├── LICENSE
├── data.csv
├── gradient_descent_example.gif
├── gradient_descent_example.py
└── readme.md

/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2015 Matt Nedrich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/data.csv:
--------------------------------------------------------------------------------
32.502345269453031,31.70700584656992
53.426804033275019,68.77759598163891
61.530358025636438,62.562382297945803
47.475639634786098,71.546632233567777
59.813207869512318,87.230925133687393
55.142188413943821,78.211518270799232
52.211796692214001,79.64197304980874
39.299566694317065,59.171489321869508
48.10504169176825,75.331242297063056
52.550014442733818,71.300879886850353
45.419730144973755,55.165677145959123
54.351634881228918,82.478846757497919
44.164049496773352,62.008923245725825
58.16847071685779,75.392870425994957
56.727208057096611,81.43619215887864
48.955888566093719,60.723602440673965
44.687196231480904,82.892503731453715
60.297326851333466,97.379896862166078
45.618643772955828,48.847153317355072
38.816817537445637,56.877213186268506
66.189816606752601,83.878564664602763
65.41605174513407,118.59121730252249
47.48120860786787,57.251819462268969
41.57564261748702,51.391744079832307
51.84518690563943,75.380651665312357
59.370822011089523,74.765564032151374
57.31000343834809,95.455052922574737
63.615561251453308,95.229366017555307
46.737619407976972,79.052406169565586
50.556760148547767,83.432071421323712
52.223996085553047,63.358790317497878
35.567830047746632,41.412885303700563
42.436476944055642,76.617341280074044
58.16454011019286,96.769566426108199
57.504447615341789,74.084130116602523
45.440530725319981,66.588144414228594
61.89622268029126,77.768482417793024
33.093831736163963,50.719588912312084
36.436009511386871,62.124570818071781
37.675654860850742,60.810246649902211
44.555608383275356,52.682983366387781
43.318282631865721,58.569824717692867
50.073145632289034,82.905981485070512
43.870612645218372,61.424709804339123
62.997480747553091,115.24415280079529
32.669043763467187,45.570588823376085
40.166899008703702,54.084054796223612
53.575077531673656,87.994452758110413
33.864214971778239,52.725494375900425
64.707138666121296,93.576118692658241
38.119824026822805,80.166275447370964
44.502538064645101,65.101711570560326
40.599538384552318,65.562301260400375
41.720676356341293,65.280886920822823
51.088634678336796,73.434641546324301
55.078095904923202,71.13972785861894
41.377726534895203,79.102829683549857
62.494697427269791,86.520538440347153
49.203887540826003,84.742697807826218
41.102685187349664,59.358850248624933
41.182016105169822,61.684037524833627
50.186389494880601,69.847604158249183
52.378446219236217,86.098291205774103
50.135485486286122,59.108839267699643
33.644706006191782,69.89968164362763
39.557901222906828,44.862490711164398
56.130388816875467,85.498067778840223
57.362052133238237,95.536686846467219
60.269214393997906,70.251934419771587
35.678093889410732,52.721734964774988
31.588116998132829,50.392670135079896
53.66093226167304,63.642398775657753
46.682228649471917,72.247251068662365
43.107820219102464,57.812512976181402
70.34607561504933,104.25710158543822
44.492855880854073,86.642020318822006
57.50453330326841,91.486778000110135
36.930076609191808,55.231660886212836
55.805733357942742,79.550436678507609
38.954769073377065,44.847124242467601
56.901214702247074,80.207523139682763
56.868900661384046,83.14274979204346
34.33312470421609,55.723489260543914
59.04974121466681,77.634182511677864
57.788223993230673,99.051414841748269
54.282328705967409,79.120646274680027
51.088719898979143,69.588897851118475
50.282836348230731,69.510503311494389
44.211741752090113,73.687564318317285
38.005488008060688,61.366904537240131
32.940479942618296,67.170655768995118
53.691639571070056,85.668203145001542
68.76573426962166,114.85387123391394
46.230966498310252,90.123572069967423
68.319360818255362,97.919821035242848
50.030174340312143,81.536990783015028
49.239765342753763,72.111832469615663
50.039575939875988,85.232007342325673
48.149858891028863,66.224957888054632
25.128484647772304,53.454394214850524
--------------------------------------------------------------------------------
/gradient_descent_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mattnedrich/GradientDescentExample/a13134d5ba83bc47d33629315071afc8d32fe072/gradient_descent_example.gif
--------------------------------------------------------------------------------
/gradient_descent_example.py:
--------------------------------------------------------------------------------
import numpy as np

# Fit y = mx + b to the points in data.csv,
# where m is the slope and b is the y-intercept.

# Mean squared error of the line y = mx + b over all points.
def compute_error_for_line_given_points(b, m, points):
    total_error = 0
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        total_error += (y - (m * x + b)) ** 2
    return total_error / float(len(points))

# Perform a single gradient descent update of b and m.
def step_gradient(b_current, m_current, points, learning_rate):
    b_gradient = 0
    m_gradient = 0
    N = float(len(points))
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        # Accumulate the partial derivatives of the mean squared error
        # with respect to b and m.
        b_gradient += -(2 / N) * (y - ((m_current * x) + b_current))
        m_gradient += -(2 / N) * x * (y - ((m_current * x) + b_current))
    # Step each parameter in the direction opposite its gradient.
    new_b = b_current - (learning_rate * b_gradient)
    new_m = m_current - (learning_rate * m_gradient)
    return [new_b, new_m]
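
# Vectorized variant of step_gradient, included here only as an illustrative
# sketch; it is not called by run() below. It assumes `points` is the same
# (N, 2) NumPy array and computes the same update as the loop above, so the
# two functions should agree up to floating-point rounding.
def step_gradient_vectorized(b_current, m_current, points, learning_rate):
    x = points[:, 0]
    y = points[:, 1]
    N = float(len(points))
    residuals = y - (m_current * x + b_current)    # per-point prediction error
    b_gradient = -(2 / N) * residuals.sum()        # d(error)/db
    m_gradient = -(2 / N) * (x * residuals).sum()  # d(error)/dm
    return [b_current - (learning_rate * b_gradient),
            m_current - (learning_rate * m_gradient)]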

# Run gradient descent from (starting_b, starting_m) for num_iterations steps.
def gradient_descent_runner(points, starting_b, starting_m, learning_rate, num_iterations):
    b = starting_b
    m = starting_m
    for i in range(num_iterations):
        b, m = step_gradient(b, m, np.array(points), learning_rate)
    return [b, m]

def run():
    points = np.genfromtxt("data.csv", delimiter=",")
    learning_rate = 0.0001
    initial_b = 0  # initial y-intercept guess
    initial_m = 0  # initial slope guess
    num_iterations = 1000
    print("Starting gradient descent at b = {0}, m = {1}, error = {2}".format(initial_b, initial_m, compute_error_for_line_given_points(initial_b, initial_m, points)))
    print("Running...")
    [b, m] = gradient_descent_runner(points, initial_b, initial_m, learning_rate, num_iterations)
    print("After {0} iterations b = {1}, m = {2}, error = {3}".format(num_iterations, b, m, compute_error_for_line_given_points(b, m, points)))

if __name__ == '__main__':
    run()
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
## Gradient Descent Example for Linear Regression
This example project demonstrates how the [gradient descent](http://en.wikipedia.org/wiki/Gradient_descent) algorithm may be used to solve a [linear regression](http://en.wikipedia.org/wiki/Linear_regression) problem. A more detailed description of this example can be found [here](https://spin.atomicobject.com/2014/06/24/gradient-descent-linear-regression/).

### Code Requirements
The example code is in Python ([version 2.6](https://www.python.org/doc/versions/) or higher, including Python 3, will work). The only other requirement is [NumPy](http://www.numpy.org/).

### Description
This code demonstrates how a gradient descent search may be used to solve the linear regression problem of fitting a line to a set of points. In this problem, we wish to model a set of points using a line. The line model is defined by two parameters: the line's slope `m` and its y-intercept `b`. Gradient descent attempts to find the best values for these parameters by minimizing an error function, in this case the mean squared error between the line's predictions and the observed points.

The code contains a main function called `run`. This function defines a set of parameters used in the gradient descent algorithm, including an initial guess of the line's slope and y-intercept, the learning rate to use, and the number of iterations to run gradient descent for.

```python
initial_b = 0 # initial y-intercept guess
initial_m = 0 # initial slope guess
num_iterations = 1000
```

Using these parameters, a gradient descent search is executed on a sample data set of 100 points. Here is a visualization of the search running for 200 iterations using an initial guess of `m = 0`, `b = 0`, and a learning rate of `0.000005`.

![Gradient descent search visualization](gradient_descent_example.gif)

### Execution
To run the example, simply run the `gradient_descent_example.py` file using Python:

```
python gradient_descent_example.py
```

The output will look like this:

```
Starting gradient descent at b = 0, m = 0, error = 5565.10783448
Running...
After 1000 iterations b = 0.0889365199374, m = 1.47774408519, error = 112.614810116
```
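
### Checking the Result
As a sanity check, the line found by gradient descent can be compared against NumPy's closed-form least-squares fit of the same data. The sketch below assumes the same `data.csv` layout used above; note that `np.polyfit` returns its coefficients highest degree first, so the slope comes before the intercept.

```python
import numpy as np

# Closed-form least-squares fit of the same data set.
points = np.genfromtxt("data.csv", delimiter=",")
m, b = np.polyfit(points[:, 0], points[:, 1], 1)
print("Closed-form least squares: b = {0}, m = {1}".format(b, m))
```

With a learning rate of `0.0001` and only 1000 iterations, gradient descent should land close to this line in error but not exactly on it; the intercept `b` in particular converges slowly, which is visible in the sample output above.
--------------------------------------------------------------------------------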