├── .gitignore ├── .travis.yml ├── README.md ├── package-lock.json ├── package.json ├── public ├── favicon.ico ├── index.html └── manifest.json └── src ├── App.css ├── App.js ├── App.test.js ├── index.css ├── index.js └── registerServiceWorker.js /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/ignore-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | 6 | # testing 7 | /coverage 8 | 9 | # production 10 | /build 11 | 12 | # misc 13 | .DS_Store 14 | .env.local 15 | .env.development.local 16 | .env.test.local 17 | .env.production.local 18 | 19 | npm-debug.log* 20 | yarn-debug.log* 21 | yarn-error.log* 22 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: node_js 2 | 3 | node_js: 4 | - stable 5 | 6 | install: 7 | - npm install 8 | 9 | script: 10 | - npm test -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Linear Regression with Gradient Descent 2 | 3 | [![Build Status](https://travis-ci.org/javascript-machine-learning/linear-regression-gradient-descent.svg?branch=master)](https://travis-ci.org/javascript-machine-learning/linear-regression-gradient-descent) 4 | 5 | This example project demonstrates how the [gradient descent](http://en.wikipedia.org/wiki/Gradient_descent) algorithm may be used to solve a [linear regression](http://en.wikipedia.org/wiki/Linear_regression) problem. 
6 | 7 | [Read more about it here.](https://www.robinwieruch.de/linear-regression-gradient-descent-javascript/) 8 | 9 | ![linear-regression-gradient-descent](https://user-images.githubusercontent.com/2479967/31553725-9a24097a-b065-11e7-9b89-771a614f464e.gif) 10 | 11 | ## Installation 12 | 13 | * `git clone git@github.com:javascript-machine-learning/linear-regression-gradient-descent.git` 14 | * `cd linear-regression-gradient-descent` 15 | * `npm install` 16 | * `npm start` 17 | * visit http://localhost:3000/ 18 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "react-linear-regression-gradient-descent", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "react": "^16.0.0", 7 | "react-dom": "^16.0.0", 8 | "react-scripts": "1.0.14" 9 | }, 10 | "scripts": { 11 | "start": "react-scripts start", 12 | "build": "react-scripts build", 13 | "test": "react-scripts test --env=jsdom", 14 | "eject": "react-scripts eject" 15 | } 16 | } -------------------------------------------------------------------------------- /public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/javascript-machine-learning/linear-regression-gradient-descent/bb243b997b3fd31ac536e7bbf2e996cdf67764ae/public/favicon.ico -------------------------------------------------------------------------------- /public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | 12 | 13 | 22 | React App 23 | 24 | 25 | 28 |
29 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "192x192", 8 | "type": "image/png" 9 | } 10 | ], 11 | "start_url": "./index.html", 12 | "display": "standalone", 13 | "theme_color": "#000000", 14 | "background_color": "#ffffff" 15 | } 16 | -------------------------------------------------------------------------------- /src/App.css: -------------------------------------------------------------------------------- 1 | svg { 2 | background: #F3F3F3; 3 | } 4 | 5 | circle { 6 | fill: #FFFFFF; 7 | stroke: #FF0000; 8 | stroke-width: 3; 9 | r: 5; 10 | } 11 | 12 | line { 13 | stroke-width: 3; 14 | stroke: #000000; 15 | } -------------------------------------------------------------------------------- /src/App.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import './App.css'; 3 | 4 | // adjust training set size 5 | 6 | const M = 10; 7 | 8 | // generate random training set 9 | 10 | const DATA = []; 11 | 12 | const getRandomIntFromInterval = (min, max) => 13 | Math.floor(Math.random() * (max - min + 1) + min); 14 | 15 | const createRandomPortlandHouse = () => ({ 16 | squareMeter: getRandomIntFromInterval(0, 100), 17 | price: getRandomIntFromInterval(0, 100), 18 | }); 19 | 20 | for (let i = 0; i < M; i++) { 21 | DATA.push(createRandomPortlandHouse()); 22 | } 23 | 24 | const x = DATA.map(date => date.squareMeter); 25 | const y = DATA.map(date => date.price); 26 | 27 | // linear regression and gradient descent 28 | 29 | const LEARNING_RATE = 0.0003; 30 | 31 | let thetaOne = 0; 32 | let thetaZero = 0; 33 | 34 | const hypothesis = x => thetaZero + thetaOne * x; 35 | 36 | const learn = (alpha) => { 37 | let thetaZeroSum = 0; 
38 | let thetaOneSum = 0; 39 | 40 | for (let i = 0; i < M; i++) { 41 | thetaZeroSum += hypothesis(x[i]) - y[i]; 42 | thetaOneSum += (hypothesis(x[i]) - y[i]) * x[i]; 43 | } 44 | 45 | thetaZero = thetaZero - (alpha / M) * thetaZeroSum; 46 | thetaOne = thetaOne - (alpha / M) * thetaOneSum; 47 | } 48 | 49 | const cost = () => { 50 | let sum = 0; 51 | 52 | for (let i = 0; i < M; i++) { 53 | sum += Math.pow(hypothesis(x[i]) - y[i], 2); 54 | } 55 | 56 | return sum / (2 * M); 57 | } 58 | 59 | // count iterations 60 | 61 | let iteration = 0; 62 | 63 | // view 64 | 65 | class App extends React.Component { 66 | componentDidMount() { 67 | this.interval = setInterval(this.onLearn, 1); 68 | } 69 | 70 | componentWillUnmount() { 71 | clearInterval(this.interval); 72 | } 73 | 74 | onLearn = () => { 75 | learn(LEARNING_RATE); 76 | 77 | iteration++; 78 | 79 | this.forceUpdate(); 80 | } 81 | 82 | render() { 83 | return ( 84 |
85 | 89 | 90 |
91 | 92 | 93 | 94 |
95 |
96 | ); 97 | } 98 | } 99 | 100 | const Plot = ({ x, y }) => 101 | 102 | 108 | 109 | {DATA.map((date, key) => 110 | 115 | )} 116 | 117 | 118 | const Iteration = ({ iteration }) => 119 |

