{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom Sample Loading and Processing\n", "\n", "`CPU` `GPU` `Linux` `Beginner`\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom Dataset\n", "\n", "You can use the `GeneratorDataset` interface provided by LuoJiaNET to implement custom data loading." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import luojianet_ms.dataset as ds\n", "\n", "\n", "# Random-access dataset: 5 samples, each with 2 features and 1 label\n", "class DatasetGenerator:\n", "    def __init__(self):\n", "        self.data = np.random.sample((5, 2))\n", "        self.label = np.random.sample((5, 1))\n", "\n", "    def __getitem__(self, index):\n", "        return self.data[index], self.label[index]\n", "\n", "    def __len__(self):\n", "        return len(self.data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class methods you need to define are as follows:\n", "\n", "- **\\_\\_init\\_\\_**\n", "\n", "    The `__init__` method is called when the dataset object is instantiated. Perform data initialization and similar setup here.\n", "\n", "    ```python\n", "    def __init__(self):\n", "        self.data = np.random.sample((5, 2))\n", "        self.label = np.random.sample((5, 1))\n", "    ```\n", "\n", "- **\\_\\_getitem\\_\\_**\n", "\n", "    Define the `__getitem__` method so that the dataset class supports random access. Given an index `index`, it retrieves the corresponding sample from the dataset and returns it.\n", "\n", "    The return value of `__getitem__` must be a tuple of NumPy arrays. To return a single NumPy array, write `return (np_array_1,)`.\n", "\n", "    ```python\n", "    def __getitem__(self, index):\n", "        return self.data[index], self.label[index]\n", "    ```\n", "\n", "- **\\_\\_len\\_\\_**\n", "\n", "    Define the `__len__` method to return the number of samples in the dataset.\n", "\n", "    ```python\n", "    def __len__(self):\n", "        return len(self.data)\n", "    ```\n", "\n", "After defining the dataset class, pass an instance of it to the `GeneratorDataset` interface. You can then load the samples in the way you defined and iterate over them." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.36510558 0.45120592] [0.78888122]\n", "[0.49606035 0.07562207] [0.38068183]\n", "[0.57176158 0.28963401] [0.16271622]\n", "[0.30880446 0.37487617] [0.54738768]\n", "[0.81585667 0.96883469] [0.77994068]\n" ] } ], "source": [ "dataset_generator = DatasetGenerator()\n", 
"# The column names correspond, in order, to the arrays returned by __getitem__\n", "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label\"], shuffle=False)\n", "\n", "for data in dataset.create_dict_iterator():\n", "    print('{}'.format(data[\"data\"]), '{}'.format(data[\"label\"]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Iterating over the Dataset\n", "\n", "You can use `create_dict_iterator` to create a data iterator and access the data iteratively. The following example prints the shape of each data sample and its label." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "for data in dataset.create_dict_iterator():\n", "    print(\"Data shape: {}\".format(data['data'].shape), \", Label: {}\".format(data['label']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Processing and Augmentation\n", "\n", "### Data Processing\n", "\n", "The dataset interfaces provided by LuoJiaNET come with common data processing methods. To process data, call the corresponding method on the dataset object.\n", "\n", "The following example first shuffles the dataset randomly and then groups the samples into batches of two." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Fix the random seed so that the shuffled order is reproducible\n", "ds.config.set_seed(30)\n", "\n", "# Randomly shuffle the order of the data\n", "dataset = dataset.shuffle(buffer_size=10)\n", "# Group the dataset into batches\n", "dataset = dataset.batch(batch_size=2)\n", "\n", "for data in dataset.create_dict_iterator():\n", "    print(\"data: {}\".format(data[\"data\"]))\n", "    print(\"label: {}\".format(data[\"label\"]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here:\n", "\n", "`buffer_size`: the size of the buffer used by the shuffle operation.\n", "\n", "`batch_size`: the number of samples in each batch. In the example, each batch contains 2 samples." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data Augmentation\n", "\n", "Too little data or a lack of variety in sample scenes can degrade model training. Data augmentation increases sample diversity and thereby improves the model's generalization ability.\n", "\n", "The following example uses operators from the `luojianet_ms.dataset.vision.c_transforms` module to augment the MNIST dataset.\n", "\n", "Import the `c_transforms` module and load the MNIST dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "import luojianet_ms.dataset as ds\n", "from luojianet_ms.dataset.vision import Inter\n", "import luojianet_ms.dataset.vision.c_transforms as c_vision\n", "\n", "DATA_DIR = './datasets/MNIST_Data/train'\n", "\n", 
"mnist_dataset = ds.MnistDataset(DATA_DIR, num_samples=6, shuffle=False)\n", "\n", "# View the original image\n", "mnist_it = mnist_dataset.create_dict_iterator()\n", "data = next(mnist_it)\n", "plt.figure(figsize=(3, 3))\n", "plt.imshow(data['image'].asnumpy().squeeze(), cmap=plt.cm.gray)\n", "plt.title(data['label'].asnumpy(), fontsize=20)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define data augmentation operators that apply `Resize` and `RandomCrop` to the images. Then insert them into the data processing pipeline with `map`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Resize each image to 200x200 with bilinear interpolation\n", "resize_op = c_vision.Resize(size=(200, 200), interpolation=Inter.LINEAR)\n", "# Randomly crop a 150x150 region from the resized image\n", "crop_op = c_vision.RandomCrop(150)\n", "transforms_list = [resize_op, crop_op]\n", "# Apply the operators, in order, to the \"image\" column\n", "mnist_dataset = mnist_dataset.map(operations=transforms_list, input_columns=[\"image\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display the result of data augmentation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Use a separate name for the iterator so that mnist_dataset remains a dataset object\n", "mnist_it = mnist_dataset.create_dict_iterator()\n", "data = next(mnist_it)\n", "plt.figure(figsize=(3, 3))\n", "plt.imshow(data['image'].asnumpy().squeeze(), cmap=plt.cm.gray)\n", "plt.title(data['label'].asnumpy(), fontsize=20)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To learn more, see the [Data Augmentation](https://www.luojianet.cn/docs/programming_guide/zh-CN/master/augmentation.html) chapter of the programming guide." ] } ], "metadata": { "kernelspec": { "display_name": "LuoJiaNET", "language": "python", "name": "luojianet" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }