自定义模型

CPU GPU Linux 入门

luojianet_ms.nn提供了神经网络模型构建的基础模块,如nn.Conv2D等基础操作算子。

自定义模型需要继承自基类nn.Module

首先导入本文档需要的模块和接口,如下所示:

[ ]:
import numpy as np
import luojianet_ms
import luojianet_ms.nn as nn
from luojianet_ms import Tensor

定义模型类

LuoJiaNET采用nn.Module作为基类,所有网络结构中使用的算子都需要继承自该类,并重写__init__方法和call方法。

[2]:
class LeNet5(nn.Module):
    """
    Lenet网络结构
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # 定义所需要的运算
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def call(self, x):
        # 使用定义好的运算构建前向网络
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

模型参数

通过调用nn.Conv,nn.ReLU,nn.MaxPool2d,nn.Flatten,nn.Dense等操作搭建网络结构后,使用的权重和偏置参数会在之后训练中进行优化。nn.Module中使用parameters_and_names()方法访问所有参数。

下面的示例中,调用parameters_and_names遍历每个参数,并打印网络各层名字和属性。

[8]:
model = LeNet5()
for m in model.parameters_and_names():
    print(m)
('conv1.weight', Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True))
('conv2.weight', Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True))
('fc1.weight', Parameter (name=fc1.weight, shape=(120, 400), dtype=Float32, requires_grad=True))
('fc1.bias', Parameter (name=fc1.bias, shape=(120,), dtype=Float32, requires_grad=True))
('fc2.weight', Parameter (name=fc2.weight, shape=(84, 120), dtype=Float32, requires_grad=True))
('fc2.bias', Parameter (name=fc2.bias, shape=(84,), dtype=Float32, requires_grad=True))
('fc3.weight', Parameter (name=fc3.weight, shape=(10, 84), dtype=Float32, requires_grad=True))
('fc3.bias', Parameter (name=fc3.bias, shape=(10,), dtype=Float32, requires_grad=True))