├── doc
│   ├── 3损失函数.md
│   ├── 4layer的实现.md
│   ├── 2优化器.md
│   └── 1自动求导基础运算实现.md
├── README.md
├── LICENSE
├── test
│   ├── test_optim.py
│   └── test_tensor.py
├── easytorch
│   ├── layer.py
│   ├── functional.py
│   ├── optim.py
│   └── tensor.py
└── example
    ├── Predict.ipynb
    └── FunctionApproximation.ipynb

/doc/3损失函数.md:
--------------------------------------------------------------------------------

# 3. Loss functions

## L2 loss

The implementation is straightforward, so it is not explained further.

The problem with L2 loss is that the gradient scales with x: when x is very large the gradient is very large, and training becomes unstable.

## L1 loss

The L1 loss has the form $loss = \sum_i |y_i - pred_i|$, and its derivative is $sign(x)$. It is not differentiable at $x = 0$, where a subgradient can be used, taking the value 0.

The problem with L1 loss is the opposite of L2's: the gradient is a constant. Even when x is very small the gradient is still 1, so with a fixed learning rate the updates easily oscillate and it is hard to converge to higher precision.
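The contrast between the two gradients is easy to check numerically. Below is a minimal standalone numpy sketch (not part of easytorch; the helper names are made up) evaluating both derivatives at a large and a small residual:

``` python
import numpy as np

def l2_grad(residual):
    # derivative of residual**2 is 2 * residual: grows with the residual
    return 2 * residual

def l1_grad(residual):
    # derivative of |residual| is sign(residual); 0 is used as the subgradient at 0
    return np.sign(residual)

for r in [100.0, 0.01]:
    print(f'residual={r}: L2 grad={l2_grad(r)}, L1 grad={l1_grad(r)}')
# residual=100.0: L2 grad=200.0, L1 grad=1.0  -> L2 steps can explode
# residual=0.01:  L2 grad=0.02,  L1 grad=1.0  -> L1 steps keep oscillating
```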

--------------------------------------------------------------------------------
/doc/4layer的实现.md:
--------------------------------------------------------------------------------

# 4. Implementing layers

See ```layer.py``` for the code: an abstract base class Layer defines the interface, and each subclass implements the ```forward``` method.

The current implementation is overly simple; I hope to later build a somewhat more complex version that follows pytorch's structure.

## Linear

The Linear layer computes $x @ W + b$, where $W$ and $b$ are trainable parameters. Only this forward computation needs to be written; backpropagation is handled by autograd, so the layer is very simple.

## Activation layers

Activation layers have no trainable parameters; they are implemented exactly like the corresponding autograd functions.

## Sequential

The Sequential layer stores all its layers at initialization and passes the data through them in order during the forward pass.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# easytorch

A simple deep learning framework implemented with Python's numpy. The API is essentially the same as pytorch's; it implements autograd, basic optimizers, layers, and so on.

## 1 Documentation

[1. Autograd basic operations](./doc/1自动求导基础运算实现.md)

[2. Optimizers](./doc/2优化器.md)

[3. Loss functions](./doc/3损失函数.md)

[4. Layers](./doc/4layer的实现.md)

## 2 Quick Start

``` python
from easytorch.layer import Linear, Tanh, Sequential
from easytorch.optim import SGD
import easytorch.functional as F

# Create a model, optimizer, loss function
model = Sequential(
    Linear(1, 5),
    Tanh(),
    Linear(5, 1)
)
opt = SGD(model.parameters(), lr=3e-4)
loss_fn = F.mse_loss

# train the model
for epoch in range(epochs):
    pred = model(x)
    loss = loss_fn(pred, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
```

## 3 Example

1. [Approximating a trigonometric function with a neural network](./example/FunctionApproximation.ipynb)
2. [Predicting Boston house prices with a neural network](./example/Predict.ipynb)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

MIT License

Copyright (c) 2021 SongLei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/test/test_optim.py:
--------------------------------------------------------------------------------

import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from easytorch import tensor, optim


def generate_data(n=100, f=lambda x: 2 * x - 1):
    data = []
    for _ in range(n):
        x = np.random.uniform(-3, 3, 3)
        y = f(x) + 0.01 * np.random.randn()
        data.append([x, y])
    return data


def sgd_linear_approximation():
    train_data = generate_data(n=100, f=lambda x: x[0]+2*x[1]+3*x[2])
    x = tensor.Tensor([x for x, y in train_data])
    y = tensor.Tensor([y for x, y in train_data])
    w = tensor.random(3, requires_grad=True)
    b = tensor.Tensor(1.0, requires_grad=True)
    opt = optim.SGD([w, b], lr=0.01)
    loss_list = []

    for _ in tqdm(range(1000)):
        for data_x, data_y in zip(x, y):
            pred = data_x @ w + b
            loss = ((pred - data_y) * (pred - data_y)).mean()
            loss_list.append(loss.data)
            opt.zero_grad()
            loss.backward()
            opt.step()

    plt.plot(loss_list)
    plt.show()


if __name__ == '__main__':
    sgd_linear_approximation()

--------------------------------------------------------------------------------
/doc/2优化器.md:
--------------------------------------------------------------------------------

# 2. Optimizers

The basic framework of an iterative optimization algorithm is:

1. Compute the gradient of the objective w.r.t. the current parameters: $g_t = \nabla f(\omega_t)$
2. Update the historical first- and second-order moments $m_t$, $V_t$
3. Use $m_t$ to control the update direction and $V_t$ to control the step size, computing the current descent step $\eta_t = \alpha \frac{m_t}{\sqrt{V_t}}$
4. Update the parameters: $\omega_{t+1} = \omega_t - \eta_t$

Different optimizers differ only in step 2; a generic sketch of this loop follows below.

See ```optim.py``` for the code.
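A minimal numpy sketch of the four-step framework (not easytorch code; ```grad_fn``` and the moment rule are placeholders — step 2 is filled in with SGD's trivial choice):

``` python
import numpy as np

def optimize(grad_fn, w0, alpha=0.1, steps=100, eps=1e-8):
    w = np.asarray(w0, dtype=float)
    for _ in range(steps):
        g = grad_fn(w)                         # step 1: gradient
        m, V = g, np.ones_like(w)              # step 2: SGD; other optimizers differ here
        eta = alpha * m / (np.sqrt(V) + eps)   # step 3: direction and step size
        w = w - eta                            # step 4: parameter update
    return w

# minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
print(optimize(lambda w: 2 * (w - 3), w0=[0.0]))  # -> close to [3.]
```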
## SGD

For SGD, step 2 is $m_t = g_t$, $V_t = 1$.

The update rule is $\omega_{t+1} = \omega_t - \alpha * g_t$.

## Adagrad (Adaptive gradient)

For Adagrad, step 2 is $m_t = g_t$, $V_{t} = V_{t-1} + g_t \odot g_t$, which adds an adaptive step size: by accumulating $V_t$, parameters whose gradients $g_t$ are large get slower updates, while parameters whose gradients are small get faster updates.

The update rule is $\omega_{t+1} = \omega_t - \frac{\alpha}{\sqrt{V_{t}} + \epsilon} \odot g_t$.

The problem with Adagrad is that late in training, since $V_t$ keeps accumulating, the denominator grows too large; the effective learning rate becomes tiny and the parameters barely change.

## Moment (momentum)

Moment introduces momentum; step 2 becomes $m_t = \beta m_{t-1} + (1 - \beta) g_t$. By averaging the current gradient with past gradients, learning slows down along oscillating directions and speeds up along directions of steady descent.

The update rule is $\omega_{t+1} = \omega_t - \alpha * m_t$.

## RMSprop

RMSprop solves Adagrad's vanishing learning rate to some extent by changing the second-moment update to $V_{t} = \beta V_{t-1} + (1 - \beta) g_t \odot g_t$.

The update rule is $\omega_{t+1} = \omega_t - \frac{\alpha}{\sqrt{V_{t}} + \epsilon} \odot g_t$.

## Adam

Adam (Adaptive moment) combines the two ideas above. Step 2 becomes $m_t = \beta_0 m_{t-1} + (1 - \beta_0) g_t$, $V_{t} = \beta_1 V_{t-1} + (1-\beta_1) g_t \odot g_t$, followed by bias correction: $\hat{m}_t = \frac{m_t}{1 - \beta_0^t}$, $\hat{V}_t = \frac{V_t}{1 - \beta_1^t}$.

The update rule is $\omega_{t+1} = \omega_t - \alpha * \frac{\hat{m}_t}{\sqrt{\hat{V}_t} + \epsilon}$.
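A quick arithmetic check of why the bias correction matters: the moments start at 0, so with $\beta_0 = 0.9$ the first update gives $m_1 = 0.9 \cdot 0 + 0.1 g_1 = 0.1 g_1$, a badly shrunken estimate of the gradient, and dividing by $1 - \beta_0^1 = 0.1$ restores $\hat{m}_1 = g_1$. A minimal numeric sketch:

``` python
import numpy as np

beta0 = 0.9
g = np.array([1.0, -2.0])  # pretend the true gradient is constant

m = np.zeros_like(g)
for t in range(1, 4):
    m = beta0 * m + (1 - beta0) * g
    m_hat = m / (1 - beta0 ** t)   # bias-corrected estimate
    print(t, m, m_hat)
# t=1: m = 0.100*g  but m_hat = g exactly
# t=2: m = 0.190*g, m_hat = g
# t=3: m = 0.271*g, m_hat = g
```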
## Open question

The optimizers above already adapt and decay the effective learning rate on their own, so is there still any benefit in adding a separate learning-rate decay schedule on top?

--------------------------------------------------------------------------------
/easytorch/layer.py:
--------------------------------------------------------------------------------

from easytorch import tensor
import easytorch.functional as F
import abc


class Layer(metaclass=abc.ABCMeta):

    def __init__(self):
        self.params = []

    @abc.abstractmethod
    def forward(self, x):
        pass

    def __call__(self, x):
        return self.forward(x)

    def parameters(self):
        return self.params


class Linear(Layer):

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = tensor.random(in_features, out_features)
        self.params.append(self.weight)
        if bias:
            self.bias = tensor.random(out_features)
            self.params.append(self.bias)
        else:
            self.bias = None

    def forward(self, x):
        y = x @ self.weight
        if self.bias is not None:
            y += self.bias
        return y


class Sequential(Layer):

    def __init__(self, *layers):
        super(Sequential, self).__init__()
        self.layers = layers
        for layer in layers:
            assert isinstance(layer, Layer)
            self.params.extend(layer.parameters())

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class ReLU(Layer):

    def __init__(self):
        super(ReLU, self).__init__()

    def forward(self, x):
        return F.relu(x)


class Tanh(Layer):

    def __init__(self):
        super(Tanh, self).__init__()

    def forward(self, x):
        return F.tanh(x)

--------------------------------------------------------------------------------
/doc/1自动求导基础运算实现.md:
--------------------------------------------------------------------------------

# 1. Autograd: implementing the basic operations

This part implements the basic operations of automatic differentiation. Note that all derivations below work with **differentials**. Deriving directly with derivatives causes many problems — for matrices the chain rule no longer holds, and the resulting derivative can be a fourth-order tensor — so differentials are the right tool. See ```tensor.py``` for the implementation.

## 1. Basic operations

### Add, Sub, Mul, Divide, Pow

These most basic arithmetic operations share one trait: they are all elementwise, so their derivatives are very simple and essentially the same as in the scalar case. The troublesome part is handling broadcasting.

1. $Z = X + Y$, so $dZ = dX + dY$; mathematically the gradient is the identity matrix, which is trivial. What is more involved in the implementation is broadcasting (see the sketch after this list), which comes in two cases. In the first, one operand has fewer dimensions than the other, e.g. $X=[[1, 2], [3, 4]]$, $Y = [1, 2]$: when differentiating w.r.t. $Y$, the extra dimensions must be summed out. In the second, the operands have equally many dimensions but some dimension of one has size 1, e.g. $X=[[1, 2], [3, 4]]$, $Y = [[1, 2]]$ with $shape(X) = (2, 2)$, $shape(Y)=(1, 2)$: the dimensions of size 1 must be summed (keeping the dimension).

2. Sub can be implemented directly via Add.

3. Elementwise multiplication, $Z = X \odot Y$.

4. Elementwise division, $Z = X / Y$.

5. Pow is also elementwise, and its derivative is simple.
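The broadcast handling referenced in item 1 reduces to two sum operations. A standalone numpy sketch of the rule used by the backward functions in ```tensor.py``` (the helper name ```unbroadcast``` is made up):

``` python
import numpy as np

def unbroadcast(grad, shape):
    # Reduce an upstream gradient back to the shape of the original operand.
    # Case 1: the operand had fewer dimensions -> sum away the leading axes.
    for _ in range(grad.ndim - len(shape)):
        grad = grad.sum(axis=0)
    # Case 2: same ndim, but size-1 dimensions were broadcast -> sum with keepdims.
    for i, d in enumerate(shape):
        if d == 1:
            grad = grad.sum(axis=i, keepdims=True)
    return grad

dZ = np.ones((2, 2))            # upstream gradient of Z = X + Y with X of shape (2, 2)
print(unbroadcast(dZ, (2,)))    # gradient for Y of shape (2,)   -> [2. 2.]
print(unbroadcast(dZ, (1, 2)))  # gradient for Y of shape (1, 2) -> [[2. 2.]]
```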
### Sum, Mean

Both compress one dimension of the matrix, so the gradient broadcasts the incoming gradient back to the original size; the only difference between the two is whether a constant factor is multiplied in.

### Matmul

Matrix multiplication (matrix inner product), $Z = XY$. The differential w.r.t. $X$ is $dZ = dX \, Y$, so the gradient is $Y^T$; note that the order of left/right multiplication matters.

A special case is the inner product of two row vectors, where the second row vector is automatically treated as a column vector and the result is a scalar. For example, the following is also legal and yields 5, but I have not yet thought of a good way to handle it.

``` python
a = Tensor([1., 2.], requires_grad=True)
b = Tensor([1., 2.], requires_grad=True)
c = a @ b
```

### reshape, __getitem__

Both operations merely rearrange the elements of the matrix; in a strict mathematical sense no gradient exists. The backward implementation only needs to record the gradient and keep it aligned with the positions of the original data.

Also note that the tensor produced by these operations shares its data with the original tensor — it is a shallow copy.

### Tanh, Relu

These two activation functions are elementwise functions $\sigma$; let $y = \sigma(x)$. Then $dy = \sigma'(x) \odot dx = diag(\sigma'(x)) \, dx$. The first equality computes elementwise, which is simpler to implement, though such an operation arguably does not exist as a standard mathematical object; the second goes through a matrix product, which is the mathematically proper form but slightly more cumbersome to implement.

Tanh derivative: $tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$; differentiating gives $1 - tanh^2(x)$.

Relu derivative: relu is not differentiable at $x=0$, where a subgradient $c \leq \frac{f(y) - f(x)}{y - x}$ is needed; usually $c=0$ (0 seems to be the preferred choice — it is cheap to compute and brings more sparsity).

### Softmax

The softmax formula is $\frac{e^x}{\sum e^x}$, but when $x$ is very small this underflows and the denominator can become 0. Implementations therefore subtract the maximum $m$, i.e. $\frac{e^{x - m}}{\sum e^{x - m}}$; then for $x_i = m$, $e^{x_i - m}$ is 1, the denominator is at least 1, and the underflow is avoided.

The derivative is unresolved for now; it disagrees with pytorch's gradient.

### Abs

The derivative is the sign function, which is undefined at $x=0$ and needs a subgradient there.

## 2. Implementation notes

This part summarizes issues that need no attention in theory but do matter in the actual implementation; a short example of the second point follows the list.

1. Broadcasting (see the Add, Sub, Mul, Divide implementations)
2. Handling the case where both operands are the same object (see the implementation of ```tensor.py:Tensor/backward```)
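A sketch of point 2 using easytorch itself: when both operands of an op are the same tensor, ```backward``` walks one graph edge at a time and accumulates into ```grad```, so $d = a + a$ correctly yields a gradient of 2 for every element:

``` python
from easytorch.tensor import Tensor

a = Tensor([1., 2.], requires_grad=True)
d = (a + a).sum()  # both operands of + are the same object
d.backward()
print(a.grad)      # [2. 2.] -- one contribution per graph edge, accumulated
```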
## 3. Testing

### Asserts in the code

After a ```backward``` call, the resulting ```grad``` must have the same shape as ```data```.

### Differential testing against pytorch

**Untested code is always wrong.** Since the API is identical to pytorch's, correctness is checked by running the same computation in both frameworks and comparing the ```grad``` of the leaf nodes after backward.

--------------------------------------------------------------------------------
/easytorch/functional.py:
--------------------------------------------------------------------------------

import numpy as np
import warnings
from easytorch import tensor


def tanh(inputs):
    data = np.tanh(inputs.data)
    requires_grad = inputs.requires_grad
    t = tensor.Tensor(data, requires_grad)
    t.is_leaf = False

    if inputs.requires_grad:
        def TanhBackward(grad):
            return grad * (1 - np.tanh(inputs.data) ** 2)
        t.grad_node.append(tensor.GRAD_NODE_FMT(inputs, TanhBackward))

    return t


def relu(inputs):
    data = np.maximum(0, inputs.data)
    requires_grad = inputs.requires_grad
    t = tensor.Tensor(data, requires_grad)
    t.is_leaf = False

    if inputs.requires_grad:
        def ReluBackward(grad):
            relu_prime = np.zeros_like(inputs.data)
            relu_prime[inputs.data > 0] = 1
            return grad * relu_prime

        t.grad_node.append(tensor.GRAD_NODE_FMT(inputs, ReluBackward))

    return t


def softmax(inputs, dim=0):
    raise NotImplementedError('There is a bug')
    def softmax_func(x):
        max_v = np.max(x)
        return np.e**(x - max_v) / np.sum(np.e**(x - max_v))
    assert inputs.data.ndim == 1 or (inputs.data.ndim == 2 and (inputs.data.shape[0] == 1 or inputs.data.shape[1] == 1))
    # data = np.apply_over_axes(softmax_func, dim, inputs.data)
    data = softmax_func(inputs.data)
    requires_grad = inputs.requires_grad
    t = tensor.Tensor(data, requires_grad)
    t.is_leaf = False

    if inputs.requires_grad:
        def SoftmaxBackward(grad):
            result = softmax_func(inputs.data)
            length = inputs.data.reshape(-1).shape[0]
            mat = np.zeros((length, length))
            for i in range(length):
                for j in range(length):
                    if i == j:
                        mat[i][j] = result[i] * (1 - result[i])
                    else:
                        # NOTE: the softmax Jacobian is s_i * (delta_ij - s_j), so this
                        # off-diagonal entry needs a minus sign: -result[i] * result[j].
                        # Its absence is the bug referred to above.
                        mat[i][j] = result[i] * result[j]
            next_grad = mat @ grad
            return next_grad
        t.grad_node.append(tensor.GRAD_NODE_FMT(inputs, SoftmaxBackward))

    return t


def mse_loss(target_y, y):
    if y.shape != target_y.shape:
        warnings.warn('mse_loss, target size {} is different from input size {}, '
                      'this will likely lead to incorrect results due to broadcasting'.format(target_y.shape, y.shape))
    return ((y - target_y) * (y - target_y)).mean()


def l1_loss(target_y, y):
    if y.shape != target_y.shape:
        warnings.warn('l1_loss, target size {} is different from input size {}, '
                      'this will likely lead to incorrect results due to broadcasting'.format(target_y.shape, y.shape))
    return (y - target_y).abs().mean()

--------------------------------------------------------------------------------
/easytorch/optim.py:
--------------------------------------------------------------------------------

import numpy as np
import abc


class Optimizer(metaclass=abc.ABCMeta):

    def __init__(self, params, lr=3e-4):
        self.params = params
        self.lr = lr
        self.V = []
        self.m = []
        for param in self.params:
            self.V.append(np.zeros_like(param.data))
            self.m.append(np.zeros_like(param.data))

    def zero_grad(self):
        for param in self.params:
            param.grad = 0

    @abc.abstractmethod
    def step(self):
        pass


class SGD(Optimizer):

    def __init__(self, params, lr=3e-4):
        super(SGD, self).__init__(params, lr)

    def step(self):
        for param in self.params:
            param.data -= self.lr * param.grad


class Adagrad(Optimizer):

    def __init__(self, params, lr=1e-2, eps=1e-8):
        super(Adagrad, self).__init__(params, lr)
        self.eps = eps

    def step(self):
        for i in range(len(self.params)):
            self.V[i] += self.params[i].grad * self.params[i].grad
            self.params[i].data -= self.lr * self.params[i].grad / (np.sqrt(self.V[i]) + self.eps)


class Moment(Optimizer):

    def __init__(self, params, lr=3e-4, beta=0.9):
        super(Moment, self).__init__(params, lr)
        self.beta = beta

    def step(self):
        for i in range(len(self.params)):
            self.m[i] = self.beta * self.m[i] + (1 - self.beta) * self.params[i].grad
            self.params[i].data -= self.lr * self.m[i]


class RMSprop(Optimizer):

    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8):
        super(RMSprop, self).__init__(params, lr)
        self.alpha = alpha
        self.eps = eps

    def step(self):
        for i in range(len(self.params)):
            self.V[i] = self.alpha * self.V[i] + (1 - self.alpha) * (self.params[i].grad * self.params[i].grad)
            self.params[i].data -= self.lr * self.params[i].grad / (np.sqrt(self.V[i]) + self.eps)


class Adam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_delay=0):
        super(Adam, self).__init__(params, lr)
        self.betas = betas
        self.eps = eps
        self.beta0_bias_correction = self.betas[0]
        self.beta1_bias_correction = self.betas[1]
        self.weight_delay = weight_delay

    def step(self):
        for i in range(len(self.params)):
            self.m[i] = self.betas[0] * self.m[i] + (1-self.betas[0]) * self.params[i].grad
            # bias-correct into a temporary, so the running moment itself is not overwritten
            m_hat = self.m[i] / (1 - self.beta0_bias_correction)
            self.V[i] = self.betas[1] * self.V[i] + (1-self.betas[1]) * (self.params[i].grad * self.params[i].grad)
            # Dividing V by the correction directly seems prone to overflow:
            # self.V[i] = self.V[i] / (1 - self.beta1_bias_correction)
            # self.params[i].data = (1 - self.weight_delay) * self.params[i].data - self.lr * m_hat \
            #     / (np.sqrt(self.V[i]) + self.eps)
            # so the correction is folded into the step size via sqrt() instead:
            self.params[i].data = (1 - self.weight_delay) * self.params[i].data - self.lr * m_hat * \
                np.sqrt(1 - self.beta1_bias_correction) / (np.sqrt(self.V[i]) + self.eps)
            self.beta0_bias_correction *= self.betas[0]
            self.beta1_bias_correction *= self.betas[1]
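

# A minimal smoke test (a sketch, not part of the original module): minimize
# f(w) = (w - 3)^2 with SGD and check that w approaches 3.
if __name__ == '__main__':
    from easytorch.tensor import Tensor

    w = Tensor(0.0, requires_grad=True)
    opt = SGD([w], lr=0.1)
    for _ in range(100):
        loss = (w - 3) * (w - 3)
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(w.data)  # -> approximately 3.0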

--------------------------------------------------------------------------------
/test/test_tensor.py:
--------------------------------------------------------------------------------

import sys
sys.path.append('..')

import unittest
from easytorch.tensor import Tensor
from torch import tensor as torchTensor


def is_tensor_equal(leaves1, leaves2):
    ret = True
    for t1, t2 in zip(leaves1, leaves2):
        # compare absolute differences, so large negative errors cannot slip through
        val_eq = (abs(t1.data - t2.detach().numpy()) < 1e-4).all()
        grad_eq = (abs(t1.grad - t2.grad.detach().numpy()) < 1e-4).all()
        requires_grad_eq = (t1.requires_grad == t2.requires_grad)
        ret = ret and val_eq and grad_eq and requires_grad_eq
    return ret


def print_leaves(leaves):
    print('-------------------------')
    for leaf in leaves:
        print(leaf.grad)
    print('-------------------------')


class TestTensor(unittest.TestCase):

    def run_test_case(self, case):
        leaves1 = case(Tensor)
        leaves2 = case(torchTensor)
        self.assertTrue(is_tensor_equal(leaves1, leaves2))

    def test_ops(self):
        def case_add(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            c = tensor([[5., 6.], [7., 8.]], requires_grad=True)
            leaves = [a, b, c]
            d = a + b + 10  # same-sized elementwise add, plus scalar add
            d = d + c + c   # broadcasting, and both operands being the same object
            d = d.mean()
            d.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))

            return leaves
        self.run_test_case(case_add)

        def case_sub(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            c = tensor([[5., 6.], [7., 8.]], requires_grad=True)
            leaves = [a, b, c]
            d = a - b - 10  # same-sized elementwise sub, plus scalar sub
            d = d - c - c   # broadcasting, and both operands being the same object
            d = 100 - d
            d = d.sum()
            d.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))

            return leaves
        self.run_test_case(case_sub)

        def case_mul(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            c = tensor([[5., 6.], [7., 8.]], requires_grad=True)
            leaves = [a, b, c]
            d = a * a
            d = d * c * c + a
            d = d.sum()
            d.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))

            return leaves
        self.run_test_case(case_mul)

        def case1(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            c = tensor([[5., 6.], [7., 8.]], requires_grad=True)
            leaves = [a, b, c]

            d = 3*a + b + 1
            d = d * b
            d = d + 5*c / 20
            d = d.mean()
            d.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))
            # print_leaves(leaves)

            return leaves
        self.run_test_case(case1)

    def test_dot(self):
        def case1(tensor):
            a = tensor([[1., 2.], [3., 4.]], requires_grad=True)
            b = tensor([[5., 6., 7.], [8., 9., 10.]], requires_grad=True)
            leaves = [a, b]
            c = a @ b
            c = c.sum()
            c.backward()
            return leaves
        self.run_test_case(case1)

        def case2(tensor):
            a = tensor([[1., 2.]], requires_grad=True)
            b = tensor([[5.], [8.]], requires_grad=True)
            leaves = [a, b]
            c = a @ b
            c = c.sum()
            c.backward()
            return leaves
        self.run_test_case(case2)

        def case3(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            leaves = [a, b]
            c = a @ b + b
            c = c.sum()
            c.backward()
            return leaves
        self.run_test_case(case3)

        def case4(tensor):
            a = tensor([1.], requires_grad=True)
            b = tensor([3.], requires_grad=True)
            leaves = [a, b]
            c = a @ b + b
            c = c.sum()
            c.backward()
            return leaves
        self.run_test_case(case4)

    def test_reshape(self):
        a = Tensor([1, 2])
        b = a.reshape(2, 1)
        a[0] = 10
        self.assertEqual(a.data[0], b.data[0][0])  # reshape shares data with the original

        def case1(tensor):
            a = tensor([[1.], [2.]], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            leaves = [a, b]
            c = (2*a + 10).reshape(2)
            d = b * (c + 10)
            d = d.sum()
            d.backward()
            return leaves
        self.run_test_case(case1)

    def test_activation_func(self):
        def case_tanh(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            c = tensor([[5., 6.], [7., 8.]], requires_grad=True)
            leaves = [a, b, c]

            d = 3 * a + b + 1
            d = (d * b).tanh()
            d = (d + 5 * c / 20).tanh()
            d = d.mean()
            d.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))

            return leaves
        self.run_test_case(case_tanh)

        def case_relu(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            c = tensor([[5., 6.], [7., 8.]], requires_grad=True)
            leaves = [a, b, c]

            d = 3 * a + b + 1
            d = (d * b).relu()
            d = (d + 5 * c / 20).relu()
            d = d.mean()
            d.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))

            return leaves
        self.run_test_case(case_relu)

    def test_pow(self):
        def case(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            c = tensor([[5., 6.], [7., 8.]], requires_grad=True)
            leaves = [a, b, c]

            d = 3 * a.pow(5) + b + 1
            d = (d * b).pow(2).tanh()
            d = (d + 5 * c / 20).relu()
            d = d.mean()
            d.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))

            return leaves
        self.run_test_case(case)

    def test_select(self):
        def case(tensor):
            a = tensor([1., 2.], requires_grad=True)
            b = tensor([3., 4.], requires_grad=True)
            c = tensor([[5., 6.], [7., 8.]], requires_grad=True)
            leaves = [a, b, c]
            d = a + b + c[0]
            d = d.mean()
            d.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))

            return leaves
        self.run_test_case(case)

    # def test_softmax(self):
    #     def case(tensor):
    #         a = tensor([1., 2., 3.], requires_grad=True)
    #         leaves = [a]
    #         b = a.softmax(dim=0)
    #         b = b.mean()
    #         print('b', b)
    #         b.backward()
    #         leaves = list(filter(lambda x: x.grad is not None, leaves))
    #         print_leaves(leaves)
    #
    #         return leaves
    #
    #     self.run_test_case(case)

    def test_abs(self):
        def case(tensor):
            a = tensor([1., 2., 0., -10, -20], requires_grad=True)
            b = tensor([-2., 4., 0., 0, -20], requires_grad=True)
            leaves = [a, b]
            c = (a + b).abs()
            c = c.sum()
            c.backward()
            leaves = list(filter(lambda x: x.grad is not None, leaves))

            return leaves
        self.run_test_case(case)


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/easytorch/tensor.py:
--------------------------------------------------------------------------------

import numpy as np
from collections import namedtuple
import easytorch.functional as F


GRAD_NODE_FMT = namedtuple('grad_node', ['tensor', 'grad_fn'])


class Tensor:

    def __init__(self, data, requires_grad=False):
        self.data = np.asarray(data)
        self.requires_grad = requires_grad
        # np.int is deprecated in recent numpy; test the dtype kind instead
        if np.issubdtype(self.data.dtype, np.integer) and self.requires_grad:
            raise RuntimeError('Only Tensors of floating point and complex dtype can require gradients')
        self.grad_node = []
        self.grad = None
        self.is_leaf = True
    def __repr__(self):
        s = 'tensor({}'.format(self.data)
        if self.grad_node:
            s += ', grad_fn=<{}>)'.format(self.grad_node[0].grad_fn.__name__)
        elif self.requires_grad:
            s += ', requires_grad=True)'
        else:
            s += ')'
        return s

    def __getitem__(self, item):
        data = self.data[item]
        requires_grad = self.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def SelectBackward(grad):
                next_grad = np.zeros_like(self.data)
                next_grad[item] = grad
                return next_grad
            t.grad_node.append(GRAD_NODE_FMT(self, SelectBackward))

        return t

    def __setitem__(self, key, value):
        self.data[key] = value

    def __len__(self):
        return len(self.data)

    @property
    def shape(self):
        return self.data.shape

    @property
    def T(self):
        raise NotImplementedError('Transpose is not implemented')

    def reshape(self, *shape):
        old_shape = self.data.shape
        t = Tensor(self.data.reshape(shape), self.requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def ViewBackward(grad):
                grad = grad.reshape(old_shape)
                return grad
            t.grad_node.append(GRAD_NODE_FMT(self, ViewBackward))

        return t

    def backward(self, gradient=None):
        if not self.requires_grad:
            raise RuntimeError('tensor does not require grad')
        if self.grad is None:
            if self.data.shape == () or self.data.shape == (1, ):
                self.grad = np.ones(1)
            else:
                raise RuntimeError('grad can be implicitly created only for scalar outputs')

        for node in self.grad_node:
            # accumulate into the child's grad, so that a tensor used twice
            # (e.g. both operands being the same object) sums its contributions
            if node.tensor.grad is None:
                node.tensor.grad = node.grad_fn(self.grad)
            else:
                node.tensor.grad += node.grad_fn(self.grad)
            node.tensor.backward()
            if not node.tensor.is_leaf:
                node.tensor.grad = None

    def __add__(self, other):
        other = Tensor.astensor(other)
        data = self.data + other.data
        requires_grad = self.requires_grad or other.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def AddBackward(grad):
                grad = grad * np.ones_like(self.data)
                # reduce broadcast dimensions back to this operand's shape
                for _ in range(grad.ndim - self.data.ndim):
                    grad = grad.sum(axis=0)
                for i, d in enumerate(self.data.shape):
                    if d == 1:
                        grad = grad.sum(axis=i, keepdims=True)

                assert grad.shape == self.data.shape, 'AddBackward, grad.shape != data.shape'
                return grad
            t.grad_node.append(GRAD_NODE_FMT(self, AddBackward))

        if other.requires_grad:
            def AddBackward(grad):
                grad = grad * np.ones_like(other.data)
                for _ in range(grad.ndim - other.data.ndim):
                    grad = grad.sum(axis=0)

                for i, d in enumerate(other.data.shape):
                    if d == 1:
                        grad = grad.sum(axis=i, keepdims=True)

                assert grad.shape == other.data.shape, 'AddBackward, grad.shape != data.shape'
                return grad

            t.grad_node.append(GRAD_NODE_FMT(other, AddBackward))

        return t

    def __radd__(self, other):
        return self + other

    def __iadd__(self, other):
        return self + other

    def __sub__(self, other):
        # TODO: rewrite sub properly; currently sub records its grad_fn as AddBackward
        return self + (-other)

    def __rsub__(self, other):
        return other + (-self)

    def __isub__(self, other):
        return self - other
    def __neg__(self):
        data = - self.data
        requires_grad = self.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if requires_grad:
            def NegBackward(grad):
                return -grad
            t.grad_node.append(GRAD_NODE_FMT(self, NegBackward))

        return t

    def __mul__(self, other):
        other = Tensor.astensor(other)
        data = self.data * other.data
        requires_grad = self.requires_grad or other.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if requires_grad:
            def MulBackward(grad):
                grad = grad * other.data

                for _ in range(grad.ndim - self.data.ndim):
                    grad = grad.sum(0)
                for i, d in enumerate(self.data.shape):
                    if d == 1:
                        grad = grad.sum(axis=i, keepdims=True)

                assert grad.shape == self.data.shape, 'MulBackward, grad.shape != data.shape'
                return grad
            t.grad_node.append(GRAD_NODE_FMT(self, MulBackward))

        if other.requires_grad:
            def MulBackward(grad):
                grad = grad * self.data

                for _ in range(grad.ndim - other.data.ndim):
                    grad = grad.sum(0)
                # this branch produces other's gradient, so reduce against other's shape
                for i, d in enumerate(other.data.shape):
                    if d == 1:
                        grad = grad.sum(axis=i, keepdims=True)

                assert grad.shape == other.data.shape, 'MulBackward, grad.shape != data.shape'
                return grad
            t.grad_node.append(GRAD_NODE_FMT(other, MulBackward))

        return t

    def __rmul__(self, other):
        return self * other

    def __imul__(self, other):
        return self * other

    def __truediv__(self, other):
        other = Tensor.astensor(other)
        data = self.data / other.data
        requires_grad = self.requires_grad or other.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def DivBackward(grad):
                grad = grad / other.data

                for _ in range(grad.ndim - self.data.ndim):
                    grad = grad.sum(0)
                for i, d in enumerate(self.data.shape):
                    if d == 1:
                        grad = grad.sum(axis=i, keepdims=True)

                assert grad.shape == self.data.shape, 'DivBackward, grad.shape != data.shape'
                return grad
            t.grad_node.append(GRAD_NODE_FMT(self, DivBackward))

        if other.requires_grad:
            def DivBackward(grad):
                grad = - (self.data * grad) / (other.data**2)

                for _ in range(grad.ndim - other.data.ndim):
                    grad = grad.sum(0)
                for i, d in enumerate(other.shape):
                    if d == 1:
                        grad = grad.sum(axis=i, keepdims=True)

                assert grad.shape == other.data.shape, 'DivBackward, grad.shape != data.shape'
                return grad
            t.grad_node.append(GRAD_NODE_FMT(other, DivBackward))

        return t

    def __floordiv__(self, other):
        raise NotImplementedError('__floordiv__ not implemented')

    def sum(self, dim=None, keepdim=False):
        data = self.data.sum(axis=dim, keepdims=keepdim)
        requires_grad = self.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def SumBackward(grad):
                grad = grad * np.ones_like(self.data)
                return grad
            t.grad_node.append(GRAD_NODE_FMT(self, SumBackward))

        return t
    def mean(self, dim=None, keepdim=False):
        data = self.data.mean(axis=dim, keepdims=keepdim)
        requires_grad = self.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def MeanBackward(grad):
                # broadcast back to the input shape, divided by the reduction factor
                grad = grad * np.ones_like(self.data) / (self.data.reshape(-1).shape[0] / data.reshape(-1).shape[0])
                return grad
            t.grad_node.append(GRAD_NODE_FMT(self, MeanBackward))

        return t

    def __matmul__(self, other):
        other = Tensor.astensor(other)

        if self.data.ndim == 1:
            self = self.reshape(1, -1)
        if other.data.ndim == 1:
            other = other.reshape(-1, 1)

        data = self.data @ other.data
        requires_grad = self.requires_grad or other.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def DotBackward(grad):
                grad = grad @ other.data.T
                assert grad.shape == self.data.shape, 'DotBackward, grad.shape != data.shape'
                return grad
            t.grad_node.append(GRAD_NODE_FMT(self, DotBackward))

        if other.requires_grad:
            def DotBackward(grad):
                grad = self.data.T @ grad
                assert grad.shape == other.data.shape, 'DotBackward, grad.shape != data.shape'
                return grad
            t.grad_node.append(GRAD_NODE_FMT(other, DotBackward))

        return t

    def tanh(self):
        return F.tanh(self)

    def relu(self):
        return F.relu(self)

    def pow(self, n):
        data = np.power(self.data, n)
        requires_grad = self.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def PowBackward(grad):
                return grad * (n * np.power(self.data, n-1))
            t.grad_node.append(GRAD_NODE_FMT(self, PowBackward))

        return t

    def softmax(self, dim=0):
        return F.softmax(self, dim)

    def abs(self):
        data = np.abs(self.data)
        requires_grad = self.requires_grad
        t = Tensor(data, requires_grad)
        t.is_leaf = False

        if self.requires_grad:
            def AbsBackward(grad):
                assert grad.shape == self.data.shape, 'AbsBackward, grad.shape != data.shape'
                return grad * np.sign(self.data)
            t.grad_node.append(GRAD_NODE_FMT(self, AbsBackward))

        return t

    @staticmethod
    def astensor(data):
        if not isinstance(data, Tensor):
            data = Tensor(data)
        return data


def random(*shape, requires_grad=True):
    return Tensor(np.random.rand(*shape), requires_grad)
"## 1. 加载数据" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "id": "wrong-player", 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "((506, 13), (506,))" 51 | ] 52 | }, 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "dataset = load_boston()\n", 60 | "data_x = dataset.data\n", 61 | "data_y = dataset.target\n", 62 | "data_name = dataset.feature_names\n", 63 | "data_x = (data_x - data_x.mean(axis=0)) / (data_x.std(axis=0) + 1e-6)\n", 64 | "data_x.shape, data_y.shape" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "affecting-malpractice", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "train_x = Tensor(data_x)\n", 75 | "train_y = Tensor(data_y)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "id": "absolute-edinburgh", 81 | "metadata": {}, 82 | "source": [ 83 | "## 2. 搭建模型及训练" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 4, 89 | "id": "agricultural-orleans", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "model = Sequential(\n", 94 | " Linear(13, 10),\n", 95 | " ReLU(),\n", 96 | " Linear(10, 1)\n", 97 | ")\n", 98 | "opt = SGD(model.parameters(), lr=3e-4)\n", 99 | "loss_fn = F.l1_loss\n", 100 | "loss_list = []" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 5, 106 | "id": "german-seven", 107 | "metadata": { 108 | "tags": [] 109 | }, 110 | "outputs": [ 111 | { 112 | "name": "stderr", 113 | "output_type": "stream", 114 | "text": [ 115 | "100%|██████████| 500/500 [00:44<00:00, 11.21it/s]\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "for _ in tqdm(range(500)):\n", 121 | " sum_loss = 0\n", 122 | " for x, y in zip(train_x, train_y):\n", 123 | " pred = model(x)\n", 124 | " loss = loss_fn(pred, y.reshape(1, 1))\n", 125 | " sum_loss += loss.data\n", 126 | " opt.zero_grad()\n", 127 | " loss.backward()\n", 128 | " opt.step()\n", 129 | " loss_list.append(sum_loss / len(train_x))" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "typical-trigger", 135 | "metadata": {}, 136 | "source": [ 137 | "## 3. 
结果" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "id": "popular-argentina", 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD6CAYAAACvZ4z8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAdlklEQVR4nO3de5CddZ3n8ff33E/f0t1JJ4QkGAgsDssqYBsBGRa8MMBa6uyyu1BTTtxlNqPrVOleatS1atx1aquc2lldR2ZkWGF0qlzUWWGlFIWIIuN4gQ6XEIhIwEASYrpz7U5fTp/Ld/94npMcmu6k6XM6p/P8Pq+qU+d5fs9zzvP7hebz/M7vuZm7IyIiyZVqdwVERGRxKehFRBJOQS8iknAKehGRhFPQi4gknIJeRCThThn0ZrbOzH5kZs+a2TNm9tG4vN/MtpjZ8/F73xyf3xSv87yZbWp1A0RE5OTsVOfRm9lqYLW7P25m3cBW4P3AB4FD7v5ZM/sE0OfuH5/x2X5gCBgEPP7sW9z98Mm2uWLFCl+/fv2CGiQiEqKtW7cecPeB2ZZlTvVhd98H7Iunx8xsB7AGeB9wTbzaV4GHgY/P+PjvAFvc/RCAmW0BrgfuPtk2169fz9DQ0KmqJiIiMTN7aa5lr2uM3szWA5cCvwBWxTsBgN8Aq2b5yBpgd8P8nrhstu/ebGZDZjY0MjLyeqolIiInMe+gN7Mu4FvAx9x9tHGZR+M/Td1Lwd3vcPdBdx8cGJj114eIiCzAvILezLJEIf81d78nLt4fj9/Xx/GHZ/noXmBdw/zauExERE6T+Zx1Y8CdwA53/1zDovuA+lk0m4Bvz/LxB4DrzKwvPivnurhMREROk/n06N8OfAB4h5k9Gb9uBD4LvNvMngfeFc9jZoNm9mWA+CDsnwKPxa/P1A/MiojI6XHK0yvbYXBw0HXWjYjI/JnZVncfnG2ZrowVEUm4xAR9uVrjrx7eySO/0qmZIiKNEhP0mZRxxyMv8r3t+069sohIQBIT9GbGG8/qZse+sXZXRURkSUlM0AO88awenvvNGLXa0jvALCLSLokK+t9a3c1kucqvD463uyoiIktGooJ+cH0/AL94Uafqi4jUJSroz1vRycruPD978WC7qyIismQkKujNjCs3LOdnLxxkKV4IJiLSDokKeoArNiznwLESO4ePtbsqIiJLQuKC/soNKwD46QsavhERgQQG/br+Dtb1F/nJzgPtroqIyJKQuKAHuOr8AX7+wkEq1Vq7qyIi0nYJDfoVjJUqPLXnaLurIiLSdokM+is3LMcM/kHDNyIiyQz6vs4cF5+9TOP0IiIkNOgB3n7+Cp54+TDjpUq7qyIi0laJDforNyynXHWeePlIu6siItJWiQ36N61dBsDTe3VAVkTCltig7+3Isa6/yHYFvYgELnOqFczsLuA9wLC7XxyXfQO4MF6lFzji7pfM8tldwBhQBSpzPbh2sbxpTS/b9h45nZsUEVly5tOj/wpwfWOBu/9rd78kDvdvAfec5PPXxuue1pAHuHjNMnYfmuTIxPTp3rSIyJJxyqB390eAWW/wbmYG/Cvg7hbXqyX+yZponH773tE210REpH2aHaP/bWC/uz8/x3IHHjSzrWa2+WRfZGabzWzIzIZGRkaarFbk4jU9ADy7T+P0IhKuZoP+Fk7em7/K3S8DbgA+YmZXz7Wiu9/h7oPuPjgwMNBktSK9HTmWd+Z4cUSPFhSRcC046M0sA/xz4BtzrePue+P3YeBeYONCt7dQ5w10KuhFJGjN9OjfBfzS3ffMttDMOs2suz4NXAdsb2J7C3Luik5ePKCHkIhIuE4Z9GZ2N/Az4EIz22Nmt8aLbmbGsI2ZnW1m98ezq4CfmNlTwKPAd939+62r+vycu6KLA8emGZsqn+5Ni4gsCac8j97db5mj/IOzlL0C3BhPvwi8ucn6NW1NXxGAV45MceFZ2TbXRkTk9EvslbF1a3oLALxyZLLNNRERaY8Agr4DgL0KehEJVOKDfqA7TyZlCnoRCVbigz6dMlb1FNinoBeRQCU+6CHq1R8c1/1uRCRMQQT9iq4cB44p6EUkTEEE/fLOPAePldpdDRGRtggj6LtyHBqfxt3bXRURkdMukKDPU6k5o5N6ULiIhCeMoO/MAXBgXMM3IhKeMIK+Kwr6gzogKyIBCiLoe4tR0B+d1I3NRCQ8QQR9TzG6d5uCXkRCFETQLytGd60cVdCLSICCCPruQhT06tGLSIiCCPp0yujOZxjVw0dEJEBBBD1ATzGrHr2IBCmooNcFUyISonCCvpDRwVgRCdJ8Hg5+l5kNm9n2hrL/amZ7zezJ+HXjHJ+93syeM7OdZvaJVlb89eopZjVGLyJBmk+P/ivA9bOUf97dL4lf989caGZp4C+BG4CLgFvM7KJmKtuMnkJWPXoRCdIpg97dHwEOLeC7NwI73f1Fd58Gvg68bwHf0xIduTQT5Wq7Ni8i0jbNjNH/kZlti4d2+mZZvgbY3TC/Jy6blZltNrMhMxsaGRlpolqz68inmZhW0ItIeBYa9F8CNgCXAPuA/9lsRdz9DncfdPfBgYGBZr/uNTpzGaYrNcrVWsu/W0RkKVtQ0Lv7fnevunsN+N9EwzQz7QXWNcyvjcvaoiOXBlCvXkSCs6CgN7PVDbO/C2yfZbXHgAvM7FwzywE3A/ctZHut0JGLbmw2qaAXkcBkTrWCmd0NXAOsMLM9wKeBa8zsEsCBXcAfxuueDXzZ3W9094qZ/RHwAJAG7nL3ZxajEfPRmY969OPTumhKRMJyyqB391tmKb5zjnVfAW5smL8feM2pl+1Q79FPlNSjF5GwBHNl7IkxevXoRSQsAQa9evQiEpZggr4zHw3daIxeREITTNAXs+rRi0iYggn6eo9+oqQevYiEJZigr4/Rj6tHLyKBCSbo85kUZlDSjc1EJDDBBL2ZUcikmVTQi0hgggl6gEI2xVRZNzUTkbAEFfTFrHr0IhKeoIK+kE0zpaAXkcAo6EVEEi6woNcYvYiEJ6igL+Y0Ri8i4Qkq6AsZDd2ISHjCCnr16EUkQGEFfSZNSWP0IhKYoIK+mEupRy8iwQkq6DVGLyIhCiro62fduHu7qyIictqcMujN7C4zGzaz7Q1l/8PMfmlm28zsXjPrneOzu8zsaTN70syGWljvBSlk07jDdFXj9CISjvn06L8CXD+jbAtwsbu/CfgV8MmTfP5ad7/E3QcXVsXWKcRPmZqaVtCLSDhOGfTu/ghwaEbZg+5ef1TTz4G1i1C3litko+ZOVTROLyLhaMUY/b8FvjfHMgceNLOtZrb5ZF9iZpvNbMjMhkZGRlpQrdeqPzd2Uk+ZEpGAN
BX0ZvYpoAJ8bY5VrnL3y4AbgI+Y2dVzfZe73+Hug+4+ODAw0Ey15nR86EY9ehEJyIKD3sw+CLwH+D2f4zQWd98bvw8D9wIbF7q9VlCPXkRCtKCgN7PrgT8G3uvuE3Os02lm3fVp4Dpg+2zrni75+hi9ro4VkYDM5/TKu4GfARea2R4zuxW4DegGtsSnTt4er3u2md0ff3QV8BMzewp4FPiuu39/UVoxT/UevS6aEpGQZE61grvfMkvxnXOs+wpwYzz9IvDmpmrXYgUFvYgEKKwrY+tj9Ap6EQlIUEF/okevMXoRCUdQQa8evYiEKKigP3HWjYJeRMIRVtBnUphBSUEvIgEJKujNjEJGjxMUkbAEFfQQ3ZNeB2NFJCTBBX0ho8cJikhYwgv6nB4nKCJhCS/o9dxYEQlMcEGvMXoRCU1wQV/IaoxeRMISXNAXsxq6EZGwBBf0+azOoxeRsAQX9MVsmpLG6EUkIMEFvcboRSQ0wQW9xuhFJDTBBX0hHqOf43nmIiKJE2TQu8N0VeP0IhKGeQW9md1lZsNmtr2hrN/MtpjZ8/F73xyf3RSv87yZbWpVxRfq+FOmphX0IhKG+fbovwJcP6PsE8BD7n4B8FA8/ypm1g98GngbsBH49Fw7hNOl/pSpqYrG6UUkDPMKend/BDg0o/h9wFfj6a8C75/lo78DbHH3Q+5+GNjCa3cYp1UhfsrU5LSCXkTC0MwY/Sp33xdP/wZYNcs6a4DdDfN74rLXMLPNZjZkZkMjIyNNVOvk1KMXkdC05GCsR6ewNHUai7vf4e6D7j44MDDQimrNqj5Grx69iISimaDfb2arAeL34VnW2Qusa5hfG5e1zfGDsbo6VkQC0UzQ3wfUz6LZBHx7lnUeAK4zs774IOx1cVnb1MfoddGUiIRivqdX3g38DLjQzPaY2a3AZ4F3m9nzwLviecxs0My+DODuh4A/BR6LX5+Jy9qmmKv36BX0IhKGzHxWcvdb5lj0zlnWHQL+oGH+LuCuBdVuERQy8Ri9gl5EAhHclbEnevQaoxeRMAQX9OrRi0howgv6nA7GikhYggv6XDqFmYJeRMIRXNCbme5JLyJBCS7o4cQ96UVEQhBk0Ec9ep11IyJhCDLo83purIgEJMigL2bTlBT0IhKIIINeY/QiEpIgg15j9CISkiCDvpBN6X70IhKMQIM+rSdMiUgwwg169ehFJBBBBn1HLs24gl5EAhFk0PcUsoxNlYkedSsikmxBBv2yYpaaw7FSpd1VERFZdMEGPcDolIJeRJIvyKDvKUZPUDw6UW5zTUREFt+Cg97MLjSzJxteo2b2sRnrXGNmRxvW+ZOma9wCPXGP/uikgl5Ekm9eDwefjbs/B1wCYGZpYC9w7yyr/r27v2eh21kMPYX60I2CXkSSr1VDN+8EXnD3l1r0fYtqmXr0IhKQVgX9zcDdcyy7wsyeMrPvmdk/btH2mlIfuhlV0ItIAJoOejPLAe8F/m6WxY8Db3D3NwNfBP7fSb5ns5kNmdnQyMhIs9U6qe58BjMFvYiEoRU9+huAx919/8wF7j7q7sfi6fuBrJmtmO1L3P0Odx9098GBgYEWVGtuqZTRnc/o9EoRCUIrgv4W5hi2MbOzzMzi6Y3x9g62YJtNW9aR1Ri9iARhwWfdAJhZJ/Bu4A8byj4E4O63AzcBHzazCjAJ3OxL5L4DPQUFvYiEoamgd/dxYPmMstsbpm8DbmtmG4tlWTGrMXoRCUKQV8aCevQiEo5gg35ZMasLpkQkCMEGfU8xox69iAQh2KBfVswyVa5R0iMFRSThgg56gNFJnUsvIskWbtB35AA4PDHd5pqIiCyuYIN+oCsPwIGxUptrIiKyuMIN+u4o6EeOKehFJNnCDfq4Rz+iHr2IJFywQd9TzJBLp9SjF5HECzbozYyB7rx69CKSeMEGPcAKBb2IBCDooB/oynHgmE6vFJFkCzvo1aMXkQCEHfRdeQ6Nl6jWlsQt8kVEFkXYQd+dp+ZwcFy9ehFJrqCDfsXxq2M1Ti8iyRV00OvqWBEJQdBBv7K7AMD+o1NtromIyOIJOuhX9xZIGew+PNHuqoiILJqmg97MdpnZ02b2pJkNzbLczOwvzGynmW0zs8ua3WarZNMpVi8rsvuQgl5EkivTou+51t0PzLHsBuCC+PU24Evx+5Kwrr/I7sOT7a6GiMiiOR1DN+8D/tYjPwd6zWz1adjuvKzr61CPXkQSrRVB78CDZrbVzDbPsnwNsLthfk9c9ipmttnMhsxsaGRkpAXVmp91/R0Mj5WYKuvZsSKSTK0I+qvc/TKiIZqPmNnVC/kSd7/D3QfdfXBgYKAF1Zqfdf1FAPZo+EZEEqrpoHf3vfH7MHAvsHHGKnuBdQ3za+OyJWFdXwegM29EJLmaCnoz6zSz7vo0cB2wfcZq9wG/H599czlw1N33NbPdVlrXHwX9Ho3Ti0hCNXvWzSrgXjOrf9f/cffvm9mHANz9duB+4EZgJzAB/Jsmt9lSA1158pkUuw4q6EUkmZoKend/EXjzLOW3N0w78JFmtrOYUinjjat72L73aLurIiKyKIK+Mrbu0nW9bNtzlEq11u6qiIi0nIIeuPScXibLVZ7bP9buqoiItJyCHrjsnD4Annj5SHsrIiKyCBT0wNq+Iiu6cgp6EUkkBT1gZrx1fT8//tUI0xWN04tIsijoYzdvPIcDx0p89+lX2l0VEZGWUtDHrr5gBRsGOvmbf9hFdEaoiEgyKOhjZsatV53Htj1H+cGO4XZXR0SkZRT0Df7l4FouWNnFx7+1TbcuFpHEUNA3yKZT/PUH3kKlWmPT3zyqsBeRRFDQz3DeQBd3ffCtHBgr8bt/9VO2vnS43VUSEWmKgn4Wg+v7ueffX0k+k+Km23/Kx//vNkbGSu2ulojIgijo53D+ym6+97Hf5g+uOpdvPb6Ha//8Yb708At6EpWInHEU9CfRU8jyqX92EQ/+h6u5/Lzl/Nn3f8m1f/4wdz/6MmXdAE1EzhAK+nk4b6CLL28a5O5/dzlnLSvwyXue5rrPP8I9j+/RHS9FZMlT0L8OV2xYzj0fvpI7Nw1SyKb5j998ind+7sd887Hd6uGLyJJlS/Eq0MHBQR8aGmp3NU6qVnN+sGM/X/zhTp7ee5Q1vUU2X30e/+Ita+nKN/vgLhGR18fMtrr74KzLFPTNcXce/tUIX3zoeR5/+Qhd+Qw3vWUtH7jiDWwY6Gp39UQkEAr60+Sp3Uf46k938Z1t+5iu1ti4vp+3ntvHpev6uOwNffR35tpdRRFJqEUJejNbB/wt0QPCHbjD3b8wY51rgG8Dv46L7nH3z5zqu8/UoK8bGSvx9Udf5sFn97Nj3yiVWvRvvH55B5ed08fFa5Zx0dk9/NbqHpYVs22urYgkwWIF/Wpgtbs/bmbdwFbg/e7+bMM61wD/2d3f83q++0wP+kaT01We3nuUx18+zOMvHeaJ3UdedfHV2r4iF67q5vyVXWwY6GLDyk42DHTR26Hev4jM38mCfsFHDd19H7Avnh4zsx3AGuDZk34wMMVcmo3n9rPx3P7j
ZcNjU+zYN8Yzrxzl2VdG2Tl8jL/feeBVDz3p78xxdm+B1cuKrOrJ05XP0l3IsLI7z1nLCqzqKdDbkaUrn6GYTWNm7WieiJwBWnJ6iJmtBy4FfjHL4ivM7CngFaLe/TOt2OaZbGV3gZXdBf7pPxo4XlatOXsOT/DCyDF2Dh/j1wfG2Xd0ipcPTvDYrkNMlKpMz3EKZyGbYqA7z0BXnv7OHL0dOXqLWXo7sizryNHXkWV5Z55lxSw9xQzdhSzd+QyplHYOIiFo+mCsmXUBPwb+u7vfM2NZD1Bz92NmdiPwBXe/YI7v2QxsBjjnnHPe8tJLLzVVrySaKlfZPzrF/tESvxmd4uhkmWNTFQ5PTDM8OsXwWInDE2WOTkxzZLLMxPTct2swg658hp5C9EuhpxDtBI7PFxvLT0x3FzJ05DKYQV9HjlxGl2KILAWLdtaNmWWB7wAPuPvn5rH+LmDQ3Q+cbL0kjdG3U6lS5ehkmcPjZQ4eKzE6VWZ0shK9T1UYnSwzNhXPN0yPTVUYmypTm8efRj6TopBNM9CdZ3lnju5Clp5Chu5ChoHuPH2dOVJmrOktxr8o6suz2kmItNCijNFbNCh8J7BjrpA3s7OA/e7uZraR6Ercgwvdprw++Uyald1pVnYXgO7X9dlazRmfrrwq/Ecny4xORb8U3OHQ+DTjpQqT8S+NwxNl9h6Z5JfxjmN0qnLSbWRSRl9njlrNWREPO3UVMnTnM9H78V8RWQrZFJ35DF31VyE6NpFNp8hlUuTjl45ViLxWM2P0bwc+ADxtZk/GZf8FOAfA3W8HbgI+bGYVYBK42ZfiifvyGqmURWP5hSxnU1zQdxwrVZgoVZiu1njlyBRjU+UTvyomy0yUqxwenyaVMkbGShydKLPn8CTHStGQ1OhUhep8flY0yMXBXw//xvdcOkU+k55lefrV66Ybl0fL0imjvg/pKWaZnK5SzKYp5tKkU0al6uSzKdyjYbFMyihk08d3RpVajVw6RTadImVGqVolm0rhQEcu2ka15lRqTjplpM10DEVappmzbn4CnPQv0d1vA25b6DbkzFbvfQOs7et43Z93dyamq4xOlZkq1xgvVRgvVThWin5pTJWrlKs1SpXoNd3wPl2tUirXmK6+urxUqTIxUXn1+tUapXJ0sLtUqdGOrogZr9lufYeRMiOTsmgHkDLSqRSZlDE+XaEjl6YznyGXTpFJG9l0imwqmk7ZiZ1TfQdUzKYBqMXzPYUstXjDlZpTc493SEYmnWKiVIm+M5OiVK7RU8wc/756HVNmpCx6Qls2HQ3HjZcq5DIndmxmkI53XLWaU8ylyaQMiz9rZpSrNYrZqD1O9N+/Uf0XW6XqVGq16IyzXBr36GSGTMrIZlKUKzWq7mTTqXgb8/n3N3oKmePDldl0Kt7ZRvWO2mjxdNQRSsdlqVT0b7CUd866KYssWWZGZz5D52m8d5B71KueuXOo1vx4+OwfLdHXkaNUqTI5XaVSczJpo1SpYURXD1arzlSlysR0lUo1CqFyrRaHEGTTRrnqpA0mytH3ZFIpshmjFvfs6+9Vd6pVPx7ElVo0X8immJiuMlGuUq7UqNSccrVGOd65ORwPcSMK9/o1HCkzau6MTpaZrjopi35ZpOJfJ/XvyWfS1NyZKke/YMbqw3H1PPO4vbUofMvVaHvFbJpytXb8YsGQpOwUO4d4/sS0Hf/FuKIzzzc/dEXL66SgF2lgZmTjnnFnfvZ1zl/5+o53hKQW74wycc++Fgd9zaMdFkQ7mYnp6vF1PV4vm05xrFShVKmfLRaFX30nNV2pUav31NPGkYky05Xa8SCt1pxy1UmloiG86cr8dzTVmjM6VSaTSlHzaEdX8xPtqbrH09G6NY/La1HbavEOuf6Zary8VmtYxz3+7GvXqTlU3elepE6Ngl5EWiaVMlINI7r1oYwU9qqwWVac/YyrPt0PalHo/DYRkYRT0IuIJJyCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScEvy4eBmNgIs9Ib0K4CT3gY5gdTmMKjNYVhom9/g7gOzLViSQd8MMxua657MSaU2h0FtDsNitFlDNyIiCaegFxFJuCQG/R3trkAbqM1hUJvD0PI2J26MXkREXi2JPXoREWmgoBcRSbjEBL2ZXW9mz5nZTjP7RLvr0ypmdpeZDZvZ9oayfjPbYmbPx+99cbmZ2V/E/wbbzOyy9tV84cxsnZn9yMyeNbNnzOyjcXli221mBTN71Myeitv83+Lyc83sF3HbvmFmubg8H8/vjJevb2sDmmBmaTN7wsy+E88nus1mtsvMnjazJ81sKC5b1L/tRAS9maWBvwRuAC4CbjGzi9pbq5b5CnD9jLJPAA+5+wXAQ/E8RO2/IH5tBr50murYahXgP7n7RcDlwEfi/55JbncJeIe7vxm4BLjezC4H/gz4vLufDxwGbo3XvxU4HJd/Pl7vTPVRYEfDfAhtvtbdL2k4X35x/7bd/Yx/AVcADzTMfxL4ZLvr1cL2rQe2N8w/B6yOp1cDz8XTfw3cMtt6Z/IL+Dbw7lDaDXQAjwNvI7pCMhOXH/87Bx4AroinM/F61u66L6Cta+NgewfwHaJHxCa9zbuAFTPKFvVvOxE9emANsLthfk9cllSr3H1fPP0bYFU8nbh/h/jn+aXAL0h4u+MhjCeBYWAL8AJwxN0r8SqN7Tre5nj5UWD5aa1wa/wv4I+BWjy/nOS32YEHzWyrmW2Oyxb1b1sPBz/DububWSLPkTWzLuBbwMfcfdTsxEOnk9hud68Cl5hZL3Av8Mb21mhxmdl7gGF332pm17S5OqfTVe6+18xWAlvM7JeNCxfjbzspPfq9wLqG+bVxWVLtN7PVAPH7cFyemH8HM8sShfzX3P2euDjx7QZw9yPAj4iGLXrNrN4ha2zX8TbHy5cBB09vTZv2duC9ZrYL+DrR8M0XSHabcfe98fsw0Q59I4v8t52UoH8MuCA+Wp8Dbgbua3OdFtN9wKZ4ehPRGHa9/PfjI/WXA0cbfg6eMSzqut8J7HD3zzUsSmy7zWwg7sljZkWiYxI7iAL/pni1mW2u/1vcBPzQ40HcM4W7f9Ld17r7eqL/Z3/o7r9HgttsZp1m1l2fBq4DtrPYf9vtPjDRwgMcNwK/IhrX/FS769PCdt0N7APKRONztxKNSz4EPA/8AOiP1zWis49eAJ4GBttd/wW2+SqiccxtwJPx68Yktxt4E/BE3ObtwJ/E5ecBjwI7gb8D8nF5IZ7fGS8/r91taLL91wDfSXqb47Y9Fb+eqWfVYv9t6xYIIiIJl5ShGxERmYOCXkQk4RT0IiIJp6AXEUk4Bb2ISMIp6EVEEk5BLyKScP8fw3NmdtJQuqEAAAAASUVORK5CYII=\n", 149 | "text/plain": [ 150 | "
" 151 | ] 152 | }, 153 | "metadata": { 154 | "needs_background": "light" 155 | }, 156 | "output_type": "display_data" 157 | } 158 | ], 159 | "source": [ 160 | "plt.plot(loss_list)\n", 161 | "plt.show()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 7, 167 | "id": "comic-salem", 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "data": { 172 | "text/plain": [ 173 | "tensor(1.98941591887784, grad_fn=)" 174 | ] 175 | }, 176 | "execution_count": 7, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | } 180 | ], 181 | "source": [ 182 | "pred = model(train_x)\n", 183 | "loss = loss_fn(pred, train_y.reshape(-1, 1)).mean()\n", 184 | "loss" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "interior-virgin", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [] 194 | } 195 | ], 196 | "metadata": { 197 | "kernelspec": { 198 | "display_name": "Python 3", 199 | "language": "python", 200 | "name": "python3" 201 | }, 202 | "language_info": { 203 | "codemirror_mode": { 204 | "name": "ipython", 205 | "version": 3 206 | }, 207 | "file_extension": ".py", 208 | "mimetype": "text/x-python", 209 | "name": "python", 210 | "nbconvert_exporter": "python", 211 | "pygments_lexer": "ipython3", 212 | "version": "3.6.9" 213 | } 214 | }, 215 | "nbformat": 4, 216 | "nbformat_minor": 5 217 | } 218 | -------------------------------------------------------------------------------- /example/FunctionApproximation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "alternative-oxide", 6 | "metadata": {}, 7 | "source": [ 8 | "# Function Approximation\n", 9 | "\n", 10 | "使用单层的神经网络近似三角函数。" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "beneficial-footwear", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import sys\n", 21 | "sys.path.append('..')\n", 22 | "\n", 23 | "import numpy as np\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from tqdm import tqdm\n", 26 | "from easytorch.layer import Linear, ReLU, Tanh, Sequential\n", 27 | "from easytorch.optim import SGD\n", 28 | "from easytorch.tensor import Tensor\n", 29 | "import easytorch.functional as F" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "split-soldier", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "def generate_data(n=100, f=lambda x: 2*np.sin(x) + np.cos(x)):\n", 40 | " data = []\n", 41 | " for _ in range(n):\n", 42 | " x = np.random.uniform(-3, 3)\n", 43 | " y = f(x) + 0.03 * np.random.randn()\n", 44 | " data.append([x, y])\n", 45 | " return data" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "id": "alpha-thanksgiving", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "train_data = generate_data()\n", 56 | "x = Tensor(np.array([x for x, y in train_data]).reshape(-1, 1))\n", 57 | "y = Tensor(np.array([y for x, y in train_data]).reshape(-1, 1))" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "id": "saved-breakdown", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "model = Sequential(\n", 68 | " Linear(1, 5),\n", 69 | " Tanh(),\n", 70 | " Linear(5, 1)\n", 71 | ")\n", 72 | "opt = SGD(model.parameters(), lr=3e-3)\n", 73 | "loss_fn = F.mse_loss\n", 74 | "loss_list = []" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "id": 
"million-immunology", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stderr", 85 | "output_type": "stream", 86 | "text": [ 87 | "100%|██████████| 700/700 [00:16<00:00, 42.21it/s]\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "for epoch in tqdm(range(700)):\n", 93 | " for data_x, data_y in zip(x, y):\n", 94 | " pred = model(data_x)\n", 95 | " loss = loss_fn(pred, data_y.reshape(-1, 1))\n", 96 | " opt.zero_grad()\n", 97 | " loss.backward()\n", 98 | " opt.step()\n", 99 | " loss_list.append(loss.data)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 6, 105 | "id": "informative-characteristic", 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAD4CAYAAAATpHZ6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAqyUlEQVR4nO3dfXyU1bXo8d+aZEImIAmSKM2LF87RDyCIvESuPaBVkaoHRLAV3zjV04NYOS09vfdE8eoNIdWSGo+ltrWSIq1etRqPEEMppRW1iq0tCcEIAq1ULAmigCaWZiCTzL5/TCZkkpm8MJN55nlmfT+ffGD2szOzBuPKnv3svbYYY1BKKeUMLqsDUEopFTua1JVSykE0qSullINoUldKKQfRpK6UUg6SasWLZmdnm9GjR1vx0kopZVu1tbVHjTE5vfWxJKmPHj2ampoaK15aKaVsS0Q+6KuPTr8opZSDaFJXSikH0aSulFIOYsmcejg+n4+GhgZOnDhhdSi2l56eTn5+Pm632+pQlFJxljBJvaGhgTPOOIPRo0cjIlaHY1vGGI4dO0ZDQwNjxoyxOhylVJwlzPTLiRMnGDlypCb0KIkII0eO1E88qm/1lfC9iVCSFfizvtLqiFQMJMxIHdCEHiP676h6qK/k5Mb/JM3XDAZaJJ10aSPFtAWuNx+EjcsCf5+0MOJzsLUUmhsgMx9mFUfuqyyTUEldKRVj9ZWw+R6M9xOGBNsEhnICulfd9nkDSTtcoq6vDCR9nzfwuD+/BJQlEmb6xWpNTU089thjcXu90aNHc/To0V77fOc734lTNMpR6is5+eA5mJJMzIt3gPcT+vvZzTQ3hL+wtfRUQg/yeTm5sYjDJefiX5HJ4ZJz2V69JqrQVfQ0qXfoLam3tbXFOZoATepqwOorad9wF0N8zQgw0Jm4j8gO2x4p2ae1NjGKI7gERnGEibX3a2K3mG2TelVdIzPKXmHM8k3MKHuFqrrGqJ5v+fLl7N+/n8mTJ1NUVMRrr73GJZdcwrx58zj//PM5cOAAEydO7Oz/8MMPU1JSAsD+/fu5+uqrmTZtGpdccgl79+7t8fzHjh3ji1/8IhMmTGDx4sV0PXFq/vz5TJs2jQkTJlBRUdEZj9frZfLkydx6660R+ykVYmvpqXnyAWoxaaxqvSHstUjJvvsvDY+0UrCj/LReX8WGLefUq+oauXf9O3h97QA0Nnm5d/07AMyfkndaz1lWVsauXbvYuXMnAK+99ho7duxg165djBkzhgMHDkT83iVLlvD4449z3nnn8Yc//IGlS5fyyiuvhPRZuXIlM2fOpLi4mE2bNvHEE090Xlu3bh1nnnkmXq+Xiy66iC996UuUlZXxwx/+sDOeSP1Gjhx5Wu9XOUS3m5em+WC/p1rajItmk8EI+TuHzEgealtI7fDZYfuuar2BVe61ZEhrZ5sx4T8JnGV6n1ZUg8uWSb18y77OhB7k9bVTvmXfaSf1cKZPn97nWu/jx4/zu9/9jhtuODXCOXnyZI9+r7/+OuvXrwdgzpw5jBgxovPao48+yoYNGwA4ePAgf/7zn8Mm6/72U0kizM1LA70m9eAHxE8ZRonvK1T7Z3Ze87hTWHXV2LDfVzN8Nss/g7tTK8mVYxwyI/FwgpFyvEffjyWbUaf5llT0bJnUDzV5B9R+uoYOHdr599TUVPx+f+fj4Dpwv99PVlZWyIh6IF577TVefvllfv/735ORkcFll10Wdo15f/upJBLm5qUL8BtwdcvsxsAnZhgr204l8hEZbvLSUjnU5CU3y0PRVWMjDoqKrhrLvetbqW499UtgQeqbPCg/CRm9e00aB6cVaVK3kC3n1HOzPANq748zzjiDv/3tbxGvn3322Xz88cccO3aMkydP8otf/AKA4cOHM2bMGF544QUgsKPz7bff7vH9l156Kc8++ywAmzdv5tNPPwWgubmZESNGkJGRwd69e3nrrbc6v8ftduPz+frsp5LH9uo1NJXkY1ZkYpoPRux3zD8MY04l82/6ljKttaIzoXvcKay4dgJvLr+C98vm8ObyK3r9lDt/Sh6rrr+AvCwPAuRlefjCl/6d3dMe4DA5+I1wmBx2TXuAi+bdGeu3rQbAliP1wKjhnZApGI87haIIHx37Y+TIkcyYMYOJEydyzTXXMGfOnJDrbreb4uJipk+fTl5eHuPGjeu89swzz3DXXXfxwAMP4PP5uOmmm7jwwgtDvn/FihXcfPPNTJgwgX/6p3/inHPOAeDqq6/m8ccfZ/z48YwdO5aLL76483uWLFnCpEmTmDp1KuvWrYvYTyWH7dVruLD2/5Ambb3OsRwy2cxsfbTzscedwpem55G390i/RuWRzJ+S1/N7ptwJHUl8VMeXspZ0XYVxWk8gUgA8BZxNYDtDhTHm+719T2Fhoel+SMaePXsYP358v1+3qq6R8i37ovohdbKB/nuqxHe45FxGcaTXPi0mjeW+xdQOn51w/2/s/+mdjPngOaTLxL+M+QLcVm1pXHYiIrXGmMLe+sRipN4G/G9jzA4ROQOoFZHfGGPejcFzRxR21KCUE3WscDnbHIk4QjcGGk125wqWN5dfEd8Y+7D/p3fyDweeC6yW6fIezPu/RZ6cp4k9hqJO6saYD4EPO/7+NxHZA+QBg5rUlUoK9ZX41y/Bhel1I1Fjx5RLbytYrPQ/PqgMG7/QkdjjHpFzxXROXURGA1OAP8TyeZVKNsHpxV+1fJ0zXL1PkZ40KTzUtpC8BJpq6S7F+CPfBzA6nRpLMUvqIjIMeBH4D2PMZ2GuLwGWAJ03CZVSPVXVNbJtw2M8z3MMk557HuDUevNPzD
[... remainder of the base64-encoded image/png data for the cell's output plot omitted ...]",
111 |       "text/plain": [
112 |        "[matplotlib Figure repr]"
" 113 | ] 114 | }, 115 | "metadata": { 116 | "needs_background": "light" 117 | }, 118 | "output_type": "display_data" 119 | } 120 | ], 121 | "source": [ 122 | "# plt.plot(loss_list)\n", 123 | "# plt.show()\n", 124 | "\n", 125 | "plt.scatter(x.data, y.data, label='true data')\n", 126 | "plt.scatter(x.data, model(x).data, label='pred data')\n", 127 | "plt.legend()\n", 128 | "plt.show()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "serious-compiler", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Python 3", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.6.9" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 5 161 | } 162 | --------------------------------------------------------------------------------