├── .github └── FUNDING.yml ├── favicon.png ├── screenshot.png ├── css └── main.css ├── manifest.json ├── js ├── CannonBall.js ├── Cannon.js ├── main.js ├── Matrix.js └── Dejavu.js ├── LICENSE ├── index.html └── README.md /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | patreon: victorqribeiro 2 | -------------------------------------------------------------------------------- /favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/victorqribeiro/bangBangML/HEAD/favicon.png -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/victorqribeiro/bangBangML/HEAD/screenshot.png -------------------------------------------------------------------------------- /css/main.css: -------------------------------------------------------------------------------- 1 | html, body{ 2 | margin: 0; 3 | padding: 0; 4 | } 5 | canvas{ 6 | display: block; 7 | } 8 | -------------------------------------------------------------------------------- /manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Bang Bang with a Neural Network", 3 | "short_name": "BangBangML", 4 | "description": "Watch as a Neural Network learns to shoot a target", 5 | "start_url": "index.html", 6 | "display": "standalone", 7 | "orientation": "portrait", 8 | "background_color": "#FFF", 9 | "theme_color": "#FFF", 10 | "icons": [ 11 | { 12 | "src": "favicon.png", 13 | "sizes": "256x256", 14 | "type": "image/png", 15 | "density": 1 16 | }, 17 | { 18 | "src": "favicon.png", 19 | "sizes": "512x512", 20 | "type": "image/png", 21 | "density": 1 22 | } 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /js/CannonBall.js: -------------------------------------------------------------------------------- 1 | class CannonBall { 2 | 3 | constructor(pos, angle, strength){ 4 | this.pos = { 5 | x: pos.x || 0, 6 | y: pos.y || 0 7 | } 8 | this.angle = angle || 0 9 | this.strength = strength || 0 10 | this.acc = { 11 | x: Math.cos( this.angle ) * this.strength, 12 | y: Math.sin( this.angle ) * this.strength 13 | } 14 | this.TWOPI = 2 * Math.PI 15 | } 16 | 17 | isGone(){ 18 | return this.pos.y > h 19 | } 20 | 21 | update(){ 22 | 23 | if( this.isGone() ) 24 | 25 | return 26 | 27 | this.pos.x += this.acc.x 28 | this.pos.y += this.acc.y 29 | 30 | this.acc.x *= 0.975 31 | this.acc.y *= 0.975 32 | 33 | this.acc.y += Math.sin( Math.PI/2 ) 34 | 35 | } 36 | 37 | show(){ 38 | 39 | c.beginPath() 40 | c.arc( this.pos.x, this.pos.y, 5, 0, this.TWOPI ) 41 | c.fillStyle = "black" 42 | c.fill() 43 | 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Victor Ribeiro 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /js/Cannon.js: -------------------------------------------------------------------------------- 1 | class Cannon { 2 | 3 | constructor(pos){ 4 | this.pos = { 5 | x: pos.x || 0, 6 | y: pos.y || 0 7 | } 8 | this.angle = 0 9 | this.strength = 10 10 | this.isRotatingUp = false 11 | this.isRotatingDown = false 12 | this.isIncreasingStrength = false 13 | this.isDecreasingStrength = false 14 | } 15 | 16 | update(){ 17 | 18 | if( this.isRotatingUp && this.angle > -Math.PI/2) 19 | this.angle -= 0.1 20 | else if( this.isRotatingDown && this.angle < 0) 21 | this.angle += 0.1 22 | 23 | if( this.isIncreasingStrength && this.strength < 100) 24 | this.strength += 1 25 | else if( this.isDecreasingStrength && this.strength > 0) 26 | this.strength -= 1 27 | 28 | } 29 | 30 | show(){ 31 | c.save() 32 | c.translate(this.pos.x, this.pos.y) 33 | c.fillStyle = "red" 34 | c.strokeStyle = "black" 35 | c.fillRect( -30, -30, 10, -this.strength ) 36 | c.strokeRect( -30, -30, 10, -100) 37 | c.fillStyle = "black" 38 | c.beginPath() 39 | c.arc(0, 0, 30, 0, Math.PI, true) 40 | c.lineTo(-30,10) 41 | c.lineTo(30,10) 42 | c.fill() 43 | c.rotate(this.angle) 44 | c.fillRect(-5,-7,50,14) 45 | c.restore() 46 | 47 | } 48 | 49 | shoot(){ 50 | cannonBall = new CannonBall(this.pos, this.angle, this.strength ) 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | bangBangML 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self Taught Neural Network 2 | 3 | Watch a Neural Network learns how to shoot a target. 4 | 5 | ![screenshot](screenshot.png) 6 | 7 | [Live version](https://victorribeiro.com/bangBangML) 8 | 9 | [Watch the video of the Neural Network learning ~ 20 minutes](https://www.youtube.com/watch?v=3FB92OOYqPI) 10 | 11 | ## About 12 | 13 | Halfway through creating a clone of the classic windows game [Bang Bang](https://archive.org/details/win3_BANGBANG) I realized I need a interesting Artificial Intelligence to play against the player. So I thought about having the opponent cannon be controlled by a Neural Network and learn how to shoot during run time. 14 | 15 | 16 | ## Training 17 | 18 | I came up with this algorithm to train the Neural Network. 19 | 20 | 1 - Shoot it at random 21 | 22 | 2 - If the shot was to the left of the target, adjust the weights to the right and vice and versa. On this step I'm not saving the training data, as I don't care for a miss shot. 23 | 24 | 3 - After a hit, collect the data and use it to further train the Neural Network. This data is stored as it is a good example of how to shoot a target.   25 | 26 | ## Application 27 | 28 | I'm thinking about using different iterations of the trained Neural Network to use as a difficult level. Neural Networks trained for a short while would be the first opponents while the Neural Network that ware trained for a long period of time would come last, as a more hard opponent. 29 | -------------------------------------------------------------------------------- /js/main.js: -------------------------------------------------------------------------------- 1 | 2 | let canvas, c, w, h, cannon, cannonBall, nn, target, u, choice, x, y, totalShots, hits, _x, _y 3 | 4 | const $ = _ => document.querySelector(_) 5 | 6 | const $c = _ => document.createElement(_) 7 | 8 | const init = () => { 9 | 10 | totalShots = hits = accuracy = 0 11 | 12 | x = [] 13 | y = [] 14 | 15 | canvas = $c('canvas') 16 | canvas.width = w = innerWidth 17 | canvas.height = h = innerHeight 18 | c = canvas.getContext('2d') 19 | 20 | c.font = "30px Arial"; 21 | 22 | $('body').appendChild( canvas ) 23 | 24 | cannonBall = null 25 | 26 | cannon = new Cannon({x:40 , y:h}) 27 | 28 | nn = new Dejavu([4,2,2],0.01,10) 29 | 30 | target = { 31 | x: Math.random() * (w-w/2) + w/2, 32 | y: h 33 | } 34 | 35 | mainLoop() 36 | 37 | } 38 | 39 | const mainLoop = () => { 40 | 41 | if( !cannonBall ){ 42 | _x = [target.x/w, target.y/h, cannon.angle / (-Math.PI / 2), cannon.strength / 100] 43 | choice = nn.predict( _x ).data 44 | cannon.angle = choice[0] * (-Math.PI/2) 45 | cannon.strength = choice[1] * 100 46 | cannon.shoot() 47 | totalShots += 1 48 | } 49 | 50 | if( cannonBall && cannonBall.isGone() ){ 51 | if( target.x < cannonBall.pos.x ) 52 | _y = [1,0] 53 | else 54 | _y = [0,1] 55 | nn.fit( [_x], [_y] ) 56 | cannonBall = null 57 | } 58 | 59 | if( cannonBall && dist(target.x, cannonBall.pos.x, target.y, cannonBall.pos.y) < 60 ){ 60 | hits += 1 61 | 62 | x.push( _x ) 63 | y.push( [choice[0], choice[1]] ) 64 | nn.shuffle( x, y ) 65 | nn.fit( x, y ) 66 | 67 | target = { 68 | x: Math.random() * (w-w/2) + w/2, 69 | y: h 70 | } 71 | cannonBall = null 72 | } 73 | 74 | 75 | cannon.update() 76 | if( cannonBall ) 77 | cannonBall.update() 78 | draw() 79 | u = requestAnimationFrame( mainLoop ) 80 | 81 | } 82 | 83 | const dist = (x1,x2,y1,y2) => { 84 | return Math.sqrt( (x1-x2) ** 2 + (y1-y2) ** 2 ) 85 | } 86 | 87 | const draw = () => { 88 | c.clearRect(0,0,w,h) 89 | c.fillStyle = "black" 90 | c.fillText("Accuracy: "+parseFloat(hits/totalShots).toFixed(4), 20, 30 ) 91 | cannon.show() 92 | if( cannonBall ) 93 | cannonBall.show() 94 | if( target ){ 95 | c.beginPath() 96 | c.arc(target.x, target.y, 30, 0, Math.PI * 2 ) 97 | c.fillStyle = "red" 98 | c.fill() 99 | } 100 | } 101 | 102 | init() 103 | -------------------------------------------------------------------------------- /js/Matrix.js: -------------------------------------------------------------------------------- 1 | class Matrix { 2 | 3 | constructor(rows, cols, values = 0){ 4 | this.rows = rows || 0; 5 | this.cols = cols || 0; 6 | if( values instanceof Array ){ 7 | this.data = values.slice(); 8 | }else if( values == "RANDOM" ){ 9 | this.data = Array( this.rows * this.cols ).fill().map( _ => Math.random() * 2 - 1 ); 10 | }else{ 11 | this.data = Array( this.rows * this.cols ).fill( values ); 12 | } 13 | } 14 | 15 | multiply(b){ 16 | 17 | if( b.rows !== this.cols ){ 18 | throw new Error('Cols from Matrix A different from Rows of Matrix B'); 19 | return; 20 | } 21 | 22 | let result = new Matrix( this.rows, b.cols ); 23 | 24 | for(let i = 0; i < this.rows; i++){ 25 | for(let j = 0; j < b.cols; j++){ 26 | let s = 0; 27 | for(let k = 0; k < this.cols; k++){ 28 | s += this.data[ i * this.cols + k ] * b.data[ k * b.cols + j ]; 29 | } 30 | result.data[ i * result.cols + j ] = s; 31 | } 32 | } 33 | return result; 34 | } 35 | 36 | transpose(){ 37 | for(let i = 0; i < this.rows; i++){ 38 | for(let j = 0; j < this.cols; j++){ 39 | let temp = this.data[ i * this.cols + j ]; 40 | this.data[ i * this.cols + j ] = this.data[ j * this.rows + i ]; 41 | this.data[ j * this.rows + i ] = temp; 42 | } 43 | } 44 | let temp = this.cols; 45 | this.cols = this.rows; 46 | this.rows = temp; 47 | } 48 | 49 | add(a){ 50 | if( this.rows != a.rows || this.cols != a.cols ){ 51 | throw new Error('Cant add Matrix of different sizes!'); 52 | return; 53 | } 54 | for(let i = 0; i < this.data.length; i++){ 55 | this.data[i] += a.data[i]; 56 | } 57 | } 58 | 59 | subtract(a){ 60 | if( this.rows != a.rows || this.cols != a.cols ){ 61 | throw new Error('Cant subtract Matrix of different sizes!'); 62 | return; 63 | } 64 | for(let i = 0; i < this.data.length; i++){ 65 | this.data[i] -= a.data[i]; 66 | } 67 | } 68 | 69 | scalar(a){ 70 | for(let i = 0; i < this.data.length; i++){ 71 | this.data[i] *= a; 72 | } 73 | } 74 | 75 | hadamard(a){ 76 | if( this.rows != a.rows || this.cols != a.cols ){ 77 | throw new Error('Cant multiply Matrix of different sizes!'); 78 | return; 79 | } 80 | for(let i = 0; i < this.data.length; i++){ 81 | this.data[i] *= a.data[i]; 82 | } 83 | } 84 | 85 | copy(){ 86 | return new Matrix( this.rows, this.cols, this.data ); 87 | } 88 | 89 | foreach( func ){ 90 | for(let i = 0; i < this.data.length; i++){ 91 | this.data[i] = func( this.data[i] ); 92 | } 93 | } 94 | 95 | } 96 | -------------------------------------------------------------------------------- /js/Dejavu.js: -------------------------------------------------------------------------------- 1 | class Dejavu { 2 | 3 | constructor( nn=[0] , learningRate = 0.1, iterations = 100){ 4 | 5 | this.layers = { 'length': 0 }; 6 | 7 | for(let i = 0; i < nn.length-1; i++){ 8 | 9 | this.layers[i] = {}; 10 | this.layers[i]['weights'] = new Matrix( nn[i+1], nn[i], "RANDOM" ) ; 11 | this.layers[i]['bias'] = new Matrix( nn[i+1], 1, "RANDOM" ); 12 | this.layers[i]['activation'] = 'sigmoid'; 13 | this.layers['length'] += 1; 14 | 15 | } 16 | 17 | this.lr = learningRate; 18 | 19 | this.it = iterations; 20 | 21 | this.activations = { 22 | 'sigmoid': { 23 | 'func': x => 1 / (1 + Math.exp(-x)), 24 | 'dfunc': x => x * (1 - x) 25 | }, 26 | 'relu': { 27 | 'func': x => x > 0 ? x : 0, 28 | 'dfunc': x => x > 0 ? 1 : 0 29 | }, 30 | 'tanh': { 31 | 'func': x => Math.tanh(x), 32 | 'dfunc': x => 1 - (x * x) 33 | } 34 | }; 35 | } 36 | 37 | predict(input){ 38 | 39 | let output = ( input instanceof Matrix ) ? input : new Matrix(input.length, 1, input); 40 | 41 | for(let i = 0; i < this.layers.length; i++){ 42 | 43 | output = this.layers[i]['weights'].multiply( output ) 44 | this.layers[i]['output'] = output; 45 | this.layers[i]['output'].add( this.layers[i]['bias'] ); 46 | this.layers[i]['output'].foreach( this.activations[ this.layers[i]['activation'] ]['func'] ); 47 | 48 | } 49 | 50 | return this.layers[this.layers.length-1]['output']; 51 | 52 | } 53 | 54 | fit(inputs, labels){ 55 | 56 | let it = 0; 57 | 58 | while( it < this.it ){ 59 | 60 | let s = 0; 61 | 62 | //this.shuffle( inputs, labels ); 63 | 64 | for(let i = 0; i < inputs.length; i++){ 65 | 66 | const input = new Matrix(inputs[i].length, 1, inputs[i]); 67 | 68 | this.predict( input ); 69 | 70 | let output_error = new Matrix(labels[0].length, 1, labels[i]); 71 | 72 | output_error.subtract( this.layers[this.layers.length-1]['output'] ); 73 | 74 | let sum = 0 75 | for(let i = 0; i < output_error.data.length; i++){ 76 | sum += output_error.data[i] ** 2; 77 | } 78 | s += sum/this.layers[this.layers.length-1]['output'].rows; 79 | 80 | for(let i = this.layers.length-1; i >= 0; i--){ 81 | 82 | let gradient = this.layers[i]['output'].copy(); 83 | gradient.foreach( this.activations[ this.layers[i]['activation'] ]['dfunc'] ); 84 | gradient.hadamard( output_error ); 85 | gradient.scalar( this.lr ); 86 | 87 | let layer = ( i ) ? this.layers[i-1]['output'].copy() : input.copy(); 88 | layer.transpose(); 89 | let delta = gradient.multiply( layer ); 90 | 91 | this.layers[i]['weights'].add( delta ); 92 | this.layers[i]['bias'].add( gradient ); 93 | 94 | let error = this.layers[i]['weights'].copy() 95 | error.transpose(); 96 | output_error = error.multiply( output_error ); 97 | 98 | } 99 | 100 | 101 | } 102 | 103 | it++; 104 | 105 | //if( !(it % 1) ) console.log( it, s ); 106 | 107 | } 108 | 109 | } 110 | 111 | shuffle(x,y){ 112 | 113 | for(let i = 0; i < y.length; i++){ 114 | const pos = Math.floor( Math.random() * y.length ); 115 | const tmpy = y[i]; 116 | const tmpx = x[i]; 117 | y[i] = y[pos]; 118 | x[i] = x[pos]; 119 | y[pos] = tmpy; 120 | x[pos] = tmpx; 121 | } 122 | 123 | } 124 | 125 | save(filename){ 126 | 127 | const nn = { 128 | 'layers': this.layers, 129 | 'lr': this.lr, 130 | 'it': this.it 131 | }; 132 | const blob = new Blob([JSON.stringify(nn)], {type: 'text/json'}); 133 | const link = document.createElement('a'); 134 | link.href = window.URL.createObjectURL(blob); 135 | link.download = filename; 136 | link.click(); 137 | 138 | } 139 | 140 | load(nn){ 141 | 142 | this.lr = nn.lr; 143 | 144 | this.it = nn.it; 145 | 146 | for(let i = 0; i < nn.layers.length; i++){ 147 | 148 | const layer = nn.layers[i]; 149 | 150 | this.layers[i] = {}; 151 | this.layers[i]['weights'] = new Matrix( layer['weights'].rows, layer['weights'].cols, layer['weights'].data ) ; 152 | this.layers[i]['bias'] = new Matrix( layer['bias'].rows, layer['bias'].cols, layer['bias'].data ); 153 | this.layers[i]['output'] = new Matrix( layer['output'].rows, layer['output'].cols, layer['output'].data ); 154 | this.layers[i]['activation'] = layer['activation']; 155 | this.layers['length'] += 1; 156 | 157 | } 158 | 159 | console.log( 'loaded' ); 160 | 161 | } 162 | 163 | } 164 | --------------------------------------------------------------------------------