├── .ipynb_checkpoints ├── 1.梯度下降-checkpoint.ipynb ├── 2.随机梯度下降-checkpoint.ipynb ├── 3.momentum-checkpoint.ipynb ├── 4.ada_grad-checkpoint.ipynb ├── 5.rms_prop-checkpoint.ipynb ├── 6.ada_delta-checkpoint.ipynb └── 7.adam-checkpoint.ipynb ├── 1.梯度下降.ipynb ├── 2.随机梯度下降.ipynb ├── 3.momentum.ipynb ├── 4.ada_grad.ipynb ├── 5.rms_prop.ipynb ├── 6.ada_delta.ipynb ├── 7.adam.ipynb ├── README.md └── 线性数据.csv /.ipynb_checkpoints/1.梯度下降-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "id": "92163201", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "0.6590042695516539" 61 | ] 62 | }, 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "#预测函数\n", 70 | "def predict(x):\n", 71 | " return w.dot(x) + b\n", 72 | "\n", 73 | "\n", 74 | "predict(x[0])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "id": "a7bb7a80", 81 | 
"metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "7.367867692433937" 87 | ] 88 | }, 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "output_type": "execute_result" 92 | } 93 | ], 94 | "source": [ 95 | "#求loss,MSELoss\n", 96 | "def get_loss():\n", 97 | " loss = 0\n", 98 | " for i in range(N):\n", 99 | " pred = predict(x[i])\n", 100 | " loss += (pred - y[i])**2\n", 101 | " return loss / N\n", 102 | "\n", 103 | "\n", 104 | "get_loss()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "id": "8027d213", 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "(array([2.03668543, 2.38225639, 1.02215384, 2.13526642, 3.22327899]),\n", 117 | " 0.0010000000036924916)" 118 | ] 119 | }, 120 | "execution_count": 5, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "def get_gradient():\n", 127 | " global w\n", 128 | " global b\n", 129 | "\n", 130 | " eps = 1e-3\n", 131 | "\n", 132 | " loss_before = get_loss()\n", 133 | "\n", 134 | " gradient_w = np.empty(M)\n", 135 | " for i in range(M):\n", 136 | " w[i] += eps\n", 137 | " loss_after = get_loss()\n", 138 | " w[i] -= eps\n", 139 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 140 | "\n", 141 | " b += eps\n", 142 | " loss_after = get_loss()\n", 143 | " b -= eps\n", 144 | " gradient_b = (loss_after - loss_before) / eps\n", 145 | "\n", 146 | " return gradient_w, gradient_b\n", 147 | "\n", 148 | "\n", 149 | "get_gradient()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 6, 155 | "id": "c371c6a4", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "0 7.112757670092038\n", 163 | "50 1.7854577366414703\n", 164 | "100 0.8188794143216034\n", 165 | "150 0.5927178198131446\n", 166 | "200 0.5305917673804184\n", 167 | "250 0.5095310094726683\n", 168 | "300 
0.5002280440367376\n", 169 | "350 0.4950505954913211\n", 170 | "400 0.49174823787396005\n", 171 | "450 0.48950848092220356\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "plt_x = []\n", 177 | "plt_y = []\n", 178 | "for i in range(500):\n", 179 | " gradient_w, gradient_b = get_gradient()\n", 180 | " w -= gradient_w * 1e-2\n", 181 | " b -= gradient_b * 1e-2\n", 182 | "\n", 183 | " plt_x.append(i)\n", 184 | " plt_y.append(get_loss())\n", 185 | "\n", 186 | " if i % 50 == 0:\n", 187 | " print(i, get_loss())" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 7, 193 | "id": "0471a70d", 194 | "metadata": { 195 | "scrolled": true 196 | }, 197 | "outputs": [ 198 | { 199 | "data": { 200 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZc0lEQVR4nO3dfXRc9X3n8fd3ZqTRw+jBkkayQAI/8mCobYgghoRASUhIwqbbDUlJ02y7S+pDN90lZ3O2J2xPd0939/Rhz7ZNu9umoZC2e5rAJgR2WZYWSCAQHgrINjZgG2yMsWVjSzKSLcnW08x3/5grIQvLHssa3Tszn9c5c3Tn3t/MfH9i+Pin3/zuXHN3REQkumJhFyAiIqenoBYRiTgFtYhIxCmoRUQiTkEtIhJxiUI8aUtLiy9btqwQTy0iUpI2bdrU7+7pUx0rSFAvW7aM7u7uQjy1iEhJMrN35jqmqQ8RkYhTUIuIRJyCWkQk4hTUIiIRp6AWEYk4BbWISMQpqEVEIi4yQe3u/NlPdvH0m31hlyIiEilnDGozu9jMXplxO2ZmX1/oQsyMv3pmD0/t7F3opxYRKWpnPDPR3d8A1gOYWRw4ADxUiGJa6pL0D48V4qlFRIrW2U59fBx4y93nPNXxXDTXViqoRURmOdugvg24rxCFALSkkhwZHi/U04uIFKW8g9rMKoHPAT+c4/hGM+s2s+6+vvl9INic0ohaRGS2sxlRfxrY7O6HT3XQ3e929y5370qnT/lNfWfUkkoycHyCyUx2Xo8XESlFZxPUX6KA0x4ALalKAN47rukPEZEpeQW1mdUCNwEPFrKY5lQSgP4hBbWIyJS8Lhzg7iNAc4FroSUI6iMjmqcWEZkSmTMTIfdhIqAPFEVEZohUUE+PqLVET0RkWqSCur4qQWU8Rp9G1CIi0yIV1GZGc6pSI2oRkRkiFdSgk15ERGaLXFDrNHIRkZNFLqiba/UNeiIiM0UuqFuCOWp3D7sUEZFIiGBQJxnPZBkamwy7FBGRSIhcUE+f9DKk6Q8REYhgUL9/Grk+UBQRgQgGtUbUIiIni1xQp+tyI2qdnSgikhO5oG6uTRIz6D2moBYRgQgGdTxmpOuSHD42GnYpIiKRELmgBmirr+Kw5qhFRICIBnVrXRW9GlGLiABRDer6JL0aUYuIABEN6ra6Kt4bGWd8UlcjFxGJZlDXa4meiMiUSAZ1axDUWvkhIpJnUJtZo5k9YGY7zWyHmV1TyKJa66oA9IGiiAiQyLPdnwL
/4O63mlklUFPAmmirD4JaHyiKiJw5qM2sAfgY8GsA7j4OFPQbk5prK4nHTFMfIiLkN/WxHOgD/trMtpjZPWZWO7uRmW00s24z6+7r6zu3omJGOpXksE4jFxHJK6gTwJXAt939CmAE+ObsRu5+t7t3uXtXOp0+58LatJZaRATIL6h7gB53fzG4/wC54C6otM5OFBEB8ghqdz8E7Dezi4NdHwe2F7QqciNqzVGLiOS/6uNfA98LVnzsAf5F4UrKaauvYuD4BGOTGZKJeKFfTkQksvIKand/BegqbCknmz47cWiMjiUFXQ0oIhJpkTwzEaA1WEut6Q8RKXeRDerzGqoBODiooBaR8hbZoG5vzI2o3z16IuRKRETCFdmgrq+qIJVMaEQtImUvskEN0N5QpRG1iJS9aAd1Y7VG1CJS9iId1Oc3akQtIhLpoG5vqKZ/eJyxyUzYpYiIhCbiQZ1b+XHoqKY/RKR8RTqoz2vMraU+MKjpDxEpX0UR1O/qA0URKWORDuqpqQ99oCgi5SzSQV1VEaeptpKDmqMWkTIW6aCG4KQXzVGLSBmLfFCfp5NeRKTMRT+oG6o4qDlqESlj0Q/qxmqGRicZGp0IuxQRkVBEPqinru6y/z2NqkWkPEU+qDubcmup9w8cD7kSEZFwRD6oL2iaGlErqEWkPEU+qBuqK6hLJugZ0NSHiJSnvK5CbmZ7gSEgA0y6+6JdkdzM6GiqYZ9G1CJSpvIK6sDPu3t/wSo5jc4l1bzdPxLGS4uIhC7yUx8AnU019AycwN3DLkVEZNHlG9QOPG5mm8xs46kamNlGM+s2s+6+vr6Fq5DciPrERIb+4fEFfV4RkWKQb1B/1N2vBD4NfM3MPja7gbvf7e5d7t6VTqcXtMjOqZUfWqInImUor6B29wPBz17gIeDqQhY1m5boiUg5O2NQm1mtmdVNbQOfBF4rdGEzTZ2dqCV6IlKO8ln10QY8ZGZT7b/v7v9Q0Kpmqa6M05JKsu+IRtQiUn7OGNTuvgdYtwi1nFZnU7XmqEWkLBXF8jzIzVPrpBcRKUdFE9QXNtdycPAEY5OZsEsREVlURRPUK1pqybpWfohI+SmaoF7WUgvA2/0KahEpL0UT1Mubp4J6OORKREQWV9EEdUNNBU21lRpRi0jZKZqgBljeUqsRtYiUnaIK6mXNtezViFpEykxRBfXylhoOHRvl+Phk2KWIiCyaIgvqFIBG1SJSVooqqJe15L6cSVd7EZFyUlxBHSzR23tEQS0i5aOogro2maCtPsmePgW1iJSPogpqgJXpFHu0RE9EykjRBfXq1hS7Dw/rQrciUjaKLqhXtdUxNDbJoWOjYZciIrIoii6oV7fmlujtOqzpDxEpD8Ub1L0KahEpD0UX1M2pJE21lezuHQq7FBGRRVF0QQ25UfWbmvoQkTJRnEHdlmLX4SGt/BCRspB3UJtZ3My2mNkjhSwoH6tb6zg2Oknf0FjYpYiIFNzZjKjvBHYUqpCzoQ8URaSc5BXUZtYBfBa4p7Dl5GdV29QSPX2gKCKlL98R9beA3wKyczUws41m1m1m3X19fQtR25zSqSRLairYeUhBLSKl74xBbWa3AL3uvul07dz9bnfvcveudDq9YAXOUROXttez491jBX0dEZEoyGdE/RHgc2a2F7gfuNHM/q6gVeXh0vZ63jg8RCarlR8iUtrOGNTufpe7d7j7MuA24El3/5WCV3YGl7bXMzqR1UUERKTkFeU6aoA17fUAbNf0h4iUuLMKanf/qbvfUqhizsaq1hQVcdM8tYiUvKIdUVcmYqxqrWP7QQW1iJS2og1qgEvb6zSiFpGSV9RBvaa9nt6hMfqHdSq5iJSuog9qQKNqESlpxR3U5+WC+tUDR0OuRESkcIo6qBtrKrmwuYZt+xXUIlK6ijqoAdZ2NLKtZzDsMkRECqbog3pdRwMHj47qu6lFpGQVfVCv7WgE0KhaREpW0Qf15efXEzPY2qN5ahEpTUUf1DWVCVa
31mlELSIlq+iDGmBtRwPbeo7qYrciUpJKI6g7G3lvZJyegRNhlyIisuBKIqiv6GwEYPO+gXALEREpgJII6kuW1lFbGad7r4JaREpPSQR1Ih7jyguX0P2OglpESk9JBDXAhy5cws5Dxzg2OhF2KSIiC6pkgrrrwibcYcu+wbBLERFZUCUT1OsvaCQeMzbtfS/sUkREFlTJBHUqmeDS9jpe1geKIlJiSiaoITf98cr+QSYy2bBLERFZMGcMajOrMrOXzGyrmb1uZr+7GIXNx4YVTZyYyOh0chEpKfmMqMeAG919HbAeuNnMNhS0qnn68PJmzOD53UfCLkVEZMGcMag9Zzi4WxHcIvmlGktqK1nTXs9zb/WHXYqIyILJa47azOJm9grQCzzh7i+eos1GM+s2s+6+vr4FLjN/165sZvM7g4xOZEKrQURkIeUV1O6ecff1QAdwtZldfoo2d7t7l7t3pdPpBS4zf9eubGE8k2WTzlIUkRJxVqs+3H0QeAq4uSDVLICrljcRjxnPa/pDREpEPqs+0mbWGGxXAzcBOwtc17ylkgnWdTTwnD5QFJESkc+Iuh14ysy2AS+Tm6N+pLBlnZuPrk6zrWeQgZHxsEsRETln+az62ObuV7j7Wne/3N3/02IUdi5uuDhN1uGZXeF9qCkislBK6szEKes6GllSU8HTbyioRaT4lWRQx2PG9RelefrNPrLZSC75FhHJW0kGNcANF7dyZGScVw8cDbsUEZFzUrJB/bGL0pjBU2/0hl2KiMg5Kdmgbqqt5IrORn6843DYpYiInJOSDWqAT162lNcOHKNn4HjYpYiIzFtJB/WnLlsKwGOva1QtIsWrpIN6eUstlyyt47HXD4VdiojIvJV0UENu+uPlve/RPzwWdikiIvNS8kF982VLcYcntmv6Q0SKU8kH9aXtdSxvqeX/bj0YdikiIvNS8kFtZnxu3Xm8sOcIh46Ohl2OiMhZK/mgBviF9efhDo9s06haRIpPWQT1inSKtR0N/O9XDoRdiojIWSuLoAb4hfXn89qBY+zuHQq7FBGRs1I2Qf1P1rUTjxk/3NQTdikiImelbIK6ta6KT1zaygPdPYxPZsMuR0Qkb2UT1ABfuvoCjoyMa021iBSVsgrq61anOb+xmvte2hd2KSIieSuroI7HjNuu6uTZ3f28c2Qk7HJERPJSVkEN8IWuTuIx4/6X94ddiohIXs4Y1GbWaWZPmdl2M3vdzO5cjMIKZWlDFR+/pJX7X9rHifFM2OWIiJxRPiPqSeAb7r4G2AB8zczWFLaswvrqdSsYOD7BA5u1VE9Eou+MQe3u77r75mB7CNgBnF/owgrpqmVLWN/ZyD0/20NGVykXkYg7qzlqM1sGXAG8eIpjG82s28y6+/r6Fqi8wjAzNn5sBe8cOc4T23VRARGJtryD2sxSwI+Ar7v7sdnH3f1ud+9y9650Or2QNRbEpy5bygVNNXznmT24a1QtItGVV1CbWQW5kP6euz9Y2JIWRzxmfPW65WzZN8g/7nkv7HJEROaUz6oPA+4Fdrj7Hxe+pMXzxa5O2uqT/NHjb2hULSKRlc+I+iPAV4AbzeyV4PaZAte1KKoq4vzmjavpfmeAn74Z7Xl1ESlf+az6eNbdzd3Xuvv64PboYhS3GH6pq5OOJdUaVYtIZJXdmYmzVSZifP0TF/HagWM89rpWgIhI9JR9UAP84hXns6o1xR/8/U7GJnW2oohEi4Ka3AqQ37llDXuPHOfeZ98OuxwRkZMoqAPXX5TmpjVt/I8nd+tq5SISKQrqGX7ns2uYzDq/9+iOsEsREZmmoJ7hguYa7rh+JQ9vPchP3+gNuxwREUBB/QH/6oaVrGpNcdeDr3JsdCLsckREFNSzVVXE+W9fWMfhY6P8l0e2h12OiIiC+lTWdzZyx/Ur+UF3D0/u1IVwRSRcCuo53PmJ1VyytI5v/GArBwdPhF2OiJQxBfUckok4f/HlK5nIOL/5/c1MZLJhlyQiZUp
BfRor0in+4PM/x+Z9g/z+ozvDLkdEylQi7AKi7pa159G9d4DvPvc2lyyt44tXdYZdkoiUGY2o8/Dbn72U61a38O8fepVnd/WHXY6IlBkFdR4q4jH+4stXsqo1xW/83SZ2HvrAlchERApGQZ2nuqoK7v21q6hJxvmVe15kd+9w2CWJSJlQUJ+F8xur+f6vbwCMX/6rf+Tt/pGwSxKRMqCgPksr0ynu+/UPk8k6v/SdFzQNIiIFp6Ceh9Vtddy3cQNm8IW/fIGX3tZVzEWkcBTU83RRWx0/+o1rSdcl+cq9L/Lw1oNhlyQiJeqMQW1m3zWzXjN7bTEKKiYdS2p44I5rWdvRwL+5bwu///c7yGR1gVwRWVj5jKj/Bri5wHUUrabaSr731Q18ZcOFfOfpPXzl3hd1hRgRWVBnDGp3fwbQJOxpVCZi/Od/ejn/9da1bNk3yKe+9Qz/b9u7YZclIiVCc9QL6ItdnTx653Usa6nla9/fzJ33b6FvaCzsskSkyC1YUJvZRjPrNrPuvr6+hXraorO8pZYH7riGOz++mkdffZcb/+in/O3zezV3LSLzZu5nDhAzWwY84u6X5/OkXV1d3t3dfY6lFb89fcP8x4df52e7+rlkaR3/7lMXc+MlrZhZ2KWJSMSY2SZ37zrVMU19FNCKdIr/+S+v5s9/+UpOTGS4/W+7+fy3n+e53f3k8w+kiAjktzzvPuAF4GIz6zGz2wtfVukwMz67tp0f/9vr+b1f/DkODJ7gy/e8yC3//Vke2tKjCxKIyBnlNfVxtjT1MbfRiQwPbTnAPT/bw1t9Iyytr+K2qzv5/JUddDbVhF2eiITkdFMfCuqQZLPO02/2ce+zb/Ps7tx3XF+zoplbP9TBJ9a00VBdEXKFIrKYFNQR1zNwnAc3H+CBTT3se+84iZhxzcpmPnnZUm66tI2lDVVhlygiBaagLhLuzuZ9gzy+/RCPv354+mtUL1lax7UrW7h2ZTNXr2iivkqjbZFSo6AuQu7O7t5hHt9+mOff6qd77wBjk1liBmvOq2dtRyPrOxpZ19nIqtYU8ZiW/IkUMwV1CRidyLBl3yAvvNXP5n2DbO0ZZGh0EoCayjgXL63jotY6VrelWN1Wx0VtKZbWV2nNtkiRUFCXoGzWefvICFv3D7Kt5yg73j3G7t5hjoyMT7dJJRN0NtXQuaR6+ucFzTV0LKmhra6K+uqEglwkIk4X1InFLkYWRixmrEynWJlO8c+u7Jjef2R4jF29w+w6PMTu3mH2D5zg7f4RntnVx+jEyWu2k4kY6bokrXVJWuuqaK3PbTfVJllSU0FDTQWN1ZU01lTQWFNBdUVcwS4SAgV1iWlOJWlOJdmwovmk/e5O3/AY+987Qc/AcfqGxugdGqP32Ci9Q2Ps6h3iubf6p6dTTqUyEaOxuoKG6gpqkwlqk3FqKxOkkglqknFqkwlSlQlqkglSyTg1lQmqK+IkK2IkE3GSiRhVFbmfs/dpjl1kbgrqMmFmuVFzXRUfunDJnO1GJzIMHB9n8PhEcBtn8ESwfWKcwZEJjp6YYGR8kpGxSY4MjzM8Nsnx8QzDY5OMT87vTMtEzIIAj1MZj5GIGxXxGImYkYjHqIgb8ZhREcsdS8RjVMRs1nauXWKqTcyIxYy4GTF7fzsey/0+4sF9M3LbMcvtD9rEgsfFg+eJGbnnik3tP7mNGRi5dhY8b+7fn6ltw2C6ndn727HY+/vmekws+Gtmev8pnucDrzPd7vSPseA9kntlpl+HGfslPApqOUlVRZz2hmraG6rn9fiJTJbjY5npIB+dyDI2mWFsMvdz+v5ElrHJLKMTHzw2PpllMuNMZJ3JTJaJjDOZze2bzGYZn8wyMp5hMjPVLjiWyU4/JtfWybiTzTpZd/QFhuduOryn738w3I2TG53q2GmfZ65jdvrXnXrUB597+sicrzuzfzPbz1XrjJc76VhzbZIf3HENC01BLQuqIh6joSZGQ0301np7ENZZdzIzwju
TfT/Mc8HOSQGfazvrcUGbTNbxGW0cB2d624PHOQT7c/t8Rj1Mtzv5MbmaT35MNthwgteb8RgPXmf2Y3x6e47HBO2mXi+oaNb99++cru1cxzjp2Cmem/frmnls5n+7fGp8/3Gzjp3iuU/5fB849sE+z3UMh7qqwkSqglrKRm5aA+IYFfGwqxHJn77mVEQk4hTUIiIRp6AWEYk4BbWISMQpqEVEIk5BLSIScQpqEZGIU1CLiERcQb7m1Mz6gHfm+fAWoH8ByykG6nN5UJ/Lw3z7fKG7p091oCBBfS7MrHuu72QtVepzeVCfy0Mh+qypDxGRiFNQi4hEXBSD+u6wCwiB+lwe1OfysOB9jtwctYiInCyKI2oREZlBQS0iEnGRCWozu9nM3jCz3Wb2zbDrWShm9l0z6zWz12bsazKzJ8xsV/BzSbDfzOzPgt/BNjO7MrzK58/MOs3sKTPbbmavm9mdwf6S7beZVZnZS2a2Nejz7wb7l5vZi0Hf/peZVQb7k8H93cHxZaF24ByYWdzMtpjZI8H9ku6zme01s1fN7BUz6w72FfS9HYmgNrM48OfAp4E1wJfMbE24VS2YvwFunrXvm8BP3H018JPgPuT6vzq4bQS+vUg1LrRJ4BvuvgbYAHwt+O9Zyv0eA25093XAeuBmM9sA/CHwJ+6+ChgAbg/a3w4MBPv/JGhXrO4Edsy4Xw59/nl3Xz9jvXRh39u5a6aFewOuAR6bcf8u4K6w61rA/i0DXptx/w2gPdhuB94Itr8DfOlU7Yr5Bvwf4KZy6TdQA2wGPkzuDLVEsH/6fQ48BlwTbCeCdhZ27fPoa0cQTDcCj5C7zmup93kv0DJrX0Hf25EYUQPnA/tn3O8J9pWqNnd/N9g+BLQF2yX3ewj+vL0CeJES73cwBfAK0As8AbwFDLr7ZNBkZr+m+xwcPwo0L2rBC+NbwG8B2eB+M6XfZwceN7NNZrYx2FfQ97Yubhsyd3czK8k1kmaWAn4EfN3dj5nZ9LFS7Le7Z4D1ZtYIPARcEm5FhWVmtwC97r7JzG4IuZzF9FF3P2BmrcATZrZz5sFCvLejMqI+AHTOuN8R7CtVh82sHSD42RvsL5nfg5lVkAvp77n7g8Huku83gLsPAk+R+7O/0cymBkQz+zXd5+B4A3BkcSs9Zx8BPmdme4H7yU1//Cml3Wfc/UDws5fcP8hXU+D3dlSC+mVgdfBpcSVwG/BwyDUV0sPArwbbv0puDndq/z8PPineAByd8edU0bDc0PleYIe7//GMQyXbbzNLByNpzKya3Jz8DnKBfWvQbHafp34XtwJPejCJWSzc/S5373D3ZeT+n33S3b9MCffZzGrNrG5qG/gk8BqFfm+HPTE/Y5L9M8Cb5Ob1fjvsehawX/cB7wIT5Oanbic3L/cTYBfwY6ApaGvkVr+8BbwKdIVd/zz7/FFy83jbgFeC22dKud/AWmBL0OfXgP8Q7F8BvATsBn4IJIP9VcH93cHxFWH34Rz7fwPwSKn3Oejb1uD2+lRWFfq9rVPIRUQiLipTHyIiMgcFtYhIxCmoRUQiTkEtIhJxCmoRkYhTUIuIRJyCWkQk4v4/y+3mmReXR0gAAAAASUVORK5CYII=\n", 201 | "text/plain": [ 202 | "
" 203 | ] 204 | }, 205 | "metadata": { 206 | "needs_background": "light" 207 | }, 208 | "output_type": "display_data" 209 | } 210 | ], 211 | "source": [ 212 | "from matplotlib import pyplot as plt\n", 213 | "%matplotlib inline\n", 214 | "\n", 215 | "plt.plot(plt_x, plt_y)\n", 216 | "plt.show()" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3 (ipykernel)", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.8.11" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 5 241 | } 242 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/2.随机梯度下降-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0" 49 | ] 50 | }, 51 | { 52 | 
"cell_type": "code", 53 | "execution_count": 3, 54 | "id": "92163201", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "0.6590042695516539" 61 | ] 62 | }, 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "#预测函数\n", 70 | "def predict(x):\n", 71 | " return w.dot(x) + b\n", 72 | "\n", 73 | "\n", 74 | "predict(x[0])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "id": "a7bb7a80", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "0.21258140154187247" 87 | ] 88 | }, 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "output_type": "execute_result" 92 | } 93 | ], 94 | "source": [ 95 | "#求loss,MSELoss\n", 96 | "def get_loss(x, y):\n", 97 | " pred = predict(x)\n", 98 | " loss = (pred - y)**2\n", 99 | " return loss\n", 100 | "\n", 101 | "\n", 102 | "get_loss(x[0], y[0])" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "id": "8027d213", 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/plain": [ 114 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 115 | " 0.923131013558981)" 116 | ] 117 | }, 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "output_type": "execute_result" 121 | } 122 | ], 123 | "source": [ 124 | "def get_gradient(x, y):\n", 125 | " global w\n", 126 | " global b\n", 127 | "\n", 128 | " eps = 1e-3\n", 129 | "\n", 130 | " loss_before = get_loss(x, y)\n", 131 | "\n", 132 | " gradient_w = np.empty(M)\n", 133 | " for i in range(M):\n", 134 | " w[i] += eps\n", 135 | " loss_after = get_loss(x, y)\n", 136 | " w[i] -= eps\n", 137 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 138 | "\n", 139 | " b += eps\n", 140 | " loss_after = get_loss(x, y)\n", 141 | " b -= eps\n", 142 | " gradient_b = (loss_after - loss_before) / eps\n", 143 | "\n", 144 | " return gradient_w, 
gradient_b\n", 145 | "\n", 146 | "\n", 147 | "get_gradient(x[0], y[0])" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 6, 153 | "id": "f39e0125", 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "11073.905141728206" 160 | ] 161 | }, 162 | "execution_count": 6, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "def total_loss():\n", 169 | " loss = 0\n", 170 | " for i in range(N):\n", 171 | " loss += get_loss(x[i], y[i])\n", 172 | " return loss\n", 173 | "\n", 174 | "\n", 175 | "total_loss()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 7, 181 | "id": "c371c6a4", 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "0 11038.895201527894\n", 189 | "150 6696.3736283721655\n", 190 | "300 4354.119709336485\n", 191 | "450 3004.4953025777063\n", 192 | "600 2310.68634531403\n", 193 | "750 1839.580107951638\n", 194 | "900 1549.8047186884628\n", 195 | "1050 1278.2079624059054\n", 196 | "1200 1099.1810250366634\n", 197 | "1350 986.6025037947752\n", 198 | "1500 921.4757031198328\n", 199 | "1650 879.159069825457\n", 200 | "1800 853.2767252227716\n", 201 | "1950 835.2496534941863\n", 202 | "2100 812.3758750332744\n", 203 | "2250 794.1165878394305\n", 204 | "2400 786.9647280480957\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "plt_x = []\n", 210 | "plt_y = []\n", 211 | "for epoch in range(2500):\n", 212 | " i = np.random.randint(N)\n", 213 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 214 | " w -= gradient_w * 1e-3\n", 215 | " b -= gradient_b * 1e-3\n", 216 | "\n", 217 | " plt_x.append(epoch)\n", 218 | " plt_y.append(total_loss())\n", 219 | "\n", 220 | " if epoch % 150 == 0:\n", 221 | " print(epoch, total_loss())" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 8, 227 | "id": "0471a70d", 228 | "metadata": { 
229 | "scrolled": true 230 | }, 231 | "outputs": [ 232 | { 233 | "data": { 234 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgw0lEQVR4nO3deZhcdZ3v8fe3qrp6Te+dpNPp7CEkJGBCE3ZRQAiBaxh1MM69EoWRq4KD8sxVGJxHR+Ze9Y4Ogt5hnihoUBQQosQRgQz7IAl0QnbopBOydGfpTnpJ791V9bt/1OnQxE5IejtdVZ/X89RT5/zOOVXfX1clnzq7OecQEZHUFvC7ABER8Z/CQEREFAYiIqIwEBERFAYiIgKE/C5goIqLi92UKVP8LkNEJGGsW7fusHOupL9pCRsGU6ZMobKy0u8yREQShpntOdE0bSYSERGFgYiIKAxERASFgYiIoDAQEREUBiIigsJARERIwTB4eXs9W2qb/S5DRGRUSakwaG7vYdlDb7B0+Rq/SxERGVVSKgzystK4as44WrsitHdH/C5HRGTUSKkwAPj4hyYA8O7hNp8rEREZPVIuDGaX5gKwcZ/2G4iI9Eq5MJhWnE1BVho/eK7K71JEREaNlAsDM+NzF02loa2btw8c9bscEZFRIeXCAOCzF07GDFZvO+R3KSIio0JKhkFhdpgJeZlU17X6XYqIyKiQkmEAcNaEXFZt3M/h1i6/SxER8V3KhsFVZ40HdIipiAikcBiU5WcC0BON+VyJiIj/UjYMwiEDoCfqfK5ERMR/KRsGacF41yNaMxARSd0wCAXiXe/sURiIiKRsGIzJCAGwZb8uSyEikrJhUF6YRVrQaOns8bsUERHfpWwYAJwxbgw1jR1+lyEi4rsPDAMze8jM6sxsS5+2QjNbbWY7vOcCr93M7H4zqzazTWa2oM8yy7z5d5jZsj7t55rZZm+Z+83MhrqTJzKxIJNahYGIyCmtGfwCWHRc253A8865mcDz3jjANcBM73EL8ADEwwP4FnA+sBD4Vm+AePN8oc9yx7/XsJlYkEVNYwfO6fBSEUltHxgGzrlXgIbjmpcAK7zhFcD1fdofdnFrgHwzKwWuBlY75xqcc43AamCRNy3XObfGxf9HfrjPaw27KUVZdPRE+ePmAyP1liIio9JA9xmMc871/g96EBjnDZcB+/rMV+O1nay9pp/2fpnZLWZWaWaV9fX1Ayz9PR//UBmTi7K47ddvDfq1REQS2aB3IHu/6EdkO4tzbrlzrsI5V1FSUjLo18vLTOOauaUA7DjUMujXExFJVAMNg0PeJh685zqvvRYo7zPfRK/tZO0T+2kfMR+bE1+p2X2kfSTfVkRkVBloGKwCeo8IWgY81af9Ru+ooguAZm9z0rPAVWZW4O04vgp41pt21Mwu8I4iurHPa42IacXZAGzc1zSSbysiMqqcyqGlvwFeB2aZWY2Z3Qx8D/iYme0ArvTGAZ4GdgHVwE+BLwM45xqAe4A3vcd3vDa8eX7mLbMT+NPQdO3UFGSHGZ+bwdPaiSwiKSz0QTM45z5zgklX9DOvA249wes8BDzUT3slMPeD6hhOs8aPYet+3Q9ZRFJXSp+B3GteWR6N7d3EYjrfQERSk8IAKMoJE405mjt0nSIRSU0KA6AoJx2Ag0c7fa5ERMQfCgNgfG4GAK9VH/a5EhERfygMgPmT8gF4btshfwsREfGJwoD4LTDPm1JA1cEWenQbTBFJQQoDz00XT6W5o4dNNbrzmYikHoWBZ1xefL/BUd35TERSkMLAkxUOAtDRHfW5EhGRkacw8GSlxU/GblcYiEgKUhh4Mr01g/buiM+ViIiMPIWBp3czUWuXwkBEUo/CwJMVDpKZFuSNd4+/w6eISPJTGHjMjPysN
P6884jfpYiIjDiFQR/nTi6gOxLTEUUiknIUBn1cPKMYgNVv67IUIpJaFAZ9XP+hMgD+adVWnysRERlZCoM+MsNBFkzK11nIIpJyFAbHuWL2OHqijjrd20BEUojC4DhzJuQCsFaHmIpIClEYHOfi6cUEDHYcavG7FBGREaMwOE44FKA0L5M9De1+lyIiMmIUBv0oL8ykprHD7zJEREaMwqAfRdnpNLZ3+12GiMiIURj0IzczxNEOXbBORFKHwqAfuRlptOhcAxFJIQqDfuSkh+iKxKht0n4DEUkNCoN+fGhSPgBrdAVTEUkRCoN+XDy9mNyMEM9tO+h3KSIiI0Jh0I9AwJhSnM26PY1+lyIiMiIGFQZm9jUz22pmW8zsN2aWYWZTzWytmVWb2WNmFvbmTffGq73pU/q8zl1ee5WZXT3IPg2JS2YU09DWTSzm/C5FRGTYDTgMzKwM+Dugwjk3FwgCS4HvA/c652YAjcDN3iI3A41e+73efJjZHG+5s4BFwL+ZWXCgdQ2V4px0Yg6aOnRUkYgkv8FuJgoBmWYWArKAA8DlwBPe9BXA9d7wEm8cb/oVZmZe+6POuS7n3LtANbBwkHUNWlFOGICGti6fKxERGX4DDgPnXC3wA2Av8RBoBtYBTc653jO2aoAyb7gM2OctG/HmL+rb3s8y72Nmt5hZpZlV1tfXD7T0U1KY3RsGWjMQkeQ3mM1EBcR/1U8FJgDZxDfzDBvn3HLnXIVzrqKkpGQ43+pYGOzTBetEJAUMZjPRlcC7zrl651wPsBK4GMj3NhsBTARqveFaoBzAm54HHOnb3s8yvinLzwRg7bs610BEkt9gwmAvcIGZZXnb/q8AtgEvAp/y5lkGPOUNr/LG8aa/4JxzXvtS72ijqcBM4I1B1DUk8rPChIMBHq+sYXNNs9/liIgMq8HsM1hLfEfwemCz91rLgW8Ad5hZNfF9Ag96izwIFHntdwB3eq+zFXiceJA8A9zqnIsOtK6hdNvlMwB4ZO0enysRERleFv9xnngqKipcZWXlsL/Poh+9wjsHW9j+z9cQDukcPRFJXGa2zjlX0d80/e/2ARZOLQTgma26NIWIJC+FwQf40kemA3D3ys0+VyIiMnwUBh+gNC+ThVMLaemKsLO+1e9yRESGhcLgFHzF25H8rDYViUiSUhicglnjxgDw6Bv7PmBOEZHEpDA4BWNzM7ihYiJtXbovsogkJ4XBKRqXm0FjezdRXdJaRJKQwuAU9V7SuqGt2+9SRESGnMLgFE0pzgagcneDz5WIiAw9hcEpOl8nn4lIElMYnKKMtCCXzCjmqQ37efC/3vW7HBGRIaUwOA3f/cQ8AO75j210R2I+VyMiMnQUBqehvDCL7yw5C4BXtg/vndZEREaSwuA0XTuvFIBtB476XImIyNBRGJymopx0Jhdl8fDre0jUy3+LiBxPYTAAl51RwuHWLo7onAMRSRIKgwE4e2I+gC5PISJJQ2EwANnhIAC7j7T7XImIyNBQGAzA3LI8AP79pZ0+VyIiMjQUBgNQXpjFZxaW8/quI3zh4Uqa2rXvQEQSm8JggP7p43NZPG88q7cd4g+bDvhdjojIoCgMBigcCnDf0vkAvFxV53M1IiKDozAYhLRggHMnF7C5ttnvUkREBkVhMEiXnVHCoaNd7NWRRSKSwBQGg3T2xPiRRdfe/6rOOxCRhKUwGKSPzBrLPUvOoqUrwsu6eJ2IJCiFwRC49uwJALykHckikqAUBkOgMDvM9JJsaps6/C5FRGRAFAZDZFxuBp09uuGNiCSmQYWBmeWb2RNm9o6ZvW1mF5pZoZmtNrMd3nOBN6+Z2f1mVm1mm8xsQZ/XWebNv8PMlg22U37ICgfZfqiFls4ev0sRETltg10zuA94xjl3JnAO8DZwJ/C8c24m8Lw3DnANMNN73AI8AGBmhcC3gPOBhcC3egMkkXxk1lhaOiNc9L0X2LZfN74RkcQy4DAwszzgw8CDAM65budcE7AEW
OHNtgK43hteAjzs4tYA+WZWClwNrHbONTjnGoHVwKKB1uWX/3HBZH5580JauyLc9pv1xGK68Y2IJI7BrBlMBeqBn5vZW2b2MzPLBsY553ov1nMQGOcNlwH7+ixf47WdqP0vmNktZlZpZpX19aPvMM5LZ5Zw9+LZ7KpvY0+DTkITkcQxmDAIAQuAB5xz84E23tskBICL3xdyyH4iO+eWO+cqnHMVJSUlQ/WyQ2r+pHwAlr+iy1uLSOIYTBjUADXOubXe+BPEw+GQt/kH77n34PtaoLzP8hO9thO1J6QFkwoYkx7iN2/so+5op9/liIickgGHgXPuILDPzGZ5TVcA24BVQO8RQcuAp7zhVcCN3lFFFwDN3uakZ4GrzKzA23F8ldeWkMyMH/9N/Gqm//C7LT5XIyJyakKDXP4rwCNmFgZ2AZ8nHjCPm9nNwB7gBm/ep4HFQDXQ7s2Lc67BzO4B3vTm+45zrmGQdfnqI7PGsnjeeJ7efJC7Vm7mu5+Y53dJIiInZfHN+omnoqLCVVZW+l3GCTW1d/O1xzbwYlU9v/3ihZw3pdDvkkQkxZnZOudcRX/TdAbyMMnPCnPfZ+Kbi9bsPOJzNSIiJ6cwGEa5GWmU5Wey4vU9RHXegYiMYgqDYXbd2aUcbu3iU//+Z7oiUb/LERHpl8JgmN3y4WlMLMjkrb1NvLr9sN/liIj0S2EwzIpy0nnySxcBcKhF5x2IyOikMBgBGWlBADq6tZlIREYnhcEIyPTCYGd9q8+ViIj0T2EwAtKCxpiM+CUqtHYgIqORwmAEmBm3fnQGAGve1TkHIjL6KAxGyF+fOxGAdw60+FyJiMhfUhiMkKKcdM4pz+exN/eSqJcAEZHkpTAYQUvOmcDuI+0caNYhpiIyuigMRtDFM4oBuPXX632uRETk/RQGI2jW+DFkhYO8tbeJ6jodZioio4fCYISt/PJFpAWNn7yww+9SRESOURiMsDPH5zKnNJfth7RmICKjh8LAB6V5mWw7cJTXdZ8DERklFAY++OZ1sxmfm8Hf/3ajLmstIqOCwsAHEwuy+PqiWdQ2dfDW3ia/yxERURj4Zc6EXABtKhKRUUFh4JPygiwA3tzd4HMlIiIKA99kp4f44mXT+fPOI+xraPe7HBFJcQoDH31yQRmgTUUi4j+FgY+ml+QAUNPU4XMlIpLqFAY+CgSM7HCQ+pYuv0sRkRSnMPDZmaW5vFRVR3ck5ncpIpLCFAY+u/2KmRxo7mT5Kzv9LkVEUpjCwGeXziymYnIBP3huO09tqPW7HBFJUQoDn5kZP7zhHABuf3QDf9x0wOeKRCQVKQxGgclF2Tzyt+cTDgW49dfreefgUb9LEpEUM+gwMLOgmb1lZv/hjU81s7VmVm1mj5lZ2GtP98arvelT+rzGXV57lZldPdiaEtHFM4r53ZcvAuDHL1T7XI2IpJqhWDO4HXi7z/j3gXudczOARuBmr/1moNFrv9ebDzObAywFzgIWAf9mZsEhqCvhnDUhj2vnlfLHTQf45eu7/S5HRFLIoMLAzCYC1wI/88YNuBx4wptlBXC9N7zEG8ebfoU3/xLgUedcl3PuXaAaWDiYuhLZ//nEPHLSQ/zjU1v5X7/dSCSqQ05FZPgNds3gR8DXgd7/sYqAJudcxBuvAcq84TJgH4A3vdmb/1h7P8u8j5ndYmaVZlZZX18/yNJHp7zMNP50+6VcOXssv11XwyNr9/pdkoikgAGHgZldB9Q559YNYT0n5Zxb7pyrcM5VlJSUjNTbjrjywix+tuw8JuRlsGFfk9/liEgKCA1i2YuBj5vZYiADyAXuA/LNLOT9+p8I9B48XwuUAzVmFgLygCN92nv1XSalTS3Jpupgi99liEgKGPCagXPuLufcROfcFOI7gF9wzv134EXgU95sy4CnvOFV3jje9Becc85rX+odbTQVmAm8MdC6ksncsjy2HTjKD56t8rsUEUlyw3GewTeAO8ysmvg+gQe99geBIq/9DuBOA
OfcVuBxYBvwDHCrc043BiZ+qYpgwPjJi9XaXCQiw8riP84TT0VFhausrPS7jGFX09jOdT/+L7LDIV6783K/yxGRBGZm65xzFf1N0xnIo9zEgiyunVdKbVMH7x5u87scEUlSCoMEcOOFU0gLGt/8/WYSdU1OREY3hUECmDV+DH81v4zXqo/w1cc2+F2OiCQhhUGCuOf6udxQMZGnNuzXzmQRGXIKgwSRHgry9UVnAvDq9uQ8+1pE/KMwSCDFOekU56Tz5Poav0sRkSSjMEgwkwoz2X2k3e8yRCTJKAwSzDVzSwH45Zo9PlciIslEYZBgPn/xFKaXZHPv6u3UtXT6XY6IJAmFQYIJBQPct3Q+je3dXPTdF3jhnUN+lyQiSUBhkIDmluXxxBcvJBgwbvpFJas27ve7JBFJcAqDBHXu5EL+847LmFOayx2PbeDGh95gS22z32WJSIJSGCSw8sIsfvW35/PRM8fyyvZ6PvfzN3ipqo5YTJesEJHTozBIcIXZYX56YwW//sL5BAPG537+Jp9e/jqvVR9WKIjIKdMlrJNIZ0+U366r4d7V22lo62Z8bgbf/eQ8PjprrN+licgocLJLWCsMklBnT5RVG/bzzd9voTsaY3ZpLj/863OYMyHX79JExEe6n0GKyUgLcsN55Tx9+yV84dKpVB08yk2/eNPvskRkFFMYJLEZY8dw97VzuO3ymRw82sn3n3nH75JEZJRSGKSAL1w6lYumF/HASztZqYvciUg/FAYpYExGGituWsi8sjzueHwj19z3Kt9etZV3Dh71uzQRGSUUBikiLRjgVzefz5c/Mp38zDR+8efdLPrRq/zs1V1EojG/yxMRn+loohS1pbaZf3xqC2/tbWJCXgaP3nIhk4qy/C5LRIaRjiaSvzC3LI+VX7qIn/zNfA63dfPZh9bS0R31uywR8YnCIIWZGdedPYGbLp7KniPt/OvqKr9LEhGfKAyEO6+J31v5xap6jnb2+FyNiPhBYSAA3H7FTKrrWrnihy/z2QfXsm5Po98licgIUhgIAF/72Bn86ubzmV+ez7o9jSxd/jrfXrWVxrZuv0sTkRGgo4nkLxxo7uBfnq3i92/VEnMwuzSXs8vymFaSzTVzS3XUkUiC0oXqZEDePnCU32+opXJ3I3uOtHG4Nb6WsHjeeH706fmEQ1qxFEkkJwuD0CBetBx4GBgHOGC5c+4+MysEHgOmALuBG5xzjWZmwH3AYqAd+Jxzbr33WsuAb3ov/c/OuRUDrUuGzuzSXGaXvnel06qDLdz//A7+uPkAm2pe4oaKcpaeV87Y3AwfqxSRoTDgNQMzKwVKnXPrzWwMsA64Hvgc0OCc+56Z3QkUOOe+YWaLga8QD4Pzgfucc+d74VEJVBAPlXXAuc65k+7B1JqBP5xzLH9lFw+8vJOm9h7K8jP5yuUz+NiccRTlpPtdnoicxLCcdOacO9D7y9451wK8DZQBS4DeX/YriAcEXvvDLm4NkO8FytXAaudcgxcAq4FFA61LhpeZ8T8vm07l3Vfy0xsr6OyJcufKzVz43Rf4xWvvEtXd1UQS0pBs9DWzKcB8YC0wzjl3wJt0kPhmJIgHxb4+i9V4bSdq7+99bjGzSjOrrK+vH4rSZYBCwQAfmzOON+6+knuWnMXkoiy+/Ydtum+CSIIadBiYWQ7wJPBV59z7LoPp4tughuynonNuuXOuwjlXUVJSMlQvK4MQDBifvXAKz33tw+Skh3h5ez1rdx3xuywROU2DCgMzSyMeBI8451Z6zYe8zT+9+xXqvPZaoLzP4hO9thO1SwIxM57+u0vJzQjx6eVruGvlJn65Zg9b9zeTqEesiaSSwexANuL7BBqcc1/t0/4vwJE+O5ALnXNfN7Nrgdt4bwfy/c65hd4O5HXAAu8l1hPfgdxwsvfXDuTRqepgC7c/+hb7Gtpp8y58VzImnStnj+W2y2dSlp/pc4UiqWtYzjMws0uAV4HNQO8F8f+B+H6Dx4FJwB7ih
5Y2eOHxE+I7h9uBzzvnKr3XuslbFuB/O+d+/kHvrzAY3aIxx5u7G1i7q4GXttexpbaZnqhjekk2/+2cCXxywUTKC3XymshI0kln4ru3DxzlDxv3s7m2mVd3HMYMrp4zntmluZw7uYCKKQVkpAX9LlMkqQ3LSWcip6PvCWzVdS08sa6WJ9bV8Oy2gzgH4VCA86YU8OGZJSw9bxJ5WWk+VyySWrRmIL5q64rwxu4GXttxmJe317OjrpVwMMBZZbl8ZuEkJhVmMX9SPukhrTWIDJY2E0nCWLenkee2HuTJ9TXHroVUlB1m0dzx3PrRGUzQDmiRAVMYSMJp64qws76VQ0e7WLm+hj9tOQjAd5acxdLzJukieSIDoDCQhPf6ziP88LkqKvc0Upafyfc+OY9LZ+rEQ5HToTCQpOCc44l1NXz9yU04B7kZIcbmZlCSk860kmwunlHMtJJszhg7hkDA/C5XZNRRGEhS2bb/KC9W1VF3tJNDR7vY19jO9kMt9ETj3+UxGSFmjM3h0xXlnDe1kLL8TB22KoIOLZUkM2dCLnMm5L6vrbMnSnVdK28fOMrGmib+XH2EO1duPjZ9YkEm8ycVMKUoi0tmFHP+tKKRLltkVNOagSSlnmiMDfua2NfQzr6GDrbsb6bqYAs1je3EHCycWsi180q5fn4ZeZk6p0FSgzYTiXia23t4Yn0ND7++mz1H2gEozkknLzNEWjBAelqQ3IwQuZlpTC3K5pKZxZw3pZCg9kFIElAYiPRj3Z5GXt95mJrGDpo7euiJOroiUVo6IzR39LDnSBsxB4XZYa6cPZainHSmFmWzYHI+04pztJNaEo72GYj049zJBZw7ueCE01u7IrxUVcczWw6yauN+uiMxem/kFg4FGJ+bwRnjxjBjbA5njMthekkOpfkZFGSFSQvqPAhJLAoDkRPISQ9x3dkTuO7sCTjncA52HW7jrb2NVNe1UtvUQdXBFl7ZXk93NHZsuXAowKTCLHLSQ8wuzWViQSblhVlMK85mWkk2WWH9s5PRR99KkVNgZpjBjLE5zBib875pkWiMnfVt7D7SxqGjnexraGf3kXaOtHbxpy0HaGrved/85YWZhAIBAgZZ4RBjMkKMHZNOZjhEXmYa+Vlp5GWmMSE/k7L8THIzQ+Skh8hMCxK/ErzI0FMYiAxSKBhg1vgxzBo/pt/p7d0R9ja0s6u+je2HWthV34YDorEYzR09tHdHWb+3ieaOHlq7IkRj/e/HCxhkh0PkZITIzwqTlxkiPzNMZjhIXmYaJWPSmVOaS1lBJmnBAEb8tqS9j4AZoYARCMSfs8IKF3mPwkBkmGWFQ5w5Ppczx+eyeF7pSed1ztHRE6WpvYd9De0caO6kpStCm/do7YrQ0hmhqb2b5o4edh1upak9HiitXZHTrCtIeihAKBggPRSgICtMZlqQcCgQfwQD7w174+mh+BFXY9JDpAWNjLTgsUd6KEAgYATNCAQgaPEQMu85HAzE3zMtQMDMe3Bsnt7hgLcWFg4GFFYjSGEgMoqYGVnhEFnh0GlfobWhrZtd9a3sb+4kFnPEnCMSc8RijqiLP0dijqj3XN/SRXckRiQWo7MnRkNbN509Udq6IzS2x+iOxOiOes/eo8sbHwm9ARIK2rFNZxmh3vAJkB6KB0tmWpDs9BDpoQBBb60nGAi8by0o2Pc5+N703vb3zxffhBcIWDz8QvH36w283sAKHgu+EwdW776m3nW9gDFqA05hIJIkCrPDFGYXDvv7xGKOtu5IPBwiMTp7onT2xOiKRIk5RzQWv+1pfDj+HHOO7kiM9u4oXZGY18ax0PqLYedo996jJ+qob+2ivStCVyRGe3eEhrYYnZEoXT3x92/titATfe9or5HW+5+8cw4HnOiIfTPICAVJCxrhUJBQwIg6h8GxNSXz1qxisXiY9P49Yi4+XpgdZvUdlw15HxQGInJaAgFjTMboPGu7dy2od+0nGnVEYjGiXnsk2mda7L1pvePRW
Hye3iDrisRDrssLu67Ie6/Vu6YVi8UDIOYcRnwTlwGYYfGnY+2RaIzOyHtrXZFojGAgADhiMd4XhmbxTW0BLxx6N6XlpA/P315hICJJIxAwAhi6LuHp05kxIiKiMBAREYWBiIigMBARERQGIiKCwkBERFAYiIgICgMRESGB73RmZvXAngEuXgwcHsJyEoH6nPxSrb+gPp+uyc65kv4mJGwYDIaZVZ7o1m/JSn1OfqnWX1Cfh5I2E4mIiMJARERSNwyW+12AD9Tn5Jdq/QX1ecik5D4DERF5v1RdMxARkT4UBiIiklphYGaLzKzKzKrN7E6/6xlKZrbbzDab2QYzq/TaCs1stZnt8J4LvHYzs/u9v8MmM1vgb/WnxsweMrM6M9vSp+20+2hmy7z5d5jZMj/6cqpO0Odvm1mt91lvMLPFfabd5fW5ysyu7tOeMN99Mys3sxfNbJuZbTWz2732pPysT9Lfkf2c4zdsTv4HEAR2AtOAMLARmON3XUPYv91A8XFt/xe40xu+E/i+N7wY+BPxu/NdAKz1u/5T7OOHgQXAloH2ESgEdnnPBd5wgd99O80+fxv4+37mneN9r9OBqd73PZho332gFFjgDY8Btnt9S8rP+iT9HdHPOZXWDBYC1c65Xc65buBRYInPNQ23JcAKb3gFcH2f9odd3Bog38xKfajvtDjnXgEajms+3T5eDax2zjU45xqB1cCiYS9+gE7Q5xNZAjzqnOtyzr0LVBP/3ifUd985d8A5t94bbgHeBspI0s/6JP09kWH5nFMpDMqAfX3Gazj5HzzROOA5M1tnZrd4beOccwe84YPAOG84mf4Wp9vHZOn7bd4mkYd6N5eQhH02synAfGAtKfBZH9dfGMHPOZXCINld4pxbAFwD3GpmH+470cXXL5P6OOJU6KPnAWA68CHgAPBDX6sZJmaWAzwJfNU5d7TvtGT8rPvp74h+zqkUBrVAeZ/xiV5bUnDO1XrPdcDviK8yHurd/OM913mzJ9Pf4nT7mPB9d84dcs5FnXMx4KfEP2tIoj6bWRrx/xgfcc6t9JqT9rPur78j/TmnUhi8Ccw0s6lmFgaWAqt8rmlImFm2mY3pHQauArYQ71/vERTLgKe84VXAjd5RGBcAzX1WvxPN6fbxWeAqMyvwVruv8toSxnH7d/6K+GcN8T4vNbN0M5sKzATeIMG++2ZmwIPA2865f+0zKSk/6xP1d8Q/Z7/3pI/kg/hRB9uJ73G/2+96hrBf04gfObAR2NrbN6AIeB7YAfwnUOi1G/D/vL/DZqDC7z6cYj9/Q3x1uYf49tCbB9JH4CbiO92qgc/73a8B9PmXXp82ef/YS/vMf7fX5yrgmj7tCfPdBy4hvgloE7DBeyxO1s/6JP0d0c9Zl6MQEZGU2kwkIiInoDAQERGFgYiIKAxERASFgYiIoDAQEREUBiIiAvx/Ytwm8HresDsAAAAASUVORK5CYII=\n", 235 | "text/plain": [ 236 | "
" 237 | ] 238 | }, 239 | "metadata": { 240 | "needs_background": "light" 241 | }, 242 | "output_type": "display_data" 243 | } 244 | ], 245 | "source": [ 246 | "from matplotlib import pyplot as plt\n", 247 | "%matplotlib inline\n", 248 | "\n", 249 | "plt.plot(plt_x, plt_y)\n", 250 | "plt.show()" 251 | ] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "Python 3 (ipykernel)", 257 | "language": "python", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.8.11" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 5 275 | } 276 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/3.momentum-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0\n", 49 | "\n", 50 | 
"#动量都初始化为0\n", 51 | "momentum_w = np.zeros(M)\n", 52 | "momentum_b = 0" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "id": "92163201", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "0.6590042695516539" 65 | ] 66 | }, 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "#预测函数\n", 74 | "def predict(x):\n", 75 | " return w.dot(x) + b\n", 76 | "\n", 77 | "\n", 78 | "predict(x[0])" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "id": "a7bb7a80", 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "0.21258140154187247" 91 | ] 92 | }, 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "#求loss,MSELoss\n", 100 | "def get_loss(x, y):\n", 101 | " pred = predict(x)\n", 102 | " loss = (pred - y)**2\n", 103 | " return loss\n", 104 | "\n", 105 | "\n", 106 | "get_loss(x[0], y[0])" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 5, 112 | "id": "8027d213", 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "text/plain": [ 118 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 119 | " 0.923131013558981)" 120 | ] 121 | }, 122 | "execution_count": 5, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "def get_gradient(x, y):\n", 129 | " global w\n", 130 | " global b\n", 131 | "\n", 132 | " eps = 1e-3\n", 133 | "\n", 134 | " loss_before = get_loss(x, y)\n", 135 | "\n", 136 | " gradient_w = np.empty(M)\n", 137 | " for i in range(M):\n", 138 | " w[i] += eps\n", 139 | " loss_after = get_loss(x, y)\n", 140 | " w[i] -= eps\n", 141 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 142 | "\n", 143 | " b += eps\n", 144 | " loss_after = get_loss(x, y)\n", 145 | " b -= eps\n", 146 | " 
gradient_b = (loss_after - loss_before) / eps\n", 147 | "\n", 148 | " return gradient_w, gradient_b\n", 149 | "\n", 150 | "\n", 151 | "get_gradient(x[0], y[0])" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 6, 157 | "id": "f39e0125", 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "11073.905141728206" 164 | ] 165 | }, 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "def total_loss():\n", 173 | " loss = 0\n", 174 | " for i in range(N):\n", 175 | " loss += get_loss(x[i], y[i])\n", 176 | " return loss\n", 177 | "\n", 178 | "\n", 179 | "total_loss()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 7, 185 | "id": "c371c6a4", 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "0 [-4.35683357 4.61375434 1.50384418 2.91083043 2.96213566] -4.022402214293841 11044.88570074484\n", 193 | "150 [ 1.26339579 2.13768158 0.49321428 -1.51848376 3.5785042 ] -1.7176170960084935 1967.1697066194426\n", 194 | "300 [ 5.80970624e-03 3.55685572e+00 3.82822053e+00 -2.58323564e-01\n", 195 | " 1.07002133e+01] 2.0649671248152166 947.1121093652372\n", 196 | "450 [-2.26148992 3.8055673 -3.18437755 -3.87469838 -0.61715981] 0.7950908166330994 815.4883875626289\n", 197 | "600 [1.55949644 0.79231053 3.14007815 3.71824435 0.98449926] -0.5781588479929312 783.2820118211715\n", 198 | "750 [ 3.59572027 -0.74530488 0.72892403 -0.15447051 -4.33429161] -2.0781779575632005 753.5637854055667\n", 199 | "900 [ 0.3379927 2.54425898 -4.14378216 0.92339395 0.76440147] 0.8774161483117863 769.3244441314099\n", 200 | "1050 [-1.00006534 -5.83176015 -5.28599293 -5.19888383 -3.62367306] 2.5986887635740827 742.322126662018\n", 201 | "1200 [ 1.32568154 -8.98944316 8.03060836 -1.84855209 -1.93601743] -4.04468557258265 753.4751326083475\n", 202 | "1350 [ 1.82402337 
-1.95342977 1.07956753 2.00744178 -4.75462661] -2.284159092719058 734.9285345602206\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "plt_x = []\n", 208 | "plt_y = []\n", 209 | "for epoch in range(1500):\n", 210 | " i = np.random.randint(N)\n", 211 | "\n", 212 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 213 | "\n", 214 | " #这是更新动量的数学公式,0.8是过去动量的权重\n", 215 | " momentum_w = 0.8 * momentum_w + gradient_w\n", 216 | " momentum_b = 0.8 * momentum_b + gradient_b\n", 217 | "\n", 218 | " #这里更新参数不再使用梯度,而是使用动量\n", 219 | " w -= momentum_w * 1e-3\n", 220 | " b -= momentum_b * 1e-3\n", 221 | "\n", 222 | " #思考一下,在时刻0,动量都是0.此时更新动量,动量就等于梯度.\n", 223 | " #也就是说,再时刻0,其实就是再用梯度下降.\n", 224 | " #时刻1,是上一个时刻的梯度乘以0.8,再加上当前时刻的梯度\n", 225 | " #所以在时刻1,差不多可以认为是梯度乘以了1.8.不过这里面两部分的梯度是在两个不同的点上评估出来的.\n", 226 | " #在时刻2,差不多等同于时刻1.往后都差不多.\n", 227 | "\n", 228 | " plt_x.append(epoch)\n", 229 | " plt_y.append(total_loss())\n", 230 | "\n", 231 | " if epoch % 150 == 0:\n", 232 | " print(epoch, momentum_w, momentum_b, total_loss())" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 8, 238 | "id": "0471a70d", 239 | "metadata": { 240 | "scrolled": true 241 | }, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhEElEQVR4nO3deZRc5Xnn8e9Ta3dV7+pGaEMSWIBlbAy0McSO4xiz2HEMOWM7JI6tOJ5wZuJMvM1xIMkJmcTJiROfsc1MjEOMPThDMJg4gfESQjAxnJmwNMYIhBASQiC1JNTqVi/qtZZn/rhvN43S2nrRre77+5xTp+99762qp1+p+lfvXc3dERGRZEvFXYCIiMRPYSAiIgoDERFRGIiICAoDEREBMnEXMFvt7e2+bt26uMsQEVk0nnjiiYPu3jHTskUbBuvWraOrqyvuMkREFg0ze+loy7SZSEREFAYiIqIwEBERFAYiIoLCQEREUBiIiAgKAxERIWFh4O78jwe28+Pne+IuRUSkpiQqDMyMWx7ayYPPHYi7FBGRmpKoMABoLmQZGC3FXYaISE1JXBi0FnIcGpmIuwwRkZqSuDBoKWQ5NKKRgYjIdAkMgxwDGhmIiLxG4sKgVSMDEZF/J3Fh0JDPMDxext3jLkVEpGYkLgyK+QzlqjNRqcZdiohIzUhcGNRn0wCMjFdirkREpHYkLgyK+RAGJYWBiMikxIVBIRfd6XNkvBxzJSIitSNxYTA5Mhie0MhARGRS4sKgPquRgYjIkRIXBlP7DDQyEBGZkrgwmNxnMDyhkYGIyKQEhoFGBiIiR0pcGBQnjyZSGIiITDluGJjZN8zsgJk9M62tzczuN7Pt4WdraDczu8nMdpjZZjO7cNpzNoX1t5vZpmntF5nZ0+E5N5mZzfcvOV395MhAO5BFRKacyMjgfwFXHdF2PfCAu28AHgjzAO8BNoTHdcDNEIUHcCPwVuBi4MbJAAnr/Oa05x35XvMql0mRS6d0aKmIyDTHDQN3fwjoO6L5auC2MH0bcM209m955BGgxcxWAFcC97t7n7sfAu4HrgrLmtz9EY+uHPetaa+1YAr5NMMaGYiITJntPoPl7r4vTO8HlofpVcDuaevtCW3Hat8zQ/uMzOw6M+sys66entnf1L6xLsPQmC5jLSIyac47kMM3+lNyPWh3v8XdO929s6OjY9av05jPclgjAxGRKbMNg1fCJh7CzwOhvRtYM2291aHtWO2rZ2hfUA11GQbHFAYiIpNmGwb3ApNHBG0C7pnW/tFwVNElwEDYnHQfcIWZtYYdx1cA94Vlg2Z2STiK6KPTXmvBNNVlOKwwEBGZkjneCmZ2B/BOoN3M9hAdFfTnwF1m9nHgJeBDYfUfAO8FdgAjwMcA3L3PzP4EeDys98fuPrlT+reIjliqB34YHguqsS7L0PjQQr+NiMiicdwwcPdfOcqiy2ZY14FPHOV1vgF8Y4b2LuC849UxnxryGYY0MhARmZK4M5AhOpro8JjugywiMimhYZClXHXGSroPsogIJDYMoq1jOtdARCSS6DDQ4aUiIpFEhkFzfRaAgVGNDEREIKFh0FLIATAwOhFzJSIitSGZYRBGBv0jGhmIiEBSw6CgMBARmS6RYdBYl8UM+rXPQEQESGgYpFNGc32W/hHtMxARgYSGAUT7DXQ0kYhIJLFh0FzIaZ+BiEiQ2DBoqc9qn4GISJDcMChkGdA+AxERIMlhoJGBiMiUxIZBU9iBrMtYi4gkOAyK+Qzu6DLWIiIkOQxyaQAOj+vKpSIiyQ2DfHQZ65EJhYGISGLDoJCLwkAjAxGRBIdBMR9tJhqZqMRciYhI/BIcBtHIYFgjAxGRBIdBbjIMNDIQEUluGITNRMPagSwikuAwyGkzkYjIpOSGwdShpdpMJCKS2DDIZVJk06ZDS0VESHAYQDQ6GFEYiIgkPAxyGQ7raCIRkbmFgZl92sy2mNkzZnaHmdWZ2Xoze9TMdpjZnWaWC+vmw/y
OsHzdtNe5IbRvM7Mr5/g7nbBiPq3LUYiIMIcwMLNVwO8Ane5+HpAGrgW+AHzJ3V8HHAI+Hp7yceBQaP9SWA8z2xie9wbgKuCrZpaebV0no5DLaJ+BiAhz30yUAerNLAMUgH3Au4C7w/LbgGvC9NVhnrD8MjOz0P5tdx939xeBHcDFc6zrhDTkMzqaSESEOYSBu3cDXwReJgqBAeAJoN/dJ79u7wFWhelVwO7w3HJYf9n09hmes6AKubTOMxARYW6biVqJvtWvB1YCRaLNPAvGzK4zsy4z6+rp6Znz6zXkMzoDWUSEuW0mejfworv3uHsJ+C7wNqAlbDYCWA10h+luYA1AWN4M9E5vn+E5r+Hut7h7p7t3dnR0zKH0SCGf1rWJRESYWxi8DFxiZoWw7f8y4FngQeADYZ1NwD1h+t4wT1j+I49uQHwvcG042mg9sAF4bA51nbBiPqPNRCIiRDuAZ8XdHzWzu4GfAGXgSeAW4PvAt83s86Ht1vCUW4G/NbMdQB/REUS4+xYzu4soSMrAJ9z9lHxdL+YyjJerlCtVMulEn3IhIgk36zAAcPcbgRuPaN7JDEcDufsY8MGjvM6fAn86l1pmo5CbvHJpheZ6hYGIJFei/wI26AY3IiJAwsOgMHXlUoWBiCRbosOgYfIGNzqiSEQSLtFhUNANbkREgISHwdQ+A12SQkQSLtFhMHU0kUYGIpJwiQ6DhrpoZDCkMBCRhEt0GDTVZQEYHC3FXImISLwSHQZ12TT12TSHhifiLkVEJFaJDgOA1kKWfo0MRCThEh8GzYUc/SMKAxFJtsSHQUt9loFRbSYSkWRTGBSyHNLIQEQSTmFQyGozkYgkXuLDoLk+x8DoBNF9dkREkinxYdDekKNUcY0ORCTREh8Gq1rqAejuH425EhGR+CQ+DFaGMNg3MBZzJSIi8Ul8GKxoqQNgr0YGIpJgiQ+D9mKeXDrF3gGFgYgkV+LDIJUyTm+uY2+/NhOJSHIlPgwAVrbUaTORiCSawgA4o63AS70jcZchIhIbhQFwVkcDBw+PM6BzDUQkoRQGRGEA8MLBwzFXIiISD4UBcNZpIQwOKAxEJJkUBsCa1nqyaeOFnuG4SxERiYXCAMikU5zV0cBz+wfjLkVEJBYKg+CNq5rZvGdAVy8VkURSGATnr2mhb3iCPYd0voGIJM+cwsDMWszsbjN7zsy2mtmlZtZmZveb2fbwszWsa2Z2k5ntMLPNZnbhtNfZFNbfbmab5vpLzcZFa1sBeGRnbxxvLyISq7mODL4C/JO7nwucD2wFrgcecPcNwANhHuA9wIbwuA64GcDM2oAbgbcCFwM3TgbIqXTu6Y10NOZ5aPvBU/3WIiKxm3UYmFkz8A7gVgB3n3D3fuBq4Law2m3ANWH6auBbHnkEaDGzFcCVwP3u3ufuh4D7gatmW9dsmRlvXd/G4y/2ab+BiCTOXEYG64Ee4Jtm9qSZfd3MisByd98X1tkPLA/Tq4Dd056/J7Qdrf3fMbPrzKzLzLp6enrmUPrMLl7fxv7BMe03EJHEmUsYZIALgZvd/QJgmFc3CQHg0Vfsefua7e63uHunu3d2dHTM18tOecu6NgAee7Fv3l9bRKSWzSUM9gB73P3RMH83UTi8Ejb/EH4eCMu7gTXTnr86tB2t/ZQ7Z3kjTXUZHt+lMBCRZJl1GLj7fmC3mZ0Tmi4DngXuBSaPCNoE3BOm7wU+Go4qugQYCJuT7gOuMLPWsOP4itB2yqVSRue6Nh7efpBypRpHCSIisZjr0UT/BbjdzDYDbwb+DPhz4HIz2w68O8wD/ADYCewA/gb4LQB37wP+BHg8PP44tMXiQ51r6O4f5Z6f7o2rBBGRU84W65EznZ2d3tXVNe+v6+5c9eWHMYMffvJnMbN5fw8RkTiY2RPu3jnTMp2BfAQz49ffto7n9g/x5O7+uMsRETklFAYz+MXzV9KQz3DTA9t1zoGIJILCYAYN+Qyfvvxs/nVbD9/bvO/4TxA
RWeQUBkex6dK1nL+6mRvv3cLIRDnuckREFpTC4Cgy6RS//a4N9A1P8Ey37nMgIkubwuAYzlvVBKCb3ojIkqcwOIbTm+poKWTZuk9hICJLm8LgGMyMc09vZOu+obhLERFZUAqD43j9iia27R+iUtUhpiKydCkMjuPs5Y2Mlirs7ddlrUVk6VIYHEdrIQfA4Fgp5kpERBaOwuA4muoyABwe07kGIrJ0KQyOoyGEwZDCQESWMIXBcTTkw8hgXGEgIkuXwuA4GuuyAAxpn4GILGEKg+NonNxMpJGBiCxhCoPjyGdSZNPG4KjCQESWLoXBcZgZbcUcfcPjcZciIrJgFAYnoKMxz8HDE3GXISKyYBQGJ6CjIU/PkEYGIrJ0KQxOQEdjngNDY3GXISKyYBQGJ2BFcz09Q+OMlytxlyIisiAUBifgzI4iVYeXekfiLkVEZEEoDE7A+vYiADt7DsdciYjIwlAYnIDJMHihZzjmSkREFobC4AQ01mVZ0VzHtv2645mILE0KgxP0xlXNPN09EHcZIiILQmFwgt60upkXDw7rJjcisiQpDE7QBWe0AtC1qy/mSkRE5t+cw8DM0mb2pJl9L8yvN7NHzWyHmd1pZrnQng/zO8LyddNe44bQvs3MrpxrTQvhorWt5DMpHnr+YNyliIjMu/kYGXwS2Dpt/gvAl9z9dcAh4OOh/ePAodD+pbAeZrYRuBZ4A3AV8FUzS89DXfOqLpvmZ85axj9v2U+l6nGXIyIyr+YUBma2GvgF4Oth3oB3AXeHVW4DrgnTV4d5wvLLwvpXA99293F3fxHYAVw8l7oWyoc617B3YIz7n90fdykiIvNqriODLwOfA6phfhnQ7+6TF//fA6wK06uA3QBh+UBYf6p9hue8hpldZ2ZdZtbV09Mzx9JP3uUbl3NmR5G/uG8bpUr1+E8QEVkkZh0GZvY+4IC7PzGP9RyTu9/i7p3u3tnR0XGq3nZKJp3ic1eey86eYf55yyun/P1FRBbKXEYGbwPeb2a7gG8TbR76CtBiZpmwzmqgO0x3A2sAwvJmoHd6+wzPqTmXb1zOyuY67uzaffyVRUQWiVmHgbvf4O6r3X0d0Q7gH7n7h4EHgQ+E1TYB94Tpe8M8YfmP3N1D+7XhaKP1wAbgsdnWtdDSKeODnWt4eHsPew7pwnUisjQsxHkGvwt8xsx2EO0TuDW03wosC+2fAa4HcPctwF3As8A/AZ9w95q+VvQHO1cDcFfXnpgrERGZHxZ9OV98Ojs7vaurK7b3/8itj7KzZ5iHP/fzpFIWWx0iIifKzJ5w986ZlukM5Fn6YOcauvtHeWRnb9yliIjMmcJglq7YuJzGugzfeUKbikRk8VMYzFJdNs3lG5fz0PM9LNZNbSIikxQGc/CmVc30Dk/wyuB43KWIiMyJwmAOzlvVDMAzus+BiCxyCoM5eP2KJszgmb0KAxFZ3BQGc1DMZ1i/rMizewfjLkVEZE4UBnP0+pVNPLtPYSAii5vCYI7esLKJPYdGGRjV7TBFZPFSGMzROcsbAdhxYCjmSkREZk9hMEdnhzDYtv9wzJWIiMyewmCOVrXU01rI0rWrL+5SRERmTWEwR6mU8fYNHfz4+R7d/UxEFi2FwTz4xTetoHd4gsdf1OhARBYnhcE8uHBtKwBb92snsogsTgqDedDekKetmGPbfp1vICKLk8Jgnpy/upmuXYfiLkNEZFYUBvPkneecxs6Dw7ponYgsSgqDefJLF66iuT7LH/zjM5R1VJGILDIKg3nSVJfl89ecx0939/PVf30h7nJERE6KwmAe/eL5K7n6zSv5ygPb2bynP+5yREROmMJgnv3x+8+jvSHHB772b9z5+MtxlyMickIUBvOsuZDlzusu5ezlDfzRvc+yu28k7pJERI5LYbAA1rUXueUjnTjO3zy8M+5yRESOS2GwQFa21NO5to2f7u6PuxQRkeNSGCyg9e1FXjw4jLvHXYqIyDEpDBbQ2mUFhsbKHBrRXdB
EpLYpDBbQOadHN755Wmcli0iNUxgsoIvWtpJNGz/e1hN3KSIixzTrMDCzNWb2oJk9a2ZbzOyTob3NzO43s+3hZ2toNzO7ycx2mNlmM7tw2mttCutvN7NNc/+1akMhl+FnN3Two+deibsUEZFjmsvIoAx81t03ApcAnzCzjcD1wAPuvgF4IMwDvAfYEB7XATdDFB7AjcBbgYuBGycDZCl46/o2dvWO0DM0HncpIiJHNeswcPd97v6TMD0EbAVWAVcDt4XVbgOuCdNXA9/yyCNAi5mtAK4E7nf3Pnc/BNwPXDXbumrNO885DYA7HtPZyCJSuzLz8SJmtg64AHgUWO7u+8Ki/cDyML0K2D3taXtC29Hal4RzTm/kPeedzpf/5Xn2HBrhjatb6GjIcelZ7TTXZ+MuT0QEmIcwMLMG4O+BT7n7oJlNLXN3N7N5O8jezK4j2sTEGWecMV8vu+C++MHzyWdS3PvUXu7q2jPVfu7pjfzee1/PO87uiLE6EZE5hoGZZYmC4HZ3/25ofsXMVrj7vrAZ6EBo7wbWTHv66tDWDbzziPZ/nen93P0W4BaAzs7ORXMmVzGf4cvXXkCl6uztH2Vv/yiPvdjHd5/s5qPfeIyr37ySKzaezmWvP426bDruckUkgWy2Z8daNAS4Dehz909Na/9LoNfd/9zMrgfa3P1zZvYLwG8D7yXaWXyTu18cdiA/AUweXfQT4CJ37zvW+3d2dnpXV9esaq8VY6UKX31wB1/78U4mKlXetLqZL/3ymzmroyHu0kRkCTKzJ9y9c8ZlcwiDtwMPA08Dk7f2+j2i/QZ3AWcALwEfcve+EB7/k2jn8AjwMXfvCq/1G+G5AH/q7t883vsvhTCYNDpR4Xub9/L5729ltFThk5dt4Dd/9kxyGZ0GIiLzZ0HCIG5LKQwmHRga48Z7tvDDZ/ZzzvJGvr6pkzVthbjLEpEl4lhhoK+eNeS0xjpu/rWL+PpHO+nuH+UP73km7pJEJCEUBjXo3RuX86l3b+DBbT38wT8+TbW6OEdvIrJ4zMt5BjL/Pva29WzdN8T/fuRlRieq3PDec2lvyMddlogsURoZ1Kh0yvjiB9/Eb//86/iHJ/fwc3/xIPc/q2scicjCUBjUMDPjv155Dvd/5uc467QGfueOJ3VPZRFZEAqDReCsjgb+6lcvZLRU4ftP7zv+E0RETpLCYJFY01bgvFVN/ODpfbqNpojMO4XBIvLLbzmDzXsGuOene+MuRUSWGIXBInLtW9Zw0dpWPvudp7hvy/64yxGRJURhsIhk0ylu+42LOW9VM5+96yl2HDgcd0kiskQoDBaZhnyGmz98IblMimv+6v/yO3c8yfc379OJaSIyJwqDRWhlSz13/6dLueINy/l/L/Tyib/7Cb9266M67FREZk0XqlvkqlXn24/v5s9+sJWqO7971bl85JK1pFJ2/CeLSKLoQnVLWCpl/Opbz+C+T7+DznVt3HjvFn7164/w0PM9DI6VXnMY6lipwvZXhtiyd4DxciXGqkWk1mhksIS4O9/p2sMf/Z8tjExEf+zNIJtKsaq1nu5Do0xUoltP1GVTvGVdGxee0cqmn1lHWzEXZ+knrffwOJWqc1pTXdyliCwaup9BwgyOlXhsZx+7eocZHC0xVq6ys2eYszqKbFzZRDpldO06xCM7e3lu/xAAhVyalS31vOvc0/jM5WfX3O03p/8//fufdHPDdzdTqTofuWQtn/j51ykURE6AwkCOasveAR7efpCeoXFePDjMg9sOsKyYZ+PKJvKZFAZUHdqKWZrqsgCUq046ZbQ35OlozHNaY/SzozFPXTbNnkMj7Do4wst9w7zUO8JLvSNk0kYxl+HQyASNdRky6RQpM5rqMrQ35GktZHHg+VeGGBork8ukqM+mGRwr88SuPvYOjGEG+UyKsVKVt6xr5XWnNXDn47tJmXHNBav45besYU1rgdMa89pnIjIDhYGcsH97oZfbH32Jl/tGKFV86hv5oZE
JBkfLpAwy6RSlSnVqU9SxtBSynNFWYLxUpVSp0lLIcni8TLnquMPAaIm+4Ymp9Zvrs7QVc0yUq4yWKuQzKS5c28qZ7UUAxstVzmwv8h8uWk02nWLXwWFu+7dd3P7Iy1ObwHLpFCta6mgr5miqy9JYl6GpPsuyYo628KjPpnHg5d4Rtu4b5IWDw6xqqeONq1poqMvg7lSrjgPFfIamugyN4bXqs2nymTR12RT5TJp8Ntr19nLfCAcPj2MY+WyKukya+lya+myaTNooV5yJcpXhiTJPdw/w0PM9PP/KELt6R1jbVuBnzlrGOac3ce6KRta2FShXncPjZQZHSwyNlRkci/qqu380GvGVqq/p62I+TUM+y5a9A2zdN8jQWJn17UV+7uwOljXkaMhnKebTFHLR71CXTU2NANMpI5s2MqkUmbThDqOlCqVKlbpsOnpkUjTWZadux1qtOuWqU65WKZWd8UqF8VKV8XJ16vfsPTxBpepEPRmNQNcuK9JayFHIpcmlU1PBPTpRoXd4nL7hCXqHJ+g9PEGpUuW0xjz1uaiGfCbFeLnK6ESF4fEyIxMVth+I+rAuk+byjcs59/RGOhrzFPOvXqG/XInqGi9XGStVpl6jf2SC8XKVctUp5tKs7yiyrJgnnTJSFl0scrbcnb7hCUYmKkxUoj6ZKEefg4lyFbPQ5+kUmZRRl03TWJchm05RrlapVJ1yxRkvVxidqDJWrjBWqmAYb9/QPquaFAayIIbHyxw8PM6BoXF6wmNkosKatnrWthU5Y1mB5vrscV+nXKkyOFbG3Wkr5mb1ATx4eJzNe/rp7h+j+9Ao3f2jHBqeYHAs/CEdLXFoZIKZTsdY3pRnfXuR7v5RdveNnvR7z9bpTXW8YWUT69uLPLtvkKd29zN8AgGbS6doLmTJZ6LRFYDjHB4rMzRWZl17kc61rRTzGbp29fHsvkFKldr9nGfTRqXqM/7bnIh0yjijrUDv4XEGx8pT7XXZqH/Gy9Ef1tlIGSEYjHRq2sOMVPiZThmpcCiOe/SoutM/UmK0NP8HarQ35On6g3fP6rnHCgPd3EZmrZjPUMxnWLusOKfXyaRTc96B3d6Q513nLj/mOpWqMzhaond4grFSBXdY2VLHsmk3DRocKzFWqkQffjOcKPQmQ2VorDz1zXK8/Oo34ao7a9oKdDTkcTz6BjpRCd/mom+DmVRqavPX+vYir1/R+Jrgq1ad7v5Rnts/xJ5DI+QyKRryGZrqsjTVRz9bCjmWFXMntRmsWnVGShWGxkoMj5cZnYhGXWPhMdk3papTrkTfkg0o5DKkUzb1e45MRL9/qVKF0D+ZdPTHcPKbbT4T/Y75TJpCLk1bMUc2Hf2lNIOhsRIv940wMFJipPRq/2VSRiGfDqO3PG3FHO0NOTLpFD1D44yGvhwvVclnUxSyaYr5DPW5NCub66nPpRkrVXi6e4CXeqMR2sGh8bBpMaorPzmSmzbdWshRl02RSaUYHCuxs+cwA6MlKlWohNHh1M8wXQnT1anp6I8/gBEd4WdAU32W1a31FPOZqF/SUd/kMimy6RRVj775l6tVJsrRCODweJlSuTo1WkilLIziopFcfTYa2S0EjQxERBJC5xmIiMgxKQxERERhICIiCgMREUFhICIiKAxERASFgYiIoDAQEREW8UlnZtYDvDTLp7cDB+exnPlW6/WBapwPtV4f1H6NtV4f1FaNa929Y6YFizYM5sLMuo52Fl4tqPX6QDXOh1qvD2q/xlqvDxZHjaDNRCIigsJARERIbhjcEncBx1Hr9YFqnA+1Xh/Ufo21Xh8sjhqTuc9AREReK6kjAxERmUZhICIiyQoDM7vKzLaZ2Q4zuz7GOtaY2YNm9qyZbTGzT4b2NjO738y2h5+tod3M7KZQ92Yzu/AU1Zk2syfN7Hthfr2ZPRrquNPMcqE9H+Z3hOXrTlF9LWZ2t5k9Z2ZbzezSWupDM/t0+Pd9xszuMLO6uPvQzL5hZgfM7JlpbSf
dZ2a2Kay/3cw2nYIa/zL8O282s38ws5Zpy24INW4zsyuntS/Y532mGqct+6yZuZm1h/lY+vGkuXsiHkAaeAE4E8gBTwEbY6plBXBhmG4Engc2An8BXB/arwe+EKbfC/yQ6K56lwCPnqI6PwP8HfC9MH8XcG2Y/hrwn8P0bwFfC9PXAneeovpuA/5jmM4BLbXSh8Aq4EWgflrf/XrcfQi8A7gQeGZa20n1GdAG7Aw/W8N06wLXeAWQCdNfmFbjxvBZzgPrw2c8vdCf95lqDO1rgPuITohtj7MfT/p3iuuNT/kvCpcC902bvwG4Ie66Qi33AJcD24AVoW0FsC1M/zXwK9PWn1pvAWtaDTwAvAv4XviPfHDaB3KqP8N//kvDdCasZwtcX3P4Y2tHtNdEHxKFwe7wQc+EPryyFvoQWHfEH9qT6jPgV4C/ntb+mvUWosYjlv0ScHuYfs3neLIfT8XnfaYagbuB84FdvBoGsfXjyTyStJlo8sM5aU9oi1XYHHAB8Ciw3N33hUX7gck7vMdR+5eBzwHVML8M6Hf38gw1TNUXlg+E9RfSeqAH+GbYlPV1MytSI33o7t3AF4GXgX1EffIEtdWHk062z+L+LP0G0TdtjlHLKa/RzK4Gut39qSMW1UyNx5KkMKg5ZtYA/D3wKXcfnL7Mo68KsRz3a2bvAw64+xNxvP8JyhAN02929wuAYaJNHFNi7sNW4Gqi0FoJFIGr4qjlZMTZZyfCzH4fKAO3x13LdGZWAH4P+MO4a5mtJIVBN9H2vEmrQ1sszCxLFAS3u/t3Q/MrZrYiLF8BHAjtp7r2twHvN7NdwLeJNhV9BWgxs8wMNUzVF5Y3A70LWB9E36L2uPujYf5uonColT58N/Ciu/e4ewn4LlG/1lIfTjrZPovls2Rmvw68D/hwCK1aqvEsouB/KnxuVgM/MbPTa6jGY0pSGDwObAhHc+SIdtLdG0chZmbArcBWd//v0xbdC0weUbCJaF/CZPtHw1EJlwAD04b1887db3D31e6+jqiffuTuHwYeBD5wlPom6/5AWH9Bv126+35gt5mdE5ouA56lRvqQaPPQJWZWCP/ek/XVTB9Oc7J9dh9whZm1hhHQFaFtwZjZVUSbLd/v7iNH1H5tOBprPbABeIxT/Hl396fd/TR3Xxc+N3uIDhLZTw314zHFtbMijgfRXv3niY4y+P0Y63g70VB8M/DT8Hgv0TbiB4DtwL8AbWF9A/4q1P000HkKa30nrx5NdCbRB20H8B0gH9rrwvyOsPzMU1Tbm4Gu0I//SHRERs30IfDfgOeAZ4C/JTriJdY+BO4g2odRIvqD9fHZ9BnRdvsd4fGxU1DjDqLt65Ofl69NW//3Q43bgPdMa1+wz/tMNR6xfBev7kCOpR9P9qHLUYiISKI2E4mIyFEoDERERGEgIiIKAxERQWEgIiIoDEREBIWBiIgA/x+8OMovNC6pQwAAAABJRU5ErkJggg==\n", 246 | "text/plain": [ 247 | "
" 248 | ] 249 | }, 250 | "metadata": { 251 | "needs_background": "light" 252 | }, 253 | "output_type": "display_data" 254 | } 255 | ], 256 | "source": [ 257 | "from matplotlib import pyplot as plt\n", 258 | "%matplotlib inline\n", 259 | "\n", 260 | "plt.plot(plt_x, plt_y)\n", 261 | "plt.show()" 262 | ] 263 | } 264 | ], 265 | "metadata": { 266 | "kernelspec": { 267 | "display_name": "Python 3 (ipykernel)", 268 | "language": "python", 269 | "name": "python3" 270 | }, 271 | "language_info": { 272 | "codemirror_mode": { 273 | "name": "ipython", 274 | "version": 3 275 | }, 276 | "file_extension": ".py", 277 | "mimetype": "text/x-python", 278 | "name": "python", 279 | "nbconvert_exporter": "python", 280 | "pygments_lexer": "ipython3", 281 | "version": "3.8.11" 282 | } 283 | }, 284 | "nbformat": 4, 285 | "nbformat_minor": 5 286 | } 287 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/4.ada_grad-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "array([0., 0., 0., 0., 0.])" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | 
"output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "#常量\n", 55 | "N, M = x.shape\n", 56 | "\n", 57 | "#变量\n", 58 | "w = np.ones(M)\n", 59 | "b = 0\n", 60 | "\n", 61 | "#初始化S为全0\n", 62 | "S_w = np.zeros(M)\n", 63 | "S_b = 0\n", 64 | "\n", 65 | "S_w" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "id": "92163201", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "0.6590042695516539" 78 | ] 79 | }, 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "#预测函数\n", 87 | "def predict(x):\n", 88 | " return w.dot(x) + b\n", 89 | "\n", 90 | "\n", 91 | "predict(x[0])" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "id": "a7bb7a80", 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "0.21258140154187247" 104 | ] 105 | }, 106 | "execution_count": 4, 107 | "metadata": {}, 108 | "output_type": "execute_result" 109 | } 110 | ], 111 | "source": [ 112 | "#求loss,MSELoss\n", 113 | "def get_loss(x, y):\n", 114 | " pred = predict(x)\n", 115 | " loss = (pred - y)**2\n", 116 | " return loss\n", 117 | "\n", 118 | "\n", 119 | "get_loss(x[0], y[0])" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 5, 125 | "id": "8027d213", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 132 | " 0.923131013558981)" 133 | ] 134 | }, 135 | "execution_count": 5, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "def get_gradient(x, y):\n", 142 | " global w\n", 143 | " global b\n", 144 | "\n", 145 | " eps = 1e-3\n", 146 | "\n", 147 | " loss_before = get_loss(x, y)\n", 148 | "\n", 149 | " gradient_w = np.empty(M)\n", 150 | " for i in range(M):\n", 151 | " w[i] += eps\n", 152 | " loss_after 
= get_loss(x, y)\n", 153 | " w[i] -= eps\n", 154 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 155 | "\n", 156 | " b += eps\n", 157 | " loss_after = get_loss(x, y)\n", 158 | " b -= eps\n", 159 | " gradient_b = (loss_after - loss_before) / eps\n", 160 | "\n", 161 | " return gradient_w, gradient_b\n", 162 | "\n", 163 | "\n", 164 | "get_gradient(x[0], y[0])" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 6, 170 | "id": "f39e0125", 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "11073.905141728206" 177 | ] 178 | }, 179 | "execution_count": 6, 180 | "metadata": {}, 181 | "output_type": "execute_result" 182 | } 183 | ], 184 | "source": [ 185 | "def total_loss():\n", 186 | " loss = 0\n", 187 | " for i in range(N):\n", 188 | " loss += get_loss(x[i], y[i])\n", 189 | " return loss\n", 190 | "\n", 191 | "\n", 192 | "total_loss()" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 7, 198 | "id": "c371c6a4", 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "0 [0.02348118 0.05140723 0.0153264 0.01141825 0.03465209] 0.014058087368505013 10246.765269340094\n", 206 | "150 [0.00215905 0.00213966 0.0025248 0.0023185 0.00215242] 0.0024356945540773096 2658.3900946275558\n", 207 | "300 [0.00159518 0.00180842 0.0019631 0.00191729 0.00173008] 0.0019876188630781272 1600.8100316915763\n", 208 | "450 [0.00148661 0.00166541 0.00178354 0.00173962 0.00157404] 0.0017842433083993126 1186.465991584585\n", 209 | "600 [0.00139184 0.00156629 0.00167125 0.00164534 0.00149001] 0.0016826253755351525 1018.0084411998928\n", 210 | "750 [0.00132517 0.00150278 0.00158741 0.00157468 0.00142465] 0.001594262705724423 927.9759431415143\n", 211 | "900 [0.0012664 0.00144711 0.00152095 0.00150801 0.00135752] 0.0015236251237426357 861.6745940332113\n", 212 | "1050 [0.00123801 0.00141408 0.0014759 0.00146659 0.00133255] 
0.001476298163704934 824.5775032016687\n", 213 | "1200 [0.00121116 0.00136314 0.001434 0.00142441 0.00128068] 0.001429332532480186 807.1329761191264\n", 214 | "1350 [0.00116843 0.0013307 0.00139645 0.00138838 0.00124704] 0.0013899172750459162 791.1949098503449\n", 215 | "1500 [0.00115129 0.00130284 0.0013622 0.00136147 0.00122016] 0.0013561117340229432 788.5470016064282\n", 216 | "1650 [0.00111086 0.00126146 0.00132726 0.00132571 0.00119364] 0.0013234756456931064 783.241050040833\n", 217 | "1800 [0.00105912 0.00119525 0.00128706 0.00128536 0.00116119] 0.0012821836202760327 776.5490903516861\n", 218 | "1950 [0.00101795 0.00115106 0.00125265 0.00124849 0.00112686] 0.0012443142251112867 767.7869135624381\n", 219 | "2100 [0.00099775 0.00112372 0.00122065 0.00122 0.00111091] 0.0012130252246838496 762.3245228610506\n", 220 | "2250 [0.00098704 0.00109155 0.00119765 0.00120011 0.00109044] 0.00119230318143276 757.1151773657198\n", 221 | "2400 [0.00096106 0.00106878 0.00117239 0.00118274 0.00108084] 0.0011689278828959015 754.8130624984661\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "plt_x = []\n", 227 | "plt_y = []\n", 228 | "for epoch in range(2500):\n", 229 | " i = np.random.randint(N)\n", 230 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 231 | "\n", 232 | " #adagrad的特点是每个变量都有属于自己的lr\n", 233 | " #要计算各个变量的lr,先要计算S\n", 234 | " #这是S的计算公式\n", 235 | " S_w = S_w + gradient_w**2\n", 236 | " S_b = S_b + gradient_b**2\n", 237 | "\n", 238 | " #计算lr的公式,其中的1e-1是原本的lr,1e-6是防止除0的\n", 239 | " lr_w = 1e-1 / ((S_w + 1e-6)**0.5)\n", 240 | " lr_b = 1e-1 / ((S_b + 1e-6)**0.5)\n", 241 | "\n", 242 | " #所以在时刻0,lr就等于梯度的倒数\n", 243 | " #梯度大的变量会有小lr,梯度小的变量会有大lr\n", 244 | " #往后的每一个时刻,都是类似动量法,考虑上一步的梯度\n", 245 | "\n", 246 | " w -= gradient_w * lr_w\n", 247 | " b -= gradient_b * lr_b\n", 248 | "\n", 249 | " plt_x.append(epoch)\n", 250 | " plt_y.append(total_loss())\n", 251 | "\n", 252 | " if epoch % 150 == 0:\n", 253 | " print(epoch, lr_w, lr_b, total_loss())" 254 | ] 255 | }, 256 | 
{ 257 | "cell_type": "code", 258 | "execution_count": 8, 259 | "id": "0471a70d", 260 | "metadata": { 261 | "scrolled": true 262 | }, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAcdklEQVR4nO3dfXRc9X3n8fd3ZvT8ZNmSjW0ZZEABHAIBFHACSUPM8hS6ZrtJlj274KbsctqlDTSb3YV2t+QkzTbpNiGhm9DSQArdNIQNyeKzTcKaZ9oEE0EcwDi2hcHYxtiyLduyZT3NfPeP+xt5ZEt+0Ei6Gt3P6xydufd37535/jTCH3730dwdERFJtlTcBYiISPwUBiIiojAQERGFgYiIoDAQEREgE3cB49XU1OStra1xlyEiUjJeeumlXe7ePNqykg2D1tZWOjo64i5DRKRkmNnmsZZpN5GIiCgMREREYSAiIigMREQEhYGIiKAwEBERFAYiIkICw+DFN/ewYUdP3GWIiEwrJXvR2Xh96q9/DsBbX/54zJWIiEwfiRsZiIjI0RQGIiKiMBAREYWBiIigMBARERJ4NtFFpzVSWaYMFBEpdNx/Fc3sATPbaWavFbTNNrNVZrYxvDaGdjOze8ys08xeMbMLC7ZZEdbfaGYrCtovMrNXwzb3mJlNdCcLpc3I5nwyP0JEpOScyP8i/y1w9RFtdwBPunsb8GSYB7gGaAs/twD3QhQewF3AJcDFwF35AAnr/PuC7Y78rAmVSkEuN5mfICJSeo4bBu7+HLDniOblwINh+kHg+oL2hzzyAjDLzOYDVwGr3H2Pu3cDq4Crw7J6d3/B3R14qOC9JkXKjJxrZCAiUmi8O8/nufv2MP0uMC9MLwS2FKy3NbQdq33rKO2jMrNbzKzDzDq6urrGVXg6ZWQVBiIiIxR9JDX8H/2U/Ovq7ve5e7u7tzc3j/pM5+Oqq8ywr3dwgisTESlt4w2DHWEXD+F1Z2jfBiwqWK8ltB2rvWWU9klz6uwatnT36iCyiEiB8YbBSiB/RtAK4LGC9pvCWUVLgX1hd9LjwJVm1hgOHF8JPB6W7TezpeEsopsK3mtSNFaXMZh1+gazk/kxIiIl5bjXGZjZ94CPAk1mtpXorKAvA4+Y2c3AZuBTYfUfA9cCnUAv8GkAd99jZl8EfhHW+4K75w9K/weiM5aqgJ+En0lTnonyb2AoR03FZH6SiEjpOG4YuPu/HmPRslHWdeDWMd7nAeCBUdo7gHOPV8dEGQ6DrM4vFRHJS9yluOXpwyMDERGJJC8MwsigX2EgIjIseWGgkYGIyFGSFwZhZDCoYwYiIsMSFwb56wtefrs75kpERKaPxIXB+nd7APjyT34dcyUiItNH4sLgty6KLnj+zLK2mCsREZk+EhcGteXRpRVVZemYKxERmT4SFwbpdPTsnCE91EBEZFjiwiCTyoeBblQnIpKX2DDIZhUGIiJ5iQuDtEYGIiJHSVwYmFn0tDOFgYjIsMSFAUSjA40MREQOS2QYZFLGkG5HISIyLJFhoJGBiMhIiQyDsnRKxwxERAokMgw0MhARGSmRYZBJGVldgSwiMiyRYaCRgYjISIkMg4yuMxARGSGRYZBOGUO6HYWIyLBEhkFZOqW7loqIFEhkGOh2FCIiIyUyDDI6gCwiMkIiw0AjAxGRkRIZBplUSgeQRUQKJDIMousMdABZRCQvkWGQSeuYgYhIoWSGgY4ZiIiMkMgwSOuYgYjICIkMA40MRERGKioMzOwPzWytmb1mZt8zs0ozW2xmq82s08y+b2blYd2K
MN8ZlrcWvM+doX29mV1VZJ+OK53WAWQRkULjDgMzWwh8Bmh393OBNHAD8BXgbnc/E+gGbg6b3Ax0h/a7w3qY2ZKw3XuBq4FvmVl6vHWdCI0MRERGKnY3UQaoMrMMUA1sBz4G/CAsfxC4PkwvD/OE5cvMzEL7w+7e7+5vAp3AxUXWdUzplDEwpJGBiEjeuMPA3bcBfwG8TRQC+4CXgL3uPhRW2wosDNMLgS1h26Gw/pzC9lG2GcHMbjGzDjPr6OrqGm/pzK2r5J19feQ0OhARAYrbTdRI9H/1i4EFQA3Rbp5J4+73uXu7u7c3NzeP+32qy6O9UIoCEZFIMbuJrgDedPcudx8EfghcCswKu40AWoBtYXobsAggLG8Adhe2j7LNpEhZ9JpzxYGICBQXBm8DS82sOuz7Xwa8DjwNfCKsswJ4LEyvDPOE5U+5u4f2G8LZRouBNuDFIuo6rqhchYGISF7m+KuMzt1Xm9kPgJeBIeCXwH3APwAPm9mfhrb7wyb3A39nZp3AHqIziHD3tWb2CFGQDAG3unt2vHWdiJAFKAtERCLjDgMAd78LuOuI5k2McjaQu/cBnxzjfb4EfKmYWk5GKqSBwkBEJJLIK5B1zEBEZKSEhoGOGYiIFEpkGBw+gBxzISIi00QiwyA1fABZaSAiAokNA40MREQKJTIMTAeQRURGSGgY6NRSEZFCiQwDHTMQERkpoWGgYwYiIoUSGgbRq55pICISSWQY7DowAMDKX03qzVFFREpGIsPgX17YAmhkICKSl8gwOKWhkopMin6FgYgIkNAwAKityHCgf+j4K4qIJEBiw6C6Ik3vwKQ+NkFEpGQkNgxqyjMc1MhARARIcBhUl2tkICKSl9gwqMik6R9SGIiIQILDoDyT0qmlIiJBosOgR8cMREQAyMRdQFxWvb4j7hJERKaNxI4MLj+rGYC+QR03EBFJbBi8f1EjAL/asjfeQkREpoHEhsHHzp4LwM837Y65EhGR+CU2DM6eXwdAOv8MTBGRBEtsGGRShhkMZHV6qYhIYsPAzChPpxQGIiIkOAyAKAx04ZmISMLDQFchi4gACQ+DsnSKQe0mEhFJdhhoZCAiEikqDMxslpn9wMx+bWbrzOyDZjbbzFaZ2cbw2hjWNTO7x8w6zewVM7uw4H1WhPU3mtmKYjt1osrSxmDWp+rjRESmrWJHBt8AfuruZwPnA+uAO4An3b0NeDLMA1wDtIWfW4B7AcxsNnAXcAlwMXBXPkAmW1k6xYtv7ZmKjxIRmdbGHQZm1gB8BLgfwN0H3H0vsBx4MKz2IHB9mF4OPOSRF4BZZjYfuApY5e573L0bWAVcPd66TkZP3xBdPf1T8VEiItNaMSODxUAX8B0z+6WZfdvMaoB57r49rPMuMC9MLwS2FGy/NbSN1X4UM7vFzDrMrKOrq6uI0iM3fGARoJvViYgUEwYZ4ELgXne/ADjI4V1CALi7AxO2U97d73P3dndvb25uLvr9GqrLgGiEICKSZMWEwVZgq7uvDvM/IAqHHWH3D+F1Z1i+DVhUsH1LaBurfdLVVUaPc+jpG5yKjxMRmbbGHQbu/i6wxczOCk3LgNeBlUD+jKAVwGNheiVwUziraCmwL+xOehy40swaw4HjK0PbpKuv1MhARASKf9LZHwDfNbNyYBPwaaKAecTMbgY2A58K6/4YuBboBHrDurj7HjP7IvCLsN4X3H1KTvGpC2GwXyMDEUm4osLA3dcA7aMsWjbKug7cOsb7PAA8UEwt41Ffld9NpJGBiCRboq9AHh4ZHNLIQESSLdFh0BjOJtqxX9caiEiyJToMqsuj3UR3P7FBN6wTkURLdBgAfLitCYDOnQdirkREJD6JD4NbLz8TgO6DAzFXIiISn8SHweyacgD29CoMRCS5Eh8Gs8JBZI0MRCTJEh8GjdXRyGDXAYWBiCRX4sOgLJ1iQUMlG3f2xF2KiEhsEh8GAKfOqdZzDUQk0RQGwJzaCnZrN5GIJJjCAJhd
XU63ziYSkQRTGBA916Cnb4joXnoiIsmjMCC6Yd1Qzukb1C0pRCSZFAboiWciIgoDDofBfj3XQEQSSmHA4cdfvrP3UMyViIjEQ2EALJhVBcB9z22KuRIRkXgoDICzTqkjZVCWtrhLERGJhcIguKytmT26WZ2IJJTCIGiqKdfN6kQksRQGwSkNlby7v49cTheeiUjyKAyCxupysjnnwIBOLxWR5FEYBIcvPFMYiEjyKAyCunCtga5CFpEkUhgEjeHxl292HYy5EhGRqacwCNpbZ1NfmeH3vvuyTjEVkcRRGATlmRTXnb8AgB/9clvM1YiITC2FQYE/uW4JAJt3a1eRiCSLwqBAZVmahbOqeOjnm/WgGxFJFIXBES45fTYAL23ujrkSEZGpozA4wm3L2gBYv6Mn5kpERKZO0WFgZmkz+6WZ/d8wv9jMVptZp5l938zKQ3tFmO8My1sL3uPO0L7ezK4qtqZitDRWkzL43x1b4yxDRGRKTcTI4DZgXcH8V4C73f1MoBu4ObTfDHSH9rvDepjZEuAG4L3A1cC3zCw9AXWNSzplXHRaI2u27GVnT19cZYiITKmiwsDMWoCPA98O8wZ8DPhBWOVB4PowvTzME5YvC+svBx529353fxPoBC4upq5ifWH5uQA8uW5nnGWIiEyZYkcGXwf+M5AL83OAve6ev8HPVmBhmF4IbAEIy/eF9YfbR9lmBDO7xcw6zKyjq6uryNLH1jqnBoDuXl18JiLJMO4wMLPrgJ3u/tIE1nNM7n6fu7e7e3tzc/OkfU5lWfRr+fOfrtctrUUkEYoZGVwK/HMzewt4mGj30DeAWWaWCeu0APnLebcBiwDC8gZgd2H7KNvEwsw4r6UBgLf39MZZiojIlBh3GLj7ne7e4u6tRAeAn3L3fwM8DXwirLYCeCxMrwzzhOVPeXRl10rghnC20WKgDXhxvHVNlM98LDrFdMd+HUQWkZlvMq4z+C/AZ82sk+iYwP2h/X5gTmj/LHAHgLuvBR4BXgd+Ctzq7tlJqOukzKuvBODxtTtirkREZPJljr/K8bn7M8AzYXoTo5wN5O59wCfH2P5LwJcmopaJcu7CegB69eQzEUkAXYE8BjPjvQvqtZtIRBJBYXAMTbUVPL1+8k5hFRGZLhQGx1BbEe1F232gP+ZKREQml8LgGG784GkA/P3qt3VLaxGZ0RQGx/C+hQ1kUsZXV23gCd2aQkRmMIXBMdRUZHjis78BwOvv7I+5GhGRyaMwOI7WphrK0kbH5j1xlyIiMmkUBifg9KZant+4i4P9uuZARGYmhcEJWH7BAgDe2Xso5kpERCaHwuAEfKA1ei7y69t13EBEZiaFwQk4v2UWZWljzZa9cZciIjIpFAYnoDyTYsn8er7zT2/FXYqIyKRQGJygM+bWAtDTNxhzJSIiE09hcIKufu8pALpXkYjMSAqDE3TJ6XNoqCrjv/7oVYayueNvICJSQhQGJ6ihqoybL1vM/r4hdh0YiLscEZEJpTA4CfkH3ryzT9cbiMjMojA4Ca1zagD41tOd5HK6i6mIzBwKg5Nw2pwaasrTPLFuJ7d/f03c5YiITBiFwUlIp4yf3bGM3zx/ASt/9Q6PrdkWd0kiIhNCYXCSGqrL+ItPnkdZ2vjhywoDEZkZFAbjUJFJs/z9C3l2Qxe/flf3KxKR0qcwGKePntUMwNVff56t3b0xVyMiUhyFwTh9/H3z+cRFLQA8u0FXJYtIaVMYjJOZ8afXnwvAU3o+soiUOIVBESrL0gAM6poDESlxCoMiffSsZp7b0IW7AkFESpfCoEinza4G4KXN3TFXIiIyfgqDIv27D58OwKO65kBESpjCoEiLZlfTVFvB9158m217dQM7ESlNCoMJ8J3f/gAAl375Kb7xxEayOqAsIiVm3GFgZovM7Gkze93M1prZbaF9tpmtMrON4bUxtJuZ3WNmnWb2ipldWPBeK8L6G81sRfHdmlrva2ngP111FgB3P7GBD3/lKd7e
rQvRRKR0FDMyGAL+o7svAZYCt5rZEuAO4El3bwOeDPMA1wBt4ecW4F6IwgO4C7gEuBi4Kx8gpeTWy8/kjf9+LVecM4939vXxu//rJd3mWkRKxrjDwN23u/vLYboHWAcsBJYDD4bVHgSuD9PLgYc88gIwy8zmA1cBq9x9j7t3A6uAq8dbV5zSKePbK9q56zeX8Pr2/Vxx97Ps7OmLuywRkeOakGMGZtYKXACsBua5+/aw6F1gXpheCGwp2GxraBurvWTduPQ0blvWxqaug9z56KtxlyMiclxFh4GZ1QKPAre7+4hbeHp0JdaE7Ssxs1vMrMPMOrq6pu/9gDLpFH/4z97DRac18syGLroP6pnJIjK9FRUGZlZGFATfdfcfhuYdYfcP4TV/455twKKCzVtC21jtR3H3+9y93d3bm5ubiyl9Sny4rYlszrngi6toveMf+MeNu+IuSURkVMWcTWTA/cA6d/9awaKVQP6MoBXAYwXtN4WzipYC+8LupMeBK82sMRw4vjK0lbw/+Fgb969o59bLzwDgs4+s0UFlEZmWMkVseylwI/Cqma0JbX8EfBl4xMxuBjYDnwrLfgxcC3QCvcCnAdx9j5l9EfhFWO8L7r6niLqmjXTKWHbOPJadM4+2uXXc/v01/Pi17Vx33oK4SxMRGcFK9QZr7e3t3tHREXcZJ6ynb5CP/o9n2NM7wEff08ztV7yH8xfNirssEUkQM3vJ3dtHW6YrkKdIXWUZj/3+pfzWBS08vb6L5d/8J/YdGoy7LBERQGEwpVoaq/nqp87nlo9EN7e74mvP8mc/Wce67XqOsojES2EQgz+69hy+ccP7KU+n+OtnN3HNN57nur98XrfBFpHY6JhBzDp39vDwi1t4pGML+/uGaJtbyw0Xn8rvXNpKdMKWiMjEONYxA4XBNLG/b5C/X/023129mS17DvHbH2rlT65bQiqlQBCRiaEDyCWgvrKM3/2NM3jmc5dzXksDf/uzt7jkz55kw46euEsTkQRQGEwz6ZRx77+9iA+dMYeunn4+fs/z3Hj/aj6/ci0H+4fiLk9EZijtJprGNnUd4JtPv8HP39jFO/uiu5+e19LAbcvaWHbOvONsLSIyko4ZzAD/55fb+PoTG9h3aJDu3kE+ft58PnflWSxuqom7NBEpEQqDGaR/KMtfPtnJ/3y6E4B59RX8q/ZF3H7Fe3SwWUSOSWEwA73RdYAH/vFNfvzqdrp7Bzl1djVnn1LHmXNrWXr6HD7c1qRTU0VkBIXBDJbLOX/13Bs8u76LrgP9bOo6CMAZzTW8b2ED/+26JdRUZKgsS8dcqYjETWGQIPv7Brnz0Vd5fmMX+/sOn33UflojHzqzicvObKIsbZRnUpxzSr12LYkkiMIgoV7YtJu17+xnZ08fP+vczavb9o1Y3tJYxamzq6mvLMNxLjqtkSXzG0gZ7Ds0SCad4rIzm6gq16hCZCY4VhgU8zwDmeaWnj6HpafPGZ7f2t3L+nd7GMw6XT19PLthF929A2zY0cO+Q4M8vnbHUe9RkUlx5txa6ivLWDCrirn1FSyYVcXZp9Sxq6efnENzXQUtjVU011VQltalKyKlSGGQIC2N1bQ0Vg/P3/jB1uFpd2ftO/s5NJgdfhpb70CW5zfuorPrAPsPDfL8xi52Hxwge4yntb13QT1nnVLHqbOraaqtoLYiw3vm1TG/oZL6qjLS2i0lMi0pDAQAM+PchQ1HtV9+9twR8+7O5t29vLX7IHWVZdRWZHhz10F2Heinc+cB1r/bw8/f2M0PXz76MdYpg/kNVTTVljOrupyGqjLqqzLUVZZRV5mhvrKM05tqaKwpp7o8TcqMirIUlWVpqsrSGnWITCKFgZwUM6O1qYbWgovdzjql7qj1BoZy7O0dYGdP//BuqN0HBtja3cue3kG6ewd4a/dBevqG6OkbZDB7/GNX6ZRRU56mtiJDOm2UpVKkUkYmZVSVR4FRXZ4eDo/q8jSVBe1VZdGy6vIMVeUp
KjNpMukU6fAe6ZSRSeenU4fbhl9TpNM2ol2n78pMoTCQSVGeSTG3vpK59ZWjjjgKuTv9Qzl2Hxxg866DbOnuJRX+ke0bytE/mKVvMMuhwSw9fUP0DmQZyuYYyjnuMJDNRcsHsuztHRxet3cgeh0Yyk1aP1NGFBL50EgfER4FYZIPm9GDpqC94D3SZmTSKcrTUfDk3MnmnJw7uRykUpCyaN3876x/KMdQNkdP3xD9Q1kqMmnKMykqMinKM9HoygwMC69Qlk5RU5EZc3nhfJ6ZFSyL5lMGNRWZ8LsxUqEP6RSkUynSBfWmw/JU2M6G+xGtU1WejgK34HPNjvzcwzUStrOjaotWKKzzqO2P6m/BOgkJfIWBxM7MqCxLs3BWFQtnVU34+2dzzqEQFn0FIdE3mCWbc4ZyTjaXYyjrBfMF7fn57BjtI5aP0j7m+0dtI+s4/DqYzYXXaN1szo/4RxRyHl1rkg8JgPJM9I9ofVWGikyagaEc/UPZ8JrDDNwh544TTQ8M5Tg0mJ3w3/1MMlZQjBZM+XYK548ITcYI2/w6w59ZsCxlxuyach79vQ9NeP8UBjLjpVNGbUWG2gr9uR9LPqDyZ5u7g+PhNRrB5cMjWuHo5Vl3evujUMl5PqQYHs3kg65whJMPpmzu8PRg1ukfyjKU9RGfffRnjlKjj6w13z56nwrmfez3ZcT7HbGe+8jPPmI9OLKefD+jDkXvMXpdjPK7r62cnL9j/dchIgBhV9YEXFNy9CEkKQE6PUNERBQGIiKiMBARERQGIiKCwkBERFAYiIgICgMREUFhICIilPDDbcysC9g8zs2bgF0TWE4pUJ9nvqT1F9Tnk3WauzePtqBkw6AYZtYx1tN+Zir1eeZLWn9BfZ5I2k0kIiIKAxERSW4Y3Bd3ATFQn2e+pPUX1OcJk8hjBiIiMlJSRwYiIlJAYSAiIskKAzO72szWm1mnmd0Rdz0TyczeMrNXzWyNmXWEttlmtsrMNobXxtBuZnZP+D28YmYXxlv9iTGzB8xsp5m9VtB20n00sxVh/Y1mtiKOvpyoMfr8eTPbFr7rNWZ2bcGyO0Of15vZVQXtJfO3b2aLzOxpM3vdzNaa2W2hfUZ+18fo79R+z9Gj2mb+D5AG3gBOB8qBXwFL4q5rAvv3FtB0RNufA3eE6TuAr4Tpa4GfED1udSmwOu76T7CPHwEuBF4bbx+B2cCm8NoYphvj7ttJ9vnzwOdGWXdJ+LuuABaHv/d0qf3tA/OBC8N0HbAh9G1GftfH6O+Ufs9JGhlcDHS6+yZ3HwAeBpbHXNNkWw48GKYfBK4vaH/IIy8As8xsfgz1nRR3fw7Yc0TzyfbxKmCVu+9x925gFXD1pBc/TmP0eSzLgYfdvd/d3wQ6if7uS+pv3923u/vLYboHWAcsZIZ+18fo71gm5XtOUhgsBLYUzG/l2L/wUuPA/zOzl8zsltA2z923h+l3gXlheib9Lk62jzOl778fdok8kN9dwgzss5m1AhcAq0nAd31Ef2EKv+ckhcFMd5m7XwhcA9xqZh8pXOjR+HJGn0echD4G9wJnAO8HtgNfjbWaSWJmtcCjwO3uvr9w2Uz8rkfp75R+z0kKg23AooL5ltA2I7j7tvC6E/gR0ZBxR373T3jdGVafSb+Lk+1jyffd3Xe4e9bdc8DfEH3XMIP6bGZlRP8wftfdfxiaZ+x3PVp/p/p7TlIY/AJoM7PFZlYO3ACsjLmmCWFmNWZWl58GrgReI+pf/gyKFcBjYXolcFM4C2MpsK9g+F1qTraPjwNXmlljGHZfGdpKxhHHd/4F0XcNUZ9vMLMKM1sMtAEvUmJ/+2ZmwP3AOnf/WsGiGfldj9XfKf+e4z6SPpU/RGcdbCA64v7Hcdczgf06nejMgV8Ba/N9A+YATwIbgSeA2aHdgG+G38OrQHvcfTjBfn6PaLg8SLQ/9Obx9BH4HaKDbp3Ap+Pu1zj6/HehT6+E/9jnF6z/x6HP64FrCtpL5m8fuIxo
F9ArwJrwc+1M/a6P0d8p/Z51OwoREUnUbiIRERmDwkBERBQGIiKiMBARERQGIiKCwkBERFAYiIgI8P8B1M/TYWHrGh0AAAAASUVORK5CYII=\n", 267 | "text/plain": [ 268 | "
" 269 | ] 270 | }, 271 | "metadata": { 272 | "needs_background": "light" 273 | }, 274 | "output_type": "display_data" 275 | } 276 | ], 277 | "source": [ 278 | "from matplotlib import pyplot as plt\n", 279 | "%matplotlib inline\n", 280 | "\n", 281 | "plt.plot(plt_x, plt_y)\n", 282 | "plt.show()" 283 | ] 284 | } 285 | ], 286 | "metadata": { 287 | "kernelspec": { 288 | "display_name": "Python 3 (ipykernel)", 289 | "language": "python", 290 | "name": "python3" 291 | }, 292 | "language_info": { 293 | "codemirror_mode": { 294 | "name": "ipython", 295 | "version": 3 296 | }, 297 | "file_extension": ".py", 298 | "mimetype": "text/x-python", 299 | "name": "python", 300 | "nbconvert_exporter": "python", 301 | "pygments_lexer": "ipython3", 302 | "version": "3.8.11" 303 | } 304 | }, 305 | "nbformat": 4, 306 | "nbformat_minor": 5 307 | } 308 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/6.ada_delta-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0\n", 49 | "\n", 50 | 
"#初始化S为全0\n", 51 | "S_w = np.zeros(M)\n", 52 | "S_b = 0\n", 53 | "\n", 54 | "#初始化delta为全0\n", 55 | "delta_w = np.zeros(M)\n", 56 | "delta_b = 0" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "id": "92163201", 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "0.6590042695516539" 69 | ] 70 | }, 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "#预测函数\n", 78 | "def predict(x):\n", 79 | " return w.dot(x) + b\n", 80 | "\n", 81 | "\n", 82 | "predict(x[0])" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "id": "a7bb7a80", 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "0.21258140154187247" 95 | ] 96 | }, 97 | "execution_count": 4, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "#求loss,MSELoss\n", 104 | "def get_loss(x, y):\n", 105 | " pred = predict(x)\n", 106 | " loss = (pred - y)**2\n", 107 | " return loss\n", 108 | "\n", 109 | "\n", 110 | "get_loss(x[0], y[0])" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "id": "8027d213", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 123 | " 0.923131013558981)" 124 | ] 125 | }, 126 | "execution_count": 5, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "def get_gradient(x, y):\n", 133 | " global w\n", 134 | " global b\n", 135 | "\n", 136 | " eps = 1e-3\n", 137 | "\n", 138 | " loss_before = get_loss(x, y)\n", 139 | "\n", 140 | " gradient_w = np.empty(M)\n", 141 | " for i in range(M):\n", 142 | " w[i] += eps\n", 143 | " loss_after = get_loss(x, y)\n", 144 | " w[i] -= eps\n", 145 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 146 | "\n", 147 | " b += 
eps\n", 148 | " loss_after = get_loss(x, y)\n", 149 | " b -= eps\n", 150 | " gradient_b = (loss_after - loss_before) / eps\n", 151 | "\n", 152 | " return gradient_w, gradient_b\n", 153 | "\n", 154 | "\n", 155 | "get_gradient(x[0], y[0])" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 6, 161 | "id": "f39e0125", 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "data": { 166 | "text/plain": [ 167 | "11073.905141728206" 168 | ] 169 | }, 170 | "execution_count": 6, 171 | "metadata": {}, 172 | "output_type": "execute_result" 173 | } 174 | ], 175 | "source": [ 176 | "def total_loss():\n", 177 | " loss = 0\n", 178 | " for i in range(N):\n", 179 | " loss += get_loss(x[i], y[i])\n", 180 | " return loss\n", 181 | "\n", 182 | "\n", 183 | "total_loss()" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 7, 189 | "id": "c371c6a4", 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "0 [9.99999848e-07 9.99999845e-07] 9.999998984693441e-07 11059.213325989294\n", 197 | "500 [2.66451164e-06 2.89372373e-06] 3.595841972186342e-06 7623.402923852171\n", 198 | "1000 [1.49058551e-06 1.77039119e-06] 2.3569819378456833e-06 5116.6673548382705\n", 199 | "1500 [1.54060667e-07 1.52269411e-07] 4.854100489197519e-07 3453.7618700421162\n", 200 | "2000 [4.5399257e-06 3.6700749e-06] 4.182369642302336e-06 2259.06518603289\n", 201 | "2500 [1.65797772e-06 1.05985930e-06] 1.3383495167537935e-06 1438.5906667372406\n", 202 | "3000 [2.96914906e-07 6.56168224e-07] 6.131919613839703e-07 1051.9695231637406\n", 203 | "3500 [1.48431059e-06 1.85664992e-06] 3.2354643896326765e-06 886.751528580343\n", 204 | "4000 [3.06717189e-06 1.53765326e-06] 7.968217077828009e-06 826.4649373166238\n", 205 | "4500 [1.83228209e-06 1.98744616e-06] 3.0032937384269872e-06 794.0081885344441\n", 206 | "5000 [5.05157589e-07 6.08365156e-07] 6.43608532358165e-07 798.5606607279586\n" 207 | ] 208 | } 209 | 
], 210 | "source": [ 211 | "plt_x = []\n", 212 | "plt_y = []\n", 213 | "\n", 214 | "for epoch in range(5500):\n", 215 | " i = np.random.randint(N)\n", 216 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 217 | "\n", 218 | " #ada_delta算法不需要设定超参数lr\n", 219 | " #他需要维持两个变量,delta和S\n", 220 | "\n", 221 | " #S的计算和rmsprop完全一致\n", 222 | " S_w = 0.2 * S_w + 0.8 * gradient_w**2\n", 223 | " S_b = 0.2 * S_b + 0.8 * gradient_b**2\n", 224 | "\n", 225 | " #计算lr的公式,这里的1e-6是为了防止除0\n", 226 | " lr = (delta_w + 1e-6) / (S_w + 1e-6)\n", 227 | " gradient_w = lr**0.5 * gradient_w\n", 228 | "\n", 229 | " lr = (delta_b + 1e-6) / (S_b + 1e-6)\n", 230 | " gradient_b = lr**0.5 * gradient_b\n", 231 | "\n", 232 | " #更新参数\n", 233 | " w -= gradient_w\n", 234 | " b -= gradient_b\n", 235 | "\n", 236 | " #更新delta,这里的两个系数和计算S时用的要一样\n", 237 | " delta_w = 0.2 * delta_w + 0.8 * gradient_w**2\n", 238 | " delta_b = 0.2 * delta_b + 0.8 * gradient_b**2\n", 239 | "\n", 240 | " #思考一下,在时刻0,S就是梯度的平方乘以0.8\n", 241 | " #所以在一开始的时候,S是比较大的.但delta还是0\n", 242 | " #所以一开始的时候lr是比较大的.\n", 243 | " #delta更新为变量更新量的平方*0.8\n", 244 | " #所以delta当中差不多相当于存储了变量更新量的历史信息\n", 245 | " #所以最后的lr,应该是取两者的比值\n", 246 | "\n", 247 | " plt_x.append(epoch)\n", 248 | " plt_y.append(total_loss())\n", 249 | "\n", 250 | " if epoch % 500 == 0:\n", 251 | " print(epoch, delta_w[:2], delta_b, total_loss())" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 8, 257 | "id": "0471a70d", 258 | "metadata": { 259 | "scrolled": true 260 | }, 261 | "outputs": [ 262 | { 263 | "data": { 264 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjAklEQVR4nO3deZxU1Z338c+vqnpfaWh2ZBEEcUPoNIgbigoaR5xRJ24jMT6SiSZjYp5MNJmJGSfJM0kmUTNGHYMLZuI2JkZcoiJxV8DGDQG1mx1soKHphqa36qrz/FEH0xIQ6O12VX3fr1e96t5zl/4dLfrbdztlzjlERCS9hYIuQEREgqcwEBERhYGIiCgMREQEhYGIiACRoAvoqH79+rkRI0YEXYaISNJYunTpNudc6b6WJW0YjBgxgoqKiqDLEBFJGma2bn/LdJpIREQUBiIiojAQEREUBiIigsJARERQGIiICAoDEREhDcPgZ89+yKMVG4IuQ0SkV0nah846YmdzlDteWgXAzKMHUpidEXBFIiK9Q1odGRRmZ3DHZRMB+O2b+30QT0Qk7aRVGACceHg/AH7+3EcBVyIi0nukXRgU5f7l1FDV1oYAKxER6T3SLgwA7vSniv5nkU4ViYhAmobB2ccMAuD+N9YGW4iISC+RlmEAMGvCYAA27mgMuBIRkeClbRh89ZTDAVi8ujbgSkREgpe2YTBuYAFFORksXrM96FJERAKXtmEQChnlI0tYvEZHBiIiaRsGAJNHlrBueyPV9U1BlyIiEqi0DoMpo/oCum4gIpLWYXDkoEIKsiO6biAiaS+twyAcMkryMnloyQZicRd0OSIigUnrMACYcdRAAN7bWBdsISIiAUr7MLjyxBEAvLu+LtA6RESClPZhMKgoh1Gledz3xhqao7GgyxERCcQBw8DM7jWzrWb2Qbu2EjNbYGaV/r2Pbzcz+5WZVZnZ+2Y2sd02s/36lWY2u137JDNb5rf5lZlZV3fyQH5w7ng21Dbxw/nLe/pHi4j0CgdzZHA/MHOvthuAhc65McBCPw9wNjDGv+YAd0IiPICbgMlAOXDTngDx61zdbru9f1a3mza2PwDLP9nZ0z9aRKRXOGAYOOdeAfa+EX8WMM9PzwPOb9f+gEtYBBSb2SBgBrDAOVfrnNsBLABm+mWFzrlFzjkHPNBuXz3qkvJhrK9tJK67ikQkDXX0msEA51y1n94MDPDTQ4D23za/0bd9XvvGfbT3uLLhJdQ3RXlh5ZYgfryISKA6fQHZ/0XfI39Om9kcM6sws4qampou3feMoxO3mD6zrPoAa4qIpJ6OhsEWf4oH/77Vt28ChrVbb6hv+7z2ofto3yfn3N3OuTLnXFlpaWkHS9+3/KwI5x47iNdXbae1Ld6l+xYR6e06GgbzgT13BM0GnmjXfoW/q2gKUO9PJz0HnGVmffyF47OA5/yynWY2xd9FdEW7ffW4v5s4hJpdLfzwSd1VJCLp5WBuLX0IeBMYa2Ybzewq4D+AM82sEjjDzwM8A6wGqoDfANcAOOdqgX8H3vKvm30bfp25fptVwJ+6pmuH7vRxA+hfkMUzy6pJnP0SEUkPkQOt4Jy7ZD+Lpu9jXQdcu5/93Avcu4/2CuDoA9XRU647Ywzff/wD1m5vZGS/vKDLERHpEWn/BPLeThlTihk8tGR90KWIiPQYhcFehpXkcuoRpbz8UdferSQi0pspDPbh2KHFfLRlF+9uqAu6FBGRHqEw2IeLJiXudj3/16/TFtNtpiKS+hQG+zCsJJfykSUA3P/G2mCLERHpAQqD/bjvy18AYMvO5oArERHpfgqD/cjLilA+soRXK7cFXYqISLdTGHyOaWNL+XDzLrbu0tGBiKQ2hcHnmD4uMRjr0+9r8DoRSW0Kg88xdmABxw4t4u5XVuuuIhFJaQqDA/iHKcOprm/mO4+9H3QpIiLdRmFwABdMTDxz8J4eQBORFKYwOIBQyLik/DC2NbToVJGIpCyFwUE4eUw
/dja3ccdLq4IuRUSkWygMDsLp4/oDcOsLH+s2UxFJSQqDg5CdEebOyyYSd3Dqz14iFtcX34hIalEYHKSzjxlEXmaYpmiMf/njB0GXIyLSpRQGh+DF70wD9MU3IpJ6FAaHoH9BNpeUD8MMtjW0BF2OiEiXURgcoqtOGoVzUPajF1hV0xB0OSIiXUJhcIhG98+nfETiuw6m/+Jl4rqYLCIpQGHQAY/+4wnMmjAYgPc21gVbjIhIF1AYdNANZ48D4F+f0J1FIpL8FAYdNKgoh375mXywaaeGuBaRpKcw6ISHrp4CwE3zdXQgIslNYdAJYwYUMH5QIdsaWjVMhYgkNYVBJ31n5lgAHl6yIeBKREQ6TmHQSdOOKGXCsGJufeFjGlragi5HRKRDFAadZGb888yxxB088paODkQkOSkMusAJo/oyqCibh5esxzk9hCYiyadTYWBm3zKz5Wb2gZk9ZGbZZjbSzBabWZWZPWJmmX7dLD9f5ZePaLefG337R2Y2o5N96nFmxjXTDqdyawPraxuDLkdE5JB1OAzMbAjwT0CZc+5oIAxcDPwUuMU5NxrYAVzlN7kK2OHbb/HrYWbj/XZHATOBO8ws3NG6glLmh6j47ZvrAq5EROTQdfY0UQTIMbMIkAtUA6cDj/nl84Dz/fQsP49fPt3MzLc/7Jxrcc6tAaqA8k7W1ePGDSxg3MAC5r62huZoLOhyREQOSYfDwDm3CfhPYD2JEKgHlgJ1zrk9t9VsBIb46SHABr9tm1+/b/v2fWzzGWY2x8wqzKyipqamo6V3CzPjq6eOAuCOF6sCrkZE5NB05jRRHxJ/1Y8EBgN5JE7zdBvn3N3OuTLnXFlpaWl3/qgOOffYwQwryeHpZRqeQkSSS2dOE50BrHHO1TjnosAfgBOBYn/aCGAosMlPbwKGAfjlRcD29u372CapZIRDXD55OKtqdlNd3xR0OSIiB60zYbAemGJmuf7c/3RgBfAicKFfZzbwhJ+e7+fxy//sEvdhzgcu9ncbjQTGAEs6UVegTh2bOGJ5SE8ki0gS6cw1g8UkLgS/DSzz+7ob+C5wvZlVkbgmcI/f5B6gr2+/HrjB72c58CiJIHkWuNY5l7RXYMcNLOTE0X351cJKrn6gIuhyREQOiiXrQ1JlZWWuoqJ3/rJ9o2obl85dDMADXynnlCN63/UNEUk/ZrbUOVe2r2V6ArkbTB3dj5e/Mw2AK+5dQl1ja7AFiYgcgMKgmwzvm8dFk4YCcN/ra4MtRkTkABQG3eim844C4LaFlTS2akRTEem9FAbdKD8rwt9NTDw/d/8ba4MtRkTkcygMutnPLjg28f7sR0Rj8YCrERHZN4VBN4uEQ9x28QQA/vvlVcEWIyKyHwqDHnDecYMZUJjF7xavp6UtaR+hEJEUpjDoAWbGzy88jur6Zp5+X+MWiUjvozDoISeP6cfAwmyeW7456FJERP6KwqCHmBlnHTWA55Zv4bXKbUGXIyLyGQqDHnTVSSMBuPyexbp2ICK9isKgBw3vm8ctXzoOgFsWVAZcjYjIXygMetj5E4ZQlJPBwpVbgi5FRORTCoMeZmbMnjqCyq0NGqJCRHoNhUEAJo8sAeBhfQGOiPQSCoMATD28L7mZYW5+aoUuJItIr6AwCICZcaEf3nqeBrATkV5AYRCQH/5NYnjrnzzzIZ/UNQVcjYikO4VBQEIh467LJwFw/aPvBluMiKQ9hUGAZh49kJlHDWTR6lreXLU96HJEJI0pDAJ2w9njALjkN4v4cPPOgKsRkXSlMAjYiH55zL2iDICbn1yBcy7gikQkHSkMeoEzxg/gB+eO541V23l+hZ5MFpGepzDoJa44YTjDSnJ44M21QZciImlIYdBLRMIh/vb4obyxajtrt+0OuhwRSTMKg17k8smHETLjR0+vDLoUEUkzCoNepH9hNlMP78trVTU0tWqYChHpOQqDXuaqk0bSHI3zgoa4FpEepDDoZaYe3o+SvEy+8dA71DdGgy5
HRNJEp8LAzIrN7DEz+9DMVprZCWZWYmYLzKzSv/fx65qZ/crMqszsfTOb2G4/s/36lWY2u7OdSmaZkRD/7++OAeCuV1YFXI2IpIvOHhncBjzrnBsHHAesBG4AFjrnxgAL/TzA2cAY/5oD3AlgZiXATcBkoBy4aU+ApKsZRw0kZHDnS6s0xLWI9IgOh4GZFQGnAPcAOOdanXN1wCxgnl9tHnC+n54FPOASFgHFZjYImAEscM7VOud2AAuAmR2tK1X8+G8TRwcvrNgacCUikg46c2QwEqgB7jOzd8xsrpnlAQOcc9V+nc3AAD89BGj/1V4bfdv+2tPaRZOGUpAV4doH36a1LR50OSKS4joTBhFgInCnc+54YDd/OSUEgEsMtNNlg+2Y2RwzqzCzipqamq7aba8UCYe44ZzEIHYPLl4XcDUikuo6EwYbgY3OucV+/jES4bDFn/7Bv+85z7EJGNZu+6G+bX/tf8U5d7dzrsw5V1ZaWtqJ0pPDpeWHUT6yhP9+ZTWxuAawE5Hu0+EwcM5tBjaY2VjfNB1YAcwH9twRNBt4wk/PB67wdxVNAer96aTngLPMrI+/cHyWb0t7ZsaVU0dQXd/Ml+9bEnQ5IpLCIp3c/hvA78wsE1gNXEkiYB41s6uAdcDf+3WfAc4BqoBGvy7OuVoz+3fgLb/ezc652k7WlTLOGD+AoX1yeLVyG0++9wl/c9zgoEsSkRRkyTp+fllZmauoqAi6jB5R3xTluH97nqxIiI9+dHbQ5YhIkjKzpc65sn0t0xPISaAoJ4PzjhtMS1ucJWt00CQiXU9hkCS+MyNxaeZbj7yrW01FpMspDJLEsJJcvjbtcDbVNXH8zc/r6zFFpEspDJLIt888glOOKGV3a4yl63YEXY6IpBCFQRKJhEP854XHAnDhXW/SHNW4RSLSNRQGSaZ/YTYnj+kHwMxbXwm4GhFJFQqDJDTvynIA1m5v5FuPvBtsMSKSEhQGSSgUMt644XQAHn9nEz9+ekXAFYlIslMYJKnBxTks+d50AO59fS3rtzcGXJGIJDOFQRLrX5jNk18/iVjcccrPX2TpOj2QJiIdozBIcscMLeL2S48H4II73+T1qm0BVyQiyUhhkALOPXYw879+IgCXzV1MXMNdi8ghUhikiGOHFvON00cD8NZanS4SkUOjMEgh/3jq4QBc8ptFbG9oCbgaEUkmCoMUkpcVYc4po4g7OO/21zV+kYgcNIVBivneOUfy5akj2FTXxG0LK4MuR0SShMIgBd30N+OZNLwPt75QSeWWXUGXIyJJQGGQgsyMG88eB8CX73uLaEzffyAin09hkKLKRpRwSflhbKpr4prfvR10OSLSyykMUthP/vZoCrIiLFixhVc+rgm6HBHpxRQGKczMePW7pwEw57cVCgQR2S+FQYorzs3kT9edTHM0zhX3LuHCO9+gqVVfiiMin6UwSANHDirkj9cmhquoWLeDnz77YcAViUhvozBIExOGFbP2P77IBROH8r8VG3R0ICKfoTBIMxeXD2N3a0xHByLyGQqDNPOFESWcPKYf97+xli/995s6QhARQGGQlubOLmPS8D4sXlPL7HuXaAwjEVEYpKOsSJjff20q508YzJK1tRz5g2dZVdMQdFkiEiCFQRq75UsTmDa2lOZonC/ft4Q2DVshkrYUBmnMzLj/ynK+feYRbKht4puPvEtM35ImkpY6HQZmFjazd8zsKT8/0swWm1mVmT1iZpm+PcvPV/nlI9rt40bf/pGZzehsTXJovjF9DNdMO5yn3q/mD29vDLocEQlAVxwZXAesbDf/U+AW59xoYAdwlW+/Ctjh22/x62Fm44GLgaOAmcAdZhbugrrkEHxnxliOGVLED55YzoOL1wddjoj0sE6FgZkNBb4IzPXzBpwOPOZXmQec76dn+Xn88ul+/VnAw865FufcGqAKKO9MXXLozIzbLz2ekMH3Hl/G61Xbgi5JRHpQZ48MbgX+Gdhz5bEvUOeca/P
zG4EhfnoIsAHAL6/363/avo9tPsPM5phZhZlV1NRo0LWuNrxvHk98PTFsxXUPv8vWnc0BVyQiPaXDYWBm5wJbnXNLu7Cez+Wcu9s5V+acKystLe2pH5tWRvcv4KlvnERdYytXzavQMwgiaaIzRwYnAueZ2VrgYRKnh24Dis0s4tcZCmzy05uAYQB+eRGwvX37PraRABw9pIhvnzWWZZvqueLeJcR1h5FIyutwGDjnbnTODXXOjSBxAfjPzrnLgBeBC/1qs4En/PR8P49f/meX+LNzPnCxv9toJDAGWNLRuqRrXH3ySEb1y+PVym08/NaGA28gIkmtO54z+C5wvZlVkbgmcI9vvwfo69uvB24AcM4tBx4FVgDPAtc65zRgTsAi4RALv30qk0eWcPNTy7nntTU6QhBJYZas54TLyspcRUVF0GWkvO0NLVxw5xus3d7ICaP68uDVk0ncBCYiycbMljrnyva1TE8gy+fqm5/Fn647hdPGlvLm6u1M/Y8/s1rjGImkHIWBHFBOZph7Zn+BsuF9qK5v5vRfvMxT738SdFki0oUUBnJQQiHjwauncM4xAwH4+oPv8MGm+oCrEpGuojCQg5YZCXHHZZN4ZM4UAM79r9eoa2wNuCoR6QoKAzlkk0f15Z7ZiWtQ//bkCj2YJpICFAbSIdOPHMC3zjiCx9/ZxE3zlwddjoh0ksJAOuwbp4/m5DH9eODNdTy2VENfiyQzhYF0WChk3DP7C5SPKOH//u97/OaV1UGXJCIdpDCQTsmMhLj90uMB+PEzK/njOxpWSiQZKQyk0/oXZlPxL2cwrCSHbz7yLr9+sSrokkTkECkMpEv0y8/iya+fxMDCbO54sYqqrXpKWSSZKAykyxTnZvL4tVOJxh1n/PJlnl++OeiSROQgKQykSw0qyuHGs8cBMOe3S/nd4nV6DkEkCSgMpMtdeeJIFt04nfGDCvn+4x9w+i9e1tAVIr2cwkC6xcCibP547YnMPGoga7bt5tz/eo1VGu1UpNdSGEi3yYyEuOsfJvH4NVPJz4rwD3MXU7tbYxmJ9EYKA+l2xx/Wh//5P5PZtruVc257lYq1tUGXJCJ7URhIj5gwrJiHrp5MS1uMS3+zmC07m4MuSUTaURhIj5k0vITvnXMkrbE4k3+ykMvmLtIQ2CK9hMJAetRFZcP4wzVTOXJQIa9Xbee821/n4y27gi5LJO0pDKTHTTysD3+67mR+cdFxrK9t5JK7F7Fw5RZicT2PIBIUhYEE5oJJQ3n2mydTmJPBVfMqmHnrK+xuaQu6LJG0pDCQQI0bWMiz3zyZ6888gsqtDZx3+2u8t6Eu6LJE0o7CQAKXFQnzT9PHcN+VX2B3S4xLfrOInzyzkgYdJYj0GIWB9Bqnje3PvK+Uc/SQIu5+ZTXTfv4S33t8Gc3RWNCliaQ8hYH0KmMHFvDoV0/g91+byuGleTy4eD3j/vVZ3lm/I+jSRFKawkB6pUnD+/DwnCn89IJjALh87mKWrtOTyyLdRWEgvZaZ8aUvHMZdl08kEg5xwZ1vcv6vX6e+KRp0aSIpR2Egvd7Mowfx2ndP48tTR/DuhjqunldBS5uuI4h0pQ6HgZkNM7MXzWyFmS03s+t8e4mZLTCzSv/ex7ebmf3KzKrM7H0zm9huX7P9+pVmNrvz3ZJUU5CdwQ/PO4p/PXc8S9bWUv7jhbxRtS3oskRSRmeODNqAbzvnxgNTgGvNbDxwA7DQOTcGWOjnAc4GxvjXHOBOSIQHcBMwGSgHbtoTICJ7u+qkkfziouOob4py6dzFXPvg29Q36rSRSGd1OAycc9XOubf99C5gJTAEmAXM86vNA87307OAB1zCIqDYzAYBM4AFzrla59wOYAEws6N1Seq7YNJQVtw8g386fTTPL9/Mube/yqLV24MuSySpdck1AzMbARwPLAYGOOeq/aLNwAA/PQTY0G6zjb5tf+0i+5WbGeH6s8Yy78pyNtQ2cfHdi/j6g2/T1KprCSId0ek
wMLN84PfAN51zO9svc4lvQu+y0cfMbI6ZVZhZRU1NTVftVpLY1NH9WHTjdM4+eiBPvV/Nebe/pi/PEemAToWBmWWQCILfOef+4Ju3+NM/+Petvn0TMKzd5kN92/7a/4pz7m7nXJlzrqy0tLQzpUsKGViUzZ2XT+L+K79AUzTGl+5exH8trCSuUVBFDlpn7iYy4B5gpXPul+0WzQf23BE0G3iiXfsV/q6iKUC9P530HHCWmfXxF47P8m0ih2Ta2P786bqTOfPIAfxiwcec8vMXmfvqatZu2x10aSK9niXO5HRgQ7OTgFeBZUDcN3+PxHWDR4HDgHXA3zvnan143E7i4nAjcKVzrsLv6yt+W4AfO+fuO9DPLysrcxUVFR2qXVJbPO548v1PuGn+cuoao2RGQvz8wmOZNUGXoiS9mdlS51zZPpd1NAyCpjCQA4nFHSurd/LD+cupWLeDAYVZnDa2P3NOGcWo0vygyxPpcQoDSWvRWJwHF6/n6WXVLFmTuLhcNrwP/3nRcYzolxdwdSI9R2Eg4q2uaeDRio389s21APzySxOYcdTAYIsS6SEKA5G9rN22m3/8n6V8uHkXY/rnc8XUEZx33GCKcjKCLk2k2ygMRPahvinK1Q9UfHrqKDMS4ovHDOKyyYdx/GF9CIcs4ApFutbnhUGkp4sR6S2KcjJ49Ksn4Jxj2aZ6Hlu6kceWbuTxdzYxsl8e1595BKeP609elv6ZSOrTkYFIO7uaozy3fAu3LPiYTXVNAJx6RCnHDSvmihOG0y8/K+AKRTpOp4lEDlEs7nhrbS0/enoF1XXNbN/dCsC4gQVMGFbMGUcO4JQjSsmM6CtBJHkoDEQ66Z31O3h+xRbe21DH0nU7aGmLk50R4pppo7mk/DBKC3TEIL2fwkCkC7W2xXn54xoerdjAghVbABjZL48TR/flwknDmDCsONgCRfZDYSDSTZauq+Wlj2p4b2M9r3ycGEl33MACzpswmHOOHsTwvrkkRmIRCZ7CQKQHbG9o4cn3PuGxtzfywabEaO6Di7IZWJTNuEGFHDOkiPKRJYzom6fbViUQCgORHrZ2225e+mgrr1Vtp2ZXM5VbG2j0X7wzqCibw0vzOXJQAaP75zO6fwHjBhboFlbpdgoDkYA55/hw8y5erazhhZVbWV3TwK7mNlraEgP+ZkZCnDa2lPGDihjRL5fhffMY0z9fASFdSmEg0gvF4o5NO5r4eEsiJBas2MIn9c2fWeewklyOGFBAS1uM4txMjhtaxJGDConFHaNK88iMhOhfkB1QDyTZKAxEkkRzNMbHW3ZRXd9M5ZZdrKjeSdXWBnIyI2zd2Uz1XmEBUFqQxfCSXI4YWMBhJbmELNE2pn8BQ4pzKM7N0EVsATQchUjSyM4Ic+zQYo4dyj5HU926q5nln+xkdc1u8rPC1DVGWVXTQNXWBp5ZVk1dY/SvtskIGyV5mTgH4VBiOj8rQlZGmNyMMJA4TdUnN4PD++dTmp9FblaEvnmZFOVk0Dc/k9xM/apIdfo/LJJE+hdk039sNqeN3ffyXc1R4g4+qWtizbbdVNc3U7OrhdrdLYTMiMYcNQ0tNLW2Ud8UpbquiVjc0RZ3bG9oYbe/yL234twMBhRk0xSNkRkJEQkZ4ZBRnJtBYXYGeVkRSguyCJsRd47MSIim1hgZ4RA7m6OEzOibl0lJfiY5GWHCIaOhpe3T9vzsCP3yswgZ5GRGCBnkZkYoysmgLRYnGnNkhBPbtMUdORlhsjPChAzMjNa2OCGDkBlmiUEI4z78IiEjEjYioRD1TVFa2mIUZGcQCRlZkdCn122isTgfb9lFXWOUTXVN1DdGKcrNYEBhNq1tcaKxOKUFWQwszPb/TRJhmRG2/R55xeOOlrY49U1R399E4OdkhMnJDBMyIxZ3NEdjmMHmnc3U7m4lEgrR0hbjk7omdjRGqdzSQEtbjHDIKMrJ4OZZR3f+w7QXhYFICinITgzBXZSTwZGDCg9p23g8ERT
bG1rZ1RxlR2OUXc1Rtu5qobq+ic31LeRkhonF44kAiTl2NLaydWcLDS1t1OxqIe4cITPa4olf3tGYozA7gnOwq6XtkPuTnRGiORr/3HXCocQvVDNwDnIywjRF9x1q3SEcMnIzwxRmZ9AWj9Palni1tMVpi3fNafj+BVnkZ0WIO0dxbmaX7HNvCgMRASAUMgYUZjOgsPMXpKOxOBnhEPG4I+SfqWiOxqhvitLYGqMtFic/O0LcQW1DKzubo9TubiXuHI2tsUR4NEfZvruV7EiIrIww0Vj807/om6MxmqNx2uJx4s6RHQnTFnc452hoiTG4OJuMcIhoLP7pkU80FqcoJ4PsjMTptbhznx4VZEVCZIQT/R/aJ5chxTkU5kTY0RilZlcLeZlhIuEQ1XVNbNnVTDgUorYhEYLN0TgNLW3sbI6SGQ6RGQn95d2/CrMzyM+KEIs7WmNxmqMxmqKxT//7ZEfCOBLXevrlZxKPQyRsvo6MHvmeDYWBiHS5jHBiAL9Qu4frsv2pnb0NKc7psboOVW5m5DP1jUzhr0nVkIsiIqIwEBERhYGIiKAwEBERFAYiIoLCQEREUBiIiAgKAxERIYlHLTWzGmBdBzfvB2zrwnJ6E/UteaVy/1K5b5A8/RvunCvd14KkDYPOMLOK/Q3jmuzUt+SVyv1L5b5BavRPp4lERERhICIi6RsGdwddQDdS35JXKvcvlfsGKdC/tLxmICIin5WuRwYiItKOwkBERNIrDMxsppl9ZGZVZnZD0PUcLDO718y2mtkH7dpKzGyBmVX69z6+3czsV76P75vZxHbbzPbrV5rZ7CD6sjczG2ZmL5rZCjNbbmbX+fak75+ZZZvZEjN7z/ft33z7SDNb7PvwiJll+vYsP1/ll49ot68bfftHZjYjoC79FTMLm9k7ZvaUn0+lvq01s2Vm9q6ZVfi2pP9c7pdzLi1eQBhYBYwCMoH3gPFB13WQtZ8CTAQ+aNf2M+AGP30D8FM/fQ7wJ8CAKcBi314CrPbvffx0n17Qt0HARD9dAHwMjE+F/vka8/10BrDY1/wocLFvvwv4mp++BrjLT18MPOKnx/vPaxYw0n+Ow0H/v/O1XQ88CDzl51Opb2uBfnu1Jf3ncn+vdDoyKAeqnHOrnXOtwMPArIBrOijOuVeA2r2aZwHz/PQ84Px27Q+4hEVAsZkNAmYAC5xztc65HcACYGa3F38Azrlq59zbfnoXsBIYQgr0z9fY4Gcz/MsBpwOP+fa9+7anz48B083MfPvDzrkW59waoIrE5zlQZjYU+CIw188bKdK3z5H0n8v9SacwGAJsaDe/0bclqwHOuWo/vRkY4Kf3189e339/6uB4En9Bp0T//GmUd4GtJH4RrALqnHNtfpX2dX7aB7+8HuhLL+0bcCvwz0Dcz/cldfoGieB+3syWmtkc35YSn8t9iQRdgHSec86ZWVLfI2xm+cDvgW8653Ym/mhMSOb+OediwAQzKwYeB8YFW1HXMLNzga3OuaVmNi3gcrrLSc65TWbWH1hgZh+2X5jMn8t9Sacjg03AsHbzQ31bstriD0Px71t9+/762Wv7b2YZJILgd865P/jmlOkfgHOuDngROIHEKYQ9f4i1r/PTPvjlRcB2emffTgTOM7O1JE65ng7cRmr0DQDn3Cb/vpVEkJeTYp/L9tIpDN4Cxvi7HTJJXMSaH3BNnTEf2HNnwmzgiXbtV/i7G6YA9f6w9jngLDPr4++AOMu3BcqfN74HWOmc+2W7RUnfPzMr9UcEmFkOcCaJayIvAhf61fbu254+Xwj82SWuQs4HLvZ35IwExgBLeqQT++Gcu9E5N9Q5N4LEv6U/O+cuIwX6BmBmeWZWsGeaxOfpA1Lgc7lfQV/B7skXiSv+H5M4b/v9oOs5hLofAqqBKIlzjleRON+6EKgEXgBK/LoG/Nr3cRlQ1m4/XyFxga4KuDLofvmaTiJxbvZ94F3/OicV+gccC7zj+/YB8APfPorEL7wq4H+
BLN+e7eer/PJR7fb1fd/nj4Czg+7bXv2cxl/uJkqJvvl+vOdfy/f8vkiFz+X+XhqOQkRE0uo0kYiI7IfCQEREFAYiIqIwEBERFAYiIoLCQEREUBiIiAjw/wEZxi+isV885QAAAABJRU5ErkJggg==\n", 265 | "text/plain": [ 266 | "
" 267 | ] 268 | }, 269 | "metadata": { 270 | "needs_background": "light" 271 | }, 272 | "output_type": "display_data" 273 | } 274 | ], 275 | "source": [ 276 | "from matplotlib import pyplot as plt\n", 277 | "%matplotlib inline\n", 278 | "\n", 279 | "plt.plot(plt_x, plt_y)\n", 280 | "plt.show()" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "Python 3 (ipykernel)", 287 | "language": "python", 288 | "name": "python3" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.8.11" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 5 305 | } 306 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/7.adam-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0\n", 49 | "\n", 50 | "#初始化S为全0\n", 51 
| "S_w = np.zeros(M)\n", 52 | "S_b = 0\n", 53 | "\n", 54 | "v_w = np.zeros(M)\n", 55 | "v_b = 0" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "id": "92163201", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "0.6590042695516539" 68 | ] 69 | }, 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "#预测函数\n", 77 | "def predict(x):\n", 78 | " return w.dot(x) + b\n", 79 | "\n", 80 | "\n", 81 | "predict(x[0])" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "id": "a7bb7a80", 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "text/plain": [ 93 | "0.21258140154187247" 94 | ] 95 | }, 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "#求loss,MSELoss\n", 103 | "def get_loss(x, y):\n", 104 | " pred = predict(x)\n", 105 | " loss = (pred - y)**2\n", 106 | " return loss\n", 107 | "\n", 108 | "\n", 109 | "get_loss(x[0], y[0])" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "id": "8027d213", 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "data": { 120 | "text/plain": [ 121 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 122 | " 0.923131013558981)" 123 | ] 124 | }, 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "def get_gradient(x, y):\n", 132 | " global w\n", 133 | " global b\n", 134 | "\n", 135 | " eps = 1e-3\n", 136 | "\n", 137 | " loss_before = get_loss(x, y)\n", 138 | "\n", 139 | " gradient_w = np.empty(M)\n", 140 | " for i in range(M):\n", 141 | " w[i] += eps\n", 142 | " loss_after = get_loss(x, y)\n", 143 | " w[i] -= eps\n", 144 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 145 | "\n", 146 | " b += eps\n", 147 | " loss_after = get_loss(x, y)\n", 
148 | " b -= eps\n", 149 | " gradient_b = (loss_after - loss_before) / eps\n", 150 | "\n", 151 | " return gradient_w, gradient_b\n", 152 | "\n", 153 | "\n", 154 | "get_gradient(x[0], y[0])" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 6, 160 | "id": "f39e0125", 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "data": { 165 | "text/plain": [ 166 | "11073.905141728206" 167 | ] 168 | }, 169 | "execution_count": 6, 170 | "metadata": {}, 171 | "output_type": "execute_result" 172 | } 173 | ], 174 | "source": [ 175 | "def total_loss():\n", 176 | " loss = 0\n", 177 | " for i in range(N):\n", 178 | " loss += get_loss(x[i], y[i])\n", 179 | " return loss\n", 180 | "\n", 181 | "\n", 182 | "total_loss()" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 7, 188 | "id": "c371c6a4", 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "500 [0.08904127 0.71062389] [16.68083241 9.3870837 ] 1261.5118603011992\n", 196 | "1000 [-0.90368761 0.83072632] [8.3751223 4.76494681] 749.9516889645644\n", 197 | "1500 [-0.30210063 0.02406904] [5.98086602 3.792364 ] 745.5930763832396\n", 198 | "2000 [ 0.05851974 -0.52552324] [5.03242601 3.45097686] 734.1425192541457\n", 199 | "2500 [ 0.19785822 -0.16807166] [4.55122551 2.92449296] 763.0359728232618\n", 200 | "3000 [ 0.11518987 -0.08617003] [4.08077263 2.93377002] 746.4590779281906\n", 201 | "3500 [ 0.14053857 -0.11983124] [3.52452827 2.59132747] 783.7272638890388\n", 202 | "4000 [0.03610707 0.36194797] [3.26520532 2.42821712] 766.6407173851062\n", 203 | "4500 [-0.24672945 0.18166294] [3.41205737 3.01430322] 791.864659000636\n", 204 | "5000 [ 0.1558802 -0.19912752] [3.24213769 2.86521236] 788.53176983857\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "plt_x = []\n", 210 | "plt_y = []\n", 211 | "\n", 212 | "for t in range(1, 5500):\n", 213 | " i = np.random.randint(N)\n", 214 | " gradient_w, gradient_b = 
get_gradient(x[i], y[i])\n", 215 | "\n", 216 | " v_w = 0.9 * v_w + 0.1 * gradient_w\n", 217 | " v_b = 0.9 * v_b + 0.1 * gradient_b\n", 218 | "\n", 219 | " #S的计算和rmsprop完全一致\n", 220 | " S_w = 0.999 * S_w + 0.001 * gradient_w**2\n", 221 | " S_b = 0.999 * S_b + 0.001 * gradient_b**2\n", 222 | "\n", 223 | " #根据以上公式,在时刻0\n", 224 | " #v = [0.1 * gradient_0]\n", 225 | "\n", 226 | " #这可能太过于小,为了消除这个影响,需要做偏差修正,也就是除以系数\n", 227 | " #v = 0.1 * sigma[0.9**(t-i) * gradient_i]\n", 228 | " #S = 0.001 * sigma[0.999**(t-i) * gradient_i**2]\n", 229 | " \n", 230 | " #将梯度的系数部分整理得到\n", 231 | " #0.1 * sigma[0.9**(t-i)] = 1-0.9**t\n", 232 | "\n", 233 | " #偏差修正\n", 234 | " v_hat_w = v_w / (1 - 0.9**t)\n", 235 | " v_hat_b = v_b / (1 - 0.9**t)\n", 236 | " S_hat_w = S_w / (1 - 0.999**t)\n", 237 | " S_hat_b = S_b / (1 - 0.999**t)\n", 238 | "\n", 239 | " #下面是adam参数更新的公式\n", 240 | " #这里的1e-2是超参数lr\n", 241 | " gradient_w = (1e-2 * v_hat_w) / (S_hat_w**0.5 + 1e-6)\n", 242 | " gradient_b = (1e-2 * v_hat_b) / (S_hat_b**0.5 + 1e-6)\n", 243 | "\n", 244 | " #更新参数\n", 245 | " w -= gradient_w\n", 246 | " b -= gradient_b\n", 247 | "\n", 248 | " plt_x.append(t)\n", 249 | " plt_y.append(total_loss())\n", 250 | "\n", 251 | " if t % 500 == 0:\n", 252 | " print(t, v_hat_w[:2], S_hat_w[:2], total_loss())" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 8, 258 | "id": "0471a70d", 259 | "metadata": { 260 | "scrolled": true 261 | }, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgqElEQVR4nO3de3Sc9X3n8fd37pJGd41kWTbxFYOhEIgLhJA2gQQITWKym3bpphtvyymnbbabNnu2JafnhG672dN0t0mT05aUDXRJTrYJJeniNhfiAIGWBBMDxviCbdlgbFmWZN2vM5qZ3/4xPzmDkHzRSB5Jz+d1zpx5nt/zzOj7s8b6zO/3PM+MOecQEZFgC5W7ABERKT+FgYiIKAxERERhICIiKAxERASIlLuAuWpqanJr1qwpdxkiIkvGCy+8cNo5l5pp25INgzVr1rBr165ylyEismSY2bHZtmmaSEREFAYiIqIwEBERFAYiIoLCQEREUBiIiAgKAxERIWBhkMs77v/REXYfHyh3KSIii0qgwmA0k+Xvnn2Nv/jBwXKXIiKyqAQqDGoSUW7a0MTRntFylyIisqgEKgwA1jRV0TEwzngmV+5SREQWjcCFQXN1HIC+sUyZKxERWTwCFwZ1lTEABhQGIiJnBDAMogAMjE2WuRIRkcUjcGFQ70cG/RoZiIicEcAwKIwM+jUyEBE5I3BhcOaYwahGBiIiUwIXBrFIiKpYWCMDEZEigQsDKIwOBsY1MhARmRLIMKiviupsIhGRIoEMg7qKmM4mEhEpEswwqNTIQESkWCDDoL5SIwMRkWIBDYMog+OT5PKu3KWIiCwKgQyDusoYzsHwhKaKRETgPMLAzB4ys24z21vU1mBmO8zssL+v9+1mZl8ys3Yz22Nm1xY9Zpvf/7CZbStqf4eZveIf8yUzs/nu5HT1VboKWUSk2PmMDP4PcPu0tnuBJ5xzG4En/DrAB4CN/nYPcD8UwgO4D7geuA64bypA/D6/WfS46T9r3tXp84lERN7knGHgnHsG6JvWvBV42C8/DNxZ1P5VV/AcUGdmrcBtwA7nXJ9zrh/YAdzut9U4555zzjngq0XPtWBqEhEARiayC/2jRESWhLkeM2hxznX65VNAi19uA44X7XfCt52t/cQM7TMys3vMbJeZ7erp6Zlj6RCPhAGYmNS3nYmIwDwcQPbv6C/KaTnOuQecc1ucc1tSqdScnycRLYTBuMJARASYexh0+Ske/H23b+8AVhftt8q3na191QztCyoRLXQ7PZlf6B8lIrIkzDUMtgNTZwRtAx4rav+4P6voBmDQTyc9DtxqZvX+wPGtwON+25CZ3eDPIvp40XMtmKmRwURWIwMREYDIuXYws78H3gM0mdkJCmcF/RnwiJndDRwDfsXv/l3gDqAdGAN+HcA512dmfwr81O/3J865qYPSv0PhjKUK4Hv+tqDOhIGmiUREgPMIA+fcr86y6ZYZ9nXAJ2Z5noeAh2Zo3wVcea465lMiUhgQTWiaSEQECOgVyJFwiEjINDIQEfECGQZQmCrSyEBEpCDAYRDSAWQRES+wYRCPhDVNJCLiBTYMKmJhXWcgIuIFNgwS0ZCuQBYR8YIbBpomEhE5I7hhEFUYiIhMCXAYhHRqqYiIF9gwiEfDOrVURMQLbBgkIjqbSERkSnDDIBrSMQMRES/AYaADyCIiUwIcBiEmsnkKH7QqIhJswQ2DSJhc3jGZUxiIiAQ3DPRtZyIiZwQ4DKa+4EZhICIS4DAojAx0eqmIiMJAIwMRERQG+kgKERECHQb+mIEOIIuIBDkMNE0kIjIluGEQ0TSRiMiU4IaBnybSt52JiAQ6DDRNJCIyJbBhEPcjg7TCQEQkuGGgU0tFRH4msGFQ4cNgLKORgYhIYMMgGg6RiIYYzWTLXYqISNmVFAZm9vtmts/M9prZ35tZwszWmtlOM2s3s2+aWczvG/fr7X77mqLn+bRvP2hmt5XYp/OWjEcZnlAYiIjMOQz
MrA34z8AW59yVQBi4C/gc8AXn3AagH7jbP+RuoN+3f8Hvh5lt9o+7Argd+BszC8+1rgtRnYgwklYYiIiUOk0UASrMLAJUAp3AzcCjfvvDwJ1+eatfx2+/xczMt3/DOZd2zr0GtAPXlVjXeUnGI4xMTF6MHyUisqjNOQyccx3A/wLeoBACg8ALwIBzburt9gmgzS+3Acf9Y7N+/8bi9hkes6CScY0MRESgtGmiegrv6tcCK4EqCtM8C8bM7jGzXWa2q6enp+Tnq05EdMxARITSponeB7zmnOtxzk0C3wbeBdT5aSOAVUCHX+4AVgP47bVAb3H7DI95E+fcA865Lc65LalUqoTSC5IKAxERoLQweAO4wcwq/dz/LcB+4Cngo36fbcBjfnm7X8dvf9I553z7Xf5so7XARuD5Euo6b9WaJhIRAQoHgOfEObfTzB4FXgSywEvAA8B3gG+Y2X/3bQ/6hzwIfM3M2oE+CmcQ4ZzbZ2aPUAiSLPAJ59xFuRIs6c8mcs5RyDMRkWCacxgAOOfuA+6b1nyUGc4Gcs5NAL88y/N8FvhsKbXMRTIeJZd3TEzmqYhdlLNZRUQWpcBegQyFA8gAw2mdXioiwaYwAIbGddxARIIt0GFQVxkDYHA8U+ZKRETKK9hhUBEFoH9U00QiEmyBDoN6PzLoH9PIQESCLdBhUFdVGBkMjmtkICLBFugwqI5HCIdMIwMRCbxAh4GZUVcRpX9MIwMRCbZAhwFAXWWUAY0MRCTgAh8G9ZUxBjQyEJGAC3wY1FVqmkhERGFQGdM0kYgEXuDDoL4yqrOJRCTwAh8GdZUxJibzTExelE/NFhFZlBQGlYULz3QQWUSCLPBhoI+kEBFRGJwZGfSPKgxEJLgCHwZNyTgAvQoDEQmwwIdBY1Vhmqh3JF3mSkREyifwYVBfGSNkcHpEIwMRCa7Ah0EoZDRUxekd1chARIIr8GEA0JSMaWQgIoGmMKBwEPm0jhmISIApDCiMDHo1MhCRAFMYAI0aGYhIwCkMgMZkjLFMjvGMPp9IRIJJYUDRtQY6o0hEAkphADRUFa5C7tNVyCISUAoDoOHMyEBhICLBpDCg+CMpFAYiEkwlhYGZ1ZnZo2b2qpkdMLN3mlmDme0ws8P+vt7va2b2JTNrN7M9ZnZt0fNs8/sfNrNtpXbqQjUmC2HQp2MGIhJQpY4Mvgh83zl3GXA1cAC4F3jCObcReMKvA3wA2Ohv9wD3A5hZA3AfcD1wHXDfVIBcLMl4hFg4pGkiEQmsOYeBmdUCvwA8COCcyzjnBoCtwMN+t4eBO/3yVuCrruA5oM7MWoHbgB3OuT7nXD+wA7h9rnXNhZnRUBWjT9NEIhJQpYwM1gI9wN+Z2Utm9hUzqwJanHOdfp9TQItfbgOOFz3+hG+brf0tzOweM9tlZrt6enpKKP2tGqpiOptIRAKrlDCIANcC9zvnrgFG+dmUEADOOQe4En7GmzjnHnDObXHObUmlUvP1tEDhuMFphYGIBFQpYXACOOGc2+nXH6UQDl1++gd/3+23dwCrix6/yrfN1n5RNSXjnB7WAWQRCaY5h4Fz7hRw3Mw2+aZbgP3AdmDqjKBtwGN+eTvwcX9W0Q3AoJ9Oehy41czq/YHjW33bRdVcHadnJE1hMCMiEiyREh//u8DXzSwGHAV+nULAPGJmdwPHgF/x+34XuANoB8b8vjjn+szsT4Gf+v3+xDnXV2JdFyxVHSeTzTM0nqW2Mnqxf7yISFmVFAbOud3Alhk23TLDvg74xCzP8xDwUCm1lCpVXfhIip6RCYWBiASOrkD2psKgW8cNRCSAFAZe89TIQGEgIgGkMPBS1QlAYSAiwaQw8GoSEWKRkMJARAJJYeCZGalkXMcMRCSQFAZFmmviGhmISCApDIqkkgoDEQkmhUGRlL8KWUQkaBQGRZqrE/SNZshk8+UuRUTkolIYFJm68KxX33gmIgGjMCiS0oVnIhJQCoMiU1chdw8pDEQkWBQGRX7
2YXUKAxEJFoVBkaakpolEJJgUBkVikRD1lVG6hyfKXYqIyEWlMJgmVa0Lz0QkeBQG0ygMRCSIFAbTNFcn9GF1IhI4CoNppkYGhW/pFBEJBoXBNKlknHQ2z3A6W+5SREQuGoXBNM01uvBMRIJHYTBNStcaiEgAKQymafJXIZ/WVcgiEiAKg2mmrkJWGIhIkCgMpqmriBIJmaaJRCRQFAbThEJGYzKmkYGIBIrCYAYraivoGBgvdxkiIheNwmAGG5uTHOoaKXcZIiIXjcJgBptaqukZTtM/mil3KSIiF0XJYWBmYTN7ycz+2a+vNbOdZtZuZt80s5hvj/v1dr99TdFzfNq3HzSz20qtqVSXtVYDcODUUJkrERG5OOZjZPBJ4EDR+ueALzjnNgD9wN2+/W6g37d/we+HmW0G7gKuAG4H/sbMwvNQ15xd3loDwP6TCgMRCYaSwsDMVgG/BHzFrxtwM/Co3+Vh4E6/vNWv47ff4vffCnzDOZd2zr0GtAPXlVJXqZqScVLVcV49NVzOMkRELppSRwZ/CfwBkPfrjcCAc27qU95OAG1+uQ04DuC3D/r9z7TP8Jg3MbN7zGyXme3q6ekpsfSz25BK0t6tg8giEgxzDgMz+yDQ7Zx7YR7rOSvn3APOuS3OuS2pVGpBf9aG5iRHukf0UdYiEgiREh77LuDDZnYHkABqgC8CdWYW8e/+VwEdfv8OYDVwwswiQC3QW9Q+pfgxZbOhOclwOkv3cJqWmkS5yxERWVBzHhk45z7tnFvlnFtD4QDwk865jwFPAR/1u20DHvPL2/06fvuTrvC2eztwlz/baC2wEXh+rnXNlw3NSQBNFYlIICzEdQZ/CHzKzNopHBN40Lc/CDT69k8B9wI45/YBjwD7ge8Dn3DO5RagrguiMBCRICllmugM59yPgB/55aPMcDaQc24C+OVZHv9Z4LPzUct8aa6OUx2PKAxEJBB0BfIszIz1zTqjSESCQWFwFhuak7T3KAxEZPlTGJzFulQVPcNpRtLZc+8sIrKEKQzOYlV9JQAd/fo4axFZ3hQGZ9FWVwFAx8BYmSsREVlYCoOzWF1fCIMTGhmIyDKnMDiLpmScWDikaSIRWfYUBmcRChlt9RUaGYjIsqcwOIe2ugpO6PuQRWSZUxicw6r6Ck0TiciypzA4h1X1FZweSTMxWfaPSxIRWTAKg3No0xlFIhIACoNzWJ8qfHrpQX0FpogsYwqDc7i8tYZENMSuY33lLkVEZMEoDM4hGg5x9ao6XjzWX+5SREQWjMLgPFy3toG9J4foH82UuxQRkQWhMDgP79/cQi7vePLV7nKXIiKyIBQG5+Hn2mpprU3w/X2nyl2KiMiCUBicBzPjA1e28vTBHgbHJstdjojIvFMYnKePXNNGJpfnn/acLHcpIiLzTmFwnq5sq2Fzaw1f+8kxnHPlLkdEZF4pDM6TmfEfb1zDwa5hXnxjoNzliIjMK4XBBXjvZc0AvPSGrjkQkeVFYXABUtVxVtYm2H18oNyliIjMK4XBBbrmknpe0jSRiCwzCoMLtGVNPR0D4xzvGyt3KSIi80ZhcIHevTEFwL8cPl3mSkRE5o/C4AKtT1WxoibBs0cUBiKyfCgMLpCZceP6Rp470qvrDURk2ZhzGJjZajN7ysz2m9k+M/ukb28wsx1mdtjf1/t2M7MvmVm7me0xs2uLnmub3/+wmW0rvVsL64b1jfSOZjjUNVLuUkRE5kUpI4Ms8F+cc5uBG4BPmNlm4F7gCefcRuAJvw7wAWCjv90D3A+F8ADuA64HrgPumwqQxerG9Y0APNuuqSIRWR7mHAbOuU7n3It+eRg4ALQBW4GH/W4PA3f65a3AV13Bc0CdmbUCtwE7nHN9zrl+YAdw+1zruhhW1VeyqaWa/7e7o9yliIjMi3k5ZmBma4BrgJ1Ai3Ou0286BbT45TbgeNHDTvi22doXtV+9bjV7Tgyyt2Ow3KWIiJSs5DAwsyTwLeD3nHNDxdtc4Qj
rvB1lNbN7zGyXme3q6emZr6edk49cu4pENMTDP369rHWIiMyHksLAzKIUguDrzrlv++YuP/2Dv5/6erAOYHXRw1f5ttna38I594BzbotzbksqlSql9JLVVkS56+cv4VsvnuBQ13BZaxERKVUpZxMZ8CBwwDn3+aJN24GpM4K2AY8VtX/cn1V0AzDop5MeB241s3p/4PhW37boffKWjSTjET77nQPlLkVEpCSljAzeBfwH4GYz2+1vdwB/BrzfzA4D7/PrAN8FjgLtwP8GfgfAOdcH/CnwU3/7E9+26NVXxfjdmzfy9KEedh7tLXc5IiJzZkv1wqktW7a4Xbt2lbsMJiZzXPfZH/Ley5r54l3XlLscEZFZmdkLzrktM23TFcglSkTDfOSaNr6zp5P9J4fO/QARkUVIYTAPfu99l1JXGeP3v7mbXH5pjrREJNgUBvOgvirGZz60mYNdwzp2ICJLksJgnrxnUwoz2Pnakjj2LSLyJgqDeVKTiHJVWy3feaVTU0UisuQoDObRb/3ietq7R/jy00fKXYqIyAVRGMyj269cwYeuXsnndxzij7fvo280U+6SRETOi8JgHpkZ/+MjV7L16pV8fecxbv3CM+w+PlDuskREzklhMM+qE1E+/+/ezvb/dBOJaIiP3v9j/us/vMxrp0fLXZqIyKwUBgvk8tYavv3bN/JrN7yN7S+f5Ja/+BGfemQ3I+lsuUsTEXkLhcECaq5J8McfvoJ//cOb+c13r+Ox3Sf58F/9K8f7xspdmojImygMLoJUdZxP33E5X7v7OnpHMvzSl/6Fz33/Vfp1gFlEFgmFwUV04/om/uG33slNG5v426eP8LGv7GR4YrLcZYmIKAwutktbqvmbj72DB7f9PAe7htn20PP0DKfLXZaIBJzCoEzee1kzf/3vr2FvxxDv+/zTPLLrOEv148RFZOlTGJTR7Ve28t1PvptLW5L8waN7+NBf/Svfe6VToSCB5JzTa7+MFAZltqE5yTfveSd//m+vYmIyz29//UW2/vWzfO25Y+ztGCSdzZW7RJkHQxOTfGdPJxOTS/f3mc87/ts/7eP+Hx056+syn3cc6hpmYjLHeOb8+nugc4ir/vgH/NqDO8nP8tleswVFLu/Ysb+LX/nyT/jo/T/m1ODErM+xEJxzDIxlGMss7dPG9U1ni0g2l+dbL57gy08fPXORWkU0zLpUFStqEkTCRmMyTlNVjNrKGPWVUcIhYzLnyOXzJKJhqmIRvvrcMV461s/qhkresynFphXVrKqvJBo2DneNUFMRZTSdJRI2WmsraK1NAPDDA12MZ3LcdsUKGpMxEtEwvSMZJiZztNQkiEcK7x0m83k6+scZn8yRyzu6htLknaMpGWc0nfXPHSISMtLZPIPjGVLVceoqY1TGwrTVVVCdiL6l78MTWfZ3DnF6JM1YJsdrp0epSUS4YmUtADduaCQeCZPPOyayOTr6x9nfOcRVq+pY21RF70iaw90jHO4e4Uj3CC+90U/fWIY7rmyltjLK4Ngk772smaZknL0dg2Ryea5f20A27zg5ME40HGJFTYIVtQmefLWbkwPjrKyr4G2NlRzvG2PX64V/0w9dvZJDXcP86GAPsUiIG9Y1sLK2giM9IwBUxSNUxsKsqEnw8olBnjjQxQvH+jncPcK6piruvKaNxmSMfN7xzvWNrKqvpH8sQ99oBsNoSsZIJiL85EgvY5kcPznay/G+MfacGOQ9m1J85oObiYRCDI5P8o8vdXDL5c1csbKG9u4R/vaZo7x+epQNzUl+8dIUN21sOvM7Gklnyeby/Mvh0zx1sJtNK6rZ0JzkHZfUs6E5SXUiSiwSYjSdZdexfh7+8es0V8f58NUr2bKmgYeefY0/+96rAKxPVfHBq1YSi4S4tKWaTS3VRCNGejLPfdv38fShHgDikRDv39zCppZqrl/XyE+O9BZex1UxekczhEPG5tYa/ufjB3mlYxCAP7rjcsygbzRz5v9Bx8A4Bzq
HuOWyFm6+rJlMLs8bfWOsqEnwxKtdPNv+5o+Ob66O8xs3rWVDKknfaIZI2Li0pZrqRIS2ugo6BycwK1wkmp7M8cMD3cQiITY2J5nM5UkmIlREw5wcmCDvHLm8o38sQywcYjST45lDPYyks1TGwvz09T66htJUxyNc3loDBh+6qpVrLqlndUMlI+ksIxNZRtKTNFbFaaqO++ce54cHunjuaC/hkHH3TeuIR0KsaaqiZzjN0Z4RQiGja3Ci0KeaOCPpHFesrGFdUxWFr6G/MGf7pjOFwSLknOONvjFe6Rhk1+v9HOsd5dRQmmwuz+mRNAPjk5zt11YZC/PBq1p5vXeMF471L8pPUa2tiBIySGfzpLP5GWuMRUJksvkz6/FIiJAZ4zO8u46EjGzRc1TFwrytsYqWmjjPHD5NLu/ess+FioYLwVv8Mx2c97/v21fXkXeOPScGL+jnVkTDrG6ooK4yxvOzfET6VG0V0TBXttVw8NQwQxOzv1O95pI63ugdo3fa6c3xSIi0/zdvrIqRyeYZTmcJGeQdvO/yFq5f28DfPfsaJ/0fqenM4J53ryMUMjoHxnl8X9eMv7NiyXiEz3xwM3/7zBGO9Pzsav3VDRUYRltdBZc0VPK9vZ1n+jX1+4xHQnzmQ5v55Xes5nt7OznQOcwzh3rY3znzNw9O9aUUFdEwlzRUMj6Z422NlVy9qo7j/WMcPDVMx8A4w2f5t5/ODKKhEJlc/tw7U/i/s/sz71cYTFnOYXAuubxjeGKS/rFJcnlHLBwiEjbGMjkGxzNsWlFDMh4BCt/R/EbfGCf6x8hkHetTVYxP5qiMRcjm83QOTHBqaALn4PLWamorovz09T6GxrOMT+aoq4xSFYvQNTzBxGSekEHYjNa6CpLxCCGDlpoEZtA7kjnzjmpiMkc4ZOTyjubqBL2jaQbGJhnNZDnRP05H/zhmEAuHiEdDxCNhKmNhLm2pprU2QSJaGEEc6xujo3+cSf+ONhyCiljhnXdDZYzLW2vYfbyfU0MT1FfGuLSlmo0tSVbUJM78Z+n37wzDIeMH+7rI5R2bV9aQyeY52DVMNGysqKkgl3ecGpqgc2CcVQ0V/MLGFJ2DExzrHcPhuO2KFbx8fIDnX+9jQyrJjRuacM7x3NE+Tg2Os3llLeOZHEMThbDuH8uwuqGS69c2cHJgnLX+3Vz/aIZMLs9YJscP93cxmc9TVxGjoSqKc9AzkqZraIIb1jUSC4e4rLWG2orCSOpA5xBPvtpNRTRMVTzM6vpK9p0conc0Q0tNnK1vb6OhKkY2l+eFY/08/1ofFbEwqeo41YkI8UiYpmScTSuqcc5xon+cl44P0D+aYXhikuGJLFXxCDWJCP/mHauIhIxnDvWw+/ggaxor2fr2NipiYaAwmktn8+w5MciJ/jGyeUfYjMtba/i5VbVves2+0TvGc6/18u6NTdRWROkdKYwWe0czHDg5xPXrGqhORBnP5HilY5C2+gqakjHikfCbnmdiMkfPcBrnCkHR3j1CfVWMpmT8Tftlc3l+sL+LkMH6VJKhiSynR9L0j2Y41jfGJQ2VhIwzf7Tfub6RimiYw90jxCOFUddIOsvaxioi4RBmUF8Z9aNwx+bWGkKhmf8YT/27/uRIL0MTkyTjEaoTUSrjYToHJhhNZxnL5AiH4P2bV9Bal2A0neXZ9l4mc3mGxiepr4qxPpUkl3e01MTPjC6SiQjdw2neu6l5Tn87FAYiInLWMNABZBERURiIiIjCQEREUBiIiAgKAxERQWEgIiIoDEREBIWBiIiwhC86M7Me4NgcH94EnJ7HchYT9W3pWs79W859g6XTv7c551IzbViyYVAKM9s121V4S536tnQt5/4t577B8uifpolERERhICIiwQ2DB8pdwAJS35au5dy/5dw3WAb9C+QxAxERebOgjgxERKSIwkBERIIVBmZ2u5kdNLN2M7u33PWcLzN7yMy6zWxvUVuDme0ws8P+vt63m5l9yfdxj5l
dW/SYbX7/w2a2rRx9mc7MVpvZU2a238z2mdknffuS75+ZJczseTN72fftv/n2tWa20/fhm2YW8+1xv97ut68peq5P+/aDZnZbmbr0FmYWNrOXzOyf/fpy6tvrZvaKme02s12+bcm/LmflnAvEDQgDR4B1QAx4Gdhc7rrOs/ZfAK4F9ha1/Tlwr1++F/icX74D+B5gwA3ATt/eABz19/V+uX4R9K0VuNYvVwOHgM3LoX++xqRfjgI7fc2PAHf59i8Dv+2Xfwf4sl++C/imX97sX69xYK1/HYfL/bvztX0K+L/AP/v15dS314GmaW1L/nU52y1II4PrgHbn3FHnXAb4BrC1zDWdF+fcM8D0b0LfCjzslx8G7ixq/6oreA6oM7NW4DZgh3OuzznXD+wAbl/w4s/BOdfpnHvRLw8DB4A2lkH/fI0jfjXqbw64GXjUt0/v21SfHwVuscIXOW8FvuGcSzvnXgPaKbyey8rMVgG/BHzFrxvLpG9nseRfl7MJUhi0AceL1k/4tqWqxTnX6ZdPAS1+ebZ+Lvr++6mDayi8g14W/fPTKLuBbgp/CI4AA865rN+luM4zffDbB4FGFmnfgL8E/gDI+/VGlk/foBDcPzCzF8zsHt+2LF6XM4mUuwApnXPOmdmSPkfYzJLAt4Dfc84NFd40Fizl/jnncsDbzawO+EfgsvJWND/M7INAt3PuBTN7T5nLWSg3Oec6zKwZ2GFmrxZvXMqvy5kEaWTQAawuWl/l25aqLj8Mxd93+/bZ+rlo+29mUQpB8HXn3Ld987LpH4BzbgB4CngnhSmEqTdixXWe6YPfXgv0sjj79i7gw2b2OoUp15uBL7I8+gaAc67D33dTCPLrWGavy2JBCoOfAhv92Q4xCgextpe5plJsB6bOTNgGPFbU/nF/dsMNwKAf1j4O3Gpm9f4MiFt9W1n5eeMHgQPOuc8XbVry/TOzlB8RYGYVwPspHBN5Cvio321636b6/FHgSVc4CrkduMufkbMW2Ag8f1E6MQvn3Kedc6ucc2so/F960jn3MZZB3wDMrMrMqqeWKbye9rIMXpezKvcR7It5o3DE/xCFeds/Knc9F1D33wOdwCSFOce7Kcy3PgEcBn4INPh9Dfhr38dXgC1Fz/MbFA7QtQO/Xu5++ZpuojA3uwfY7W93LIf+AVcBL/m+7QU+49vXUfiD1w78AxD37Qm/3u63ryt6rj/yfT4IfKDcfZvWz/fws7OJlkXffD9e9rd9U38vlsPrcrabPo5CREQCNU0kIiKzUBiIiIjCQEREFAYiIoLCQEREUBiIiAgKAxERAf4/Vqzj2EHm6PMAAAAASUVORK5CYII=\n", 266 | "text/plain": [ 267 | "
" 268 | ] 269 | }, 270 | "metadata": { 271 | "needs_background": "light" 272 | }, 273 | "output_type": "display_data" 274 | } 275 | ], 276 | "source": [ 277 | "from matplotlib import pyplot as plt\n", 278 | "%matplotlib inline\n", 279 | "\n", 280 | "plt.plot(plt_x, plt_y)\n", 281 | "plt.show()" 282 | ] 283 | } 284 | ], 285 | "metadata": { 286 | "kernelspec": { 287 | "display_name": "Python 3 (ipykernel)", 288 | "language": "python", 289 | "name": "python3" 290 | }, 291 | "language_info": { 292 | "codemirror_mode": { 293 | "name": "ipython", 294 | "version": 3 295 | }, 296 | "file_extension": ".py", 297 | "mimetype": "text/x-python", 298 | "name": "python", 299 | "nbconvert_exporter": "python", 300 | "pygments_lexer": "ipython3", 301 | "version": "3.8.11" 302 | } 303 | }, 304 | "nbformat": 4, 305 | "nbformat_minor": 5 306 | } 307 | -------------------------------------------------------------------------------- /1.梯度下降.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | 
"execution_count": 3, 54 | "id": "92163201", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "0.6590042695516539" 61 | ] 62 | }, 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "#预测函数\n", 70 | "def predict(x):\n", 71 | " return w.dot(x) + b\n", 72 | "\n", 73 | "\n", 74 | "predict(x[0])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "id": "a7bb7a80", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "7.367867692433937" 87 | ] 88 | }, 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "output_type": "execute_result" 92 | } 93 | ], 94 | "source": [ 95 | "#求loss,MSELoss\n", 96 | "def get_loss():\n", 97 | " loss = 0\n", 98 | " for i in range(N):\n", 99 | " pred = predict(x[i])\n", 100 | " loss += (pred - y[i])**2\n", 101 | " return loss / N\n", 102 | "\n", 103 | "\n", 104 | "get_loss()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "id": "8027d213", 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "(array([2.03668543, 2.38225639, 1.02215384, 2.13526642, 3.22327899]),\n", 117 | " 0.0010000000036924916)" 118 | ] 119 | }, 120 | "execution_count": 5, 121 | "metadata": {}, 122 | "output_type": "execute_result" 123 | } 124 | ], 125 | "source": [ 126 | "def get_gradient():\n", 127 | " global w\n", 128 | " global b\n", 129 | "\n", 130 | " eps = 1e-3\n", 131 | "\n", 132 | " loss_before = get_loss()\n", 133 | "\n", 134 | " gradient_w = np.empty(M)\n", 135 | " for i in range(M):\n", 136 | " w[i] += eps\n", 137 | " loss_after = get_loss()\n", 138 | " w[i] -= eps\n", 139 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 140 | "\n", 141 | " b += eps\n", 142 | " loss_after = get_loss()\n", 143 | " b -= eps\n", 144 | " gradient_b = (loss_after - loss_before) / eps\n", 145 | "\n", 146 | " return gradient_w, 
gradient_b\n", 147 | "\n", 148 | "\n", 149 | "get_gradient()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 6, 155 | "id": "c371c6a4", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "0 7.112757670092038\n", 163 | "50 1.7854577366414703\n", 164 | "100 0.8188794143216034\n", 165 | "150 0.5927178198131446\n", 166 | "200 0.5305917673804184\n", 167 | "250 0.5095310094726683\n", 168 | "300 0.5002280440367376\n", 169 | "350 0.4950505954913211\n", 170 | "400 0.49174823787396005\n", 171 | "450 0.48950848092220356\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "plt_x = []\n", 177 | "plt_y = []\n", 178 | "for i in range(500):\n", 179 | " gradient_w, gradient_b = get_gradient()\n", 180 | " w -= gradient_w * 1e-2\n", 181 | " b -= gradient_b * 1e-2\n", 182 | "\n", 183 | " plt_x.append(i)\n", 184 | " plt_y.append(get_loss())\n", 185 | "\n", 186 | " if i % 50 == 0:\n", 187 | " print(i, get_loss())" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 7, 193 | "id": "0471a70d", 194 | "metadata": { 195 | "scrolled": true 196 | }, 197 | "outputs": [ 198 | { 199 | "data": { 200 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZc0lEQVR4nO3dfXRc9X3n8fd3ZqTRw+jBkkayQAI/8mCobYgghoRASUhIwqbbDUlJ02y7S+pDN90lZ3O2J2xPd0939/Rhz7ZNu9umoZC2e5rAJgR2WZYWSCAQHgrINjZgG2yMsWVjSzKSLcnW08x3/5grIQvLHssa3Tszn9c5c3Tn3t/MfH9i+Pin3/zuXHN3REQkumJhFyAiIqenoBYRiTgFtYhIxCmoRUQiTkEtIhJxiUI8aUtLiy9btqwQTy0iUpI2bdrU7+7pUx0rSFAvW7aM7u7uQjy1iEhJMrN35jqmqQ8RkYhTUIuIRJyCWkQk4hTUIiIRp6AWEYk4BbWISMQpqEVEIi4yQe3u/NlPdvH0m31hlyIiEilnDGozu9jMXplxO2ZmX1/oQsyMv3pmD0/t7F3opxYRKWpnPDPR3d8A1gOYWRw4ADxUiGJa6pL0D48V4qlFRIrW2U59fBx4y93nPNXxXDTXViqoRURmOdugvg24rxCFALSkkhwZHi/U04uIFKW8g9rMKoHPAT+c4/hGM+s2s+6+vvl9INic0ohaRGS2sxlRfxrY7O6HT3XQ3e929y5370qnT/lNfWfUkkoycHyCyUx2Xo8XESlFZxPUX6KA0x4ALalKAN47rukPEZEpeQW1mdUCNwEPFrKY5lQSgP4hBbWIyJS8Lhzg7iNAc4FroSUI6iMjmqcWEZkSmTMTIfdhIqAPFEVEZohUUE+PqLVET0RkWqSCur4qQWU8Rp9G1CIi0yIV1GZGc6pSI2oRkRkiFdSgk15ERGaLXFDrNHIRkZNFLqiba/UNeiIiM0UuqFuCOWp3D7sUEZFIiGBQJxnPZBkamwy7FBGRSIhcUE+f9DKk6Q8REYhgUL9/Grk+UBQRgQgGtUbUIiIni1xQp+tyI2qdnSgikhO5oG6uTRIz6D2moBYRgQgGdTxmpOuSHD42GnYpIiKRELmgBmirr+Kw5qhFRICIBnVrXRW9GlGLiABRDer6JL0aUYuIABEN6ra6Kt4bGWd8UlcjFxGJZlDXa4meiMiUSAZ1axDUWvkhIpJnUJtZo5k9YGY7zWyHmV1TyKJa66oA9IGiiAiQyLPdnwL/4O63mlklUFPAmmirD4JaHyiKiJw5qM2sAfgY8GsA7j4OFPQbk5prK4nHTFMfIiLkN/WxHOgD/trMtpjZPWZWO7uRmW00s24z6+7r6zu3omJGOpXksE4jFxHJK6gTwJXAt939CmAE+ObsRu5+t7t3uXtXOp0+58LatJZaRATIL6h7gB53fzG4/wC54C6otM5OFBEB8ghqdz8E7Dezi4NdHwe2F7QqciNqzVGLiOS/6uNfA98LVnzsAf5F4UrKaauvYuD4BGOTGZKJeKFfTkQksvIKand/BegqbCknmz47cWiMjiUFXQ0oIhJpkTwzEaA1WEut6Q8RKXeRDerzGqoBODiooBaR8hbZoG5vzI2o3z16IuRKRETCFdmgrq+qIJVMaEQtImUvskEN0N5QpRG1iJS9aAd1Y7VG1CJS9iId1Oc3akQtIhLpoG5vqKZ/eJyxyUzYpYiIhCbiQZ1b+XHoqKY/RKR8RTqoz2vMraU+MKjpDxEpX0UR1O/qA0URKWORDuqpqQ99oCgi5SzSQV1VEaeptpKDmqMWkTIW6aCG4KQXzVGLSBmLfFCfp5NeRKTMRT+oG6o4qDlqESlj0Q/qxmqGRicZGp0IuxQRkVBEPqinru6y/z2NqkWkPEU+qDubcmup9w8cD7kSEZFwRD6oL2iaGlErqEWkPEU+qBuqK6hLJugZ0NSHiJSnvK5CbmZ7gSEgA0y6+6JdkdzM6GiqYZ9G1CJ
SpvIK6sDPu3t/wSo5jc4l1bzdPxLGS4uIhC7yUx8AnU019AycwN3DLkVEZNHlG9QOPG5mm8xs46kamNlGM+s2s+6+vr6Fq5DciPrERIb+4fEFfV4RkWKQb1B/1N2vBD4NfM3MPja7gbvf7e5d7t6VTqcXtMjOqZUfWqInImUor6B29wPBz17gIeDqQhY1m5boiUg5O2NQm1mtmdVNbQOfBF4rdGEzTZ2dqCV6IlKO8ln10QY8ZGZT7b/v7v9Q0Kpmqa6M05JKsu+IRtQiUn7OGNTuvgdYtwi1nFZnU7XmqEWkLBXF8jzIzVPrpBcRKUdFE9QXNtdycPAEY5OZsEsREVlURRPUK1pqybpWfohI+SmaoF7WUgvA2/0KahEpL0UT1Mubp4J6OORKREQWV9EEdUNNBU21lRpRi0jZKZqgBljeUqsRtYiUnaIK6mXNtezViFpEykxRBfXylhoOHRvl+Phk2KWIiCyaIgvqFIBG1SJSVooqqJe15L6cSVd7EZFyUlxBHSzR23tEQS0i5aOogro2maCtPsmePgW1iJSPogpqgJXpFHu0RE9EykjRBfXq1hS7Dw/rQrciUjaKLqhXtdUxNDbJoWOjYZciIrIoii6oV7fmlujtOqzpDxEpD8Ub1L0KahEpD0UX1M2pJE21lezuHQq7FBGRRVF0QQ25UfWbmvoQkTJRnEHdlmLX4SGt/BCRspB3UJtZ3My2mNkjhSwoH6tb6zg2Oknf0FjYpYiIFNzZjKjvBHYUqpCzoQ8URaSc5BXUZtYBfBa4p7Dl5GdV29QSPX2gKCKlL98R9beA3wKyczUws41m1m1m3X19fQtR25zSqSRLairYeUhBLSKl74xBbWa3AL3uvul07dz9bnfvcveudDq9YAXOUROXttez491jBX0dEZEoyGdE/RHgc2a2F7gfuNHM/q6gVeXh0vZ63jg8RCarlR8iUtrOGNTufpe7d7j7MuA24El3/5WCV3YGl7bXMzqR1UUERKTkFeU6aoA17fUAbNf0h4iUuLMKanf/qbvfUqhizsaq1hQVcdM8tYiUvKIdUVcmYqxqrWP7QQW1iJS2og1qgEvb6zSiFpGSV9RBvaa9nt6hMfqHdSq5iJSuog9qQKNqESlpxR3U5+WC+tUDR0OuRESkcIo6qBtrKrmwuYZt+xXUIlK6ijqoAdZ2NLKtZzDsMkRECqbog3pdRwMHj47qu6lFpGQVfVCv7WgE0KhaREpW0Qf15efXEzPY2qN5ahEpTUUf1DWVCVa31mlELSIlq+iDGmBtRwPbeo7qYrciUpJKI6g7G3lvZJyegRNhlyIisuBKIqiv6GwEYPO+gXALEREpgJII6kuW1lFbGad7r4JaREpPSQR1Ih7jyguX0P2OglpESk9JBDXAhy5cws5Dxzg2OhF2KSIiC6pkgrrrwibcYcu+wbBLERFZUCUT1OsvaCQeMzbtfS/sUkREFlTJBHUqmeDS9jpe1geKIlJiSiaoITf98cr+QSYy2bBLERFZMGcMajOrMrOXzGyrmb1uZr+7GIXNx4YVTZyYyOh0chEpKfmMqMeAG919HbAeuNnMNhS0qnn68PJmzOD53UfCLkVEZMGcMag9Zzi4WxHcIvmlGktqK1nTXs9zb/WHXYqIyILJa47azOJm9grQCzzh7i+eos1GM+s2s+6+vr4FLjN/165sZvM7g4xOZEKrQURkIeUV1O6ecff1QAdwtZldfoo2d7t7l7t3pdPpBS4zf9eubGE8k2WTzlIUkRJxVqs+3H0QeAq4uSDVLICrljcRjxnPa/pDREpEPqs+0mbWGGxXAzcBOwtc17ylkgnWdTTwnD5QFJESkc+Iuh14ysy2AS+Tm6N+pLBlnZuPrk6zrWeQgZHxsEsRETln+az62ObuV7j7Wne/3N3/02IUdi5uuDhN1uGZXeF9qCkislBK6szEKes6GllSU8HTbyioRaT4lWRQx2PG9RelefrNPrLZSC75FhHJW0kGNcANF7d
yZGScVw8cDbsUEZFzUrJB/bGL0pjBU2/0hl2KiMg5Kdmgbqqt5IrORn6843DYpYiInJOSDWqAT162lNcOHKNn4HjYpYiIzFtJB/WnLlsKwGOva1QtIsWrpIN6eUstlyyt47HXD4VdiojIvJV0UENu+uPlve/RPzwWdikiIvNS8kF982VLcYcntmv6Q0SKU8kH9aXtdSxvqeX/bj0YdikiIvNS8kFtZnxu3Xm8sOcIh46Ohl2OiMhZK/mgBviF9efhDo9s06haRIpPWQT1inSKtR0N/O9XDoRdiojIWSuLoAb4hfXn89qBY+zuHQq7FBGRs1I2Qf1P1rUTjxk/3NQTdikiImelbIK6ta6KT1zaygPdPYxPZsMuR0Qkb2UT1ABfuvoCjoyMa021iBSVsgrq61anOb+xmvte2hd2KSIieSuroI7HjNuu6uTZ3f28c2Qk7HJERPJSVkEN8IWuTuIx4/6X94ddiohIXs4Y1GbWaWZPmdl2M3vdzO5cjMIKZWlDFR+/pJX7X9rHifFM2OWIiJxRPiPqSeAb7r4G2AB8zczWFLaswvrqdSsYOD7BA5u1VE9Eou+MQe3u77r75mB7CNgBnF/owgrpqmVLWN/ZyD0/20NGVykXkYg7qzlqM1sGXAG8eIpjG82s28y6+/r6Fqi8wjAzNn5sBe8cOc4T23VRARGJtryD2sxSwI+Ar7v7sdnH3f1ud+9y9650Or2QNRbEpy5bygVNNXznmT24a1QtItGVV1CbWQW5kP6euz9Y2JIWRzxmfPW65WzZN8g/7nkv7HJEROaUz6oPA+4Fdrj7Hxe+pMXzxa5O2uqT/NHjb2hULSKRlc+I+iPAV4AbzeyV4PaZAte1KKoq4vzmjavpfmeAn74Z7Xl1ESlf+az6eNbdzd3Xuvv64PboYhS3GH6pq5OOJdUaVYtIZJXdmYmzVSZifP0TF/HagWM89rpWgIhI9JR9UAP84hXns6o1xR/8/U7GJnW2oohEi4Ka3AqQ37llDXuPHOfeZ98OuxwRkZMoqAPXX5TmpjVt/I8nd+tq5SISKQrqGX7ns2uYzDq/9+iOsEsREZmmoJ7hguYa7rh+JQ9vPchP3+gNuxwREUBB/QH/6oaVrGpNcdeDr3JsdCLsckREFNSzVVXE+W9fWMfhY6P8l0e2h12OiIiC+lTWdzZyx/Ur+UF3D0/u1IVwRSRcCuo53PmJ1VyytI5v/GArBwdPhF2OiJQxBfUckok4f/HlK5nIOL/5/c1MZLJhlyQiZUpBfRor0in+4PM/x+Z9g/z+ozvDLkdEylQi7AKi7pa159G9d4DvPvc2lyyt44tXdYZdkoiUGY2o8/Dbn72U61a38O8fepVnd/WHXY6IlBkFdR4q4jH+4stXsqo1xW/83SZ2HvrAlchERApGQZ2nuqoK7v21q6hJxvmVe15kd+9w2CWJSJlQUJ+F8xur+f6vbwCMX/6rf+Tt/pGwSxKRMqCgPksr0ynu+/UPk8k6v/SdFzQNIiIFp6Ceh9Vtddy3cQNm8IW/fIGX3tZVzEWkcBTU83RRWx0/+o1rSdcl+cq9L/Lw1oNhlyQiJeqMQW1m3zWzXjN7bTEKKiYdS2p44I5rWdvRwL+5bwu///c7yGR1gVwRWVj5jKj/Bri5wHUUrabaSr731Q18ZcOFfOfpPXzl3hd1hRgRWVBnDGp3fwbQJOxpVCZi/Od/ejn/9da1bNk3yKe+9Qz/b9u7YZclIiVCc9QL6ItdnTx653Usa6nla9/fzJ33b6FvaCzsskSkyC1YUJvZRjPrNrPuvr6+hXraorO8pZYH7riGOz++mkdffZcb/+in/O3zezV3LSLzZu5nDhAzWwY84u6X5/OkXV1d3t3dfY6lFb89fcP8x4df52e7+rlkaR3/7lMXc+MlrZhZ2KWJSMSY2SZ37zrVMU19FNCKdIr/+S+v5s9/+UpOTGS4/W+7+fy3n+e53f3k8w+kiAjktzzvPuAF4GIz6zGz2wtfVuk
wMz67tp0f/9vr+b1f/DkODJ7gy/e8yC3//Vke2tKjCxKIyBnlNfVxtjT1MbfRiQwPbTnAPT/bw1t9Iyytr+K2qzv5/JUddDbVhF2eiITkdFMfCuqQZLPO02/2ce+zb/Ps7tx3XF+zoplbP9TBJ9a00VBdEXKFIrKYFNQR1zNwnAc3H+CBTT3se+84iZhxzcpmPnnZUm66tI2lDVVhlygiBaagLhLuzuZ9gzy+/RCPv354+mtUL1lax7UrW7h2ZTNXr2iivkqjbZFSo6AuQu7O7t5hHt9+mOff6qd77wBjk1liBmvOq2dtRyPrOxpZ19nIqtYU8ZiW/IkUMwV1CRidyLBl3yAvvNXP5n2DbO0ZZGh0EoCayjgXL63jotY6VrelWN1Wx0VtKZbWV2nNtkiRUFCXoGzWefvICFv3D7Kt5yg73j3G7t5hjoyMT7dJJRN0NtXQuaR6+ucFzTV0LKmhra6K+uqEglwkIk4X1InFLkYWRixmrEynWJlO8c+u7Jjef2R4jF29w+w6PMTu3mH2D5zg7f4RntnVx+jEyWu2k4kY6bokrXVJWuuqaK3PbTfVJllSU0FDTQWN1ZU01lTQWFNBdUVcwS4SAgV1iWlOJWlOJdmwovmk/e5O3/AY+987Qc/AcfqGxugdGqP32Ci9Q2Ps6h3iubf6p6dTTqUyEaOxuoKG6gpqkwlqk3FqKxOkkglqknFqkwlSlQlqkglSyTg1lQmqK+IkK2IkE3GSiRhVFbmfs/dpjl1kbgrqMmFmuVFzXRUfunDJnO1GJzIMHB9n8PhEcBtn8ESwfWKcwZEJjp6YYGR8kpGxSY4MjzM8Nsnx8QzDY5OMT87vTMtEzIIAj1MZj5GIGxXxGImYkYjHqIgb8ZhREcsdS8RjVMRs1nauXWKqTcyIxYy4GTF7fzsey/0+4sF9M3LbMcvtD9rEgsfFg+eJGbnnik3tP7mNGRi5dhY8b+7fn6ltw2C6ndn727HY+/vmekws+Gtmev8pnucDrzPd7vSPseA9kntlpl+HGfslPApqOUlVRZz2hmraG6rn9fiJTJbjY5npIB+dyDI2mWFsMvdz+v5ElrHJLKMTHzw2PpllMuNMZJ3JTJaJjDOZze2bzGYZn8wyMp5hMjPVLjiWyU4/JtfWybiTzTpZd/QFhuduOryn738w3I2TG53q2GmfZ65jdvrXnXrUB597+sicrzuzfzPbz1XrjJc76VhzbZIf3HENC01BLQuqIh6joSZGQ0301np7ENZZdzIzwjuTfT/Mc8HOSQGfazvrcUGbTNbxGW0cB2d624PHOQT7c/t8Rj1Mtzv5MbmaT35MNthwgteb8RgPXmf2Y3x6e47HBO2mXi+oaNb99++cru1cxzjp2Cmem/frmnls5n+7fGp8/3Gzjp3iuU/5fB849sE+z3UMh7qqwkSqglrKRm5aA+IYFfGwqxHJn77mVEQk4hTUIiIRp6AWEYk4BbWISMQpqEVEIk5BLSIScQpqEZGIU1CLiERcQb7m1Mz6gHfm+fAWoH8ByykG6nN5UJ/Lw3z7fKG7p091oCBBfS7MrHuu72QtVepzeVCfy0Mh+qypDxGRiFNQi4hEXBSD+u6wCwiB+lwe1OfysOB9jtwctYiInCyKI2oREZlBQS0iEnGRCWozu9nM3jCz3Wb2zbDrWShm9l0z6zWz12bsazKzJ8xsV/BzSbDfzOzPgt/BNjO7MrzK58/MOs3sKTPbbmavm9mdwf6S7beZVZnZS2a2Nejz7wb7l5vZi0Hf/peZVQb7k8H93cHxZaF24ByYWdzMtpjZI8H9ku6zme01s1fN7BUz6w72FfS9HYmgNrM48OfAp4E1wJfMbE24VS2YvwFunrXvm8BP3H018JPgPuT6vzq4bQS+vUg1LrRJ4BvuvgbYAHwt+O9Zyv0eA25093XAeuBmM9sA/CHwJ+6+ChgAbg/a3w4MBPv/JGhXrO4Edsy4Xw59/nl3Xz9jvXRh39u5a6aFewOuAR6
bcf8u4K6w61rA/i0DXptx/w2gPdhuB94Itr8DfOlU7Yr5Bvwf4KZy6TdQA2wGPkzuDLVEsH/6fQ48BlwTbCeCdhZ27fPoa0cQTDcCj5C7zmup93kv0DJrX0Hf25EYUQPnA/tn3O8J9pWqNnd/N9g+BLQF2yX3ewj+vL0CeJES73cwBfAK0As8AbwFDLr7ZNBkZr+m+xwcPwo0L2rBC+NbwG8B2eB+M6XfZwceN7NNZrYx2FfQ97Yubhsyd3czK8k1kmaWAn4EfN3dj5nZ9LFS7Le7Z4D1ZtYIPARcEm5FhWVmtwC97r7JzG4IuZzF9FF3P2BmrcATZrZz5sFCvLejMqI+AHTOuN8R7CtVh82sHSD42RvsL5nfg5lVkAvp77n7g8Huku83gLsPAk+R+7O/0cymBkQz+zXd5+B4A3BkcSs9Zx8BPmdme4H7yU1//Cml3Wfc/UDws5fcP8hXU+D3dlSC+mVgdfBpcSVwG/BwyDUV0sPArwbbv0puDndq/z8PPineAByd8edU0bDc0PleYIe7//GMQyXbbzNLByNpzKya3Jz8DnKBfWvQbHafp34XtwJPejCJWSzc/S5373D3ZeT+n33S3b9MCffZzGrNrG5qG/gk8BqFfm+HPTE/Y5L9M8Cb5Ob1fjvsehawX/cB7wIT5Oanbic3L/cTYBfwY6ApaGvkVr+8BbwKdIVd/zz7/FFy83jbgFeC22dKud/AWmBL0OfXgP8Q7F8BvATsBn4IJIP9VcH93cHxFWH34Rz7fwPwSKn3Oejb1uD2+lRWFfq9rVPIRUQiLipTHyIiMgcFtYhIxCmoRUQiTkEtIhJxCmoRkYhTUIuIRJyCWkQk4v4/y+3mmReXR0gAAAAASUVORK5CYII=\n", 201 | "text/plain": [ 202 | "
" 203 | ] 204 | }, 205 | "metadata": { 206 | "needs_background": "light" 207 | }, 208 | "output_type": "display_data" 209 | } 210 | ], 211 | "source": [ 212 | "from matplotlib import pyplot as plt\n", 213 | "%matplotlib inline\n", 214 | "\n", 215 | "plt.plot(plt_x, plt_y)\n", 216 | "plt.show()" 217 | ] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3 (ipykernel)", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.8.11" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 5 241 | } 242 | -------------------------------------------------------------------------------- /2.随机梯度下降.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | 
"execution_count": 3, 54 | "id": "92163201", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "0.6590042695516539" 61 | ] 62 | }, 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "#预测函数\n", 70 | "def predict(x):\n", 71 | " return w.dot(x) + b\n", 72 | "\n", 73 | "\n", 74 | "predict(x[0])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "id": "a7bb7a80", 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/plain": [ 86 | "0.21258140154187247" 87 | ] 88 | }, 89 | "execution_count": 4, 90 | "metadata": {}, 91 | "output_type": "execute_result" 92 | } 93 | ], 94 | "source": [ 95 | "#求loss,MSELoss\n", 96 | "def get_loss(x, y):\n", 97 | " pred = predict(x)\n", 98 | " loss = (pred - y)**2\n", 99 | " return loss\n", 100 | "\n", 101 | "\n", 102 | "get_loss(x[0], y[0])" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 5, 108 | "id": "8027d213", 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/plain": [ 114 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 115 | " 0.923131013558981)" 116 | ] 117 | }, 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "output_type": "execute_result" 121 | } 122 | ], 123 | "source": [ 124 | "def get_gradient(x, y):\n", 125 | " global w\n", 126 | " global b\n", 127 | "\n", 128 | " eps = 1e-3\n", 129 | "\n", 130 | " loss_before = get_loss(x, y)\n", 131 | "\n", 132 | " gradient_w = np.empty(M)\n", 133 | " for i in range(M):\n", 134 | " w[i] += eps\n", 135 | " loss_after = get_loss(x, y)\n", 136 | " w[i] -= eps\n", 137 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 138 | "\n", 139 | " b += eps\n", 140 | " loss_after = get_loss(x, y)\n", 141 | " b -= eps\n", 142 | " gradient_b = (loss_after - loss_before) / eps\n", 143 | "\n", 144 | " return gradient_w, gradient_b\n", 145 | "\n", 146 | "\n", 
147 | "get_gradient(x[0], y[0])" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 6, 153 | "id": "f39e0125", 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "11073.905141728206" 160 | ] 161 | }, 162 | "execution_count": 6, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "def total_loss():\n", 169 | " loss = 0\n", 170 | " for i in range(N):\n", 171 | " loss += get_loss(x[i], y[i])\n", 172 | " return loss\n", 173 | "\n", 174 | "\n", 175 | "total_loss()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 7, 181 | "id": "c371c6a4", 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "0 11038.895201527894\n", 189 | "150 6696.3736283721655\n", 190 | "300 4354.119709336485\n", 191 | "450 3004.4953025777063\n", 192 | "600 2310.68634531403\n", 193 | "750 1839.580107951638\n", 194 | "900 1549.8047186884628\n", 195 | "1050 1278.2079624059054\n", 196 | "1200 1099.1810250366634\n", 197 | "1350 986.6025037947752\n", 198 | "1500 921.4757031198328\n", 199 | "1650 879.159069825457\n", 200 | "1800 853.2767252227716\n", 201 | "1950 835.2496534941863\n", 202 | "2100 812.3758750332744\n", 203 | "2250 794.1165878394305\n", 204 | "2400 786.9647280480957\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "plt_x = []\n", 210 | "plt_y = []\n", 211 | "for epoch in range(2500):\n", 212 | " i = np.random.randint(N)\n", 213 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 214 | " w -= gradient_w * 1e-3\n", 215 | " b -= gradient_b * 1e-3\n", 216 | "\n", 217 | " plt_x.append(epoch)\n", 218 | " plt_y.append(total_loss())\n", 219 | "\n", 220 | " if epoch % 150 == 0:\n", 221 | " print(epoch, total_loss())" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 8, 227 | "id": "0471a70d", 228 | "metadata": { 229 | "scrolled": true 230 | }, 231 | 
"outputs": [ 232 | { 233 | "data": { 234 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgw0lEQVR4nO3deZhcdZ3v8fe3qrp6Te+dpNPp7CEkJGBCE3ZRQAiBaxh1MM69EoWRq4KD8sxVGJxHR+Ze9Y4Ogt5hnihoUBQQosQRgQz7IAl0QnbopBOydGfpTnpJ791V9bt/1OnQxE5IejtdVZ/X89RT5/zOOVXfX1clnzq7OecQEZHUFvC7ABER8Z/CQEREFAYiIqIwEBERFAYiIgKE/C5goIqLi92UKVP8LkNEJGGsW7fusHOupL9pCRsGU6ZMobKy0u8yREQShpntOdE0bSYSERGFgYiIKAxERASFgYiIoDAQEREUBiIigsJARERIwTB4eXs9W2qb/S5DRGRUSakwaG7vYdlDb7B0+Rq/SxERGVVSKgzystK4as44WrsitHdH/C5HRGTUSKkwAPj4hyYA8O7hNp8rEREZPVIuDGaX5gKwcZ/2G4iI9Eq5MJhWnE1BVho/eK7K71JEREaNlAsDM+NzF02loa2btw8c9bscEZFRIeXCAOCzF07GDFZvO+R3KSIio0JKhkFhdpgJeZlU17X6XYqIyKiQkmEAcNaEXFZt3M/h1i6/SxER8V3KhsFVZ40HdIipiAikcBiU5WcC0BON+VyJiIj/UjYMwiEDoCfqfK5ERMR/KRsGacF41yNaMxARSd0wCAXiXe/sURiIiKRsGIzJCAGwZb8uSyEikrJhUF6YRVrQaOns8bsUERHfpWwYAJwxbgw1jR1+lyEi4rsPDAMze8jM6sxsS5+2QjNbbWY7vOcCr93M7H4zqzazTWa2oM8yy7z5d5jZsj7t55rZZm+Z+83MhrqTJzKxIJNahYGIyCmtGfwCWHRc253A8865mcDz3jjANcBM73EL8ADEwwP4FnA+sBD4Vm+AePN8oc9yx7/XsJlYkEVNYwfO6fBSEUltHxgGzrlXgIbjmpcAK7zhFcD1fdofdnFrgHwzKwWuBlY75xqcc43AamCRNy3XObfGxf9HfrjPaw27KUVZdPRE+ePmAyP1liIio9JA9xmMc871/g96EBjnDZcB+/rMV+O1nay9pp/2fpnZLWZWaWaV9fX1Ayz9PR//UBmTi7K47ddvDfq1REQS2aB3IHu/6EdkO4tzbrlzrsI5V1FSUjLo18vLTOOauaUA7DjUMujXExFJVAMNg0PeJh685zqvvRYo7zPfRK/tZO0T+2kfMR+bE1+p2X2kfSTfVkRkVBloGKwCeo8IWgY81af9Ru+ooguAZm9z0rPAVWZW4O04vgp41pt21Mwu8I4iurHPa42IacXZAGzc1zSSbysiMqqcyqGlvwFeB2aZWY2Z3Qx8D/iYme0ArvTGAZ4GdgHVwE+BLwM45xqAe4A3vcd3vDa8eX7mLbMT+NPQdO3UFGSHGZ+bwdPaiSwiKSz0QTM45z5zgklX9DOvA249wes8BDzUT3slMPeD6hhOs8aPYet+3Q9ZRFJXSp+B3GteWR6N7d3EYjrfQERSk8IAKMoJE405mjt0nSIRSU0KA6AoJx2Ag0c7fa5ERMQfCgNgfG4GAK9VH/a5EhERfygMgPmT8gF4btshfwsREfGJwoD4LTDPm1JA1cEWenQbTBFJQQoDz00XT6W5o4dNNbrzmYikHoWBZ1xefL/BUd35TERSkMLAkxUOAtDRHfW5EhGRkacw8GSlxU/GblcYiEgKUhh4Mr01g/buiM+ViIiMPIWBp3czUWuXwkBEUo/CwJMVDpKZFuSNd4+/w6eISPJTGHjMjPysNP6884jfpYiIjDiFQR/nTi6gOxLTEUUiknIUBn1
cPKMYgNVv67IUIpJaFAZ9XP+hMgD+adVWnysRERlZCoM+MsNBFkzK11nIIpJyFAbHuWL2OHqijjrd20BEUojC4DhzJuQCsFaHmIpIClEYHOfi6cUEDHYcavG7FBGREaMwOE44FKA0L5M9De1+lyIiMmIUBv0oL8ykprHD7zJEREaMwqAfRdnpNLZ3+12GiMiIURj0IzczxNEOXbBORFKHwqAfuRlptOhcAxFJIQqDfuSkh+iKxKht0n4DEUkNCoN+fGhSPgBrdAVTEUkRCoN+XDy9mNyMEM9tO+h3KSIiI0Jh0I9AwJhSnM26PY1+lyIiMiIGFQZm9jUz22pmW8zsN2aWYWZTzWytmVWb2WNmFvbmTffGq73pU/q8zl1ee5WZXT3IPg2JS2YU09DWTSzm/C5FRGTYDTgMzKwM+Dugwjk3FwgCS4HvA/c652YAjcDN3iI3A41e+73efJjZHG+5s4BFwL+ZWXCgdQ2V4px0Yg6aOnRUkYgkv8FuJgoBmWYWArKAA8DlwBPe9BXA9d7wEm8cb/oVZmZe+6POuS7n3LtANbBwkHUNWlFOGICGti6fKxERGX4DDgPnXC3wA2Av8RBoBtYBTc653jO2aoAyb7gM2OctG/HmL+rb3s8y72Nmt5hZpZlV1tfXD7T0U1KY3RsGWjMQkeQ3mM1EBcR/1U8FJgDZxDfzDBvn3HLnXIVzrqKkpGQ43+pYGOzTBetEJAUMZjPRlcC7zrl651wPsBK4GMj3NhsBTARqveFaoBzAm54HHOnb3s8yvinLzwRg7bs610BEkt9gwmAvcIGZZXnb/q8AtgEvAp/y5lkGPOUNr/LG8aa/4JxzXvtS72ijqcBM4I1B1DUk8rPChIMBHq+sYXNNs9/liIgMq8HsM1hLfEfwemCz91rLgW8Ad5hZNfF9Ag96izwIFHntdwB3eq+zFXiceJA8A9zqnIsOtK6hdNvlMwB4ZO0enysRERleFv9xnngqKipcZWXlsL/Poh+9wjsHW9j+z9cQDukcPRFJXGa2zjlX0d80/e/2ARZOLQTgma26NIWIJC+FwQf40kemA3D3ys0+VyIiMnwUBh+gNC+ThVMLaemKsLO+1e9yRESGhcLgFHzF25H8rDYViUiSUhicglnjxgDw6Bv7PmBOEZHEpDA4BWNzM7ihYiJtXbovsogkJ4XBKRqXm0FjezdRXdJaRJKQwuAU9V7SuqGt2+9SRESGnMLgFE0pzgagcneDz5WIiAw9hcEpOl8nn4lIElMYnKKMtCCXzCjmqQ37efC/3vW7HBGRIaUwOA3f/cQ8AO75j210R2I+VyMiMnQUBqehvDCL7yw5C4BXtg/vndZEREaSwuA0XTuvFIBtB476XImIyNBRGJymopx0Jhdl8fDre0jUy3+LiBxPYTAAl51RwuHWLo7onAMRSRIKgwE4e2I+gC5PISJJQ2EwANnhIAC7j7T7XImIyNBQGAzA3LI8AP79pZ0+VyIiMjQUBgNQXpjFZxaW8/quI3zh4Uqa2rXvQEQSm8JggP7p43NZPG88q7cd4g+bDvhdjojIoCgMBigcCnDf0vkAvFxV53M1IiKDozAYhLRggHMnF7C5ttnvUkREBkVhMEiXnVHCoaNd7NWRRSKSwBQGg3T2xPiRRdfe/6rOOxCRhKUwGKSPzBrLPUvOoqUrwsu6eJ2IJCiFwRC49uwJALykHckikqAUBkOgMDvM9JJsaps6/C5FRGRAFAZDZFxuBp09uuGNiCSmQYWBmeWb2RNm9o6ZvW1mF5pZoZmtNrMd3nOBN6+Z2f1mVm1mm8xsQZ/XWebNv8PMlg22U37ICgfZfqiFls4ev0sRETltg10zuA94xjl3JnAO8DZwJ/C8c24m8Lw3DnANMNN73AI8AGBmhcC3gPOBhcC3egMkkXxk1lhaOiNc9L0X2LZfN74RkcQy4DAwszzgw8CDAM65budcE7AEWOHNtgK43hteAjzs4tYA+WZWClwNrHbONTjnGoH
VwKKB1uWX/3HBZH5580JauyLc9pv1xGK68Y2IJI7BrBlMBeqBn5vZW2b2MzPLBsY553ov1nMQGOcNlwH7+ixf47WdqP0vmNktZlZpZpX19aPvMM5LZ5Zw9+LZ7KpvY0+DTkITkcQxmDAIAQuAB5xz84E23tskBICL3xdyyH4iO+eWO+cqnHMVJSUlQ/WyQ2r+pHwAlr+iy1uLSOIYTBjUADXOubXe+BPEw+GQt/kH77n34PtaoLzP8hO9thO1J6QFkwoYkx7iN2/so+5op9/liIickgGHgXPuILDPzGZ5TVcA24BVQO8RQcuAp7zhVcCN3lFFFwDN3uakZ4GrzKzA23F8ldeWkMyMH/9N/Gqm//C7LT5XIyJyakKDXP4rwCNmFgZ2AZ8nHjCPm9nNwB7gBm/ep4HFQDXQ7s2Lc67BzO4B3vTm+45zrmGQdfnqI7PGsnjeeJ7efJC7Vm7mu5+Y53dJIiInZfHN+omnoqLCVVZW+l3GCTW1d/O1xzbwYlU9v/3ihZw3pdDvkkQkxZnZOudcRX/TdAbyMMnPCnPfZ+Kbi9bsPOJzNSIiJ6cwGEa5GWmU5Wey4vU9RHXegYiMYgqDYXbd2aUcbu3iU//+Z7oiUb/LERHpl8JgmN3y4WlMLMjkrb1NvLr9sN/liIj0S2EwzIpy0nnySxcBcKhF5x2IyOikMBgBGWlBADq6tZlIREYnhcEIyPTCYGd9q8+ViIj0T2EwAtKCxpiM+CUqtHYgIqORwmAEmBm3fnQGAGve1TkHIjL6KAxGyF+fOxGAdw60+FyJiMhfUhiMkKKcdM4pz+exN/eSqJcAEZHkpTAYQUvOmcDuI+0caNYhpiIyuigMRtDFM4oBuPXX632uRETk/RQGI2jW+DFkhYO8tbeJ6jodZioio4fCYISt/PJFpAWNn7yww+9SRESOURiMsDPH5zKnNJfth7RmICKjh8LAB6V5mWw7cJTXdZ8DERklFAY++OZ1sxmfm8Hf/3ajLmstIqOCwsAHEwuy+PqiWdQ2dfDW3ia/yxERURj4Zc6EXABtKhKRUUFh4JPygiwA3tzd4HMlIiIKA99kp4f44mXT+fPOI+xraPe7HBFJcQoDH31yQRmgTUUi4j+FgY+ml+QAUNPU4XMlIpLqFAY+CgSM7HCQ+pYuv0sRkRSnMPDZmaW5vFRVR3ck5ncpIpLCFAY+u/2KmRxo7mT5Kzv9LkVEUpjCwGeXziymYnIBP3huO09tqPW7HBFJUQoDn5kZP7zhHABuf3QDf9x0wOeKRCQVKQxGgclF2Tzyt+cTDgW49dfreefgUb9LEpEUM+gwMLOgmb1lZv/hjU81s7VmVm1mj5lZ2GtP98arvelT+rzGXV57lZldPdiaEtHFM4r53ZcvAuDHL1T7XI2IpJqhWDO4HXi7z/j3gXudczOARuBmr/1moNFrv9ebDzObAywFzgIWAf9mZsEhqCvhnDUhj2vnlfLHTQf45eu7/S5HRFLIoMLAzCYC1wI/88YNuBx4wptlBXC9N7zEG8ebfoU3/xLgUedcl3PuXaAaWDiYuhLZ//nEPHLSQ/zjU1v5X7/dSCSqQ05FZPgNds3gR8DXgd7/sYqAJudcxBuvAcq84TJgH4A3vdmb/1h7P8u8j5ndYmaVZlZZX18/yNJHp7zMNP50+6VcOXssv11XwyNr9/pdkoikgAGHgZldB9Q559YNYT0n5Zxb7pyrcM5VlJSUjNTbjrjywix+tuw8JuRlsGFfk9/liEgKCA1i2YuBj5vZYiADyAXuA/LNLOT9+p8I9B48XwuUAzVmFgLygCN92nv1XSalTS3Jpupgi99liEgKGPCagXPuLufcROfcFOI7gF9wzv134EXgU95sy4CnvOFV3jje9Becc85rX+odbTQVmAm8MdC6ksncsjy2HTjKD56t8rsUEUlyw3GewTeAO8ysmvg+gQe99geBIq/9DuBOAOfcVuBxYBvwDHCrc043BiZ+qYpgwPjJi9XaXCQ
iw8riP84TT0VFhausrPS7jGFX09jOdT/+L7LDIV6783K/yxGRBGZm65xzFf1N0xnIo9zEgiyunVdKbVMH7x5u87scEUlSCoMEcOOFU0gLGt/8/WYSdU1OREY3hUECmDV+DH81v4zXqo/w1cc2+F2OiCQhhUGCuOf6udxQMZGnNuzXzmQRGXIKgwSRHgry9UVnAvDq9uQ8+1pE/KMwSCDFOekU56Tz5Poav0sRkSSjMEgwkwoz2X2k3e8yRCTJKAwSzDVzSwH45Zo9PlciIslEYZBgPn/xFKaXZHPv6u3UtXT6XY6IJAmFQYIJBQPct3Q+je3dXPTdF3jhnUN+lyQiSUBhkIDmluXxxBcvJBgwbvpFJas27ve7JBFJcAqDBHXu5EL+847LmFOayx2PbeDGh95gS22z32WJSIJSGCSw8sIsfvW35/PRM8fyyvZ6PvfzN3ipqo5YTJesEJHTozBIcIXZYX56YwW//sL5BAPG537+Jp9e/jqvVR9WKIjIKdMlrJNIZ0+U366r4d7V22lo62Z8bgbf/eQ8PjprrN+licgocLJLWCsMklBnT5RVG/bzzd9voTsaY3ZpLj/863OYMyHX79JExEe6n0GKyUgLcsN55Tx9+yV84dKpVB08yk2/eNPvskRkFFMYJLEZY8dw97VzuO3ymRw82sn3n3nH75JEZJRSGKSAL1w6lYumF/HASztZqYvciUg/FAYpYExGGituWsi8sjzueHwj19z3Kt9etZV3Dh71uzQRGSUUBikiLRjgVzefz5c/Mp38zDR+8efdLPrRq/zs1V1EojG/yxMRn+loohS1pbaZf3xqC2/tbWJCXgaP3nIhk4qy/C5LRIaRjiaSvzC3LI+VX7qIn/zNfA63dfPZh9bS0R31uywR8YnCIIWZGdedPYGbLp7KniPt/OvqKr9LEhGfKAyEO6+J31v5xap6jnb2+FyNiPhBYSAA3H7FTKrrWrnihy/z2QfXsm5Po98licgIUhgIAF/72Bn86ubzmV+ez7o9jSxd/jrfXrWVxrZuv0sTkRGgo4nkLxxo7uBfnq3i92/VEnMwuzSXs8vymFaSzTVzS3XUkUiC0oXqZEDePnCU32+opXJ3I3uOtHG4Nb6WsHjeeH706fmEQ1qxFEkkJwuD0CBetBx4GBgHOGC5c+4+MysEHgOmALuBG5xzjWZmwH3AYqAd+Jxzbr33WsuAb3ov/c/OuRUDrUuGzuzSXGaXvnel06qDLdz//A7+uPkAm2pe4oaKcpaeV87Y3AwfqxSRoTDgNQMzKwVKnXPrzWwMsA64Hvgc0OCc+56Z3QkUOOe+YWaLga8QD4Pzgfucc+d74VEJVBAPlXXAuc65k+7B1JqBP5xzLH9lFw+8vJOm9h7K8jP5yuUz+NiccRTlpPtdnoicxLCcdOacO9D7y9451wK8DZQBS4DeX/YriAcEXvvDLm4NkO8FytXAaudcgxcAq4FFA61LhpeZ8T8vm07l3Vfy0xsr6OyJcufKzVz43Rf4xWvvEtXd1UQS0pBs9DWzKcB8YC0wzjl3wJt0kPhmJIgHxb4+i9V4bSdq7+99bjGzSjOrrK+vH4rSZYBCwQAfmzOON+6+knuWnMXkoiy+/Ydtum+CSIIadBiYWQ7wJPBV59z7LoPp4tughuynonNuuXOuwjlXUVJSMlQvK4MQDBifvXAKz33tw+Skh3h5ez1rdx3xuywROU2DCgMzSyMeBI8451Z6zYe8zT+9+xXqvPZaoLzP4hO9thO1SwIxM57+u0vJzQjx6eVruGvlJn65Zg9b9zeTqEesiaSSwexANuL7BBqcc1/t0/4vwJE+O5ALnXNfN7Nrgdt4bwfy/c65hd4O5HXAAu8l1hPfgdxwsvfXDuTRqepgC7c/+hb7Gtpp8y58VzImnStnj+W2y2dSlp/pc4UiqWtYzjMws0uAV4HNQO8F8f+B+H6Dx4FJwB7ih5Y2eOHxE+I7h9uBzzvnKr3XuslbFuB/O+d+/kH
vrzAY3aIxx5u7G1i7q4GXttexpbaZnqhjekk2/+2cCXxywUTKC3XymshI0kln4ru3DxzlDxv3s7m2mVd3HMYMrp4zntmluZw7uYCKKQVkpAX9LlMkqQ3LSWcip6PvCWzVdS08sa6WJ9bV8Oy2gzgH4VCA86YU8OGZJSw9bxJ5WWk+VyySWrRmIL5q64rwxu4GXttxmJe317OjrpVwMMBZZbl8ZuEkJhVmMX9SPukhrTWIDJY2E0nCWLenkee2HuTJ9TXHroVUlB1m0dzx3PrRGUzQDmiRAVMYSMJp64qws76VQ0e7WLm+hj9tOQjAd5acxdLzJukieSIDoDCQhPf6ziP88LkqKvc0Upafyfc+OY9LZ+rEQ5HToTCQpOCc44l1NXz9yU04B7kZIcbmZlCSk860kmwunlHMtJJszhg7hkDA/C5XZNRRGEhS2bb/KC9W1VF3tJNDR7vY19jO9kMt9ETj3+UxGSFmjM3h0xXlnDe1kLL8TB22KoIOLZUkM2dCLnMm5L6vrbMnSnVdK28fOMrGmib+XH2EO1duPjZ9YkEm8ycVMKUoi0tmFHP+tKKRLltkVNOagSSlnmiMDfua2NfQzr6GDrbsb6bqYAs1je3EHCycWsi180q5fn4ZeZk6p0FSgzYTiXia23t4Yn0ND7++mz1H2gEozkknLzNEWjBAelqQ3IwQuZlpTC3K5pKZxZw3pZCg9kFIElAYiPRj3Z5GXt95mJrGDpo7euiJOroiUVo6IzR39LDnSBsxB4XZYa6cPZainHSmFmWzYHI+04pztJNaEo72GYj049zJBZw7ueCE01u7IrxUVcczWw6yauN+uiMxem/kFg4FGJ+bwRnjxjBjbA5njMthekkOpfkZFGSFSQvqPAhJLAoDkRPISQ9x3dkTuO7sCTjncA52HW7jrb2NVNe1UtvUQdXBFl7ZXk93NHZsuXAowKTCLHLSQ8wuzWViQSblhVlMK85mWkk2WWH9s5PRR99KkVNgZpjBjLE5zBib875pkWiMnfVt7D7SxqGjnexraGf3kXaOtHbxpy0HaGrved/85YWZhAIBAgZZ4RBjMkKMHZNOZjhEXmYa+Vlp5GWmMSE/k7L8THIzQ+Skh8hMCxK/ErzI0FMYiAxSKBhg1vgxzBo/pt/p7d0R9ja0s6u+je2HWthV34YDorEYzR09tHdHWb+3ieaOHlq7IkRj/e/HCxhkh0PkZITIzwqTlxkiPzNMZjhIXmYaJWPSmVOaS1lBJmnBAEb8tqS9j4AZoYARCMSfs8IKF3mPwkBkmGWFQ5w5Ppczx+eyeF7pSed1ztHRE6WpvYd9De0caO6kpStCm/do7YrQ0hmhqb2b5o4edh1upak9HiitXZHTrCtIeihAKBggPRSgICtMZlqQcCgQfwQD7w174+mh+BFXY9JDpAWNjLTgsUd6KEAgYATNCAQgaPEQMu85HAzE3zMtQMDMe3Bsnt7hgLcWFg4GFFYjSGEgMoqYGVnhEFnh0GlfobWhrZtd9a3sb+4kFnPEnCMSc8RijqiLP0dijqj3XN/SRXckRiQWo7MnRkNbN509Udq6IzS2x+iOxOiOes/eo8sbHwm9ARIK2rFNZxmh3vAJkB6KB0tmWpDs9BDpoQBBb60nGAi8by0o2Pc5+N703vb3zxffhBcIWDz8QvH36w283sAKHgu+EwdW776m3nW9gDFqA05hIJIkCrPDFGYXDvv7xGKOtu5IPBwiMTp7onT2xOiKRIk5RzQWv+1pfDj+HHOO7kiM9u4oXZGY18ax0PqLYedo996jJ+qob+2ivStCVyRGe3eEhrYYnZEoXT3x92/titATfe9or5HW+5+8cw4HnOiIfTPICAVJCxrhUJBQwIg6h8GxNSXz1qxisXiY9P49Yi4+XpgdZvUdlw15HxQGInJaAgFjTMboPGu7dy2od+0nGnVEYjGiXnsk2mda7L1pvePRWHye3iDrisRDrssLu67Ie6/Vu6YVi8UDIOYcRnw
TlwGYYfGnY+2RaIzOyHtrXZFojGAgADhiMd4XhmbxTW0BLxx6N6XlpA/P315hICJJIxAwAhi6LuHp05kxIiKiMBAREYWBiIigMBARERQGIiKCwkBERFAYiIgICgMRESGB73RmZvXAngEuXgwcHsJyEoH6nPxSrb+gPp+uyc65kv4mJGwYDIaZVZ7o1m/JSn1OfqnWX1Cfh5I2E4mIiMJARERSNwyW+12AD9Tn5Jdq/QX1ecik5D4DERF5v1RdMxARkT4UBiIiklphYGaLzKzKzKrN7E6/6xlKZrbbzDab2QYzq/TaCs1stZnt8J4LvHYzs/u9v8MmM1vgb/WnxsweMrM6M9vSp+20+2hmy7z5d5jZMj/6cqpO0Odvm1mt91lvMLPFfabd5fW5ysyu7tOeMN99Mys3sxfNbJuZbTWz2732pPysT9Lfkf2c4zdsTv4HEAR2AtOAMLARmON3XUPYv91A8XFt/xe40xu+E/i+N7wY+BPxu/NdAKz1u/5T7OOHgQXAloH2ESgEdnnPBd5wgd99O80+fxv4+37mneN9r9OBqd73PZho332gFFjgDY8Btnt9S8rP+iT9HdHPOZXWDBYC1c65Xc65buBRYInPNQ23JcAKb3gFcH2f9odd3Bog38xKfajvtDjnXgEajms+3T5eDax2zjU45xqB1cCiYS9+gE7Q5xNZAjzqnOtyzr0LVBP/3ifUd985d8A5t94bbgHeBspI0s/6JP09kWH5nFMpDMqAfX3Gazj5HzzROOA5M1tnZrd4beOccwe84YPAOG84mf4Wp9vHZOn7bd4mkYd6N5eQhH02synAfGAtKfBZH9dfGMHPOZXCINld4pxbAFwD3GpmH+470cXXL5P6OOJU6KPnAWA68CHgAPBDX6sZJmaWAzwJfNU5d7TvtGT8rPvp74h+zqkUBrVAeZ/xiV5bUnDO1XrPdcDviK8yHurd/OM913mzJ9Pf4nT7mPB9d84dcs5FnXMx4KfEP2tIoj6bWRrx/xgfcc6t9JqT9rPur78j/TmnUhi8Ccw0s6lmFgaWAqt8rmlImFm2mY3pHQauArYQ71/vERTLgKe84VXAjd5RGBcAzX1WvxPN6fbxWeAqMyvwVruv8toSxnH7d/6K+GcN8T4vNbN0M5sKzATeIMG++2ZmwIPA2865f+0zKSk/6xP1d8Q/Z7/3pI/kg/hRB9uJ73G/2+96hrBf04gfObAR2NrbN6AIeB7YAfwnUOi1G/D/vL/DZqDC7z6cYj9/Q3x1uYf49tCbB9JH4CbiO92qgc/73a8B9PmXXp82ef/YS/vMf7fX5yrgmj7tCfPdBy4hvgloE7DBeyxO1s/6JP0d0c9Zl6MQEZGU2kwkIiInoDAQERGFgYiIKAxERASFgYiIoDAQEREUBiIiAvx/Ytwm8HresDsAAAAASUVORK5CYII=\n", 235 | "text/plain": [ 236 | "
" 237 | ] 238 | }, 239 | "metadata": { 240 | "needs_background": "light" 241 | }, 242 | "output_type": "display_data" 243 | } 244 | ], 245 | "source": [ 246 | "from matplotlib import pyplot as plt\n", 247 | "%matplotlib inline\n", 248 | "\n", 249 | "plt.plot(plt_x, plt_y)\n", 250 | "plt.show()" 251 | ] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "Python 3 (ipykernel)", 257 | "language": "python", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.8.11" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 5 275 | } 276 | -------------------------------------------------------------------------------- /3.momentum.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0\n", 49 | "\n", 50 | "#动量都初始化为0\n", 51 | "momentum_w = 
np.zeros(M)\n", 52 | "momentum_b = 0" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "id": "92163201", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "0.6590042695516539" 65 | ] 66 | }, 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "#预测函数\n", 74 | "def predict(x):\n", 75 | " return w.dot(x) + b\n", 76 | "\n", 77 | "\n", 78 | "predict(x[0])" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "id": "a7bb7a80", 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "0.21258140154187247" 91 | ] 92 | }, 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "#求loss,MSELoss\n", 100 | "def get_loss(x, y):\n", 101 | " pred = predict(x)\n", 102 | " loss = (pred - y)**2\n", 103 | " return loss\n", 104 | "\n", 105 | "\n", 106 | "get_loss(x[0], y[0])" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 5, 112 | "id": "8027d213", 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "text/plain": [ 118 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 119 | " 0.923131013558981)" 120 | ] 121 | }, 122 | "execution_count": 5, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "def get_gradient(x, y):\n", 129 | " global w\n", 130 | " global b\n", 131 | "\n", 132 | " eps = 1e-3\n", 133 | "\n", 134 | " loss_before = get_loss(x, y)\n", 135 | "\n", 136 | " gradient_w = np.empty(M)\n", 137 | " for i in range(M):\n", 138 | " w[i] += eps\n", 139 | " loss_after = get_loss(x, y)\n", 140 | " w[i] -= eps\n", 141 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 142 | "\n", 143 | " b += eps\n", 144 | " loss_after = get_loss(x, y)\n", 145 | " b -= eps\n", 146 | " gradient_b = (loss_after - 
loss_before) / eps\n", 147 | "\n", 148 | " return gradient_w, gradient_b\n", 149 | "\n", 150 | "\n", 151 | "get_gradient(x[0], y[0])" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 6, 157 | "id": "f39e0125", 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "11073.905141728206" 164 | ] 165 | }, 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "def total_loss():\n", 173 | " loss = 0\n", 174 | " for i in range(N):\n", 175 | " loss += get_loss(x[i], y[i])\n", 176 | " return loss\n", 177 | "\n", 178 | "\n", 179 | "total_loss()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 7, 185 | "id": "c371c6a4", 186 | "metadata": {}, 187 | "outputs": [ 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "0 [-4.35683357 4.61375434 1.50384418 2.91083043 2.96213566] -4.022402214293841 11044.88570074484\n", 193 | "150 [ 1.26339579 2.13768158 0.49321428 -1.51848376 3.5785042 ] -1.7176170960084935 1967.1697066194426\n", 194 | "300 [ 5.80970624e-03 3.55685572e+00 3.82822053e+00 -2.58323564e-01\n", 195 | " 1.07002133e+01] 2.0649671248152166 947.1121093652372\n", 196 | "450 [-2.26148992 3.8055673 -3.18437755 -3.87469838 -0.61715981] 0.7950908166330994 815.4883875626289\n", 197 | "600 [1.55949644 0.79231053 3.14007815 3.71824435 0.98449926] -0.5781588479929312 783.2820118211715\n", 198 | "750 [ 3.59572027 -0.74530488 0.72892403 -0.15447051 -4.33429161] -2.0781779575632005 753.5637854055667\n", 199 | "900 [ 0.3379927 2.54425898 -4.14378216 0.92339395 0.76440147] 0.8774161483117863 769.3244441314099\n", 200 | "1050 [-1.00006534 -5.83176015 -5.28599293 -5.19888383 -3.62367306] 2.5986887635740827 742.322126662018\n", 201 | "1200 [ 1.32568154 -8.98944316 8.03060836 -1.84855209 -1.93601743] -4.04468557258265 753.4751326083475\n", 202 | "1350 [ 1.82402337 -1.95342977 1.07956753 
2.00744178 -4.75462661] -2.284159092719058 734.9285345602206\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "plt_x = []\n", 208 | "plt_y = []\n", 209 | "for epoch in range(1500):\n", 210 | " i = np.random.randint(N)\n", 211 | "\n", 212 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 213 | "\n", 214 | " #这是更新动量的数学公式,0.8是过去动量的权重\n", 215 | " momentum_w = 0.8 * momentum_w + gradient_w\n", 216 | " momentum_b = 0.8 * momentum_b + gradient_b\n", 217 | "\n", 218 | " #这里更新参数不再使用梯度,而是使用动量\n", 219 | " w -= momentum_w * 1e-3\n", 220 | " b -= momentum_b * 1e-3\n", 221 | "\n", 222 | " #思考一下,在时刻0,动量都是0.此时更新动量,动量就等于梯度.\n", 223 | " #也就是说,再时刻0,其实就是再用梯度下降.\n", 224 | " #时刻1,是上一个时刻的梯度乘以0.8,再加上当前时刻的梯度\n", 225 | " #所以在时刻1,差不多可以认为是梯度乘以了1.8.不过这里面两部分的梯度是在两个不同的点上评估出来的.\n", 226 | " #在时刻2,差不多等同于时刻1.往后都差不多.\n", 227 | "\n", 228 | " plt_x.append(epoch)\n", 229 | " plt_y.append(total_loss())\n", 230 | "\n", 231 | " if epoch % 150 == 0:\n", 232 | " print(epoch, momentum_w, momentum_b, total_loss())" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 8, 238 | "id": "0471a70d", 239 | "metadata": { 240 | "scrolled": true 241 | }, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhEElEQVR4nO3deZRc5Xnn8e9Ta3dV7+pGaEMSWIBlbAy0McSO4xiz2HEMOWM7JI6tOJ5wZuJMvM1xIMkJmcTJiROfsc1MjEOMPThDMJg4gfESQjAxnJmwNMYIhBASQiC1JNTqVi/qtZZn/rhvN43S2nrRre77+5xTp+99762qp1+p+lfvXc3dERGRZEvFXYCIiMRPYSAiIgoDERFRGIiICAoDEREBMnEXMFvt7e2+bt26uMsQEVk0nnjiiYPu3jHTskUbBuvWraOrqyvuMkREFg0ze+loy7SZSEREFAYiIqIwEBERFAYiIoLCQEREUBiIiAgKAxERIWFh4O78jwe28+Pne+IuRUSkpiQqDMyMWx7ayYPPHYi7FBGRmpKoMABoLmQZGC3FXYaISE1JXBi0FnIcGpmIuwwRkZqSuDBoKWQ5NKKRgYjIdAkMgxwDGhmIiLxG4sKgVSMDEZF/J3Fh0JDPMDxext3jLkVEpGYkLgyK+QzlqjNRqcZdiohIzUhcGNRn0wCMjFdirkREpHYkLgyK+RAGJYWBiMikxIVBIRfd6XNkvBxzJSIitSNxYTA5Mhie0MhARGRS4sKgPquRgYjIkRIXBlP7DDQyEBGZkrgwmNxnMDyhkYGIyKQEhoFGBiIiR0pcGBQnjyZSGIiITDluGJjZN8zsgJk9M62tzczuN7Pt4WdraDczu8nMdpjZZjO7cNpzNoX1t5vZpmntF5nZ0+E5N5mZzfcvOV395MhAO5BFRKacyMjgfwFXHdF2PfCAu28AHgjzAO8BNoTHdcDNEIUHcCPwVuBi4MbJAAnr/Oa05x35XvMql0mRS6d0aKmIyDTHDQN3fwjoO6L5auC2MH0bcM209m955BGgxcxWAFcC97t7n7sfAu4HrgrLmtz9EY+uHPetaa+1YAr5NMMaGYiITJntPoPl7r4vTO8HlofpVcDuaevtCW3Hat8zQ/uMzOw6M+sys66entnf1L6xLsPQmC5jLSIyac47kMM3+lNyPWh3v8XdO929s6OjY9av05jPclgjAxGRKbMNg1fCJh7CzwOhvRtYM2291aHtWO2rZ2hfUA11GQbHFAYiIpNmGwb3ApNHBG0C7pnW/tFwVNElwEDYnHQfcIWZtYYdx1cA94Vlg2Z2STiK6KPTXmvBNNVlOKwwEBGZkjneCmZ2B/BOoN3M9hAdFfTnwF1m9nHgJeBDYfUfAO8FdgAjwMcA3L3PzP4EeDys98fuPrlT+reIjliqB34YHguqsS7L0PjQQr+NiMiicdwwcPdfOcqiy2ZY14FPHOV1vgF8Y4b2LuC849UxnxryGYY0MhARmZK4M5AhOpro8JjugywiMimhYZClXHXGSroPsogIJDYMoq1jOtdARCSS6DDQ4aUiIpFEhkFzfRaAgVGNDEREIKFh0FLIATAwOhFzJSIitSGZYRBGBv0jGhmIiEBSw6CgMBARmS6RYdBYl8UM+rXPQEQESGgYpFNGc32W/hHtMxARgYSGAUT7DXQ0kYhIJLFh0FzIaZ+BiEiQ2DBoqc9qn4GISJDcMChkGdA+AxERIMlhoJGBiMiUxIZBU9iBrMtYi4gkOAyK+Qzu6DLWIiIkOQxyaQAOj+vKpSIiyQ2DfHQZ65EJhYGISGLDoJCLwkAjAxGRBIdBMR9tJhqZqMRciYhI/BIcBtHIYFgjAxGRBIdBbjIMNDIQEUluGITNRMPagSwikuAwyGkzkYjIpOSGwdShpdpMJCKS2DDIZVJk06ZDS0VESHAYQDQ6GFEYiIgkPAxyGQ7raCIRkbmFgZl92sy2mNkzZnaHmdWZ2Xoze9TMdpjZnWaWC+vmw/y
OsHzdtNe5IbRvM7Mr5/g7nbBiPq3LUYiIMIcwMLNVwO8Ane5+HpAGrgW+AHzJ3V8HHAI+Hp7yceBQaP9SWA8z2xie9wbgKuCrZpaebV0no5DLaJ+BiAhz30yUAerNLAMUgH3Au4C7w/LbgGvC9NVhnrD8MjOz0P5tdx939xeBHcDFc6zrhDTkMzqaSESEOYSBu3cDXwReJgqBAeAJoN/dJ79u7wFWhelVwO7w3HJYf9n09hmes6AKubTOMxARYW6biVqJvtWvB1YCRaLNPAvGzK4zsy4z6+rp6Znz6zXkMzoDWUSEuW0mejfworv3uHsJ+C7wNqAlbDYCWA10h+luYA1AWN4M9E5vn+E5r+Hut7h7p7t3dnR0zKH0SCGf1rWJRESYWxi8DFxiZoWw7f8y4FngQeADYZ1NwD1h+t4wT1j+I49uQHwvcG042mg9sAF4bA51nbBiPqPNRCIiRDuAZ8XdHzWzu4GfAGXgSeAW4PvAt83s86Ht1vCUW4G/NbMdQB/REUS4+xYzu4soSMrAJ9z9lHxdL+YyjJerlCtVMulEn3IhIgk36zAAcPcbgRuPaN7JDEcDufsY8MGjvM6fAn86l1pmo5CbvHJpheZ6hYGIJFei/wI26AY3IiJAwsOgMHXlUoWBiCRbosOgYfIGNzqiSEQSLtFhUNANbkREgISHwdQ+A12SQkQSLtFhMHU0kUYGIpJwiQ6DhrpoZDCkMBCRhEt0GDTVZQEYHC3FXImISLwSHQZ12TT12TSHhifiLkVEJFaJDgOA1kKWfo0MRCThEh8GzYUc/SMKAxFJtsSHQUt9loFRbSYSkWRTGBSyHNLIQEQSTmFQyGozkYgkXuLDoLk+x8DoBNF9dkREkinxYdDekKNUcY0ORCTREh8Gq1rqAejuH425EhGR+CQ+DFaGMNg3MBZzJSIi8Ul8GKxoqQNgr0YGIpJgiQ+D9mKeXDrF3gGFgYgkV+LDIJUyTm+uY2+/NhOJSHIlPgwAVrbUaTORiCSawgA4o63AS70jcZchIhIbhQFwVkcDBw+PM6BzDUQkoRQGRGEA8MLBwzFXIiISD4UBcNZpIQwOKAxEJJkUBsCa1nqyaeOFnuG4SxERiYXCAMikU5zV0cBz+wfjLkVEJBYKg+CNq5rZvGdAVy8VkURSGATnr2mhb3iCPYd0voGIJM+cwsDMWszsbjN7zsy2mtmlZtZmZveb2fbwszWsa2Z2k5ntMLPNZnbhtNfZFNbfbmab5vpLzcZFa1sBeGRnbxxvLyISq7mODL4C/JO7nwucD2wFrgcecPcNwANhHuA9wIbwuA64GcDM2oAbgbcCFwM3TgbIqXTu6Y10NOZ5aPvBU/3WIiKxm3UYmFkz8A7gVgB3n3D3fuBq4Law2m3ANWH6auBbHnkEaDGzFcCVwP3u3ufuh4D7gatmW9dsmRlvXd/G4y/2ab+BiCTOXEYG64Ee4Jtm9qSZfd3MisByd98X1tkPLA/Tq4Dd056/J7Qdrf3fMbPrzKzLzLp6enrmUPrMLl7fxv7BMe03EJHEmUsYZIALgZvd/QJgmFc3CQHg0Vfsefua7e63uHunu3d2dHTM18tOecu6NgAee7Fv3l9bRKSWzSUM9gB73P3RMH83UTi8Ejb/EH4eCMu7gTXTnr86tB2t/ZQ7Z3kjTXUZHt+lMBCRZJl1GLj7fmC3mZ0Tmi4DngXuBSaPCNoE3BOm7wU+Go4qugQYCJuT7gOuMLPWsOP4itB2yqVSRue6Nh7efpBypRpHCSIisZjr0UT/BbjdzDYDbwb+DPhz4HIz2w68O8wD/ADYCewA/gb4LQB37wP+BHg8PP44tMXiQ51r6O4f5Z6f7o2rBBGRU84W65EznZ2d3tXVNe+v6+5c9eWHMYMffvJnMbN5fw8RkTiY2RPu3jnTMp2BfAQz49ffto7n9g/x5O7+uMsRETklFAYz+MXzV9KQz3DTA9t1zoGIJILCYAYN+Qyfvvxs/nVbD9/bvO/4TxA
RWeQUBkex6dK1nL+6mRvv3cLIRDnuckREFpTC4Cgy6RS//a4N9A1P8Ey37nMgIkubwuAYzlvVBKCb3ojIkqcwOIbTm+poKWTZuk9hICJLm8LgGMyMc09vZOu+obhLERFZUAqD43j9iia27R+iUtUhpiKydCkMjuPs5Y2Mlirs7ddlrUVk6VIYHEdrIQfA4Fgp5kpERBaOwuA4muoyABwe07kGIrJ0KQyOoyGEwZDCQESWMIXBcTTkw8hgXGEgIkuXwuA4GuuyAAxpn4GILGEKg+NonNxMpJGBiCxhCoPjyGdSZNPG4KjCQESWLoXBcZgZbcUcfcPjcZciIrJgFAYnoKMxz8HDE3GXISKyYBQGJ6CjIU/PkEYGIrJ0KQxOQEdjngNDY3GXISKyYBQGJ2BFcz09Q+OMlytxlyIisiAUBifgzI4iVYeXekfiLkVEZEEoDE7A+vYiADt7DsdciYjIwlAYnIDJMHihZzjmSkREFobC4AQ01mVZ0VzHtv2645mILE0KgxP0xlXNPN09EHcZIiILQmFwgt60upkXDw7rJjcisiQpDE7QBWe0AtC1qy/mSkRE5t+cw8DM0mb2pJl9L8yvN7NHzWyHmd1pZrnQng/zO8LyddNe44bQvs3MrpxrTQvhorWt5DMpHnr+YNyliIjMu/kYGXwS2Dpt/gvAl9z9dcAh4OOh/ePAodD+pbAeZrYRuBZ4A3AV8FUzS89DXfOqLpvmZ85axj9v2U+l6nGXIyIyr+YUBma2GvgF4Oth3oB3AXeHVW4DrgnTV4d5wvLLwvpXA99293F3fxHYAVw8l7oWyoc617B3YIz7n90fdykiIvNqriODLwOfA6phfhnQ7+6TF//fA6wK06uA3QBh+UBYf6p9hue8hpldZ2ZdZtbV09Mzx9JP3uUbl3NmR5G/uG8bpUr1+E8QEVkkZh0GZvY+4IC7PzGP9RyTu9/i7p3u3tnR0XGq3nZKJp3ic1eey86eYf55yyun/P1FRBbKXEYGbwPeb2a7gG8TbR76CtBiZpmwzmqgO0x3A2sAwvJmoHd6+wzPqTmXb1zOyuY67uzaffyVRUQWiVmHgbvf4O6r3X0d0Q7gH7n7h4EHgQ+E1TYB94Tpe8M8YfmP3N1D+7XhaKP1wAbgsdnWtdDSKeODnWt4eHsPew7pwnUisjQsxHkGvwt8xsx2EO0TuDW03wosC+2fAa4HcPctwF3As8A/AZ9w95q+VvQHO1cDcFfXnpgrERGZHxZ9OV98Ojs7vaurK7b3/8itj7KzZ5iHP/fzpFIWWx0iIifKzJ5w986ZlukM5Fn6YOcauvtHeWRnb9yliIjMmcJglq7YuJzGugzfeUKbikRk8VMYzFJdNs3lG5fz0PM9LNZNbSIikxQGc/CmVc30Dk/wyuB43KWIiMyJwmAOzlvVDMAzus+BiCxyCoM5eP2KJszgmb0KAxFZ3BQGc1DMZ1i/rMizewfjLkVEZE4UBnP0+pVNPLtPYSAii5vCYI7esLKJPYdGGRjV7TBFZPFSGMzROcsbAdhxYCjmSkREZk9hMEdnhzDYtv9wzJWIiMyewmCOVrXU01rI0rWrL+5SRERmTWEwR6mU8fYNHfz4+R7d/UxEFi2FwTz4xTetoHd4gsdf1OhARBYnhcE8uHBtKwBb92snsogsTgqDedDekKetmGPbfp1vICKLk8Jgnpy/upmuXYfiLkNEZFYUBvPkneecxs6Dw7ponYgsSgqDefJLF66iuT7LH/zjM5R1VJGILDIKg3nSVJfl89ecx0939/PVf30h7nJERE6KwmAe/eL5K7n6zSv5ygPb2bynP+5yREROmMJgnv3x+8+jvSHHB772b9z5+MtxlyMickIUBvOsuZDlzusu5ezlDfzRvc+yu28k7pJERI5LYbAA1rUXueUjnTjO3zy8M+5yRESOS2GwQFa21NO5to2f7u6PuxQRkeNSGCyg9e1FXjw4jLvHXYqIyDEpDBbQ2mUFhsbKHBrRXdB
EpLYpDBbQOadHN755Wmcli0iNUxgsoIvWtpJNGz/e1hN3KSIixzTrMDCzNWb2oJk9a2ZbzOyTob3NzO43s+3hZ2toNzO7ycx2mNlmM7tw2mttCutvN7NNc/+1akMhl+FnN3Two+deibsUEZFjmsvIoAx81t03ApcAnzCzjcD1wAPuvgF4IMwDvAfYEB7XATdDFB7AjcBbgYuBGycDZCl46/o2dvWO0DM0HncpIiJHNeswcPd97v6TMD0EbAVWAVcDt4XVbgOuCdNXA9/yyCNAi5mtAK4E7nf3Pnc/BNwPXDXbumrNO885DYA7HtPZyCJSuzLz8SJmtg64AHgUWO7u+8Ki/cDyML0K2D3taXtC29Hal4RzTm/kPeedzpf/5Xn2HBrhjatb6GjIcelZ7TTXZ+MuT0QEmIcwMLMG4O+BT7n7oJlNLXN3N7N5O8jezK4j2sTEGWecMV8vu+C++MHzyWdS3PvUXu7q2jPVfu7pjfzee1/PO87uiLE6EZE5hoGZZYmC4HZ3/25ofsXMVrj7vrAZ6EBo7wbWTHv66tDWDbzziPZ/nen93P0W4BaAzs7ORXMmVzGf4cvXXkCl6uztH2Vv/yiPvdjHd5/s5qPfeIyr37ySKzaezmWvP426bDruckUkgWy2Z8daNAS4Dehz909Na/9LoNfd/9zMrgfa3P1zZvYLwG8D7yXaWXyTu18cdiA/AUweXfQT4CJ37zvW+3d2dnpXV9esaq8VY6UKX31wB1/78U4mKlXetLqZL/3ymzmroyHu0kRkCTKzJ9y9c8ZlcwiDtwMPA08Dk7f2+j2i/QZ3AWcALwEfcve+EB7/k2jn8AjwMXfvCq/1G+G5AH/q7t883vsvhTCYNDpR4Xub9/L5729ltFThk5dt4Dd/9kxyGZ0GIiLzZ0HCIG5LKQwmHRga48Z7tvDDZ/ZzzvJGvr6pkzVthbjLEpEl4lhhoK+eNeS0xjpu/rWL+PpHO+nuH+UP73km7pJEJCEUBjXo3RuX86l3b+DBbT38wT8+TbW6OEdvIrJ4zMt5BjL/Pva29WzdN8T/fuRlRieq3PDec2lvyMddlogsURoZ1Kh0yvjiB9/Eb//86/iHJ/fwc3/xIPc/q2scicjCUBjUMDPjv155Dvd/5uc467QGfueOJ3VPZRFZEAqDReCsjgb+6lcvZLRU4ftP7zv+E0RETpLCYJFY01bgvFVN/ODpfbqNpojMO4XBIvLLbzmDzXsGuOene+MuRUSWGIXBInLtW9Zw0dpWPvudp7hvy/64yxGRJURhsIhk0ylu+42LOW9VM5+96yl2HDgcd0kiskQoDBaZhnyGmz98IblMimv+6v/yO3c8yfc379OJaSIyJwqDRWhlSz13/6dLueINy/l/L/Tyib/7Cb9266M67FREZk0XqlvkqlXn24/v5s9+sJWqO7971bl85JK1pFJ2/CeLSKLoQnVLWCpl/Opbz+C+T7+DznVt3HjvFn7164/w0PM9DI6VXnMY6lipwvZXhtiyd4DxciXGqkWk1mhksIS4O9/p2sMf/Z8tjExEf+zNIJtKsaq1nu5Do0xUoltP1GVTvGVdGxee0cqmn1lHWzEXZ+knrffwOJWqc1pTXdyliCwaup9BwgyOlXhsZx+7eocZHC0xVq6ys2eYszqKbFzZRDpldO06xCM7e3lu/xAAhVyalS31vOvc0/jM5WfX3O03p/8//fufdHPDdzdTqTofuWQtn/j51ykURE6AwkCOasveAR7efpCeoXFePDjMg9sOsKyYZ+PKJvKZFAZUHdqKWZrqsgCUq046ZbQ35OlozHNaY/SzozFPXTbNnkMj7Do4wst9w7zUO8JLvSNk0kYxl+HQyASNdRky6RQpM5rqMrQ35GktZHHg+VeGGBork8ukqM+mGRwr88SuPvYOjGEG+UyKsVKVt6xr5XWnNXDn47tJmXHNBav45besYU1rgdMa89pnIjIDhYGcsH97oZfbH32Jl/tGKFV86hv5oZE
JBkfLpAwy6RSlSnVqU9SxtBSynNFWYLxUpVSp0lLIcni8TLnquMPAaIm+4Ymp9Zvrs7QVc0yUq4yWKuQzKS5c28qZ7UUAxstVzmwv8h8uWk02nWLXwWFu+7dd3P7Iy1ObwHLpFCta6mgr5miqy9JYl6GpPsuyYo628KjPpnHg5d4Rtu4b5IWDw6xqqeONq1poqMvg7lSrjgPFfIamugyN4bXqs2nymTR12RT5TJp8Ntr19nLfCAcPj2MY+WyKukya+lya+myaTNooV5yJcpXhiTJPdw/w0PM9PP/KELt6R1jbVuBnzlrGOac3ce6KRta2FShXncPjZQZHSwyNlRkci/qqu380GvGVqq/p62I+TUM+y5a9A2zdN8jQWJn17UV+7uwOljXkaMhnKebTFHLR71CXTU2NANMpI5s2MqkUmbThDqOlCqVKlbpsOnpkUjTWZadux1qtOuWqU65WKZWd8UqF8VKV8XJ16vfsPTxBpepEPRmNQNcuK9JayFHIpcmlU1PBPTpRoXd4nL7hCXqHJ+g9PEGpUuW0xjz1uaiGfCbFeLnK6ESF4fEyIxMVth+I+rAuk+byjcs59/RGOhrzFPOvXqG/XInqGi9XGStVpl6jf2SC8XKVctUp5tKs7yiyrJgnnTJSFl0scrbcnb7hCUYmKkxUoj6ZKEefg4lyFbPQ5+kUmZRRl03TWJchm05RrlapVJ1yxRkvVxidqDJWrjBWqmAYb9/QPquaFAayIIbHyxw8PM6BoXF6wmNkosKatnrWthU5Y1mB5vrscV+nXKkyOFbG3Wkr5mb1ATx4eJzNe/rp7h+j+9Ao3f2jHBqeYHAs/CEdLXFoZIKZTsdY3pRnfXuR7v5RdveNnvR7z9bpTXW8YWUT69uLPLtvkKd29zN8AgGbS6doLmTJZ6LRFYDjHB4rMzRWZl17kc61rRTzGbp29fHsvkFKldr9nGfTRqXqM/7bnIh0yjijrUDv4XEGx8pT7XXZqH/Gy9Ef1tlIGSEYjHRq2sOMVPiZThmpcCiOe/SoutM/UmK0NP8HarQ35On6g3fP6rnHCgPd3EZmrZjPUMxnWLusOKfXyaRTc96B3d6Q513nLj/mOpWqMzhaond4grFSBXdY2VLHsmk3DRocKzFWqkQffjOcKPQmQ2VorDz1zXK8/Oo34ao7a9oKdDTkcTz6BjpRCd/mom+DmVRqavPX+vYir1/R+Jrgq1ad7v5Rnts/xJ5DI+QyKRryGZrqsjTVRz9bCjmWFXMntRmsWnVGShWGxkoMj5cZnYhGXWPhMdk3papTrkTfkg0o5DKkUzb1e45MRL9/qVKF0D+ZdPTHcPKbbT4T/Y75TJpCLk1bMUc2Hf2lNIOhsRIv940wMFJipPRq/2VSRiGfDqO3PG3FHO0NOTLpFD1D44yGvhwvVclnUxSyaYr5DPW5NCub66nPpRkrVXi6e4CXeqMR2sGh8bBpMaorPzmSmzbdWshRl02RSaUYHCuxs+cwA6MlKlWohNHh1M8wXQnT1anp6I8/gBEd4WdAU32W1a31FPOZqF/SUd/kMimy6RRVj775l6tVJsrRCODweJlSuTo1WkilLIziopFcfTYa2S0EjQxERBJC5xmIiMgxKQxERERhICIiCgMREUFhICIiKAxERASFgYiIoDAQEREW8UlnZtYDvDTLp7cDB+exnPlW6/WBapwPtV4f1H6NtV4f1FaNa929Y6YFizYM5sLMuo52Fl4tqPX6QDXOh1qvD2q/xlqvDxZHjaDNRCIigsJARERIbhjcEncBx1Hr9YFqnA+1Xh/Ufo21Xh8sjhqTuc9AREReK6kjAxERmUZhICIiyQoDM7vKzLaZ2Q4zuz7GOtaY2YNm9qyZbTGzT4b2NjO738y2h5+tod3M7KZQ92Yzu/AU1Zk2syfN7Hthfr2ZPRrquNPMcqE9H+Z3hOXrTlF9LWZ2t5k9Z2ZbzezSWupDM/t0+Pd9xszuMLO6uPvQzL5hZgfM7JlpbSf
dZ2a2Kay/3cw2nYIa/zL8O282s38ws5Zpy24INW4zsyuntS/Y532mGqct+6yZuZm1h/lY+vGkuXsiHkAaeAE4E8gBTwEbY6plBXBhmG4Engc2An8BXB/arwe+EKbfC/yQ6K56lwCPnqI6PwP8HfC9MH8XcG2Y/hrwn8P0bwFfC9PXAneeovpuA/5jmM4BLbXSh8Aq4EWgflrf/XrcfQi8A7gQeGZa20n1GdAG7Aw/W8N06wLXeAWQCdNfmFbjxvBZzgPrw2c8vdCf95lqDO1rgPuITohtj7MfT/p3iuuNT/kvCpcC902bvwG4Ie66Qi33AJcD24AVoW0FsC1M/zXwK9PWn1pvAWtaDTwAvAv4XviPfHDaB3KqP8N//kvDdCasZwtcX3P4Y2tHtNdEHxKFwe7wQc+EPryyFvoQWHfEH9qT6jPgV4C/ntb+mvUWosYjlv0ScHuYfs3neLIfT8XnfaYagbuB84FdvBoGsfXjyTyStJlo8sM5aU9oi1XYHHAB8Ciw3N33hUX7gck7vMdR+5eBzwHVML8M6Hf38gw1TNUXlg+E9RfSeqAH+GbYlPV1MytSI33o7t3AF4GXgX1EffIEtdWHk062z+L+LP0G0TdtjlHLKa/RzK4Gut39qSMW1UyNx5KkMKg5ZtYA/D3wKXcfnL7Mo68KsRz3a2bvAw64+xNxvP8JyhAN02929wuAYaJNHFNi7sNW4Gqi0FoJFIGr4qjlZMTZZyfCzH4fKAO3x13LdGZWAH4P+MO4a5mtJIVBN9H2vEmrQ1sszCxLFAS3u/t3Q/MrZrYiLF8BHAjtp7r2twHvN7NdwLeJNhV9BWgxs8wMNUzVF5Y3A70LWB9E36L2uPujYf5uonColT58N/Ciu/e4ewn4LlG/1lIfTjrZPovls2Rmvw68D/hwCK1aqvEsouB/KnxuVgM/MbPTa6jGY0pSGDwObAhHc+SIdtLdG0chZmbArcBWd//v0xbdC0weUbCJaF/CZPtHw1EJlwAD04b1887db3D31e6+jqiffuTuHwYeBD5wlPom6/5AWH9Bv126+35gt5mdE5ouA56lRvqQaPPQJWZWCP/ek/XVTB9Oc7J9dh9whZm1hhHQFaFtwZjZVUSbLd/v7iNH1H5tOBprPbABeIxT/Hl396fd/TR3Xxc+N3uIDhLZTw314zHFtbMijgfRXv3niY4y+P0Y63g70VB8M/DT8Hgv0TbiB4DtwL8AbWF9A/4q1P000HkKa30nrx5NdCbRB20H8B0gH9rrwvyOsPzMU1Tbm4Gu0I//SHRERs30IfDfgOeAZ4C/JTriJdY+BO4g2odRIvqD9fHZ9BnRdvsd4fGxU1DjDqLt65Ofl69NW//3Q43bgPdMa1+wz/tMNR6xfBev7kCOpR9P9qHLUYiISKI2E4mIyFEoDERERGEgIiIKAxERQWEgIiIoDEREBIWBiIgA/x+8OMovNC6pQwAAAABJRU5ErkJggg==\n", 246 | "text/plain": [ 247 | "
" 248 | ] 249 | }, 250 | "metadata": { 251 | "needs_background": "light" 252 | }, 253 | "output_type": "display_data" 254 | } 255 | ], 256 | "source": [ 257 | "from matplotlib import pyplot as plt\n", 258 | "%matplotlib inline\n", 259 | "\n", 260 | "plt.plot(plt_x, plt_y)\n", 261 | "plt.show()" 262 | ] 263 | } 264 | ], 265 | "metadata": { 266 | "kernelspec": { 267 | "display_name": "Python 3 (ipykernel)", 268 | "language": "python", 269 | "name": "python3" 270 | }, 271 | "language_info": { 272 | "codemirror_mode": { 273 | "name": "ipython", 274 | "version": 3 275 | }, 276 | "file_extension": ".py", 277 | "mimetype": "text/x-python", 278 | "name": "python", 279 | "nbconvert_exporter": "python", 280 | "pygments_lexer": "ipython3", 281 | "version": "3.8.11" 282 | } 283 | }, 284 | "nbformat": 4, 285 | "nbformat_minor": 5 286 | } 287 | -------------------------------------------------------------------------------- /4.ada_grad.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "array([0., 0., 0., 0., 0.])" 46 | ] 47 | }, 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | 
} 52 | ], 53 | "source": [ 54 | "#常量\n", 55 | "N, M = x.shape\n", 56 | "\n", 57 | "#变量\n", 58 | "w = np.ones(M)\n", 59 | "b = 0\n", 60 | "\n", 61 | "#初始化S为全0\n", 62 | "S_w = np.zeros(M)\n", 63 | "S_b = 0\n", 64 | "\n", 65 | "S_w" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "id": "92163201", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "0.6590042695516539" 78 | ] 79 | }, 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "#预测函数\n", 87 | "def predict(x):\n", 88 | " return w.dot(x) + b\n", 89 | "\n", 90 | "\n", 91 | "predict(x[0])" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "id": "a7bb7a80", 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "0.21258140154187247" 104 | ] 105 | }, 106 | "execution_count": 4, 107 | "metadata": {}, 108 | "output_type": "execute_result" 109 | } 110 | ], 111 | "source": [ 112 | "#求loss,MSELoss\n", 113 | "def get_loss(x, y):\n", 114 | " pred = predict(x)\n", 115 | " loss = (pred - y)**2\n", 116 | " return loss\n", 117 | "\n", 118 | "\n", 119 | "get_loss(x[0], y[0])" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 5, 125 | "id": "8027d213", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 132 | " 0.923131013558981)" 133 | ] 134 | }, 135 | "execution_count": 5, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "def get_gradient(x, y):\n", 142 | " global w\n", 143 | " global b\n", 144 | "\n", 145 | " eps = 1e-3\n", 146 | "\n", 147 | " loss_before = get_loss(x, y)\n", 148 | "\n", 149 | " gradient_w = np.empty(M)\n", 150 | " for i in range(M):\n", 151 | " w[i] += eps\n", 152 | " loss_after = get_loss(x, y)\n", 153 | " w[i] -= 
eps\n", 154 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 155 | "\n", 156 | " b += eps\n", 157 | " loss_after = get_loss(x, y)\n", 158 | " b -= eps\n", 159 | " gradient_b = (loss_after - loss_before) / eps\n", 160 | "\n", 161 | " return gradient_w, gradient_b\n", 162 | "\n", 163 | "\n", 164 | "get_gradient(x[0], y[0])" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 6, 170 | "id": "f39e0125", 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "11073.905141728206" 177 | ] 178 | }, 179 | "execution_count": 6, 180 | "metadata": {}, 181 | "output_type": "execute_result" 182 | } 183 | ], 184 | "source": [ 185 | "def total_loss():\n", 186 | " loss = 0\n", 187 | " for i in range(N):\n", 188 | " loss += get_loss(x[i], y[i])\n", 189 | " return loss\n", 190 | "\n", 191 | "\n", 192 | "total_loss()" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 7, 198 | "id": "c371c6a4", 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "0 [0.02348118 0.05140723 0.0153264 0.01141825 0.03465209] 0.014058087368505013 10246.765269340094\n", 206 | "150 [0.00215905 0.00213966 0.0025248 0.0023185 0.00215242] 0.0024356945540773096 2658.3900946275558\n", 207 | "300 [0.00159518 0.00180842 0.0019631 0.00191729 0.00173008] 0.0019876188630781272 1600.8100316915763\n", 208 | "450 [0.00148661 0.00166541 0.00178354 0.00173962 0.00157404] 0.0017842433083993126 1186.465991584585\n", 209 | "600 [0.00139184 0.00156629 0.00167125 0.00164534 0.00149001] 0.0016826253755351525 1018.0084411998928\n", 210 | "750 [0.00132517 0.00150278 0.00158741 0.00157468 0.00142465] 0.001594262705724423 927.9759431415143\n", 211 | "900 [0.0012664 0.00144711 0.00152095 0.00150801 0.00135752] 0.0015236251237426357 861.6745940332113\n", 212 | "1050 [0.00123801 0.00141408 0.0014759 0.00146659 0.00133255] 0.001476298163704934 824.5775032016687\n", 213 
| "1200 [0.00121116 0.00136314 0.001434 0.00142441 0.00128068] 0.001429332532480186 807.1329761191264\n", 214 | "1350 [0.00116843 0.0013307 0.00139645 0.00138838 0.00124704] 0.0013899172750459162 791.1949098503449\n", 215 | "1500 [0.00115129 0.00130284 0.0013622 0.00136147 0.00122016] 0.0013561117340229432 788.5470016064282\n", 216 | "1650 [0.00111086 0.00126146 0.00132726 0.00132571 0.00119364] 0.0013234756456931064 783.241050040833\n", 217 | "1800 [0.00105912 0.00119525 0.00128706 0.00128536 0.00116119] 0.0012821836202760327 776.5490903516861\n", 218 | "1950 [0.00101795 0.00115106 0.00125265 0.00124849 0.00112686] 0.0012443142251112867 767.7869135624381\n", 219 | "2100 [0.00099775 0.00112372 0.00122065 0.00122 0.00111091] 0.0012130252246838496 762.3245228610506\n", 220 | "2250 [0.00098704 0.00109155 0.00119765 0.00120011 0.00109044] 0.00119230318143276 757.1151773657198\n", 221 | "2400 [0.00096106 0.00106878 0.00117239 0.00118274 0.00108084] 0.0011689278828959015 754.8130624984661\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "plt_x = []\n", 227 | "plt_y = []\n", 228 | "for epoch in range(2500):\n", 229 | " i = np.random.randint(N)\n", 230 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 231 | "\n", 232 | " #adagrad的特点是每个变量都有属于自己的lr\n", 233 | " #要计算各个变量的lr,先要计算S\n", 234 | " #这是S的计算公式\n", 235 | " S_w = S_w + gradient_w**2\n", 236 | " S_b = S_b + gradient_b**2\n", 237 | "\n", 238 | " #计算lr的公式,其中的1e-1是原本的lr,1e-6是防止除0的\n", 239 | " lr_w = 1e-1 / ((S_w + 1e-6)**0.5)\n", 240 | " lr_b = 1e-1 / ((S_b + 1e-6)**0.5)\n", 241 | "\n", 242 | " #所以在时刻0,lr就等于梯度的倒数\n", 243 | " #梯度大的变量会有小lr,梯度小的变量会有大lr\n", 244 | " #往后的每一个时刻,都是类似动量法,考虑上一步的梯度\n", 245 | "\n", 246 | " w -= gradient_w * lr_w\n", 247 | " b -= gradient_b * lr_b\n", 248 | "\n", 249 | " plt_x.append(epoch)\n", 250 | " plt_y.append(total_loss())\n", 251 | "\n", 252 | " if epoch % 150 == 0:\n", 253 | " print(epoch, lr_w, lr_b, total_loss())" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | 
"execution_count": 8, 259 | "id": "0471a70d", 260 | "metadata": { 261 | "scrolled": true 262 | }, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAcdklEQVR4nO3dfXRc9X3n8fd3ZvT8ZNmSjW0ZZEABHAIBFHACSUPM8hS6ZrtJlj274KbsctqlDTSb3YV2t+QkzTbpNiGhm9DSQArdNIQNyeKzTcKaZ9oEE0EcwDi2hcHYxtiyLduyZT3NfPeP+xt5ZEt+0Ei6Gt3P6xydufd37535/jTCH3730dwdERFJtlTcBYiISPwUBiIiojAQERGFgYiIoDAQEREgE3cB49XU1OStra1xlyEiUjJeeumlXe7ePNqykg2D1tZWOjo64i5DRKRkmNnmsZZpN5GIiCgMREREYSAiIigMREQEhYGIiKAwEBERFAYiIkICw+DFN/ewYUdP3GWIiEwrJXvR2Xh96q9/DsBbX/54zJWIiEwfiRsZiIjI0RQGIiKiMBAREYWBiIigMBARERJ4NtFFpzVSWaYMFBEpdNx/Fc3sATPbaWavFbTNNrNVZrYxvDaGdjOze8ys08xeMbMLC7ZZEdbfaGYrCtovMrNXwzb3mJlNdCcLpc3I5nwyP0JEpOScyP8i/y1w9RFtdwBPunsb8GSYB7gGaAs/twD3QhQewF3AJcDFwF35AAnr/PuC7Y78rAmVSkEuN5mfICJSeo4bBu7+HLDniOblwINh+kHg+oL2hzzyAjDLzOYDVwGr3H2Pu3cDq4Crw7J6d3/B3R14qOC9JkXKjJxrZCAiUmi8O8/nufv2MP0uMC9MLwS2FKy3NbQdq33rKO2jMrNbzKzDzDq6urrGVXg6ZWQVBiIiIxR9JDX8H/2U/Ovq7ve5e7u7tzc3j/pM5+Oqq8ywr3dwgisTESlt4w2DHWEXD+F1Z2jfBiwqWK8ltB2rvWWU9klz6uwatnT36iCyiEiB8YbBSiB/RtAK4LGC9pvCWUVLgX1hd9LjwJVm1hgOHF8JPB6W7TezpeEsopsK3mtSNFaXMZh1+gazk/kxIiIl5bjXGZjZ94CPAk1mtpXorKAvA4+Y2c3AZuBTYfUfA9cCnUAv8GkAd99jZl8EfhHW+4K75w9K/weiM5aqgJ+En0lTnonyb2AoR03FZH6SiEjpOG4YuPu/HmPRslHWdeDWMd7nAeCBUdo7gHOPV8dEGQ6DrM4vFRHJS9yluOXpwyMDERGJJC8MwsigX2EgIjIseWGgkYGIyFGSFwZhZDCoYwYiIsMSFwb56wtefrs75kpERKaPxIXB+nd7APjyT34dcyUiItNH4sLgty6KLnj+zLK2mCsREZk+EhcGteXRpRVVZemYKxERmT4SFwbpdPTsnCE91EBEZFjiwiCTyoeBblQnIpKX2DDIZhUGIiJ5iQuDtEYGIiJHSVwYmFn0tDOFgYjIsMSFAUSjA40MREQOS2QYZFLGkG5HISIyLJFhoJGBiMhIiQyDsnRKxwxERAokMgw0MhARGSmRYZBJGVldgSwiMiyRYaCRgYjISIkMg4yuMxARGSGRYZBOGUO6HYWIyLBEhkFZOqW7loqIFEhkGOh2FCIiIyUyDDI6gCwiMkIiw0AjAxGRkRIZBplUSgeQRUQKJDIMousMdABZRCQvkWGQSeuYgYhIoWSGgY4ZiIiMkMgwSOuYgYjICIkMA40MRERGKioMzOwPzWytmb1mZt8zs0ozW2xmq82s08y+b2blYd2KMN8ZlrcWvM+doX29mV1VZJ+OK53WAWQRkUL
jDgMzWwh8Bmh393OBNHAD8BXgbnc/E+gGbg6b3Ax0h/a7w3qY2ZKw3XuBq4FvmVl6vHWdCI0MRERGKnY3UQaoMrMMUA1sBz4G/CAsfxC4PkwvD/OE5cvMzEL7w+7e7+5vAp3AxUXWdUzplDEwpJGBiEjeuMPA3bcBfwG8TRQC+4CXgL3uPhRW2wosDNMLgS1h26Gw/pzC9lG2GcHMbjGzDjPr6OrqGm/pzK2r5J19feQ0OhARAYrbTdRI9H/1i4EFQA3Rbp5J4+73uXu7u7c3NzeP+32qy6O9UIoCEZFIMbuJrgDedPcudx8EfghcCswKu40AWoBtYXobsAggLG8Adhe2j7LNpEhZ9JpzxYGICBQXBm8DS82sOuz7Xwa8DjwNfCKsswJ4LEyvDPOE5U+5u4f2G8LZRouBNuDFIuo6rqhchYGISF7m+KuMzt1Xm9kPgJeBIeCXwH3APwAPm9mfhrb7wyb3A39nZp3AHqIziHD3tWb2CFGQDAG3unt2vHWdiJAFKAtERCLjDgMAd78LuOuI5k2McjaQu/cBnxzjfb4EfKmYWk5GKqSBwkBEJJLIK5B1zEBEZKSEhoGOGYiIFEpkGBw+gBxzISIi00QiwyA1fABZaSAiAokNA40MREQKJTIMTAeQRURGSGgY6NRSEZFCiQwDHTMQERkpoWGgYwYiIoUSGgbRq55pICISSWQY7DowAMDKX03qzVFFREpGIsPgX17YAmhkICKSl8gwOKWhkopMin6FgYgIkNAwAKityHCgf+j4K4qIJEBiw6C6Ik3vwKQ+NkFEpGQkNgxqyjMc1MhARARIcBhUl2tkICKSl9gwqMik6R9SGIiIQILDoDyT0qmlIiJBosOgR8cMREQAyMRdQFxWvb4j7hJERKaNxI4MLj+rGYC+QR03EBFJbBi8f1EjAL/asjfeQkREpoHEhsHHzp4LwM837Y65EhGR+CU2DM6eXwdAOv8MTBGRBEtsGGRShhkMZHV6qYhIYsPAzChPpxQGIiIkOAyAKAx04ZmISMLDQFchi4gACQ+DsnSKQe0mEhFJdhhoZCAiEikqDMxslpn9wMx+bWbrzOyDZjbbzFaZ2cbw2hjWNTO7x8w6zewVM7uw4H1WhPU3mtmKYjt1osrSxmDWp+rjRESmrWJHBt8AfuruZwPnA+uAO4An3b0NeDLMA1wDtIWfW4B7AcxsNnAXcAlwMXBXPkAmW1k6xYtv7ZmKjxIRmdbGHQZm1gB8BLgfwN0H3H0vsBx4MKz2IHB9mF4OPOSRF4BZZjYfuApY5e573L0bWAVcPd66TkZP3xBdPf1T8VEiItNaMSODxUAX8B0z+6WZfdvMaoB57r49rPMuMC9MLwS2FGy/NbSN1X4UM7vFzDrMrKOrq6uI0iM3fGARoJvViYgUEwYZ4ELgXne/ADjI4V1CALi7AxO2U97d73P3dndvb25uLvr9GqrLgGiEICKSZMWEwVZgq7uvDvM/IAqHHWH3D+F1Z1i+DVhUsH1LaBurfdLVVUaPc+jpG5yKjxMRmbbGHQbu/i6wxczOCk3LgNeBlUD+jKAVwGNheiVwUziraCmwL+xOehy40swaw4HjK0PbpKuv1MhARASKf9LZHwDfNbNyYBPwaaKAecTMbgY2A58K6/4YuBboBHrDurj7HjP7IvCLsN4X3H1KTvGpC2GwXyMDEUm4osLA3dcA7aMsWjbKug7cOsb7PAA8UEwt41Ffld9NpJGBiCRboq9AHh4ZHNLIQESSLdFh0BjOJtqxX9caiEiyJToMqsuj3UR3P7FBN6wTkURLdBgAfLitCYDOnQdirkREJD6JD4NbLz8TgO6DAzFXIiISn8SHweyacgD29CoMRCS5Eh8Gs8JBZI0MRCTJEh8GjdXRyGDXAYWBiCRX4sOgLJ1iQUMlG3f2xF2KiEhsEh8GAKfOqdZzDUQk0RQGwJzaCnZrN5GIJJjCAJhdXU63ziYSkQRTGBA916Cnb4joXnoiIsmjMCC
6Yd1Qzukb1C0pRCSZFAboiWciIgoDDofBfj3XQEQSSmHA4cdfvrP3UMyViIjEQ2EALJhVBcB9z22KuRIRkXgoDICzTqkjZVCWtrhLERGJhcIguKytmT26WZ2IJJTCIGiqKdfN6kQksRQGwSkNlby7v49cTheeiUjyKAyCxupysjnnwIBOLxWR5FEYBIcvPFMYiEjyKAyCunCtga5CFpEkUhgEjeHxl292HYy5EhGRqacwCNpbZ1NfmeH3vvuyTjEVkcRRGATlmRTXnb8AgB/9clvM1YiITC2FQYE/uW4JAJt3a1eRiCSLwqBAZVmahbOqeOjnm/WgGxFJFIXBES45fTYAL23ujrkSEZGpozA4wm3L2gBYv6Mn5kpERKZO0WFgZmkz+6WZ/d8wv9jMVptZp5l938zKQ3tFmO8My1sL3uPO0L7ezK4qtqZitDRWkzL43x1b4yxDRGRKTcTI4DZgXcH8V4C73f1MoBu4ObTfDHSH9rvDepjZEuAG4L3A1cC3zCw9AXWNSzplXHRaI2u27GVnT19cZYiITKmiwsDMWoCPA98O8wZ8DPhBWOVB4PowvTzME5YvC+svBx529353fxPoBC4upq5ifWH5uQA8uW5nnGWIiEyZYkcGXwf+M5AL83OAve6ev8HPVmBhmF4IbAEIy/eF9YfbR9lmBDO7xcw6zKyjq6uryNLH1jqnBoDuXl18JiLJMO4wMLPrgJ3u/tIE1nNM7n6fu7e7e3tzc/OkfU5lWfRr+fOfrtctrUUkEYoZGVwK/HMzewt4mGj30DeAWWaWCeu0APnLebcBiwDC8gZgd2H7KNvEwsw4r6UBgLf39MZZiojIlBh3GLj7ne7e4u6tRAeAn3L3fwM8DXwirLYCeCxMrwzzhOVPeXRl10rghnC20WKgDXhxvHVNlM98LDrFdMd+HUQWkZlvMq4z+C/AZ82sk+iYwP2h/X5gTmj/LHAHgLuvBR4BXgd+Ctzq7tlJqOukzKuvBODxtTtirkREZPJljr/K8bn7M8AzYXoTo5wN5O59wCfH2P5LwJcmopaJcu7CegB69eQzEUkAXYE8BjPjvQvqtZtIRBJBYXAMTbUVPL1+8k5hFRGZLhQGx1BbEe1F232gP+ZKREQml8LgGG784GkA/P3qt3VLaxGZ0RQGx/C+hQ1kUsZXV23gCd2aQkRmMIXBMdRUZHjis78BwOvv7I+5GhGRyaMwOI7WphrK0kbH5j1xlyIiMmkUBifg9KZant+4i4P9uuZARGYmhcEJWH7BAgDe2Xso5kpERCaHwuAEfKA1ei7y69t13EBEZiaFwQk4v2UWZWljzZa9cZciIjIpFAYnoDyTYsn8er7zT2/FXYqIyKRQGJygM+bWAtDTNxhzJSIiE09hcIKufu8pALpXkYjMSAqDE3TJ6XNoqCrjv/7oVYayueNvICJSQhQGJ6ihqoybL1vM/r4hdh0YiLscEZEJpTA4CfkH3ryzT9cbiMjMojA4Ca1zagD41tOd5HK6i6mIzBwKg5Nw2pwaasrTPLFuJ7d/f03c5YiITBiFwUlIp4yf3bGM3zx/ASt/9Q6PrdkWd0kiIhNCYXCSGqrL+ItPnkdZ2vjhywoDEZkZFAbjUJFJs/z9C3l2Qxe/flf3KxKR0qcwGKePntUMwNVff56t3b0xVyMiUhyFwTh9/H3z+cRFLQA8u0FXJYtIaVMYjJOZ8afXnwvAU3o+soiUOIVBESrL0gAM6poDESlxCoMiffSsZp7b0IW7AkFESpfCoEinza4G4KXN3TFXIiIyfgqDIv27D58OwKO65kBESpjCoEiLZlfTVFvB9158m217dQM7ESlNCoMJ8J3f/gAAl375Kb7xxEayOqAsIiVm3GFgZovM7Gkze93M1prZbaF9tpmtMrON4bUxtJuZ3WNmnWb2ipldWPBeK8L6G81sRfHdmlrva2ngP111FgB3P7GBD3/lKd7erQvRRKR0FDMyGAL+o7svAZYCt5rZEuAO4El
3bwOeDPMA1wBt4ecW4F6IwgO4C7gEuBi4Kx8gpeTWy8/kjf9+LVecM4939vXxu//rJd3mWkRKxrjDwN23u/vLYboHWAcsBJYDD4bVHgSuD9PLgYc88gIwy8zmA1cBq9x9j7t3A6uAq8dbV5zSKePbK9q56zeX8Pr2/Vxx97Ps7OmLuywRkeOakGMGZtYKXACsBua5+/aw6F1gXpheCGwp2GxraBurvWTduPQ0blvWxqaug9z56KtxlyMiclxFh4GZ1QKPAre7+4hbeHp0JdaE7Ssxs1vMrMPMOrq6pu/9gDLpFH/4z97DRac18syGLroP6pnJIjK9FRUGZlZGFATfdfcfhuYdYfcP4TV/455twKKCzVtC21jtR3H3+9y93d3bm5ubiyl9Sny4rYlszrngi6toveMf+MeNu+IuSURkVMWcTWTA/cA6d/9awaKVQP6MoBXAYwXtN4WzipYC+8LupMeBK82sMRw4vjK0lbw/+Fgb969o59bLzwDgs4+s0UFlEZmWMkVseylwI/Cqma0JbX8EfBl4xMxuBjYDnwrLfgxcC3QCvcCnAdx9j5l9EfhFWO8L7r6niLqmjXTKWHbOPJadM4+2uXXc/v01/Pi17Vx33oK4SxMRGcFK9QZr7e3t3tHREXcZJ6ynb5CP/o9n2NM7wEff08ztV7yH8xfNirssEUkQM3vJ3dtHW6YrkKdIXWUZj/3+pfzWBS08vb6L5d/8J/YdGoy7LBERQGEwpVoaq/nqp87nlo9EN7e74mvP8mc/Wce67XqOsojES2EQgz+69hy+ccP7KU+n+OtnN3HNN57nur98XrfBFpHY6JhBzDp39vDwi1t4pGML+/uGaJtbyw0Xn8rvXNpKdMKWiMjEONYxA4XBNLG/b5C/X/023129mS17DvHbH2rlT65bQiqlQBCRiaEDyCWgvrKM3/2NM3jmc5dzXksDf/uzt7jkz55kw46euEsTkQRQGEwz6ZRx77+9iA+dMYeunn4+fs/z3Hj/aj6/ci0H+4fiLk9EZijtJprGNnUd4JtPv8HP39jFO/uiu5+e19LAbcvaWHbOvONsLSIyko4ZzAD/55fb+PoTG9h3aJDu3kE+ft58PnflWSxuqom7NBEpEQqDGaR/KMtfPtnJ/3y6E4B59RX8q/ZF3H7Fe3SwWUSOSWEwA73RdYAH/vFNfvzqdrp7Bzl1djVnn1LHmXNrWXr6HD7c1qRTU0VkBIXBDJbLOX/13Bs8u76LrgP9bOo6CMAZzTW8b2ED/+26JdRUZKgsS8dcqYjETWGQIPv7Brnz0Vd5fmMX+/sOn33UflojHzqzicvObKIsbZRnUpxzSr12LYkkiMIgoV7YtJu17+xnZ08fP+vczavb9o1Y3tJYxamzq6mvLMNxLjqtkSXzG0gZ7Ds0SCad4rIzm6gq16hCZCY4VhgU8zwDmeaWnj6HpafPGZ7f2t3L+nd7GMw6XT19PLthF929A2zY0cO+Q4M8vnbHUe9RkUlx5txa6ivLWDCrirn1FSyYVcXZp9Sxq6efnENzXQUtjVU011VQltalKyKlSGGQIC2N1bQ0Vg/P3/jB1uFpd2ftO/s5NJgdfhpb70CW5zfuorPrAPsPDfL8xi52Hxwge4yntb13QT1nnVLHqbOraaqtoLYiw3vm1TG/oZL6qjLS2i0lMi0pDAQAM+PchQ1HtV9+9twR8+7O5t29vLX7IHWVZdRWZHhz10F2Heinc+cB1r/bw8/f2M0PXz76MdYpg/kNVTTVljOrupyGqjLqqzLUVZZRV5mhvrKM05tqaKwpp7o8TcqMirIUlWVpqsrSGnWITCKFgZwUM6O1qYbWgovdzjql7qj1BoZy7O0dYGdP//BuqN0HBtja3cue3kG6ewd4a/dBevqG6OkbZDB7/GNX6ZRRU56mtiJDOm2UpVKkUkYmZVSVR4FRXZ4eDo/q8jSVBe1VZdGy6vIMVeUpKjNpMukU6fAe6ZSRSeenU4fbhl9TpNM2ol2
n78pMoTCQSVGeSTG3vpK59ZWjjjgKuTv9Qzl2Hxxg866DbOnuJRX+ke0bytE/mKVvMMuhwSw9fUP0DmQZyuYYyjnuMJDNRcsHsuztHRxet3cgeh0Yyk1aP1NGFBL50EgfER4FYZIPm9GDpqC94D3SZmTSKcrTUfDk3MnmnJw7uRykUpCyaN3876x/KMdQNkdP3xD9Q1kqMmnKMykqMinKM9HoygwMC69Qlk5RU5EZc3nhfJ6ZFSyL5lMGNRWZ8LsxUqEP6RSkUynSBfWmw/JU2M6G+xGtU1WejgK34HPNjvzcwzUStrOjaotWKKzzqO2P6m/BOgkJfIWBxM7MqCxLs3BWFQtnVU34+2dzzqEQFn0FIdE3mCWbc4ZyTjaXYyjrBfMF7fn57BjtI5aP0j7m+0dtI+s4/DqYzYXXaN1szo/4RxRyHl1rkg8JgPJM9I9ofVWGikyagaEc/UPZ8JrDDNwh544TTQ8M5Tg0mJ3w3/1MMlZQjBZM+XYK548ITcYI2/w6w59ZsCxlxuyach79vQ9NeP8UBjLjpVNGbUWG2gr9uR9LPqDyZ5u7g+PhNRrB5cMjWuHo5Vl3evujUMl5PqQYHs3kg65whJMPpmzu8PRg1ukfyjKU9RGfffRnjlKjj6w13z56nwrmfez3ZcT7HbGe+8jPPmI9OLKefD+jDkXvMXpdjPK7r62cnL9j/dchIgBhV9YEXFNy9CEkKQE6PUNERBQGIiKiMBARERQGIiKCwkBERFAYiIgICgMREUFhICIilPDDbcysC9g8zs2bgF0TWE4pUJ9nvqT1F9Tnk3WauzePtqBkw6AYZtYx1tN+Zir1eeZLWn9BfZ5I2k0kIiIKAxERSW4Y3Bd3ATFQn2e+pPUX1OcJk8hjBiIiMlJSRwYiIlJAYSAiIskKAzO72szWm1mnmd0Rdz0TyczeMrNXzWyNmXWEttlmtsrMNobXxtBuZnZP+D28YmYXxlv9iTGzB8xsp5m9VtB20n00sxVh/Y1mtiKOvpyoMfr8eTPbFr7rNWZ2bcGyO0Of15vZVQXtJfO3b2aLzOxpM3vdzNaa2W2hfUZ+18fo79R+z9Gj2mb+D5AG3gBOB8qBXwFL4q5rAvv3FtB0RNufA3eE6TuAr4Tpa4GfED1udSmwOu76T7CPHwEuBF4bbx+B2cCm8NoYphvj7ttJ9vnzwOdGWXdJ+LuuABaHv/d0qf3tA/OBC8N0HbAh9G1GftfH6O+Ufs9JGhlcDHS6+yZ3HwAeBpbHXNNkWw48GKYfBK4vaH/IIy8As8xsfgz1nRR3fw7Yc0TzyfbxKmCVu+9x925gFXD1pBc/TmP0eSzLgYfdvd/d3wQ6if7uS+pv3923u/vLYboHWAcsZIZ+18fo71gm5XtOUhgsBLYUzG/l2L/wUuPA/zOzl8zsltA2z923h+l3gXlheib9Lk62jzOl778fdok8kN9dwgzss5m1AhcAq0nAd31Ef2EKv+ckhcFMd5m7XwhcA9xqZh8pXOjR+HJGn0echD4G9wJnAO8HtgNfjbWaSWJmtcCjwO3uvr9w2Uz8rkfp75R+z0kKg23AooL5ltA2I7j7tvC6E/gR0ZBxR373T3jdGVafSb+Lk+1jyffd3Xe4e9bdc8DfEH3XMIP6bGZlRP8wftfdfxiaZ+x3PVp/p/p7TlIY/AJoM7PFZlYO3ACsjLmmCWFmNWZWl58GrgReI+pf/gyKFcBjYXolcFM4C2MpsK9g+F1qTraPjwNXmlljGHZfGdpKxhHHd/4F0XcNUZ9vMLMKM1sMtAEvUmJ/+2ZmwP3AOnf/WsGiGfldj9XfKf+e4z6SPpU/RGcdbCA64v7Hcdczgf06nejMgV8Ba/N9A+YATwIbgSeA2aHdgG+G38OrQHvcfTjBfn6PaLg8SLQ/9Obx9BH4HaKDbp3Ap+Pu1zj6/HehT6+E/9jnF6z/x6HP64FrCtpL5m8fuIxoF9ArwJrwc+1M/a6P0d8p/Z51OwoREUnUbiI
RERmDwkBERBQGIiKiMBARERQGIiKCwkBERFAYiIgI8P8B1M/TYWHrGh0AAAAASUVORK5CYII=\n", 267 | "text/plain": [ 268 | "
" 269 | ] 270 | }, 271 | "metadata": { 272 | "needs_background": "light" 273 | }, 274 | "output_type": "display_data" 275 | } 276 | ], 277 | "source": [ 278 | "from matplotlib import pyplot as plt\n", 279 | "%matplotlib inline\n", 280 | "\n", 281 | "plt.plot(plt_x, plt_y)\n", 282 | "plt.show()" 283 | ] 284 | } 285 | ], 286 | "metadata": { 287 | "kernelspec": { 288 | "display_name": "Python 3 (ipykernel)", 289 | "language": "python", 290 | "name": "python3" 291 | }, 292 | "language_info": { 293 | "codemirror_mode": { 294 | "name": "ipython", 295 | "version": 3 296 | }, 297 | "file_extension": ".py", 298 | "mimetype": "text/x-python", 299 | "name": "python", 300 | "nbconvert_exporter": "python", 301 | "pygments_lexer": "ipython3", 302 | "version": "3.8.11" 303 | } 304 | }, 305 | "nbformat": 4, 306 | "nbformat_minor": 5 307 | } 308 | -------------------------------------------------------------------------------- /6.ada_delta.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0\n", 49 | "\n", 50 | "#初始化S为全0\n", 51 | "S_w = np.zeros(M)\n", 
52 | "S_b = 0\n", 53 | "\n", 54 | "#初始化delta为全0\n", 55 | "delta_w = np.zeros(M)\n", 56 | "delta_b = 0" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "id": "92163201", 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "0.6590042695516539" 69 | ] 70 | }, 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "#预测函数\n", 78 | "def predict(x):\n", 79 | " return w.dot(x) + b\n", 80 | "\n", 81 | "\n", 82 | "predict(x[0])" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 4, 88 | "id": "a7bb7a80", 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "0.21258140154187247" 95 | ] 96 | }, 97 | "execution_count": 4, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "#求loss,MSELoss\n", 104 | "def get_loss(x, y):\n", 105 | " pred = predict(x)\n", 106 | " loss = (pred - y)**2\n", 107 | " return loss\n", 108 | "\n", 109 | "\n", 110 | "get_loss(x[0], y[0])" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "id": "8027d213", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 123 | " 0.923131013558981)" 124 | ] 125 | }, 126 | "execution_count": 5, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "def get_gradient(x, y):\n", 133 | " global w\n", 134 | " global b\n", 135 | "\n", 136 | " eps = 1e-3\n", 137 | "\n", 138 | " loss_before = get_loss(x, y)\n", 139 | "\n", 140 | " gradient_w = np.empty(M)\n", 141 | " for i in range(M):\n", 142 | " w[i] += eps\n", 143 | " loss_after = get_loss(x, y)\n", 144 | " w[i] -= eps\n", 145 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 146 | "\n", 147 | " b += eps\n", 148 | " loss_after = get_loss(x, 
y)\n", 149 | " b -= eps\n", 150 | " gradient_b = (loss_after - loss_before) / eps\n", 151 | "\n", 152 | " return gradient_w, gradient_b\n", 153 | "\n", 154 | "\n", 155 | "get_gradient(x[0], y[0])" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 6, 161 | "id": "f39e0125", 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "data": { 166 | "text/plain": [ 167 | "11073.905141728206" 168 | ] 169 | }, 170 | "execution_count": 6, 171 | "metadata": {}, 172 | "output_type": "execute_result" 173 | } 174 | ], 175 | "source": [ 176 | "def total_loss():\n", 177 | " loss = 0\n", 178 | " for i in range(N):\n", 179 | " loss += get_loss(x[i], y[i])\n", 180 | " return loss\n", 181 | "\n", 182 | "\n", 183 | "total_loss()" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 7, 189 | "id": "c371c6a4", 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "0 [9.99999848e-07 9.99999845e-07] 9.999998984693441e-07 11059.213325989294\n", 197 | "500 [2.66451164e-06 2.89372373e-06] 3.595841972186342e-06 7623.402923852171\n", 198 | "1000 [1.49058551e-06 1.77039119e-06] 2.3569819378456833e-06 5116.6673548382705\n", 199 | "1500 [1.54060667e-07 1.52269411e-07] 4.854100489197519e-07 3453.7618700421162\n", 200 | "2000 [4.5399257e-06 3.6700749e-06] 4.182369642302336e-06 2259.06518603289\n", 201 | "2500 [1.65797772e-06 1.05985930e-06] 1.3383495167537935e-06 1438.5906667372406\n", 202 | "3000 [2.96914906e-07 6.56168224e-07] 6.131919613839703e-07 1051.9695231637406\n", 203 | "3500 [1.48431059e-06 1.85664992e-06] 3.2354643896326765e-06 886.751528580343\n", 204 | "4000 [3.06717189e-06 1.53765326e-06] 7.968217077828009e-06 826.4649373166238\n", 205 | "4500 [1.83228209e-06 1.98744616e-06] 3.0032937384269872e-06 794.0081885344441\n", 206 | "5000 [5.05157589e-07 6.08365156e-07] 6.43608532358165e-07 798.5606607279586\n" 207 | ] 208 | } 209 | ], 210 | "source": [ 211 | "plt_x = 
[]\n", 212 | "plt_y = []\n", 213 | "\n", 214 | "for epoch in range(5500):\n", 215 | " i = np.random.randint(N)\n", 216 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 217 | "\n", 218 | " #ada_delta算法不需要设定超参数lr\n", 219 | " #它需要维持两个变量,delta和S\n", 220 | "\n", 221 | " #S的计算和rmsprop完全一致\n", 222 | " S_w = 0.2 * S_w + 0.8 * gradient_w**2\n", 223 | " S_b = 0.2 * S_b + 0.8 * gradient_b**2\n", 224 | "\n", 225 | " #计算lr的公式,这里的1e-6是为了防止除0\n", 226 | " lr = (delta_w + 1e-6) / (S_w + 1e-6)\n", 227 | " gradient_w = lr**0.5 * gradient_w\n", 228 | "\n", 229 | " lr = (delta_b + 1e-6) / (S_b + 1e-6)\n", 230 | " gradient_b = lr**0.5 * gradient_b\n", 231 | "\n", 232 | " #更新参数\n", 233 | " w -= gradient_w\n", 234 | " b -= gradient_b\n", 235 | "\n", 236 | " #更新delta,这里的两个系数和计算S时用的要一样\n", 237 | " delta_w = 0.2 * delta_w + 0.8 * gradient_w**2\n", 238 | " delta_b = 0.2 * delta_b + 0.8 * gradient_b**2\n", 239 | "\n", 240 | " #思考一下,在时刻0,S就是梯度的平方乘以0.8\n", 241 | " #所以在一开始的时候,S是比较大的.但delta还是0\n", 242 | " #所以一开始的时候lr=(0+1e-6)/(S+1e-6)是比较小的.\n", 243 | " #delta更新为变量更新量的平方*0.8\n", 244 | " #所以delta当中差不多相当于存储了变量更新量的历史信息\n", 245 | " #所以最后的lr,应该是取两者的比值,实际使用时再开平方\n", 246 | "\n", 247 | " plt_x.append(epoch)\n", 248 | " plt_y.append(total_loss())\n", 249 | "\n", 250 | " if epoch % 500 == 0:\n", 251 | " print(epoch, delta_w[:2], delta_b, total_loss())" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 8, 257 | "id": "0471a70d", 258 | "metadata": { 259 | "scrolled": true 260 | }, 261 | "outputs": [ 262 | { 263 | "data": { 264 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjAklEQVR4nO3deZxU1Z338c+vqnpfaWh2ZBEEcUPoNIgbigoaR5xRJ24jMT6SiSZjYp5MNJmJGSfJM0kmUTNGHYMLZuI2JkZcoiJxV8DGDQG1mx1soKHphqa36qrz/FEH0xIQ6O12VX3fr1e96t5zl/4dLfrbdztlzjlERCS9hYIuQEREgqcwEBERhYGIiCgMREQEhYGIiACRoAvoqH79+rkRI0YEXYaISNJYunTpNudc6b6WJW0YjBgxgoqKiqDLEBFJGma2bn/LdJpIREQUBiIiojAQEREUBiIigsJARERQGIiICAoDEREhDcPgZ89+yKMVG4IuQ0SkV0nah846YmdzlDteWgXAzKMHUpidEXBFIiK9Q1odGRRmZ3DHZRMB+O2b+30QT0Qk7aRVGACceHg/AH7+3EcBVyIi0nukXRgU5f7l1FDV1oYAKxER6T3SLgwA7vSniv5nkU4ViYhAmobB2ccMAuD+N9YGW4iISC+RlmEAMGvCYAA27mgMuBIRkeClbRh89ZTDAVi8ujbgSkREgpe2YTBuYAFFORksXrM96FJERAKXtmEQChnlI0tYvEZHBiIiaRsGAJNHlrBueyPV9U1BlyIiEqi0DoMpo/oCum4gIpLWYXDkoEIKsiO6biAiaS+twyAcMkryMnloyQZicRd0OSIigUnrMACYcdRAAN7bWBdsISIiAUr7MLjyxBEAvLu+LtA6RESClPZhMKgoh1Gledz3xhqao7GgyxERCcQBw8DM7jWzrWb2Qbu2EjNbYGaV/r2Pbzcz+5WZVZnZ+2Y2sd02s/36lWY2u137JDNb5rf5lZlZV3fyQH5w7ng21Dbxw/nLe/pHi4j0CgdzZHA/MHOvthuAhc65McBCPw9wNjDGv+YAd0IiPICbgMlAOXDTngDx61zdbru9f1a3mza2PwDLP9nZ0z9aRKRXOGAYOOdeAfa+EX8WMM9PzwPOb9f+gEtYBBSb2SBgBrDAOVfrnNsBLABm+mWFzrlFzjkHPNBuXz3qkvJhrK9tJK67ikQkDXX0msEA51y1n94MDPDTQ4D23za/0bd9XvvGfbT3uLLhJdQ3RXlh5ZYgfryISKA6fQHZ/0XfI39Om9kcM6sws4qampou3feMoxO3mD6zrPoAa4qIpJ6OhsEWf4oH/77Vt28ChrVbb6hv+7z2ofto3yfn3N3OuTLnXFlpaWkHS9+3/KwI5x47iNdXbae1Ld6l+xYR6e06GgbzgT13BM0GnmjXfoW/q2gKUO9PJz0HnGVmffyF47OA5/yynWY2xd9FdEW7ffW4v5s4hJpdLfzwSd1VJCLp5WBuLX0IeBMYa2Ybzewq4D+AM82sEjjDzwM8A6wGqoDfANcAOOdqgX8H3vKvm30bfp25fptVwJ+6pmuH7vRxA+hfkMUzy6pJnP0SEUkPkQOt4Jy7ZD+Lpu9jXQdcu5/93Avcu4/2CuDoA9XRU647Ywzff/wD1m5vZGS/vKDLERHpEWn/BPLeThlTihk8tGR90KWIiPQYhcFehpXkcuoRpbz8UdferSQi0pspDPbh2KHFfLRlF+9uqAu6FBGRHqEw2IeLJiXudj3/16/TFtNtpiKS+hQG+zCsJJfykSUA3P/G2mCLERHpAQqD/bjvy18AYMvO5oArERHpfgqD/cjLilA+soRXK7cFXYqISLdTGHyOaWNL+XDzLrbu0tGBiKQ2hcHnmD4uMRjr0+9r8DoRSW0Kg88xdmABxw4t4u5XVuuuIhFJaQqDA/iHKcOprm/mO4+9H3QpIiLdRmFwABdMTDxz8J4eQBORFKYwOIBQyLik/DC2NbToVJGIpCyFwUE4eUw
/dja3ccdLq4IuRUSkWygMDsLp4/oDcOsLH+s2UxFJSQqDg5CdEebOyyYSd3Dqz14iFtcX34hIalEYHKSzjxlEXmaYpmiMf/njB0GXIyLSpRQGh+DF70wD9MU3IpJ6FAaHoH9BNpeUD8MMtjW0BF2OiEiXURgcoqtOGoVzUPajF1hV0xB0OSIiXUJhcIhG98+nfETiuw6m/+Jl4rqYLCIpQGHQAY/+4wnMmjAYgPc21gVbjIhIF1AYdNANZ48D4F+f0J1FIpL8FAYdNKgoh375mXywaaeGuBaRpKcw6ISHrp4CwE3zdXQgIslNYdAJYwYUMH5QIdsaWjVMhYgkNYVBJ31n5lgAHl6yIeBKREQ6TmHQSdOOKGXCsGJufeFjGlragi5HRKRDFAadZGb888yxxB088paODkQkOSkMusAJo/oyqCibh5esxzk9hCYiyadTYWBm3zKz5Wb2gZk9ZGbZZjbSzBabWZWZPWJmmX7dLD9f5ZePaLefG337R2Y2o5N96nFmxjXTDqdyawPraxuDLkdE5JB1OAzMbAjwT0CZc+5oIAxcDPwUuMU5NxrYAVzlN7kK2OHbb/HrYWbj/XZHATOBO8ws3NG6glLmh6j47ZvrAq5EROTQdfY0UQTIMbMIkAtUA6cDj/nl84Dz/fQsP49fPt3MzLc/7Jxrcc6tAaqA8k7W1ePGDSxg3MAC5r62huZoLOhyREQOSYfDwDm3CfhPYD2JEKgHlgJ1zrk9t9VsBIb46SHABr9tm1+/b/v2fWzzGWY2x8wqzKyipqamo6V3CzPjq6eOAuCOF6sCrkZE5NB05jRRHxJ/1Y8EBgN5JE7zdBvn3N3OuTLnXFlpaWl3/qgOOffYwQwryeHpZRqeQkSSS2dOE50BrHHO1TjnosAfgBOBYn/aCGAosMlPbwKGAfjlRcD29u372CapZIRDXD55OKtqdlNd3xR0OSIiB60zYbAemGJmuf7c/3RgBfAicKFfZzbwhJ+e7+fxy//sEvdhzgcu9ncbjQTGAEs6UVegTh2bOGJ5SE8ki0gS6cw1g8UkLgS/DSzz+7ob+C5wvZlVkbgmcI/f5B6gr2+/HrjB72c58CiJIHkWuNY5l7RXYMcNLOTE0X351cJKrn6gIuhyREQOiiXrQ1JlZWWuoqJ3/rJ9o2obl85dDMADXynnlCN63/UNEUk/ZrbUOVe2r2V6ArkbTB3dj5e/Mw2AK+5dQl1ja7AFiYgcgMKgmwzvm8dFk4YCcN/ra4MtRkTkABQG3eim844C4LaFlTS2akRTEem9FAbdKD8rwt9NTDw/d/8ba4MtRkTkcygMutnPLjg28f7sR0Rj8YCrERHZN4VBN4uEQ9x28QQA/vvlVcEWIyKyHwqDHnDecYMZUJjF7xavp6UtaR+hEJEUpjDoAWbGzy88jur6Zp5+X+MWiUjvozDoISeP6cfAwmyeW7456FJERP6KwqCHmBlnHTWA55Zv4bXKbUGXIyLyGQqDHnTVSSMBuPyexbp2ICK9isKgBw3vm8ctXzoOgFsWVAZcjYjIXygMetj5E4ZQlJPBwpVbgi5FRORTCoMeZmbMnjqCyq0NGqJCRHoNhUEAJo8sAeBhfQGOiPQSCoMATD28L7mZYW5+aoUuJItIr6AwCICZcaEf3nqeBrATkV5AYRCQH/5NYnjrnzzzIZ/UNQVcjYikO4VBQEIh467LJwFw/aPvBluMiKQ9hUGAZh49kJlHDWTR6lreXLU96HJEJI0pDAJ2w9njALjkN4v4cPPOgKsRkXSlMAjYiH55zL2iDICbn1yBcy7gikQkHSkMeoEzxg/gB+eO541V23l+hZ5MFpGepzDoJa44YTjDSnJ44M21QZciImlIYdBLRMIh/vb4obyxajtrt+0OuhwRSTMKg17k8smHETLjR0+vDLoUEUkzCoNepH9hNlMP78trVTU0tWqYChHpOQqDXuaqk0bSHI3zgoa4FpEepDDoZaYe3o+SvEy+8dA71DdGgy5
HRNJEp8LAzIrN7DEz+9DMVprZCWZWYmYLzKzSv/fx65qZ/crMqszsfTOb2G4/s/36lWY2u7OdSmaZkRD/7++OAeCuV1YFXI2IpIvOHhncBjzrnBsHHAesBG4AFjrnxgAL/TzA2cAY/5oD3AlgZiXATcBkoBy4aU+ApKsZRw0kZHDnS6s0xLWI9IgOh4GZFQGnAPcAOOdanXN1wCxgnl9tHnC+n54FPOASFgHFZjYImAEscM7VOud2AAuAmR2tK1X8+G8TRwcvrNgacCUikg46c2QwEqgB7jOzd8xsrpnlAQOcc9V+nc3AAD89BGj/1V4bfdv+2tPaRZOGUpAV4doH36a1LR50OSKS4joTBhFgInCnc+54YDd/OSUEgEsMtNNlg+2Y2RwzqzCzipqamq7aba8UCYe44ZzEIHYPLl4XcDUikuo6EwYbgY3OucV+/jES4bDFn/7Bv+85z7EJGNZu+6G+bX/tf8U5d7dzrsw5V1ZaWtqJ0pPDpeWHUT6yhP9+ZTWxuAawE5Hu0+EwcM5tBjaY2VjfNB1YAcwH9twRNBt4wk/PB67wdxVNAer96aTngLPMrI+/cHyWb0t7ZsaVU0dQXd/Ml+9bEnQ5IpLCIp3c/hvA78wsE1gNXEkiYB41s6uAdcDf+3WfAc4BqoBGvy7OuVoz+3fgLb/ezc652k7WlTLOGD+AoX1yeLVyG0++9wl/c9zgoEsSkRRkyTp+fllZmauoqAi6jB5R3xTluH97nqxIiI9+dHbQ5YhIkjKzpc65sn0t0xPISaAoJ4PzjhtMS1ucJWt00CQiXU9hkCS+MyNxaeZbj7yrW01FpMspDJLEsJJcvjbtcDbVNXH8zc/r6zFFpEspDJLIt888glOOKGV3a4yl63YEXY6IpBCFQRKJhEP854XHAnDhXW/SHNW4RSLSNRQGSaZ/YTYnj+kHwMxbXwm4GhFJFQqDJDTvynIA1m5v5FuPvBtsMSKSEhQGSSgUMt644XQAHn9nEz9+ekXAFYlIslMYJKnBxTks+d50AO59fS3rtzcGXJGIJDOFQRLrX5jNk18/iVjcccrPX2TpOj2QJiIdozBIcscMLeL2S48H4II73+T1qm0BVyQiyUhhkALOPXYw879+IgCXzV1MXMNdi8ghUhikiGOHFvON00cD8NZanS4SkUOjMEgh/3jq4QBc8ptFbG9oCbgaEUkmCoMUkpcVYc4po4g7OO/21zV+kYgcNIVBivneOUfy5akj2FTXxG0LK4MuR0SShMIgBd30N+OZNLwPt75QSeWWXUGXIyJJQGGQgsyMG88eB8CX73uLaEzffyAin09hkKLKRpRwSflhbKpr4prfvR10OSLSyykMUthP/vZoCrIiLFixhVc+rgm6HBHpxRQGKczMePW7pwEw57cVCgQR2S+FQYorzs3kT9edTHM0zhX3LuHCO9+gqVVfiiMin6UwSANHDirkj9cmhquoWLeDnz77YcAViUhvozBIExOGFbP2P77IBROH8r8VG3R0ICKfoTBIMxeXD2N3a0xHByLyGQqDNPOFESWcPKYf97+xli/995s6QhARQGGQlubOLmPS8D4sXlPL7HuXaAwjEVEYpKOsSJjff20q508YzJK1tRz5g2dZVdMQdFkiEiCFQRq75UsTmDa2lOZonC/ft4Q2DVshkrYUBmnMzLj/ynK+feYRbKht4puPvEtM35ImkpY6HQZmFjazd8zsKT8/0swWm1mVmT1iZpm+PcvPV/nlI9rt40bf/pGZzehsTXJovjF9DNdMO5yn3q/mD29vDLocEQlAVxwZXAesbDf/U+AW59xoYAdwlW+/Ctjh22/x62Fm44GLgaOAmcAdZhbugrrkEHxnxliOGVLED55YzoOL1wddjoj0sE6FgZkNBb4IzPXzBpwOPOZXmQec76dn+Xn88ul+/VnAw865FufcGqAKKO9MXXLozIzbLz2ekMH3Hl/G61Xbgi5JRHpQZ48MbgX+Gdhz5bEvUOeca/P
zG4EhfnoIsAHAL6/363/avo9tPsPM5phZhZlV1NRo0LWuNrxvHk98PTFsxXUPv8vWnc0BVyQiPaXDYWBm5wJbnXNLu7Cez+Wcu9s5V+acKystLe2pH5tWRvcv4KlvnERdYytXzavQMwgiaaIzRwYnAueZ2VrgYRKnh24Dis0s4tcZCmzy05uAYQB+eRGwvX37PraRABw9pIhvnzWWZZvqueLeJcR1h5FIyutwGDjnbnTODXXOjSBxAfjPzrnLgBeBC/1qs4En/PR8P49f/meX+LNzPnCxv9toJDAGWNLRuqRrXH3ySEb1y+PVym08/NaGA28gIkmtO54z+C5wvZlVkbgmcI9vvwfo69uvB24AcM4tBx4FVgDPAtc65zRgTsAi4RALv30qk0eWcPNTy7nntTU6QhBJYZas54TLyspcRUVF0GWkvO0NLVxw5xus3d7ICaP68uDVk0ncBCYiycbMljrnyva1TE8gy+fqm5/Fn647hdPGlvLm6u1M/Y8/s1rjGImkHIWBHFBOZph7Zn+BsuF9qK5v5vRfvMxT738SdFki0oUUBnJQQiHjwauncM4xAwH4+oPv8MGm+oCrEpGuojCQg5YZCXHHZZN4ZM4UAM79r9eoa2wNuCoR6QoKAzlkk0f15Z7ZiWtQ//bkCj2YJpICFAbSIdOPHMC3zjiCx9/ZxE3zlwddjoh0ksJAOuwbp4/m5DH9eODNdTy2VENfiyQzhYF0WChk3DP7C5SPKOH//u97/OaV1UGXJCIdpDCQTsmMhLj90uMB+PEzK/njOxpWSiQZKQyk0/oXZlPxL2cwrCSHbz7yLr9+sSrokkTkECkMpEv0y8/iya+fxMDCbO54sYqqrXpKWSSZKAykyxTnZvL4tVOJxh1n/PJlnl++OeiSROQgKQykSw0qyuHGs8cBMOe3S/nd4nV6DkEkCSgMpMtdeeJIFt04nfGDCvn+4x9w+i9e1tAVIr2cwkC6xcCibP547YnMPGoga7bt5tz/eo1VGu1UpNdSGEi3yYyEuOsfJvH4NVPJz4rwD3MXU7tbYxmJ9EYKA+l2xx/Wh//5P5PZtruVc257lYq1tUGXJCJ7URhIj5gwrJiHrp5MS1uMS3+zmC07m4MuSUTaURhIj5k0vITvnXMkrbE4k3+ykMvmLtIQ2CK9hMJAetRFZcP4wzVTOXJQIa9Xbee821/n4y27gi5LJO0pDKTHTTysD3+67mR+cdFxrK9t5JK7F7Fw5RZicT2PIBIUhYEE5oJJQ3n2mydTmJPBVfMqmHnrK+xuaQu6LJG0pDCQQI0bWMiz3zyZ6888gsqtDZx3+2u8t6Eu6LJE0o7CQAKXFQnzT9PHcN+VX2B3S4xLfrOInzyzkgYdJYj0GIWB9Bqnje3PvK+Uc/SQIu5+ZTXTfv4S33t8Gc3RWNCliaQ8hYH0KmMHFvDoV0/g91+byuGleTy4eD3j/vVZ3lm/I+jSRFKawkB6pUnD+/DwnCn89IJjALh87mKWrtOTyyLdRWEgvZaZ8aUvHMZdl08kEg5xwZ1vcv6vX6e+KRp0aSIpR2Egvd7Mowfx2ndP48tTR/DuhjqunldBS5uuI4h0pQ6HgZkNM7MXzWyFmS03s+t8e4mZLTCzSv/ex7ebmf3KzKrM7H0zm9huX7P9+pVmNrvz3ZJUU5CdwQ/PO4p/PXc8S9bWUv7jhbxRtS3oskRSRmeODNqAbzvnxgNTgGvNbDxwA7DQOTcGWOjnAc4GxvjXHOBOSIQHcBMwGSgHbtoTICJ7u+qkkfziouOob4py6dzFXPvg29Q36rSRSGd1OAycc9XOubf99C5gJTAEmAXM86vNA87307OAB1zCIqDYzAYBM4AFzrla59wOYAEws6N1Seq7YNJQVtw8g386fTTPL9/Mube/yqLV24MuSySpdck1AzMbARwPLAYGOOeq/aLNwAA/PQTY0G6zjb5tf+0i+5WbGeH6s8Yy78pyNtQ2cfHdi/j6g2/T1KprCSId0ek
wMLN84PfAN51zO9svc4lvQu+y0cfMbI6ZVZhZRU1NTVftVpLY1NH9WHTjdM4+eiBPvV/Nebe/pi/PEemAToWBmWWQCILfOef+4Ju3+NM/+Petvn0TMKzd5kN92/7a/4pz7m7nXJlzrqy0tLQzpUsKGViUzZ2XT+L+K79AUzTGl+5exH8trCSuUVBFDlpn7iYy4B5gpXPul+0WzQf23BE0G3iiXfsV/q6iKUC9P530HHCWmfXxF47P8m0ih2Ta2P786bqTOfPIAfxiwcec8vMXmfvqatZu2x10aSK9niXO5HRgQ7OTgFeBZUDcN3+PxHWDR4HDgHXA3zvnan143E7i4nAjcKVzrsLv6yt+W4AfO+fuO9DPLysrcxUVFR2qXVJbPO548v1PuGn+cuoao2RGQvz8wmOZNUGXoiS9mdlS51zZPpd1NAyCpjCQA4nFHSurd/LD+cupWLeDAYVZnDa2P3NOGcWo0vygyxPpcQoDSWvRWJwHF6/n6WXVLFmTuLhcNrwP/3nRcYzolxdwdSI9R2Eg4q2uaeDRio389s21APzySxOYcdTAYIsS6SEKA5G9rN22m3/8n6V8uHkXY/rnc8XUEZx33GCKcjKCLk2k2ygMRPahvinK1Q9UfHrqKDMS4ovHDOKyyYdx/GF9CIcs4ApFutbnhUGkp4sR6S2KcjJ49Ksn4Jxj2aZ6Hlu6kceWbuTxdzYxsl8e1595BKeP609elv6ZSOrTkYFIO7uaozy3fAu3LPiYTXVNAJx6RCnHDSvmihOG0y8/K+AKRTpOp4lEDlEs7nhrbS0/enoF1XXNbN/dCsC4gQVMGFbMGUcO4JQjSsmM6CtBJHkoDEQ66Z31O3h+xRbe21DH0nU7aGmLk50R4pppo7mk/DBKC3TEIL2fwkCkC7W2xXn54xoerdjAghVbABjZL48TR/flwknDmDCsONgCRfZDYSDSTZauq+Wlj2p4b2M9r3ycGEl33MACzpswmHOOHsTwvrkkRmIRCZ7CQKQHbG9o4cn3PuGxtzfywabEaO6Di7IZWJTNuEGFHDOkiPKRJYzom6fbViUQCgORHrZ2225e+mgrr1Vtp2ZXM5VbG2j0X7wzqCibw0vzOXJQAaP75zO6fwHjBhboFlbpdgoDkYA55/hw8y5erazhhZVbWV3TwK7mNlraEgP+ZkZCnDa2lPGDihjRL5fhffMY0z9fASFdSmEg0gvF4o5NO5r4eEsiJBas2MIn9c2fWeewklyOGFBAS1uM4txMjhtaxJGDConFHaNK88iMhOhfkB1QDyTZKAxEkkRzNMbHW3ZRXd9M5ZZdrKjeSdXWBnIyI2zd2Uz1XmEBUFqQxfCSXI4YWMBhJbmELNE2pn8BQ4pzKM7N0EVsATQchUjSyM4Ic+zQYo4dyj5HU926q5nln+xkdc1u8rPC1DVGWVXTQNXWBp5ZVk1dY/SvtskIGyV5mTgH4VBiOj8rQlZGmNyMMJA4TdUnN4PD++dTmp9FblaEvnmZFOVk0Dc/k9xM/apIdfo/LJJE+hdk039sNqeN3ffyXc1R4g4+qWtizbbdVNc3U7OrhdrdLYTMiMYcNQ0tNLW2Ud8UpbquiVjc0RZ3bG9oYbe/yL234twMBhRk0xSNkRkJEQkZ4ZBRnJtBYXYGeVkRSguyCJsRd47MSIim1hgZ4RA7m6OEzOibl0lJfiY5GWHCIaOhpe3T9vzsCP3yswgZ5GRGCBnkZkYoysmgLRYnGnNkhBPbtMUdORlhsjPChAzMjNa2OCGDkBlmiUEI4z78IiEjEjYioRD1TVFa2mIUZGcQCRlZkdCn122isTgfb9lFXWOUTXVN1DdGKcrNYEBhNq1tcaKxOKUFWQwszPb/TRJhmRG2/R55xeOOlrY49U1R399E4OdkhMnJDBMyIxZ3NEdjmMHmnc3U7m4lEgrR0hbjk7omdjRGqdzSQEtbjHDIKMrJ4OZZR3f+w7QXhYFICinITgzBXZSTwZGDCg9p23g8ERT
bG1rZ1RxlR2OUXc1Rtu5qobq+ic31LeRkhonF44kAiTl2NLaydWcLDS1t1OxqIe4cITPa4olf3tGYozA7gnOwq6XtkPuTnRGiORr/3HXCocQvVDNwDnIywjRF9x1q3SEcMnIzwxRmZ9AWj9Palni1tMVpi3fNafj+BVnkZ0WIO0dxbmaX7HNvCgMRASAUMgYUZjOgsPMXpKOxOBnhEPG4I+SfqWiOxqhvitLYGqMtFic/O0LcQW1DKzubo9TubiXuHI2tsUR4NEfZvruV7EiIrIww0Vj807/om6MxmqNx2uJx4s6RHQnTFnc452hoiTG4OJuMcIhoLP7pkU80FqcoJ4PsjMTptbhznx4VZEVCZIQT/R/aJ5chxTkU5kTY0RilZlcLeZlhIuEQ1XVNbNnVTDgUorYhEYLN0TgNLW3sbI6SGQ6RGQn95d2/CrMzyM+KEIs7WmNxmqMxmqKxT//7ZEfCOBLXevrlZxKPQyRsvo6MHvmeDYWBiHS5jHBiAL9Qu4frsv2pnb0NKc7psboOVW5m5DP1jUzhr0nVkIsiIqIwEBERhYGIiKAwEBERFAYiIoLCQEREUBiIiAgKAxERIYlHLTWzGmBdBzfvB2zrwnJ6E/UteaVy/1K5b5A8/RvunCvd14KkDYPOMLOK/Q3jmuzUt+SVyv1L5b5BavRPp4lERERhICIi6RsGdwddQDdS35JXKvcvlfsGKdC/tLxmICIin5WuRwYiItKOwkBERNIrDMxsppl9ZGZVZnZD0PUcLDO718y2mtkH7dpKzGyBmVX69z6+3czsV76P75vZxHbbzPbrV5rZ7CD6sjczG2ZmL5rZCjNbbmbX+fak75+ZZZvZEjN7z/ft33z7SDNb7PvwiJll+vYsP1/ll49ot68bfftHZjYjoC79FTMLm9k7ZvaUn0+lvq01s2Vm9q6ZVfi2pP9c7pdzLi1eQBhYBYwCMoH3gPFB13WQtZ8CTAQ+aNf2M+AGP30D8FM/fQ7wJ8CAKcBi314CrPbvffx0n17Qt0HARD9dAHwMjE+F/vka8/10BrDY1/wocLFvvwv4mp++BrjLT18MPOKnx/vPaxYw0n+Ow0H/v/O1XQ88CDzl51Opb2uBfnu1Jf3ncn+vdDoyKAeqnHOrnXOtwMPArIBrOijOuVeA2r2aZwHz/PQ84Px27Q+4hEVAsZkNAmYAC5xztc65HcACYGa3F38Azrlq59zbfnoXsBIYQgr0z9fY4Gcz/MsBpwOP+fa9+7anz48B083MfPvDzrkW59waoIrE5zlQZjYU+CIw188bKdK3z5H0n8v9SacwGAJsaDe/0bclqwHOuWo/vRkY4Kf3189e339/6uB4En9Bp0T//GmUd4GtJH4RrALqnHNtfpX2dX7aB7+8HuhLL+0bcCvwz0Dcz/cldfoGieB+3syWmtkc35YSn8t9iQRdgHSec86ZWVLfI2xm+cDvgW8653Ym/mhMSOb+OediwAQzKwYeB8YFW1HXMLNzga3OuaVmNi3gcrrLSc65TWbWH1hgZh+2X5jMn8t9Sacjg03AsHbzQ31bstriD0Px71t9+/762Wv7b2YZJILgd865P/jmlOkfgHOuDngROIHEKYQ9f4i1r/PTPvjlRcB2emffTgTOM7O1JE65ng7cRmr0DQDn3Cb/vpVEkJeTYp/L9tIpDN4Cxvi7HTJJXMSaH3BNnTEf2HNnwmzgiXbtV/i7G6YA9f6w9jngLDPr4++AOMu3BcqfN74HWOmc+2W7RUnfPzMr9UcEmFkOcCaJayIvAhf61fbu254+Xwj82SWuQs4HLvZ35IwExgBLeqQT++Gcu9E5N9Q5N4LEv6U/O+cuIwX6BmBmeWZWsGeaxOfpA1Lgc7lfQV/B7skXiSv+H5M4b/v9oOs5hLofAqqBKIlzjleRON+6EKgEXgBK/LoG/Nr3cRlQ1m4/XyFxga4KuDLofvmaTiJxbvZ94F3/OicV+gccC7zj+/YB8APfPorEL7wq4H+
BLN+e7eer/PJR7fb1fd/nj4Czg+7bXv2cxl/uJkqJvvl+vOdfy/f8vkiFz+X+XhqOQkRE0uo0kYiI7IfCQEREFAYiIqIwEBERFAYiIoLCQEREUBiIiAjw/wEZxi+isV885QAAAABJRU5ErkJggg==\n", 265 | "text/plain": [ 266 | "
" 267 | ] 268 | }, 269 | "metadata": { 270 | "needs_background": "light" 271 | }, 272 | "output_type": "display_data" 273 | } 274 | ], 275 | "source": [ 276 | "from matplotlib import pyplot as plt\n", 277 | "%matplotlib inline\n", 278 | "\n", 279 | "plt.plot(plt_x, plt_y)\n", 280 | "plt.show()" 281 | ] 282 | } 283 | ], 284 | "metadata": { 285 | "kernelspec": { 286 | "display_name": "Python 3 (ipykernel)", 287 | "language": "python", 288 | "name": "python3" 289 | }, 290 | "language_info": { 291 | "codemirror_mode": { 292 | "name": "ipython", 293 | "version": 3 294 | }, 295 | "file_extension": ".py", 296 | "mimetype": "text/x-python", 297 | "name": "python", 298 | "nbconvert_exporter": "python", 299 | "pygments_lexer": "ipython3", 300 | "version": "3.8.11" 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 5 305 | } 306 | -------------------------------------------------------------------------------- /7.adam.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4fcd93aa", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "((1503, 5), (1503,))" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import numpy as np\n", 22 | "\n", 23 | "#加载数据\n", 24 | "data = np.loadtxt(fname='./线性数据.csv', delimiter='\\t')\n", 25 | "\n", 26 | "#标准化\n", 27 | "data -= data.mean(axis=0)\n", 28 | "data /= data.std(axis=0)\n", 29 | "\n", 30 | "x = data[:, :-1]\n", 31 | "y = data[:, -1]\n", 32 | "\n", 33 | "x.shape, y.shape" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "id": "cc6c8e3c", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "#常量\n", 44 | "N, M = x.shape\n", 45 | "\n", 46 | "#变量\n", 47 | "w = np.ones(M)\n", 48 | "b = 0\n", 49 | "\n", 50 | "#初始化S为全0\n", 51 | "S_w = np.zeros(M)\n", 52 | 
"S_b = 0\n", 53 | "\n", 54 | "v_w = np.zeros(M)\n", 55 | "v_b = 0" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "id": "92163201", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "0.6590042695516539" 68 | ] 69 | }, 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "#预测函数\n", 77 | "def predict(x):\n", 78 | " return w.dot(x) + b\n", 79 | "\n", 80 | "\n", 81 | "predict(x[0])" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "id": "a7bb7a80", 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "text/plain": [ 93 | "0.21258140154187247" 94 | ] 95 | }, 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "#求loss,MSELoss\n", 103 | "def get_loss(x, y):\n", 104 | " pred = predict(x)\n", 105 | " loss = (pred - y)**2\n", 106 | " return loss\n", 107 | "\n", 108 | "\n", 109 | "get_loss(x[0], y[0])" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "id": "8027d213", 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "data": { 120 | "text/plain": [ 121 | "(array([-0.61003339, -1.05581946, 1.66242713, 1.21242212, -0.59417855]),\n", 122 | " 0.923131013558981)" 123 | ] 124 | }, 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "def get_gradient(x, y):\n", 132 | " global w\n", 133 | " global b\n", 134 | "\n", 135 | " eps = 1e-3\n", 136 | "\n", 137 | " loss_before = get_loss(x, y)\n", 138 | "\n", 139 | " gradient_w = np.empty(M)\n", 140 | " for i in range(M):\n", 141 | " w[i] += eps\n", 142 | " loss_after = get_loss(x, y)\n", 143 | " w[i] -= eps\n", 144 | " gradient_w[i] = (loss_after - loss_before) / eps\n", 145 | "\n", 146 | " b += eps\n", 147 | " loss_after = get_loss(x, y)\n", 148 | " b -= eps\n", 149 | " 
gradient_b = (loss_after - loss_before) / eps\n", 150 | "\n", 151 | " return gradient_w, gradient_b\n", 152 | "\n", 153 | "\n", 154 | "get_gradient(x[0], y[0])" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 6, 160 | "id": "f39e0125", 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "data": { 165 | "text/plain": [ 166 | "11073.905141728206" 167 | ] 168 | }, 169 | "execution_count": 6, 170 | "metadata": {}, 171 | "output_type": "execute_result" 172 | } 173 | ], 174 | "source": [ 175 | "def total_loss():\n", 176 | " loss = 0\n", 177 | " for i in range(N):\n", 178 | " loss += get_loss(x[i], y[i])\n", 179 | " return loss\n", 180 | "\n", 181 | "\n", 182 | "total_loss()" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 7, 188 | "id": "c371c6a4", 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "500 [0.08904127 0.71062389] [16.68083241 9.3870837 ] 1261.5118603011992\n", 196 | "1000 [-0.90368761 0.83072632] [8.3751223 4.76494681] 749.9516889645644\n", 197 | "1500 [-0.30210063 0.02406904] [5.98086602 3.792364 ] 745.5930763832396\n", 198 | "2000 [ 0.05851974 -0.52552324] [5.03242601 3.45097686] 734.1425192541457\n", 199 | "2500 [ 0.19785822 -0.16807166] [4.55122551 2.92449296] 763.0359728232618\n", 200 | "3000 [ 0.11518987 -0.08617003] [4.08077263 2.93377002] 746.4590779281906\n", 201 | "3500 [ 0.14053857 -0.11983124] [3.52452827 2.59132747] 783.7272638890388\n", 202 | "4000 [0.03610707 0.36194797] [3.26520532 2.42821712] 766.6407173851062\n", 203 | "4500 [-0.24672945 0.18166294] [3.41205737 3.01430322] 791.864659000636\n", 204 | "5000 [ 0.1558802 -0.19912752] [3.24213769 2.86521236] 788.53176983857\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "plt_x = []\n", 210 | "plt_y = []\n", 211 | "\n", 212 | "for t in range(1, 5500):\n", 213 | " i = np.random.randint(N)\n", 214 | " gradient_w, gradient_b = get_gradient(x[i], y[i])\n", 215 
| "\n", 216 | " v_w = 0.9 * v_w + 0.1 * gradient_w\n", 217 | " v_b = 0.9 * v_b + 0.1 * gradient_b\n", 218 | "\n", 219 | " #S的计算和rmsprop完全一致\n", 220 | " S_w = 0.999 * S_w + 0.001 * gradient_w**2\n", 221 | " S_b = 0.999 * S_b + 0.001 * gradient_b**2\n", 222 | "\n", 223 | " #根据以上公式,在时刻0\n", 224 | " #v = [0.1 * gradient_0]\n", 225 | "\n", 226 | " #这可能太过于小,为了消除这个影响,需要做偏差修正,也就是除以系数\n", 227 | " #v = 0.1 * sigma[0.9**(t-i) * gradient_i]\n", 228 | " #S = 0.001 * sigma[0.999**(t-i) * gradient_i**2]\n", 229 | " \n", 230 | " #将梯度的系数部分整理得到\n", 231 | " #0.1 * sigma[0.9**(t-i)] = 1-0.9**t\n", 232 | "\n", 233 | " #偏差修正\n", 234 | " v_hat_w = v_w / (1 - 0.9**t)\n", 235 | " v_hat_b = v_b / (1 - 0.9**t)\n", 236 | " S_hat_w = S_w / (1 - 0.999**t)\n", 237 | " S_hat_b = S_b / (1 - 0.999**t)\n", 238 | "\n", 239 | " #下面是adam参数更新的公式\n", 240 | " #这里的1e-2是超参数lr\n", 241 | " gradient_w = (1e-2 * v_hat_w) / (S_hat_w**0.5 + 1e-6)\n", 242 | " gradient_b = (1e-2 * v_hat_b) / (S_hat_b**0.5 + 1e-6)\n", 243 | "\n", 244 | " #更新参数\n", 245 | " w -= gradient_w\n", 246 | " b -= gradient_b\n", 247 | "\n", 248 | " plt_x.append(t)\n", 249 | " plt_y.append(total_loss())\n", 250 | "\n", 251 | " if t % 500 == 0:\n", 252 | " print(t, v_hat_w[:2], S_hat_w[:2], total_loss())" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 8, 258 | "id": "0471a70d", 259 | "metadata": { 260 | "scrolled": true 261 | }, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgqElEQVR4nO3de3Sc9X3n8fd37pJGd41kWTbxFYOhEIgLhJA2gQQITWKym3bpphtvyymnbbabNnu2JafnhG672dN0t0mT05aUDXRJTrYJJeniNhfiAIGWBBMDxviCbdlgbFmWZN2vM5qZ3/4xPzmDkHzRSB5Jz+d1zpx5nt/zzOj7s8b6zO/3PM+MOecQEZFgC5W7ABERKT+FgYiIKAxERERhICIiKAxERASIlLuAuWpqanJr1qwpdxkiIkvGCy+8cNo5l5pp25INgzVr1rBr165ylyEismSY2bHZtmmaSEREFAYiIqIwEBERFAYiIoLCQEREUBiIiAgKAxERIWBhkMs77v/REXYfHyh3KSIii0qgwmA0k+Xvnn2Nv/jBwXKXIiKyqAQqDGoSUW7a0MTRntFylyIisqgEKgwA1jRV0TEwzngmV+5SREQWjcCFQXN1HIC+sUyZKxERWTwCFwZ1lTEABhQGIiJnBDAMogAMjE2WuRIRkcUjcGFQ70cG/RoZiIicEcAwKIwM+jUyEBE5I3BhcOaYwahGBiIiUwIXBrFIiKpYWCMDEZEigQsDKIwOBsY1MhARmRLIMKiviupsIhGRIoEMg7qKmM4mEhEpEswwqNTIQESkWCDDoL5SIwMRkWIBDYMog+OT5PKu3KWIiCwKgQyDusoYzsHwhKaKRETgPMLAzB4ys24z21vU1mBmO8zssL+v9+1mZl8ys3Yz22Nm1xY9Zpvf/7CZbStqf4eZveIf8yUzs/nu5HT1VboKWUSk2PmMDP4PcPu0tnuBJ5xzG4En/DrAB4CN/nYPcD8UwgO4D7geuA64bypA/D6/WfS46T9r3tXp84lERN7knGHgnHsG6JvWvBV42C8/DNxZ1P5VV/AcUGdmrcBtwA7nXJ9zrh/YAdzut9U4555zzjngq0XPtWBqEhEARiayC/2jRESWhLkeM2hxznX65VNAi19uA44X7XfCt52t/cQM7TMys3vMbJeZ7erp6Zlj6RCPhAGYmNS3nYmIwDwcQPbv6C/KaTnOuQecc1ucc1tSqdScnycRLYTBuMJARASYexh0+Ske/H23b+8AVhftt8q3na191QztCyoRLXQ7PZlf6B8lIrIkzDUMtgNTZwRtAx4rav+4P6voBmDQTyc9DtxqZvX+wPGtwON+25CZ3eDPIvp40XMtmKmRwURWIwMREYDIuXYws78H3gM0mdkJCmcF/RnwiJndDRwDfsXv/l3gDqAdGAN+HcA512dmfwr81O/3J865qYPSv0PhjKUK4Hv+tqDOhIGmiUREgPMIA+fcr86y6ZYZ9nXAJ2Z5noeAh2Zo3wVcea465lMiUhgQTWiaSEQECOgVyJFwiEjINDIQEfECGQZQmCrSyEBEpCDAYRDSAWQRES+wYRCPhDVNJCLiBTYMKmJhXWcgIuIFNgwS0ZCuQBYR8YIbBpomEhE5I7hhEFUYiIhMCXAYhHRqqYiIF9gwiEfDOrVURMQLbBgkIjqbSERkSnDDIBrSMQMRES/AYaADyCIiUwIcBiEmsnkKH7QqIhJswQ2DSJhc3jGZUxiIiAQ3DPRtZyIiZwQ4DKa+4EZhICIS4DAojAx0eqmIiMJAIwMRERQG+kgKERECHQb+mIEOIIuIBDkMNE0kIjIluGEQ0TSRiMiU4IaBnybSt52JiAQ6DDRNJCIyJbBhEPcjg7TCQEQkuGGgU0tFRH4msGFQ4cNgLKORgYhIYMMgGg6RiIYYzWTLXYqISNmVFAZm9vtmts/M9prZ35tZwszWmtlOM2s3s2+aWczvG/fr7X77mqLn+bRvP2hmt5XYp/OWjEcZnlAYiIjMOQz
MrA34z8AW59yVQBi4C/gc8AXn3AagH7jbP+RuoN+3f8Hvh5lt9o+7Argd+BszC8+1rgtRnYgwklYYiIiUOk0UASrMLAJUAp3AzcCjfvvDwJ1+eatfx2+/xczMt3/DOZd2zr0GtAPXlVjXeUnGI4xMTF6MHyUisqjNOQyccx3A/wLeoBACg8ALwIBzburt9gmgzS+3Acf9Y7N+/8bi9hkes6CScY0MRESgtGmiegrv6tcCK4EqCtM8C8bM7jGzXWa2q6enp+Tnq05EdMxARITSponeB7zmnOtxzk0C3wbeBdT5aSOAVUCHX+4AVgP47bVAb3H7DI95E+fcA865Lc65LalUqoTSC5IKAxERoLQweAO4wcwq/dz/LcB+4Cngo36fbcBjfnm7X8dvf9I553z7Xf5so7XARuD5Euo6b9WaJhIRAQoHgOfEObfTzB4FXgSywEvAA8B3gG+Y2X/3bQ/6hzwIfM3M2oE+CmcQ4ZzbZ2aPUAiSLPAJ59xFuRIs6c8mcs5RyDMRkWCacxgAOOfuA+6b1nyUGc4Gcs5NAL88y/N8FvhsKbXMRTIeJZd3TEzmqYhdlLNZRUQWpcBegQyFA8gAw2mdXioiwaYwAIbGddxARIIt0GFQVxkDYHA8U+ZKRETKK9hhUBEFoH9U00QiEmyBDoN6PzLoH9PIQESCLdBhUFdVGBkMjmtkICLBFugwqI5HCIdMIwMRCbxAh4GZUVcRpX9MIwMRCbZAhwFAXWWUAY0MRCTgAh8G9ZUxBjQyEJGAC3wY1FVqmkhERGFQGdM0kYgEXuDDoL4yqrOJRCTwAh8GdZUxJibzTExelE/NFhFZlBQGlYULz3QQWUSCLPBhoI+kEBFRGJwZGfSPKgxEJLgCHwZNyTgAvQoDEQmwwIdBY1Vhmqh3JF3mSkREyifwYVBfGSNkcHpEIwMRCa7Ah0EoZDRUxekd1chARIIr8GEA0JSMaWQgIoGmMKBwEPm0jhmISIApDCiMDHo1MhCRAFMYAI0aGYhIwCkMgMZkjLFMjvGMPp9IRIJJYUDRtQY6o0hEAkphADRUFa5C7tNVyCISUAoDoOHMyEBhICLBpDCg+CMpFAYiEkwlhYGZ1ZnZo2b2qpkdMLN3mlmDme0ws8P+vt7va2b2JTNrN7M9ZnZt0fNs8/sfNrNtpXbqQjUmC2HQp2MGIhJQpY4Mvgh83zl3GXA1cAC4F3jCObcReMKvA3wA2Ohv9wD3A5hZA3AfcD1wHXDfVIBcLMl4hFg4pGkiEQmsOYeBmdUCvwA8COCcyzjnBoCtwMN+t4eBO/3yVuCrruA5oM7MWoHbgB3OuT7nXD+wA7h9rnXNhZnRUBWjT9NEIhJQpYwM1gI9wN+Z2Utm9hUzqwJanHOdfp9TQItfbgOOFz3+hG+brf0tzOweM9tlZrt6enpKKP2tGqpiOptIRAKrlDCIANcC9zvnrgFG+dmUEADOOQe4En7GmzjnHnDObXHObUmlUvP1tEDhuMFphYGIBFQpYXACOOGc2+nXH6UQDl1++gd/3+23dwCrix6/yrfN1n5RNSXjnB7WAWQRCaY5h4Fz7hRw3Mw2+aZbgP3AdmDqjKBtwGN+eTvwcX9W0Q3AoJ9Oehy41czq/YHjW33bRdVcHadnJE1hMCMiEiyREh//u8DXzSwGHAV+nULAPGJmdwPHgF/x+34XuANoB8b8vjjn+szsT4Gf+v3+xDnXV2JdFyxVHSeTzTM0nqW2Mnqxf7yISFmVFAbOud3Alhk23TLDvg74xCzP8xDwUCm1lCpVXfhIip6RCYWBiASOrkD2psKgW8cNRCSAFAZe89TIQGEgIgGkMPBS1QlAYSAiwaQw8GoSEWKRkMJARAJJYeCZGalkXMcMRCSQFAZFmmviGhmISCApDIqkkgoDEQkmhUGRlL8KWUQkaBQGRZqrE/SNZshk8+UuRUTkolIYFJm68KxX33gmIgGjMCiS0oVnIhJQCoMiU1chdw8pDEQkWBQGRX7
2YXUKAxEJFoVBkaakpolEJJgUBkVikRD1lVG6hyfKXYqIyEWlMJgmVa0Lz0QkeBQG0ygMRCSIFAbTNFcn9GF1IhI4CoNppkYGhW/pFBEJBoXBNKlknHQ2z3A6W+5SREQuGoXBNM01uvBMRIJHYTBNStcaiEgAKQymafJXIZ/WVcgiEiAKg2mmrkJWGIhIkCgMpqmriBIJmaaJRCRQFAbThEJGYzKmkYGIBIrCYAYraivoGBgvdxkiIheNwmAGG5uTHOoaKXcZIiIXjcJgBptaqukZTtM/mil3KSIiF0XJYWBmYTN7ycz+2a+vNbOdZtZuZt80s5hvj/v1dr99TdFzfNq3HzSz20qtqVSXtVYDcODUUJkrERG5OOZjZPBJ4EDR+ueALzjnNgD9wN2+/W6g37d/we+HmW0G7gKuAG4H/sbMwvNQ15xd3loDwP6TCgMRCYaSwsDMVgG/BHzFrxtwM/Co3+Vh4E6/vNWv47ff4vffCnzDOZd2zr0GtAPXlVJXqZqScVLVcV49NVzOMkRELppSRwZ/CfwBkPfrjcCAc27qU95OAG1+uQ04DuC3D/r9z7TP8Jg3MbN7zGyXme3q6ekpsfSz25BK0t6tg8giEgxzDgMz+yDQ7Zx7YR7rOSvn3APOuS3OuS2pVGpBf9aG5iRHukf0UdYiEgiREh77LuDDZnYHkABqgC8CdWYW8e/+VwEdfv8OYDVwwswiQC3QW9Q+pfgxZbOhOclwOkv3cJqWmkS5yxERWVBzHhk45z7tnFvlnFtD4QDwk865jwFPAR/1u20DHvPL2/06fvuTrvC2eztwlz/baC2wEXh+rnXNlw3NSQBNFYlIICzEdQZ/CHzKzNopHBN40Lc/CDT69k8B9wI45/YBjwD7ge8Dn3DO5RagrguiMBCRICllmugM59yPgB/55aPMcDaQc24C+OVZHv9Z4LPzUct8aa6OUx2PKAxEJBB0BfIszIz1zTqjSESCQWFwFhuak7T3KAxEZPlTGJzFulQVPcNpRtLZc+8sIrKEKQzOYlV9JQAd/fo4axFZ3hQGZ9FWVwFAx8BYmSsREVlYCoOzWF1fCIMTGhmIyDKnMDiLpmScWDikaSIRWfYUBmcRChlt9RUaGYjIsqcwOIe2ugpO6PuQRWSZUxicw6r6Ck0TiciypzA4h1X1FZweSTMxWfaPSxIRWTAKg3No0xlFIhIACoNzWJ8qfHrpQX0FpogsYwqDc7i8tYZENMSuY33lLkVEZMEoDM4hGg5x9ao6XjzWX+5SREQWjMLgPFy3toG9J4foH82UuxQRkQWhMDgP79/cQi7vePLV7nKXIiKyIBQG5+Hn2mpprU3w/X2nyl2KiMiCUBicBzPjA1e28vTBHgbHJstdjojIvFMYnKePXNNGJpfnn/acLHcpIiLzTmFwnq5sq2Fzaw1f+8kxnHPlLkdEZF4pDM6TmfEfb1zDwa5hXnxjoNzliIjMK4XBBXjvZc0AvPSGrjkQkeVFYXABUtVxVtYm2H18oNyliIjMK4XBBbrmknpe0jSRiCwzCoMLtGVNPR0D4xzvGyt3KSIi80ZhcIHevTEFwL8cPl3mSkRE5o/C4AKtT1WxoibBs0cUBiKyfCgMLpCZceP6Rp470qvrDURk2ZhzGJjZajN7ysz2m9k+M/ukb28wsx1mdtjf1/t2M7MvmVm7me0xs2uLnmub3/+wmW0rvVsL64b1jfSOZjjUNVLuUkRE5kUpI4Ms8F+cc5uBG4BPmNlm4F7gCefcRuAJvw7wAWCjv90D3A+F8ADuA64HrgPumwqQxerG9Y0APNuuqSIRWR7mHAbOuU7n3It+eRg4ALQBW4GH/W4PA3f65a3AV13Bc0CdmbUCtwE7nHN9zrl+YAdw+1zruhhW1VeyqaWa/7e7o9yliIjMi3k5ZmBma4BrgJ1Ai3Ou0286BbT45TbgeNHDTvi22doXtV+9bjV7Tgyyt2Ow3KWIiJSs5DAwsyTwLeD3nHNDxdtc4Qj
rvB1lNbN7zGyXme3q6emZr6edk49cu4pENMTDP369rHWIiMyHksLAzKIUguDrzrlv++YuP/2Dv5/6erAOYHXRw1f5ttna38I594BzbotzbksqlSql9JLVVkS56+cv4VsvnuBQ13BZaxERKVUpZxMZ8CBwwDn3+aJN24GpM4K2AY8VtX/cn1V0AzDop5MeB241s3p/4PhW37boffKWjSTjET77nQPlLkVEpCSljAzeBfwH4GYz2+1vdwB/BrzfzA4D7/PrAN8FjgLtwP8GfgfAOdcH/CnwU3/7E9+26NVXxfjdmzfy9KEedh7tLXc5IiJzZkv1wqktW7a4Xbt2lbsMJiZzXPfZH/Ley5r54l3XlLscEZFZmdkLzrktM23TFcglSkTDfOSaNr6zp5P9J4fO/QARkUVIYTAPfu99l1JXGeP3v7mbXH5pjrREJNgUBvOgvirGZz60mYNdwzp2ICJLksJgnrxnUwoz2Pnakjj2LSLyJgqDeVKTiHJVWy3feaVTU0UisuQoDObRb/3ietq7R/jy00fKXYqIyAVRGMyj269cwYeuXsnndxzij7fvo280U+6SRETOi8JgHpkZ/+MjV7L16pV8fecxbv3CM+w+PlDuskREzklhMM+qE1E+/+/ezvb/dBOJaIiP3v9j/us/vMxrp0fLXZqIyKwUBgvk8tYavv3bN/JrN7yN7S+f5Ja/+BGfemQ3I+lsuUsTEXkLhcECaq5J8McfvoJ//cOb+c13r+Ox3Sf58F/9K8f7xspdmojImygMLoJUdZxP33E5X7v7OnpHMvzSl/6Fz33/Vfp1gFlEFgmFwUV04/om/uG33slNG5v426eP8LGv7GR4YrLcZYmIKAwutktbqvmbj72DB7f9PAe7htn20PP0DKfLXZaIBJzCoEzee1kzf/3vr2FvxxDv+/zTPLLrOEv148RFZOlTGJTR7Ve28t1PvptLW5L8waN7+NBf/Svfe6VToSCB5JzTa7+MFAZltqE5yTfveSd//m+vYmIyz29//UW2/vWzfO25Y+ztGCSdzZW7RJkHQxOTfGdPJxOTS/f3mc87/ts/7eP+Hx056+syn3cc6hpmYjLHeOb8+nugc4ir/vgH/NqDO8nP8tleswVFLu/Ysb+LX/nyT/jo/T/m1ODErM+xEJxzDIxlGMss7dPG9U1ni0g2l+dbL57gy08fPXORWkU0zLpUFStqEkTCRmMyTlNVjNrKGPWVUcIhYzLnyOXzJKJhqmIRvvrcMV461s/qhkresynFphXVrKqvJBo2DneNUFMRZTSdJRI2WmsraK1NAPDDA12MZ3LcdsUKGpMxEtEwvSMZJiZztNQkiEcK7x0m83k6+scZn8yRyzu6htLknaMpGWc0nfXPHSISMtLZPIPjGVLVceoqY1TGwrTVVVCdiL6l78MTWfZ3DnF6JM1YJsdrp0epSUS4YmUtADduaCQeCZPPOyayOTr6x9nfOcRVq+pY21RF70iaw90jHO4e4Uj3CC+90U/fWIY7rmyltjLK4Ngk772smaZknL0dg2Ryea5f20A27zg5ME40HGJFTYIVtQmefLWbkwPjrKyr4G2NlRzvG2PX64V/0w9dvZJDXcP86GAPsUiIG9Y1sLK2giM9IwBUxSNUxsKsqEnw8olBnjjQxQvH+jncPcK6piruvKaNxmSMfN7xzvWNrKqvpH8sQ99oBsNoSsZIJiL85EgvY5kcPznay/G+MfacGOQ9m1J85oObiYRCDI5P8o8vdXDL5c1csbKG9u4R/vaZo7x+epQNzUl+8dIUN21sOvM7Gklnyeby/Mvh0zx1sJtNK6rZ0JzkHZfUs6E5SXUiSiwSYjSdZdexfh7+8es0V8f58NUr2bKmgYeefY0/+96rAKxPVfHBq1YSi4S4tKWaTS3VRCNGejLPfdv38fShHgDikRDv39zCppZqrl/XyE+O9BZex1UxekczhEPG5tYa/ufjB3mlYxCAP7rjcsygbzRz5v9Bx8A4Bzq
HuOWyFm6+rJlMLs8bfWOsqEnwxKtdPNv+5o+Ob66O8xs3rWVDKknfaIZI2Li0pZrqRIS2ugo6BycwK1wkmp7M8cMD3cQiITY2J5nM5UkmIlREw5wcmCDvHLm8o38sQywcYjST45lDPYyks1TGwvz09T66htJUxyNc3loDBh+6qpVrLqlndUMlI+ksIxNZRtKTNFbFaaqO++ce54cHunjuaC/hkHH3TeuIR0KsaaqiZzjN0Z4RQiGja3Ci0KeaOCPpHFesrGFdUxWFr6G/MGf7pjOFwSLknOONvjFe6Rhk1+v9HOsd5dRQmmwuz+mRNAPjk5zt11YZC/PBq1p5vXeMF471L8pPUa2tiBIySGfzpLP5GWuMRUJksvkz6/FIiJAZ4zO8u46EjGzRc1TFwrytsYqWmjjPHD5NLu/ess+FioYLwVv8Mx2c97/v21fXkXeOPScGL+jnVkTDrG6ooK4yxvOzfET6VG0V0TBXttVw8NQwQxOzv1O95pI63ugdo3fa6c3xSIi0/zdvrIqRyeYZTmcJGeQdvO/yFq5f28DfPfsaJ/0fqenM4J53ryMUMjoHxnl8X9eMv7NiyXiEz3xwM3/7zBGO9Pzsav3VDRUYRltdBZc0VPK9vZ1n+jX1+4xHQnzmQ5v55Xes5nt7OznQOcwzh3rY3znzNw9O9aUUFdEwlzRUMj6Z422NlVy9qo7j/WMcPDVMx8A4w2f5t5/ODKKhEJlc/tw7U/i/s/sz71cYTFnOYXAuubxjeGKS/rFJcnlHLBwiEjbGMjkGxzNsWlFDMh4BCt/R/EbfGCf6x8hkHetTVYxP5qiMRcjm83QOTHBqaALn4PLWamorovz09T6GxrOMT+aoq4xSFYvQNTzBxGSekEHYjNa6CpLxCCGDlpoEZtA7kjnzjmpiMkc4ZOTyjubqBL2jaQbGJhnNZDnRP05H/zhmEAuHiEdDxCNhKmNhLm2pprU2QSJaGEEc6xujo3+cSf+ONhyCiljhnXdDZYzLW2vYfbyfU0MT1FfGuLSlmo0tSVbUJM78Z+n37wzDIeMH+7rI5R2bV9aQyeY52DVMNGysqKkgl3ecGpqgc2CcVQ0V/MLGFJ2DExzrHcPhuO2KFbx8fIDnX+9jQyrJjRuacM7x3NE+Tg2Os3llLeOZHEMThbDuH8uwuqGS69c2cHJgnLX+3Vz/aIZMLs9YJscP93cxmc9TVxGjoSqKc9AzkqZraIIb1jUSC4e4rLWG2orCSOpA5xBPvtpNRTRMVTzM6vpK9p0conc0Q0tNnK1vb6OhKkY2l+eFY/08/1ofFbEwqeo41YkI8UiYpmScTSuqcc5xon+cl44P0D+aYXhikuGJLFXxCDWJCP/mHauIhIxnDvWw+/ggaxor2fr2NipiYaAwmktn8+w5MciJ/jGyeUfYjMtba/i5VbVves2+0TvGc6/18u6NTdRWROkdKYwWe0czHDg5xPXrGqhORBnP5HilY5C2+gqakjHikfCbnmdiMkfPcBrnCkHR3j1CfVWMpmT8Tftlc3l+sL+LkMH6VJKhiSynR9L0j2Y41jfGJQ2VhIwzf7Tfub6RimiYw90jxCOFUddIOsvaxioi4RBmUF8Z9aNwx+bWGkKhmf8YT/27/uRIL0MTkyTjEaoTUSrjYToHJhhNZxnL5AiH4P2bV9Bal2A0neXZ9l4mc3mGxiepr4qxPpUkl3e01MTPjC6SiQjdw2neu6l5Tn87FAYiInLWMNABZBERURiIiIjCQEREUBiIiAgKAxERQWEgIiIoDEREBIWBiIiwhC86M7Me4NgcH94EnJ7HchYT9W3pWs79W859g6XTv7c551IzbViyYVAKM9s121V4S536tnQt5/4t577B8uifpolERERhICIiwQ2DB8pdwAJS35au5dy/5dw3WAb9C+QxAxERebOgjgxERKSIwkBERIIVBmZ2u5kdNLN2M7u33PWcLzN7yMy6zWxvUVuDme0ws8P+vt63m5l9yfdxj5l
dW/SYbX7/w2a2rRx9mc7MVpvZU2a238z2mdknffuS75+ZJczseTN72fftv/n2tWa20/fhm2YW8+1xv97ut68peq5P+/aDZnZbmbr0FmYWNrOXzOyf/fpy6tvrZvaKme02s12+bcm/LmflnAvEDQgDR4B1QAx4Gdhc7rrOs/ZfAK4F9ha1/Tlwr1++F/icX74D+B5gwA3ATt/eABz19/V+uX4R9K0VuNYvVwOHgM3LoX++xqRfjgI7fc2PAHf59i8Dv+2Xfwf4sl++C/imX97sX69xYK1/HYfL/bvztX0K+L/AP/v15dS314GmaW1L/nU52y1II4PrgHbn3FHnXAb4BrC1zDWdF+fcM8D0b0LfCjzslx8G7ixq/6oreA6oM7NW4DZgh3OuzznXD+wAbl/w4s/BOdfpnHvRLw8DB4A2lkH/fI0jfjXqbw64GXjUt0/v21SfHwVuscIXOW8FvuGcSzvnXgPaKbyey8rMVgG/BHzFrxvLpG9nseRfl7MJUhi0AceL1k/4tqWqxTnX6ZdPAS1+ebZ+Lvr++6mDayi8g14W/fPTKLuBbgp/CI4AA865rN+luM4zffDbB4FGFmnfgL8E/gDI+/VGlk/foBDcPzCzF8zsHt+2LF6XM4mUuwApnXPOmdmSPkfYzJLAt4Dfc84NFd40Fizl/jnncsDbzawO+EfgsvJWND/M7INAt3PuBTN7T5nLWSg3Oec6zKwZ2GFmrxZvXMqvy5kEaWTQAawuWl/l25aqLj8Mxd93+/bZ+rlo+29mUQpB8HXn3Ld987LpH4BzbgB4CngnhSmEqTdixXWe6YPfXgv0sjj79i7gw2b2OoUp15uBL7I8+gaAc67D33dTCPLrWGavy2JBCoOfAhv92Q4xCgextpe5plJsB6bOTNgGPFbU/nF/dsMNwKAf1j4O3Gpm9f4MiFt9W1n5eeMHgQPOuc8XbVry/TOzlB8RYGYVwPspHBN5Cvio321636b6/FHgSVc4CrkduMufkbMW2Ag8f1E6MQvn3Kedc6ucc2so/F960jn3MZZB3wDMrMrMqqeWKbye9rIMXpezKvcR7It5o3DE/xCFeds/Knc9F1D33wOdwCSFOce7Kcy3PgEcBn4INPh9Dfhr38dXgC1Fz/MbFA7QtQO/Xu5++ZpuojA3uwfY7W93LIf+AVcBL/m+7QU+49vXUfiD1w78AxD37Qm/3u63ryt6rj/yfT4IfKDcfZvWz/fws7OJlkXffD9e9rd9U38vlsPrcrabPo5CREQCNU0kIiKzUBiIiIjCQEREFAYiIoLCQEREUBiIiAgKAxERAf4/Vqzj2EHm6PMAAAAASUVORK5CYII=\n", 266 | "text/plain": [ 267 | "
" 268 | ] 269 | }, 270 | "metadata": { 271 | "needs_background": "light" 272 | }, 273 | "output_type": "display_data" 274 | } 275 | ], 276 | "source": [ 277 | "from matplotlib import pyplot as plt\n", 278 | "%matplotlib inline\n", 279 | "\n", 280 | "plt.plot(plt_x, plt_y)\n", 281 | "plt.show()" 282 | ] 283 | } 284 | ], 285 | "metadata": { 286 | "kernelspec": { 287 | "display_name": "Python 3 (ipykernel)", 288 | "language": "python", 289 | "name": "python3" 290 | }, 291 | "language_info": { 292 | "codemirror_mode": { 293 | "name": "ipython", 294 | "version": 3 295 | }, 296 | "file_extension": ".py", 297 | "mimetype": "text/x-python", 298 | "name": "python", 299 | "nbconvert_exporter": "python", 300 | "pygments_lexer": "ipython3", 301 | "version": "3.8.11" 302 | } 303 | }, 304 | "nbformat": 4, 305 | "nbformat_minor": 5 306 | } 307 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 视频课程地址:https://www.bilibili.com/video/BV1M64y187qX/ 2 | --------------------------------------------------------------------------------