120 | Iteration: {iteration} 121 |

122 | 123 | const Hypothesis = () => 124 |

125 | Hypothesis: f(x) = {thetaZero.toFixed(2)} + {thetaOne.toFixed(2)}x 126 |

127 | 128 | const Cost = () => 129 |

130 | Cost: {cost().toFixed(2)} 131 |

132 | 133 | export default App; 134 | -------------------------------------------------------------------------------- /src/App.test.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import App from './App'; 4 | 5 | it('renders without crashing', () => { 6 | const div = document.createElement('div'); 7 | ReactDOM.render(, div); 8 | }); 9 | -------------------------------------------------------------------------------- /src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | padding: 0; 4 | font-family: sans-serif; 5 | } 6 | -------------------------------------------------------------------------------- /src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import './index.css'; 4 | import App from './App'; 5 | import registerServiceWorker from './registerServiceWorker'; 6 | 7 | ReactDOM.render(, document.getElementById('root')); 8 | registerServiceWorker(); 9 | -------------------------------------------------------------------------------- /src/registerServiceWorker.js: -------------------------------------------------------------------------------- 1 | // In production, we register a service worker to serve assets from local cache. 2 | 3 | // This lets the app load faster on subsequent visits in production, and gives 4 | // it offline capabilities. However, it also means that developers (and users) 5 | // will only see deployed updates on the "N+1" visit to a page, since previously 6 | // cached resources are updated in the background. 7 | 8 | // To learn more about the benefits of this model, read https://goo.gl/KwvDNy. 9 | // This link also includes instructions on opting out of this behavior. 
// True when the page is served from a loopback host. On such hosts the
// service worker is first validated (see checkValidServiceWorker) instead of
// being registered unconditionally.
const isLocalhost = Boolean(
  window.location.hostname === 'localhost' ||
    // [::1] is the IPv6 localhost address.
    window.location.hostname === '[::1]' ||
    // 127.0.0.1/8 is considered localhost for IPv4.
    window.location.hostname.match(
      /^127(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}$/
    )
);

