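{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Automatic Differentiation\n", "\n", "`CPU` `GPU` `Linux` `Beginner`\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Neural networks are usually trained with the backpropagation algorithm, which adjusts the parameters (model weights) according to the gradient of the loss function with respect to those parameters.\n", "\n", "LuoJiaNET computes gradients with `luojianet_ms.ops.GradOperation(get_all=False, get_by_list=False, sens_param=False)`. When `get_all` is `False`, only the derivative with respect to the first input is computed; when it is `True`, derivatives with respect to all inputs are computed. When `get_by_list` is `False`, no derivatives with respect to the weights are computed; when it is `True`, they are. `sens_param` scales the output value of the network, which changes the final gradient. The sections below take a closer look using the derivative of the MatMul operator.\n", "\n", "First, import the modules and interfaces used in this document:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import luojianet_ms.nn as nn\n", "import luojianet_ms.ops as ops\n", "from luojianet_ms import Tensor\n", "from luojianet_ms import ParameterTuple, Parameter\n", "from luojianet_ms import dtype as mstype" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First-order derivative with respect to the inputs\n", "\n", "To differentiate with respect to the inputs, first define the network to be differentiated. The example here is a network built from the MatMul operator, $f(x,y)=z*x*y$, where $z$ is a weight of the network.\n", "\n", "The network is defined as follows:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class Net(nn.Module):\n", "    def __init__(self):\n", "        super(Net, self).__init__()\n", "        self.matmul = ops.MatMul()\n", "        # z is a weight (Parameter) of the network, not an input\n", "        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')\n", "\n", "    def call(self, x, y):\n", "        x = x * self.z\n", "        out = self.matmul(x, y)\n", "        return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, define the gradient network for the inputs. `__init__` stores the network to be differentiated, `self.net`, together with the `ops.GradOperation` operation; `call` differentiates `self.net`.\n", "\n", "The gradient network is defined as follows:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class GradNetWrtX(nn.Module):\n", "    def __init__(self, net):\n", "        super(GradNetWrtX, self).__init__()\n", "        self.net = net\n", "        # Default GradOperation: gradient with respect to the first input only\n", "        self.grad_op = ops.GradOperation()\n", "\n", "    def call(self, x, y):\n", "        gradient_function = self.grad_op(self.net)\n", "        return gradient_function(x, y)" ] }, { "cell_type": "markdown", 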
"metadata": {}, "source": [ "Define the inputs and print the output:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[4.5099998 2.7 3.6000001]\n", " [4.5099998 2.7 3.6000001]]\n" ] } ], "source": [ "x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)\n", "y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)\n", "output = GradNetWrtX(Net())(x, y)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To differentiate with respect to both inputs `x` and `y`, set `self.grad_op = ops.GradOperation(get_all=True)` in `GradNetWrtX`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First-order derivative with respect to the weights\n", "\n", "To differentiate with respect to the weights, set `get_by_list` to `True` in `ops.GradOperation`.\n", "\n", "`GradNetWrtX` then becomes:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class GradNetWrtX(nn.Module):\n", "    def __init__(self, net):\n", "        super(GradNetWrtX, self).__init__()\n", "        self.net = net\n", "        # Collect the trainable weights to differentiate with respect to\n", "        self.params = ParameterTuple(net.trainable_params())\n", "        self.grad_op = ops.GradOperation(get_by_list=True)\n", "\n", "    def call(self, x, y):\n", "        gradient_function = self.grad_op(self.net, self.params)\n", "        return gradient_function(x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run it and print the output:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[1], dtype=Float32, value= [ 2.15359993e+01]),)\n" ] } ], "source": [ "output = GradNetWrtX(Net())(x, y)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To exclude a weight from differentiation, set `requires_grad` to `False` on that weight when defining the network:\n", "\n", "```Python\n", "self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z', requires_grad=False)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stopping gradient computation\n", "\n", "`stop_gradient` blocks the contribution of an operator inside the network to the gradient, for example:" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[4.5 2.7 3.6]\n", " [4.5 2.7 3.6]]\n" ] } ], "source": [ "import numpy as np\n", "import luojianet_ms.nn as nn\n", "import luojianet_ms.ops as ops\n", "from luojianet_ms import Tensor\n", "from luojianet_ms import ParameterTuple, Parameter\n", "from luojianet_ms import dtype as mstype\n", "from luojianet_ms.ops import stop_gradient\n", "\n", "class MyNet(nn.Module):\n", "    def __init__(self):\n", "        super(MyNet, self).__init__()\n", "        self.matmul = ops.MatMul()\n", "\n", "    def call(self, x, y):\n", "        out1 = self.matmul(x, y)\n", "        out2 = self.matmul(x, y)\n", "        # Exclude out2 from gradient computation\n", "        out2 = stop_gradient(out2)\n", "        out = out1 + out2\n", "        return out\n", "\n", "class GradMyNetWrtX(nn.Module):\n", "    def __init__(self, net):\n", "        super(GradMyNetWrtX, self).__init__()\n", "        self.net = net\n", "        self.grad_op = ops.GradOperation()\n", "\n", "    def call(self, x, y):\n", "        gradient_function = self.grad_op(self.net)\n", "        return gradient_function(x, y)\n", "\n", "x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)\n", "y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)\n", "output = GradMyNetWrtX(MyNet())(x, y)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here `stop_gradient` is applied to `out2`, so `out2` contributes nothing to the gradient. If `out2 = stop_gradient(out2)` is removed, the output becomes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[9.0 5.4 7.2]\n", " [9.0 5.4 7.2]]\n" ] } ], "source": [ "output = GradMyNetWrtX(MyNet())(x, y)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Without `stop_gradient` on `out2`, `out2` and `out1` contribute equally to the gradient, so every entry of the result is twice its previous value." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 4 }