└── GradingDesent
    ├── LinearRegressionTest.py
    └── data.txt

/GradingDesent/LinearRegressionTest.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Load the two-column training data (feature, target) shipped alongside this script.
path = 'data.txt'
data = pd.read_csv(path, header=None)
plt.scatter(data[0], data[1], marker='+')

# Convert to a NumPy array and prepend a column of ones for the intercept term.
data = np.array(data)
m = data.shape[0]
theta = np.array([0.0, 0.0])
data = np.hstack([np.ones([m, 1]), data])
y = data[:, 2]
data = data[:, :2]


def cost_function(data, theta, y):
    # Squared-error cost: J(theta) = (1 / 2m) * sum((X.theta - y)^2)
    cost = np.sum((data.dot(theta) - y) ** 2)
    return cost / (2 * m)


def gradient(data, theta, y):
    # Gradient of the squared-error cost. The 1/m factor is omitted here
    # and effectively absorbed into the learning rate eta.
    grad = np.empty(len(theta))
    grad[0] = np.sum(data.dot(theta) - y)
    for i in range(1, len(theta)):
        grad[i] = (data.dot(theta) - y).dot(data[:, i])
    return grad


def gradient_descent(data, theta, y, eta):
    # Batch gradient descent: update theta until the change in cost is negligible.
    while True:
        last_theta = theta
        grad = gradient(data, theta, y)
        theta = theta - eta * grad
        print(theta)
        if abs(cost_function(data, last_theta, y) - cost_function(data, theta, y)) < 1e-15:
            break
    return theta


res = gradient_descent(data, theta, y, 0.0001)

# Plot the fitted line over the scatter plot of the data.
X = np.arange(3, 25)
Y = res[0] + res[1] * X
plt.plot(X, Y, color='r')
plt.show()
--------------------------------------------------------------------------------
/GradingDesent/data.txt:
--------------------------------------------------------------------------------
6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705
--------------------------------------------------------------------------------
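
The parameters found by gradient descent can be sanity-checked against the closed-form ordinary least-squares solution for the same data. Below is a minimal sketch, not part of the original script, assuming data.txt sits in the working directory; it builds the same design matrix (ones column plus the single feature) and solves it with np.linalg.lstsq:

```python
import numpy as np
import pandas as pd

# Load the same two-column data set: column 0 is the feature, column 1 the target.
data = pd.read_csv('data.txt', header=None).to_numpy()
X = np.hstack([np.ones((data.shape[0], 1)), data[:, :1]])  # prepend intercept column
y = data[:, 1]

# Closed-form ordinary least squares.
theta_ols, *_ = np.linalg.lstsq(X, y, rcond=None)
print('closed-form theta:', theta_ols)
```

With a small enough learning rate and the tight stopping threshold used in gradient_descent, the iterative result `res` should land close to `theta_ols`.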