自定义样本加载与处理

CPU GPU Linux 入门

自定义数据集

用户可以采用LuoJiaNET提供的GeneratorDataset接口自定义数据加载。

[4]:
import numpy as np
import luojianet_ms.dataset as ds
class DatasetGenerator:
    def __init__(self):
        self.data = np.random.sample((5, 2))
        self.label = np.random.sample((5, 1))

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

其中用户需要自定义的类函数如下:

  • __init__

    实例化数据集对象时,__init__函数被调用,用户可以在此进行数据初始化等操作。

    def __init__(self):
        self.data = np.random.sample((5, 2))
        self.label = np.random.sample((5, 1))
    
  • __getitem__

    定义数据集类的__getitem__函数,使其支持随机访问,能够根据给定的索引值index,获取数据集中的数据并返回。

    其中__getitem__函数的返回值,需要是由numpy数组组成的元组(tuple),当返回单个numpy数组时可以写成 return (np_array_1,)

    def __getitem__(self, index):
        return self.data[index], self.label[index]
    
  • __len__

    定义数据集类的__len__函数,返回数据集的样本数量。

    def __len__(self):
        return len(self.data)
    

定义数据集类之后,就可以通过DatasetGenerator接口按照用户定义的方式加载并迭代访问数据集样本。

[ ]:
dataset_generator = DatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False)

for data in dataset.create_dict_iterator():
    print('{}'.format(data["data"]), '{}'.format(data["label"]))
[0.36510558 0.45120592] [0.78888122]
[0.49606035 0.07562207] [0.38068183]
[0.57176158 0.28963401] [0.16271622]
[0.30880446 0.37487617] [0.54738768]
[0.81585667 0.96883469] [0.77994068]

迭代数据集

用户可以用create_dict_iterator创建数据迭代器,迭代访问数据,下面展示了对应图片的形状和标签。

[ ]:
for data in dataset.create_dict_iterator():
    print("Image shape: {}".format(data['data'].shape), ", Label: {}".format(data['label']))

数据处理及增强

数据处理

LuoJiaNET提供的数据集接口具备常用的数据处理方法,用户只需调用相应的函数接口即可快速进行数据处理。

下面的样例先将数据集随机打乱顺序,然后将样本两两组成一个批次。

[ ]:
ds.config.set_seed(30)

# 随机打乱数据顺序
dataset = ds.shuffle(buffer_size=10)
# 对数据集进行分批
dataset = ds.batch(batch_size=2)

for data in dataset.create_dict_iterator():
    print("data: {}".format(data["data"]))
    print("label: {}".format(data["label"]))

其中,

buffer_size:数据集中进行shuffle操作的缓存区的大小。

batch_size:每组包含的数据个数,示例程序中每组包含2个数据。

数据增强

数据量过小或是样本场景单一等问题会影响模型的训练效果,用户可以通过数据增强操作扩充样本多样性,从而提升模型的泛化能力。

下面的样例使用luojianet.dataset.vision.c_transforms模块中的算子对MNIST数据集进行数据增强。

导入c_transforms模块,加载MNIST数据集。

[ ]:
import matplotlib.pyplot as plt

from luojianet_ms.dataset.vision import Inter
import luojianet_ms.dataset.vision.c_transforms as c_vision

DATA_DIR = './datasets/MNIST_Data/train'

mnist_dataset = ds.MnistDataset(DATA_DIR, num_samples=6, shuffle=False)

# 查看数据原图
mnist_it = mnist_dataset.create_dict_iterator()
data = next(mnist_it)
plt.figure(figsize=(3, 3))
plt.imshow(data['image'].asnumpy().squeeze(), cmap=plt.cm.gray)
plt.title(data['label'].asnumpy(), fontsize=20)
plt.show()

定义数据增强算子,对数据集进行ResizeRandomCrop操作,然后通过map映射将其插入数据处理管道。

[8]:
resize_op = c_vision.Resize(size=(200, 200), interpolation=Inter.LINEAR)
crop_op = c_vision.RandomCrop(150)
transforms_list = [resize_op, crop_op]
mnist_dataset = mnist_dataset.map(operations=transforms_list, input_columns=["image"])

显示数据增强效果。

[ ]:
mnist_dataset = mnist_dataset.create_dict_iterator()
data = next(mnist_dataset)
plt.figure(figsize=(3, 3))
plt.imshow(data['image'].asnumpy().squeeze(), cmap=plt.cm.gray)
plt.title(data['label'].asnumpy(), fontsize=20)
plt.show()

想要了解更多可以参考编程指南中数据增强章节。