/**
 * Registers the app's service worker after the window `load` event.
 *
 * No-op unless NODE_ENV is 'production' and the browser exposes
 * `navigator.serviceWorker`. Also bails out when PUBLIC_URL resolves to a
 * different origin than the page, because a service worker cannot be
 * registered cross-origin.
 */
export default function register() {
  if (process.env.NODE_ENV === 'production' && 'serviceWorker' in navigator) {
    // The URL constructor is available in all browsers that support SW.
    const publicUrl = new URL(process.env.PUBLIC_URL, window.location);
    if (publicUrl.origin !== window.location.origin) {
      // Our service worker won't work if PUBLIC_URL is on a different origin
      // from what our page is served on. This might happen if a CDN is used to
      // serve assets; see https://github.com/facebookincubator/create-react-app/issues/2374
      return;
    }

    window.addEventListener('load', () => {
      const swUrl = `${process.env.PUBLIC_URL}/service-worker.js`;

      if (!isLocalhost) {
        // Not localhost: just register the service worker.
        registerValidSW(swUrl);
      } else {
        // This is running on localhost. Let's check whether a service worker
        // still exists at swUrl before registering it.
        checkValidServiceWorker(swUrl);
      }
    });
  }
}

/**
 * Registers the service worker at `swUrl` and logs when new content has been
 * installed or when content has been cached for offline use.
 *
 * @param {string} swUrl - URL of the service worker script.
 */
function registerValidSW(swUrl) {
  navigator.serviceWorker
    .register(swUrl)
    .then(registration => {
      registration.onupdatefound = () => {
        const installingWorker = registration.installing;
        // `registration.installing` is nullable; guard instead of throwing a
        // TypeError inside the event handler.
        if (installingWorker == null) {
          return;
        }
        installingWorker.onstatechange = () => {
          if (installingWorker.state === 'installed') {
            if (navigator.serviceWorker.controller) {
              // At this point, the old content will have been purged and
              // the fresh content will have been added to the cache.
              // It's the perfect time to display a "New content is
              // available; please refresh." message in your web app.
              console.log('New content is available; please refresh.');
            } else {
              // At this point, everything has been precached.
              // It's the perfect time to display a
              // "Content is cached for offline use." message.
              console.log('Content is cached for offline use.');
            }
          }
        };
      };
    })
    .catch(error => {
      console.error('Error during service worker registration:', error);
    });
}

/**
 * Fetches `swUrl` to confirm a JavaScript service worker is served there.
 * If the script is missing (404) or is served with a non-JavaScript
 * Content-Type, unregisters the active worker and reloads the page.
 * Otherwise registers it normally.
 *
 * @param {string} swUrl - URL of the service worker script.
 */
function checkValidServiceWorker(swUrl) {
  // Check if the service worker can be found. If it can't, reload the page.
  fetch(swUrl)
    .then(response => {
      // Ensure service worker exists, and that we really are getting a JS file.
      // Headers.get() returns null when the header is absent; calling a string
      // method on it would throw, and the .catch below would then misreport
      // the failure as a missing network connection.
      const contentType = response.headers.get('content-type');
      if (
        response.status === 404 ||
        (contentType !== null && !contentType.includes('javascript'))
      ) {
        // No service worker found. Probably a different app. Reload the page.
        navigator.serviceWorker.ready.then(registration => {
          registration.unregister().then(() => {
            window.location.reload();
          });
        });
      } else {
        // Service worker found. Proceed as normal.
        registerValidSW(swUrl);
      }
    })
    .catch(() => {
      console.log(
        'No internet connection found. App is running in offline mode.'
      );
    });
}

/**
 * Unregisters the currently active service worker, if the browser supports
 * service workers.
 */
export function unregister() {
  if ('serviceWorker' in navigator) {
    navigator.serviceWorker.ready.then(registration => {
      registration.unregister();
    });
  }
}