{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 自定义模型\n", "\n", "`CPU` `GPU` `Linux` `入门`\n", "\n", "\n", "luojianet_ms.nn提供了神经网络模型构建的基础模块,如nn.Conv2D等基础操作算子。\n", "\n", "自定义模型需要继承自基类nn.Module\n", "\n", "首先导入本文档需要的模块和接口,如下所示:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import luojianet_ms\n", "import luojianet_ms.nn as nn\n", "from luojianet_ms import Tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 定义模型类\n", "\n", "LuoJiaNET采用nn.Module作为基类,所有网络结构中使用的算子都需要继承自该类,并重写`__init__`方法和`call`方法。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class LeNet5(nn.Module):\n", " \"\"\"\n", " Lenet网络结构\n", " \"\"\"\n", " def __init__(self, num_class=10, num_channel=1):\n", " super(LeNet5, self).__init__()\n", " # 定义所需要的运算\n", " self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')\n", " self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')\n", " self.fc1 = nn.Dense(16 * 5 * 5, 120)\n", " self.fc2 = nn.Dense(120, 84)\n", " self.fc3 = nn.Dense(84, num_class)\n", " self.relu = nn.ReLU()\n", " self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.flatten = nn.Flatten()\n", "\n", " def call(self, x):\n", " # 使用定义好的运算构建前向网络\n", " x = self.conv1(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = self.conv2(x)\n", " x = self.relu(x)\n", " x = self.max_pool2d(x)\n", " x = self.flatten(x)\n", " x = self.fc1(x)\n", " x = self.relu(x)\n", " x = self.fc2(x)\n", " x = self.relu(x)\n", " x = self.fc3(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 模型参数\n", "\n", "通过调用nn.Conv,nn.ReLU,nn.MaxPool2d,nn.Flatten,nn.Dense等操作搭建网络结构后,使用的权重和偏置参数会在之后训练中进行优化。`nn.Module`中使用`parameters_and_names()`方法访问所有参数。\n", "\n", "下面的示例中,调用parameters_and_names遍历每个参数,并打印网络各层名字和属性。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('conv1.weight', Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True))\n", "('conv2.weight', Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True))\n", "('fc1.weight', Parameter (name=fc1.weight, shape=(120, 400), dtype=Float32, requires_grad=True))\n", "('fc1.bias', Parameter (name=fc1.bias, shape=(120,), dtype=Float32, requires_grad=True))\n", "('fc2.weight', Parameter (name=fc2.weight, shape=(84, 120), dtype=Float32, requires_grad=True))\n", "('fc2.bias', Parameter (name=fc2.bias, shape=(84,), dtype=Float32, requires_grad=True))\n", "('fc3.weight', Parameter (name=fc3.weight, shape=(10, 84), dtype=Float32, requires_grad=True))\n", "('fc3.bias', Parameter (name=fc3.bias, shape=(10,), dtype=Float32, requires_grad=True))\n" ] } ], "source": [ "model = LeNet5()\n", "for m in model.parameters_and_names():\n", " print(m)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }