├── .gitignore ├── LICENSE ├── README.md ├── decision-tree-demo ├── controls.js ├── demo.html ├── demo_2d.png └── style.css ├── decision-tree-min.js ├── decision-tree.js └── random-forest-demo ├── controls.js ├── demo.html ├── demo_2d.png └── style.css /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2013 Yurii Lahodiuk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | decision-tree-js 2 | ================ 3 | 4 | A small JavaScript implementation of algorithms for training [Decision Tree](http://en.wikipedia.org/wiki/Decision_tree) and [Random Forest](http://en.wikipedia.org/wiki/Random_forest) classifiers. 5 | 6 | ### Random forest demo ### 7 | 8 | Online demo: http://fiddle.jshell.net/7WsMf/show/light/ 9 | 10 | ![Random forest demo](https://raw.github.com/lagodiuk/decision-tree-js/master/random-forest-demo/demo_2d.png) 11 | 12 | ### Decision tree demo ### 13 | 14 | Online demo: http://fiddle.jshell.net/92Jxj/show/light/ 15 | 16 | ![Decision tree demo](https://raw.github.com/lagodiuk/decision-tree-js/master/decision-tree-demo/demo_2d.png) 17 | 18 | ### Toy example of usage ### 19 | Predicting the sex of characters from 'The Simpsons' cartoon, using features such as weight, hair length and age. 20 | 21 | Online demo: http://jsfiddle.net/xur98/ 22 | ```javascript 23 | // Training set 24 | var data = 25 | [{person: 'Homer', hairLength: 0, weight: 250, age: 36, sex: 'male'}, 26 | {person: 'Marge', hairLength: 10, weight: 150, age: 34, sex: 'female'}, 27 | {person: 'Bart', hairLength: 2, weight: 90, age: 10, sex: 'male'}, 28 | {person: 'Lisa', hairLength: 6, weight: 78, age: 8, sex: 'female'}, 29 | {person: 'Maggie', hairLength: 4, weight: 20, age: 1, sex: 'female'}, 30 | {person: 'Abe', hairLength: 1, weight: 170, age: 70, sex: 'male'}, 31 | {person: 'Selma', hairLength: 8, weight: 160, age: 41, sex: 'female'}, 32 | {person: 'Otto', hairLength: 10, weight: 180, age: 38, sex: 'male'}, 33 | {person: 'Krusty', hairLength: 6, weight: 200, age: 45, sex: 'male'}]; 34 | 35 | // Configuration 36 | var config = { 37 | trainingSet: data, 38 | categoryAttr: 'sex', 39 | ignoredAttributes: ['person'] 40 | }; 41 | 42 | // Building Decision Tree 43 | var decisionTree = new
dt.DecisionTree(config); 44 | 45 | // Building Random Forest 46 | var numberOfTrees = 3; 47 | var randomForest = new dt.RandomForest(config, numberOfTrees); 48 | 49 | // Testing Decision Tree and Random Forest 50 | var comic = {person: 'Comic guy', hairLength: 8, weight: 290, age: 38}; 51 | 52 | var decisionTreePrediction = decisionTree.predict(comic); 53 | var randomForestPrediction = randomForest.predict(comic); 54 | ``` 55 | Data taken from presentation: http://www.cs.sjsu.edu/faculty/lee/cs157b/ID3-AllanNeymark.ppt 56 | -------------------------------------------------------------------------------- /decision-tree-demo/controls.js: -------------------------------------------------------------------------------- 1 | function init() { 2 | 3 | var canv = document.getElementById('myCanvas'); 4 | var clearBtn = document.getElementById('clearBtn'); 5 | var context = canv.getContext('2d'); 6 | var displayTreeDiv = document.getElementById('displayTree'); 7 | 8 | var NOT_SELECTED_COLOR_STYLE = '2px solid white'; 9 | var SELECTED_COLOR_STYLE = '2px solid black'; 10 | var colorSelectElements = document.getElementsByClassName('color-select'); 11 | for (var i = 0; i < colorSelectElements.length; i++) { 12 | colorSelectElements[i].style.backgroundColor = colorSelectElements[i].getAttribute('label'); 13 | colorSelectElements[i].style.border = NOT_SELECTED_COLOR_STYLE; 14 | } 15 | 16 | var color = colorSelectElements[0].getAttribute('label'); 17 | var POINT_RADIUS = 3; 18 | var points = []; 19 | var tree = null; 20 | var MAX_ALPHA = 128; 21 | var addingPoints = false; 22 | 23 | colorSelectElements[0].style.border = SELECTED_COLOR_STYLE; 24 | 25 | canv.addEventListener('mousedown', enableAddingPointsListener, false); 26 | 27 | canv.addEventListener('mouseup', rebuildForestListener, false); 28 | 29 | canv.addEventListener('mouseout', rebuildForestListener, false); 30 | 31 | canv.addEventListener('mousemove', addPointsListener, false); 32 | 33 | 34 | for (var i = 0; i < 
colorSelectElements.length; i++) { 35 | colorSelectElements[i].addEventListener('click', selectColorListener, false); 36 | } 37 | 38 | clearBtn.addEventListener('click', clearCanvasListener, false); 39 | 40 | function enableAddingPointsListener(e) { 41 | e.preventDefault(); 42 | addingPoints = true; 43 | } 44 | 45 | function addPointsListener(e) { 46 | if (addingPoints) { 47 | var x = e.offsetX ? e.offsetX : (e.layerX - canv.offsetLeft); 48 | var y = e.offsetY ? e.offsetY : (e.layerY - canv.offsetTop); 49 | 50 | drawCircle(context, x, y, POINT_RADIUS, color); 51 | points.push({ 52 | x: x, 53 | y: y, 54 | color: color 55 | }); 56 | } 57 | } 58 | 59 | function rebuildForestListener(e) { 60 | 61 | if (!addingPoints) return; 62 | 63 | if (points.length == 0) return; 64 | 65 | addingPoints = false; 66 | 67 | 68 | var threshold = Math.floor(points.length / 100); 69 | threshold = (threshold > 1) ? threshold : 1; 70 | 71 | tree = new dt.DecisionTree({ 72 | trainingSet: points, 73 | categoryAttr: 'color', 74 | minItemsCount: threshold 75 | }); 76 | 77 | displayTreePredictions(); 78 | displayPoints(); 79 | 80 | displayTreeDiv.innerHTML = treeToHtml(tree.root); 81 | } 82 | 83 | function displayTreePredictions() { 84 | context.clearRect(0, 0, canv.width, canv.height); 85 | var imageData = context.getImageData(0, 0, canv.width, canv.height); 86 | 87 | for (var x = 0; x < canv.width; x++) { 88 | for (var y = 0; y < canv.height; y++) { 89 | var predictedHexColor = tree.predict({ 90 | x: x, 91 | y: y 92 | }); 93 | putPixel(imageData, canv.width, x, y, predictedHexColor, MAX_ALPHA); 94 | } 95 | } 96 | 97 | context.putImageData(imageData, 0, 0); 98 | } 99 | 100 | function displayPoints() { 101 | for (var p in points) { 102 | drawCircle(context, points[p].x, points[p].y, POINT_RADIUS, points[p].color); 103 | } 104 | } 105 | 106 | function drawCircle(context, x, y, radius, hexColor) { 107 | context.beginPath(); 108 | context.arc(x, y, radius, 0, 2 * Math.PI, false); 109 | 110 | var c 
= hexToRgb(hexColor) 111 | context.fillStyle = 'rgb(' + c.r + ',' + c.g + ',' + c.b + ')'; 112 | 113 | context.fill(); 114 | context.closePath(); 115 | context.stroke(); 116 | } 117 | 118 | function putPixel(imageData, width, x, y, hexColor, alpha) { 119 | var c = hexToRgb(hexColor); 120 | var indx = (y * width + x) * 4; 121 | 122 | var currAlpha = imageData.data[indx + 3]; 123 | 124 | imageData.data[indx + 0] = (c.r * alpha + imageData.data[indx + 0] * currAlpha) / (alpha + currAlpha); 125 | imageData.data[indx + 1] = (c.g * alpha + imageData.data[indx + 1] * currAlpha) / (alpha + currAlpha); 126 | imageData.data[indx + 2] = (c.b * alpha + imageData.data[indx + 2] * currAlpha) / (alpha + currAlpha); 127 | imageData.data[indx + 3] = alpha + currAlpha; 128 | } 129 | 130 | function selectColorListener(event) { 131 | color = this.getAttribute('label'); 132 | 133 | for (var i = 0; i < colorSelectElements.length; i++) { 134 | colorSelectElements[i].style.border = NOT_SELECTED_COLOR_STYLE; 135 | } 136 | 137 | this.style.border = SELECTED_COLOR_STYLE; 138 | } 139 | 140 | function clearCanvasListener(event) { 141 | context.clearRect(0, 0, canv.width, canv.height); 142 | points = []; 143 | displayTreeDiv.innerHTML = ''; 144 | } 145 | 146 | /** 147 | * Taken from: http://stackoverflow.com/a/5624139/653511 148 | */ 149 | function hexToRgb(hex) { 150 | // Expand shorthand form (e.g. "03F") to full form (e.g. "0033FF") 151 | var shorthandRegex = /^#?([a-f\d])([a-f\d])([a-f\d])$/i; 152 | hex = hex.replace(shorthandRegex, function (m, r, g, b) { 153 | return r + r + g + g + b + b; 154 | }); 155 | 156 | var result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex); 157 | return result ? 
{ 158 | r: parseInt(result[1], 16), 159 | g: parseInt(result[2], 16), 160 | b: parseInt(result[3], 16) 161 | } : null; 162 | } 163 | 164 | // Repeating of string taken from: http://stackoverflow.com/a/202627/653511 165 | var EMPTY_STRING = new Array(26).join(' '); 166 | 167 | // Recursively traversing decision tree (DFS) 168 | function treeToHtml(tree) { 169 | 170 | if (tree.category) { 171 | return [''].join(''); 176 | } 177 | 178 | return [''].join(''); 193 | } 194 | } -------------------------------------------------------------------------------- /decision-tree-demo/demo.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 |
 
12 |
 
13 |
 
14 |
 
15 | 16 | 17 |
18 | 19 | 20 |
21 | 22 |
23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /decision-tree-demo/demo_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lagodiuk/decision-tree-js/91977780b71775802aacbc6e1a415bf6e7234965/decision-tree-demo/demo_2d.png -------------------------------------------------------------------------------- /decision-tree-demo/style.css: -------------------------------------------------------------------------------- 1 | canvas { 2 | border: 1px solid #ccc; 3 | margin: 30px; 4 | margin-top: 5px; 5 | } 6 | 7 | .color-select { 8 | margin: 1px; 9 | margin-top: 10px; 10 | width : 30px; 11 | height : 30px; 12 | display:inline-block; 13 | } 14 | 15 | .color-select:first-child { 16 | margin-left: 30px; 17 | } 18 | 19 | .clearBtn { 20 | margin-top: 12px; 21 | margin-left: 15px; 22 | -moz-border-radius: 5px; 23 | -webkit-border-radius: 5px; 24 | border-radius:5px; 25 | color: green; 26 | background-color: #EDEDED; 27 | font-size: 15px; 28 | text-decoration: none; 29 | cursor: pointer; /* was misspelled 'poiner', an invalid value that browsers silently drop */ 30 | height: 30px; 31 | position: absolute; 32 | } 33 | 34 | /* 35 | Transforming nested lists to pretty tree 36 | 37 |
38 | 46 |
47 | 48 | Source: http://thecodeplayer.com/walkthrough/css3-family-tree 49 | 50 | Some other advices about displaying trees: http://stackoverflow.com/questions/1695115/how-do-i-draw-the-lines-of-a-family-tree-using-html-css 51 | */ 52 | 53 | * { 54 | margin: 0; 55 | padding: 0; 56 | } 57 | 58 | .tree ul { 59 | padding-top: 20px; 60 | position: relative; 61 | 62 | transition: all 0.5s; 63 | -webkit-transition: all 0.5s; 64 | -moz-transition: all 0.5s; 65 | } 66 | 67 | .tree li { 68 | white-space: nowrap; 69 | float: left; 70 | text-align: center; 71 | list-style-type: none; 72 | position: relative; 73 | padding: 20px 5px 0 5px; 74 | 75 | transition: all 0.5s; 76 | -webkit-transition: all 0.5s; 77 | -moz-transition: all 0.5s; 78 | } 79 | 80 | /*We will use ::before and ::after to draw the connectors*/ 81 | 82 | .tree li::before, .tree li::after{ 83 | content: ''; 84 | position: absolute; 85 | top: 0; 86 | right: 50%; 87 | border-top: 1px solid #ccc; 88 | width: 50%; 89 | height: 20px; 90 | } 91 | .tree li::after{ 92 | right: auto; 93 | left: 50%; 94 | border-left: 1px solid #ccc; 95 | } 96 | 97 | /*We need to remove left-right connectors from elements without 98 | any siblings*/ 99 | .tree li:only-child::after, .tree li:only-child::before { 100 | display: none; 101 | } 102 | 103 | /*Remove space from the top of single children*/ 104 | .tree li:only-child{ 105 | padding-top: 0; 106 | } 107 | 108 | /*Remove left connector from first child and 109 | right connector from last child*/ 110 | .tree li:first-child::before, .tree li:last-child::after{ 111 | border: 0 none; 112 | } 113 | /*Adding back the vertical connector to the last nodes*/ 114 | .tree li:last-child::before{ 115 | border-right: 1px solid #ccc; 116 | border-radius: 0 5px 0 0; 117 | -webkit-border-radius: 0 5px 0 0; 118 | -moz-border-radius: 0 5px 0 0; 119 | } 120 | .tree li:first-child::after{ 121 | border-radius: 5px 0 0 0; 122 | -webkit-border-radius: 5px 0 0 0; 123 | -moz-border-radius: 5px 0 0 0; 124 | } 
125 | 126 | /*Time to add downward connectors from parents*/ 127 | .tree ul ul::before{ 128 | content: ''; 129 | position: absolute; 130 | top: 0; 131 | left: 50%; 132 | border-left: 1px solid #ccc; 133 | width: 0; 134 | height: 20px; 135 | } 136 | 137 | .tree li a{ 138 | border: 1px solid #ccc; 139 | padding: 5px 10px; 140 | text-decoration: none; 141 | color: #666; 142 | font-family: arial, verdana, tahoma; 143 | font-size: 11px; 144 | display: inline-block; 145 | 146 | border-radius: 5px; 147 | -webkit-border-radius: 5px; 148 | -moz-border-radius: 5px; 149 | 150 | transition: all 0.5s; 151 | -webkit-transition: all 0.5s; 152 | -moz-transition: all 0.5s; 153 | } 154 | 155 | /*Time for some hover effects*/ 156 | /*We will apply the hover effect the the lineage of the element also*/ 157 | .tree li a:hover, .tree li a:hover+ul li a { 158 | background: #c8e4f8; 159 | color: #000; 160 | border: 1px solid #94a0b4; 161 | } 162 | /*Connector styles on hover*/ 163 | .tree li a:hover+ul li::after, 164 | .tree li a:hover+ul li::before, 165 | .tree li a:hover+ul::before, 166 | .tree li a:hover+ul ul::before{ 167 | border-color: #94a0b4; 168 | } 169 | 170 | /*Thats all. I hope you enjoyed it. 
171 | Thanks :)*/ -------------------------------------------------------------------------------- /decision-tree-min.js: -------------------------------------------------------------------------------- 1 | var dt=function(){function n(b){var c=v,e=b.trainingSet,d=b.ignoredAttributes,a={};if(d)for(var f in d)a[d[f]]=!0;this.root=c({trainingSet:e,ignoredAttributes:a,categoryAttr:b.categoryAttr||"category",minItemsCount:b.minItemsCount||1,entropyThrehold:b.entropyThrehold||0.01,maxTreeDepth:b.maxTreeDepth||70})}function p(b,c){for(var e=b.trainingSet,d=[],a=0;ad&&(d=e[f],a=f);return a}function v(b){var c=b.trainingSet,e=b.minItemsCount,d=b.categoryAttr,a=b.entropyThrehold,f=b.maxTreeDepth,n=b.ignoredAttributes;if(0==f||c.length<=e)return{category:x(c,d)};e=w(c,d);if(e<=a)return{category:x(c,d)};for(var m={},a={gain:0}, 3 | y=c.length-1;0<=y;y--){var p=c[y],k;for(k in p)if(k!=d&&!n[k]){var s=p[k],t;t="number"==typeof s?">=":"==";var r=k+t+s;if(!m[r]){m[r]=!0;var r=D[t],g;g=c;for(var l=k,z=r,h=s,q=[],B=[],u=void 0,C=void 0,A=g.length-1;0<=A;A--)u=g[A],C=u[l],z(C,h)?q.push(u):B.push(u);g={match:q,notMatch:B};l=w(g.match,d);z=w(g.notMatch,d);h=0;h+=l*g.match.length;h+=z*g.notMatch.length;h/=c.length;l=e-h;l>a.gain&&(a=g,a.predicateName=t,a.predicate=r,a.attribute=k,a.pivot=s,a.gain=l)}}}if(!a.gain)return{category:x(c,d)}; 4 | b.maxTreeDepth=f-1;b.trainingSet=a.match;c=v(b);b.trainingSet=a.notMatch;b=v(b);return{attribute:a.attribute,predicate:a.predicate,predicateName:a.predicateName,pivot:a.pivot,match:c,notMatch:b,matchedCount:a.match.length,notMatchedCount:a.notMatch.length}}n.prototype.predict=function(b){a:{for(var c=this.root,e,d,a;;){if(c.category){b=c.category;break a}e=c.attribute;e=b[e];d=c.predicate;a=c.pivot;c=d(e,a)?c.match:c.notMatch}b=void 0}return b};p.prototype.predict=function(b){var c=this.trees,e={}, 5 | d;for(d in c){var a=c[d].predict(b);e[a]=e[a]?e[a]+1:1}return e};var D={"==":function(b,c){return b==c},">=":function(b,c){return 
b>=c}},m={};m.DecisionTree=n;m.RandomForest=p;return m}(); -------------------------------------------------------------------------------- /decision-tree.js: -------------------------------------------------------------------------------- 1 | var dt = (function () { 2 | 3 | /** 4 | * Creates an instance of DecisionTree 5 | * 6 | * @constructor 7 | * @param builder - contains training set and 8 | * some configuration parameters 9 | */ 10 | function DecisionTree(builder) { 11 | this.root = buildDecisionTree({ 12 | trainingSet: builder.trainingSet, 13 | ignoredAttributes: arrayToHashSet(builder.ignoredAttributes), 14 | categoryAttr: builder.categoryAttr || 'category', 15 | minItemsCount: builder.minItemsCount || 1, 16 | entropyThrehold: builder.entropyThrehold || 0.01, 17 | maxTreeDepth: builder.maxTreeDepth || 70 18 | }); 19 | } 20 | 21 | DecisionTree.prototype.predict = function (item) { 22 | return predict(this.root, item); 23 | } 24 | 25 | /** 26 | * Creates an instance of RandomForest 27 | * with specific number of trees 28 | * 29 | * @constructor 30 | * @param builder - contains training set and some 31 | * configuration parameters for 32 | * building decision trees 33 | */ 34 | function RandomForest(builder, treesNumber) { 35 | this.trees = buildRandomForest(builder, treesNumber); 36 | } 37 | 38 | RandomForest.prototype.predict = function (item) { 39 | return predictRandomForest(this.trees, item); 40 | } 41 | 42 | /** 43 | * Transforming array to object with such attributes 44 | * as elements of array (afterwards it can be used as HashSet) 45 | */ 46 | function arrayToHashSet(array) { 47 | var hashSet = {}; 48 | if (array) { 49 | for(var i in array) { 50 | var attr = array[i]; 51 | hashSet[attr] = true; 52 | } 53 | } 54 | return hashSet; 55 | } 56 | 57 | /** 58 | * Calculating how many objects have the same 59 | * values of specific attribute. 
60 | * 61 | * @param items - array of objects 62 | * 63 | * @param attr - variable with name of attribute, 64 | * which embedded in each object 65 | */ 66 | function countUniqueValues(items, attr) { 67 | var counter = {}; 68 | 69 | // detecting different values of attribute 70 | for (var i = items.length - 1; i >= 0; i--) { 71 | // items[i][attr] - value of attribute 72 | counter[items[i][attr]] = 0; 73 | } 74 | 75 | // counting number of occurrences of each of values 76 | // of attribute 77 | for (var i = items.length - 1; i >= 0; i--) { 78 | counter[items[i][attr]] += 1; 79 | } 80 | 81 | return counter; 82 | } 83 | 84 | /** 85 | * Calculating entropy of array of objects 86 | * by specific attribute. 87 | * 88 | * @param items - array of objects 89 | * 90 | * @param attr - variable with name of attribute, 91 | * which embedded in each object 92 | */ 93 | function entropy(items, attr) { 94 | // counting number of occurrences of each of values 95 | // of attribute 96 | var counter = countUniqueValues(items, attr); 97 | 98 | var entropy = 0; 99 | var p; 100 | for (var i in counter) { 101 | p = counter[i] / items.length; 102 | entropy += -p * Math.log(p); 103 | } 104 | 105 | return entropy; 106 | } 107 | 108 | /** 109 | * Splitting array of objects by value of specific attribute, 110 | * using specific predicate and pivot. 111 | * 112 | * Items which matched by predicate will be copied to 113 | * the new array called 'match', and the rest of the items 114 | * will be copied to array with name 'notMatch' 115 | * 116 | * @param items - array of objects 117 | * 118 | * @param attr - variable with name of attribute, 119 | * which embedded in each object 120 | * 121 | * @param predicate - function(x, y) 122 | * which returns 'true' or 'false' 123 | * 124 | * @param pivot - used as the second argument when 125 | * calling predicate function: 126 | * e.g. 
predicate(item[attr], pivot) 127 | */ 128 | function split(items, attr, predicate, pivot) { 129 | var match = []; 130 | var notMatch = []; 131 | 132 | var item, 133 | attrValue; 134 | 135 | for (var i = items.length - 1; i >= 0; i--) { 136 | item = items[i]; 137 | attrValue = item[attr]; 138 | 139 | if (predicate(attrValue, pivot)) { 140 | match.push(item); 141 | } else { 142 | notMatch.push(item); 143 | } 144 | }; 145 | 146 | return { 147 | match: match, 148 | notMatch: notMatch 149 | }; 150 | } 151 | 152 | /** 153 | * Finding value of specific attribute which is most frequent 154 | * in given array of objects. 155 | * 156 | * @param items - array of objects 157 | * 158 | * @param attr - variable with name of attribute, 159 | * which embedded in each object 160 | */ 161 | function mostFrequentValue(items, attr) { 162 | // counting number of occurrences of each of values 163 | // of attribute 164 | var counter = countUniqueValues(items, attr); 165 | 166 | var mostFrequentCount = 0; 167 | var mostFrequentValue; 168 | 169 | for (var value in counter) { 170 | if (counter[value] > mostFrequentCount) { 171 | mostFrequentCount = counter[value]; 172 | mostFrequentValue = value; 173 | } 174 | }; 175 | 176 | return mostFrequentValue; 177 | } 178 | 179 | var predicates = { 180 | '==': function (a, b) { return a == b }, 181 | '>=': function (a, b) { return a >= b } 182 | }; 183 | 184 | /** 185 | * Function for building decision tree 186 | */ 187 | function buildDecisionTree(builder) { 188 | 189 | var trainingSet = builder.trainingSet; 190 | var minItemsCount = builder.minItemsCount; 191 | var categoryAttr = builder.categoryAttr; 192 | var entropyThrehold = builder.entropyThrehold; 193 | var maxTreeDepth = builder.maxTreeDepth; 194 | var ignoredAttributes = builder.ignoredAttributes; 195 | 196 | if ((maxTreeDepth == 0) || (trainingSet.length <= minItemsCount)) { 197 | // restriction by maximal depth of tree 198 | // or size of training set is to small 199 | // so we have to 
terminate process of building tree 200 | return { 201 | category: mostFrequentValue(trainingSet, categoryAttr) 202 | }; 203 | } 204 | 205 | var initialEntropy = entropy(trainingSet, categoryAttr); 206 | 207 | if (initialEntropy <= entropyThrehold) { 208 | // entropy of training set too small 209 | // (it means that training set is almost homogeneous), 210 | // so we have to terminate process of building tree 211 | return { 212 | category: mostFrequentValue(trainingSet, categoryAttr) 213 | }; 214 | } 215 | 216 | // used as hash-set for avoiding the checking of split by rules 217 | // with the same 'attribute-predicate-pivot' more than once 218 | var alreadyChecked = {}; 219 | 220 | // this variable expected to contain rule, which splits training set 221 | // into subsets with smaller values of entropy (produces informational gain) 222 | var bestSplit = {gain: 0}; 223 | 224 | for (var i = trainingSet.length - 1; i >= 0; i--) { 225 | var item = trainingSet[i]; 226 | 227 | // iterating over all attributes of item 228 | for (var attr in item) { 229 | if ((attr == categoryAttr) || ignoredAttributes[attr]) { 230 | continue; 231 | } 232 | 233 | // let the value of current attribute be the pivot 234 | var pivot = item[attr]; 235 | 236 | // pick the predicate 237 | // depending on the type of the attribute value 238 | var predicateName; 239 | if (typeof pivot == 'number') { 240 | predicateName = '>='; 241 | } else { 242 | // there is no sense to compare non-numeric attributes 243 | // so we will check only equality of such attributes 244 | predicateName = '=='; 245 | } 246 | 247 | var attrPredPivot = attr + predicateName + pivot; 248 | if (alreadyChecked[attrPredPivot]) { 249 | // skip such pairs of 'attribute-predicate-pivot', 250 | // which been already checked 251 | continue; 252 | } 253 | alreadyChecked[attrPredPivot] = true; 254 | 255 | var predicate = predicates[predicateName]; 256 | 257 | // splitting training set by given 'attribute-predicate-value' 258 | var 
currSplit = split(trainingSet, attr, predicate, pivot); 259 | 260 | // calculating entropy of subsets 261 | var matchEntropy = entropy(currSplit.match, categoryAttr); 262 | var notMatchEntropy = entropy(currSplit.notMatch, categoryAttr); 263 | 264 | // calculating informational gain 265 | var newEntropy = 0; 266 | newEntropy += matchEntropy * currSplit.match.length; 267 | newEntropy += notMatchEntropy * currSplit.notMatch.length; 268 | newEntropy /= trainingSet.length; 269 | var currGain = initialEntropy - newEntropy; 270 | 271 | if (currGain > bestSplit.gain) { 272 | // remember pairs 'attribute-predicate-value' 273 | // which provides informational gain 274 | bestSplit = currSplit; 275 | bestSplit.predicateName = predicateName; 276 | bestSplit.predicate = predicate; 277 | bestSplit.attribute = attr; 278 | bestSplit.pivot = pivot; 279 | bestSplit.gain = currGain; 280 | } 281 | } 282 | } 283 | 284 | if (!bestSplit.gain) { 285 | // can't find optimal split 286 | return { category: mostFrequentValue(trainingSet, categoryAttr) }; 287 | } 288 | 289 | // building subtrees 290 | 291 | builder.maxTreeDepth = maxTreeDepth - 1; 292 | 293 | builder.trainingSet = bestSplit.match; 294 | var matchSubTree = buildDecisionTree(builder); 295 | 296 | builder.trainingSet = bestSplit.notMatch; 297 | var notMatchSubTree = buildDecisionTree(builder); 298 | 299 | return { 300 | attribute: bestSplit.attribute, 301 | predicate: bestSplit.predicate, 302 | predicateName: bestSplit.predicateName, 303 | pivot: bestSplit.pivot, 304 | match: matchSubTree, 305 | notMatch: notMatchSubTree, 306 | matchedCount: bestSplit.match.length, 307 | notMatchedCount: bestSplit.notMatch.length 308 | }; 309 | } 310 | 311 | /** 312 | * Classifying item, using decision tree 313 | */ 314 | function predict(tree, item) { 315 | var attr, 316 | value, 317 | predicate, 318 | pivot; 319 | 320 | // Traversing tree from the root to leaf 321 | while(true) { 322 | 323 | if (tree.category) { 324 | // only leafs contains 
predicted category 325 | return tree.category; 326 | } 327 | 328 | attr = tree.attribute; 329 | value = item[attr]; 330 | 331 | predicate = tree.predicate; 332 | pivot = tree.pivot; 333 | 334 | // move to one of subtrees 335 | if (predicate(value, pivot)) { 336 | tree = tree.match; 337 | } else { 338 | tree = tree.notMatch; 339 | } 340 | } 341 | } 342 | 343 | /** 344 | * Building array of decision trees 345 | */ 346 | function buildRandomForest(builder, treesNumber) { 347 | var items = builder.trainingSet; 348 | 349 | // creating training sets for each tree 350 | var trainingSets = []; 351 | for (var t = 0; t < treesNumber; t++) { 352 | trainingSets[t] = []; 353 | } 354 | for (var i = items.length - 1; i >= 0 ; i--) { 355 | // assigning items to training sets of each tree 356 | // using 'round-robin' strategy 357 | var correspondingTree = i % treesNumber; 358 | trainingSets[correspondingTree].push(items[i]); 359 | } 360 | 361 | // building decision trees 362 | var forest = []; 363 | for (var t = 0; t < treesNumber; t++) { 364 | builder.trainingSet = trainingSets[t]; 365 | 366 | var tree = new DecisionTree(builder); 367 | forest.push(tree); 368 | } 369 | return forest; 370 | } 371 | 372 | /** 373 | * Each of decision tree classifying item 374 | * ('voting' that item corresponds to some class). 375 | * 376 | * This function returns hash, which contains 377 | * all classifying results, and number of votes 378 | * which were given for each of classifying results 379 | */ 380 | function predictRandomForest(forest, item) { 381 | var result = {}; 382 | for (var i in forest) { 383 | var tree = forest[i]; 384 | var prediction = tree.predict(item); 385 | result[prediction] = result[prediction] ? 
result[prediction] + 1 : 1; 386 | } 387 | return result; 388 | } 389 | 390 | var exports = {}; 391 | exports.DecisionTree = DecisionTree; 392 | exports.RandomForest = RandomForest; 393 | return exports; 394 | })(); -------------------------------------------------------------------------------- /random-forest-demo/controls.js: -------------------------------------------------------------------------------- 1 | function init() { 2 | 3 | var canv = document.getElementById('myCanvas'); 4 | var clearBtn = document.getElementById('clearBtn'); 5 | var context = canv.getContext('2d'); 6 | 7 | var NOT_SELECTED_COLOR_STYLE = '2px solid white'; 8 | var SELECTED_COLOR_STYLE = '2px solid black'; 9 | var colorSelectElements = document.getElementsByClassName('color-select'); 10 | for (var i = 0; i < colorSelectElements.length; i++) { 11 | colorSelectElements[i].style.backgroundColor = colorSelectElements[i].getAttribute('label'); 12 | colorSelectElements[i].style.border = NOT_SELECTED_COLOR_STYLE; 13 | } 14 | 15 | var color = colorSelectElements[0].getAttribute('label'); 16 | var POINT_RADIUS = 3; 17 | var points = []; 18 | var forest = null; 19 | var TREES_NUMBER = 7; 20 | var MAX_ALPHA = 128; 21 | var addingPoints = false; 22 | 23 | colorSelectElements[0].style.border = SELECTED_COLOR_STYLE; 24 | 25 | canv.addEventListener('mousedown', enableAddingPointsListener, false); 26 | 27 | canv.addEventListener('mouseup', rebuildForestListener, false); 28 | 29 | canv.addEventListener('mouseout', rebuildForestListener, false); 30 | 31 | canv.addEventListener('mousemove', addPointsListener, false); 32 | 33 | 34 | for (var i = 0; i < colorSelectElements.length; i++) { 35 | colorSelectElements[i].addEventListener('click', selectColorListener, false); 36 | } 37 | 38 | clearBtn.addEventListener('click', clearCanvasListener, false); 39 | 40 | function enableAddingPointsListener(e) { 41 | e.preventDefault(); 42 | addingPoints = true; 43 | } 44 | 45 | function addPointsListener(e) { 46 | if 
(addingPoints) { 47 | var x = e.offsetX ? e.offsetX : (e.layerX - canv.offsetLeft); 48 | var y = e.offsetY ? e.offsetY : (e.layerY - canv.offsetTop); 49 | 50 | drawCircle(context, x, y, POINT_RADIUS, color); 51 | points.push({ 52 | x: x, 53 | y: y, 54 | color: color 55 | }); 56 | } 57 | } 58 | 59 | function rebuildForestListener(e) { 60 | 61 | if (!addingPoints) return; 62 | 63 | if (points.length == 0) return; 64 | 65 | addingPoints = false; 66 | 67 | 68 | var threshold = Math.floor(points.length / 100); 69 | threshold = (threshold > 1) ? threshold : 1; 70 | 71 | forest = new dt.RandomForest({ 72 | trainingSet: points, 73 | categoryAttr: 'color', 74 | minItemsCount: threshold 75 | }, TREES_NUMBER); 76 | 77 | displayTreePredictions(); 78 | displayPoints(); 79 | } 80 | 81 | function displayTreePredictions() { 82 | context.clearRect(0, 0, canv.width, canv.height); 83 | var imageData = context.getImageData(0, 0, canv.width, canv.height); 84 | 85 | for (var x = 0; x < canv.width; x++) { 86 | for (var y = 0; y < canv.height; y++) { 87 | 88 | var prediction = forest.predict({ 89 | x: x, 90 | y: y 91 | }); 92 | 93 | var sum = 0; 94 | for (var predictedHexColor in prediction) { 95 | sum += prediction[predictedHexColor]; 96 | } 97 | 98 | for (var predictedHexColor in prediction) { 99 | var numberOfVotes = prediction[predictedHexColor]; 100 | putPixel(imageData, canv.width, x, y, predictedHexColor, numberOfVotes * MAX_ALPHA / sum); 101 | } 102 | } 103 | } 104 | 105 | context.putImageData(imageData, 0, 0); 106 | } 107 | 108 | function displayPoints() { 109 | for (var p in points) { 110 | drawCircle(context, points[p].x, points[p].y, POINT_RADIUS, points[p].color); 111 | } 112 | } 113 | 114 | function drawCircle(context, x, y, radius, hexColor) { 115 | context.beginPath(); 116 | context.arc(x, y, radius, 0, 2 * Math.PI, false); 117 | 118 | var c = hexToRgb(hexColor) 119 | context.fillStyle = 'rgb(' + c.r + ',' + c.g + ',' + c.b + ')'; 120 | 121 | context.fill(); 122 | 
context.closePath(); 123 | context.stroke(); 124 | } 125 | 126 | function putPixel(imageData, width, x, y, hexColor, alpha) { 127 | var c = hexToRgb(hexColor); 128 | var indx = (y * width + x) * 4; 129 | 130 | var currAlpha = imageData.data[indx + 3]; 131 | 132 | imageData.data[indx + 0] = (c.r * alpha + imageData.data[indx + 0] * currAlpha) / (alpha + currAlpha); 133 | imageData.data[indx + 1] = (c.g * alpha + imageData.data[indx + 1] * currAlpha) / (alpha + currAlpha); 134 | imageData.data[indx + 2] = (c.b * alpha + imageData.data[indx + 2] * currAlpha) / (alpha + currAlpha); 135 | imageData.data[indx + 3] = alpha + currAlpha; 136 | } 137 | 138 | function selectColorListener(event) { 139 | color = this.getAttribute('label'); 140 | 141 | for (var i = 0; i < colorSelectElements.length; i++) { 142 | colorSelectElements[i].style.border = NOT_SELECTED_COLOR_STYLE; 143 | } 144 | 145 | this.style.border = SELECTED_COLOR_STYLE; 146 | } 147 | 148 | function clearCanvasListener(event) { 149 | context.clearRect(0, 0, canv.width, canv.height); 150 | points = []; 151 | } 152 | 153 | /** 154 | * Taken from: http://stackoverflow.com/a/5624139/653511 155 | */ 156 | function hexToRgb(hex) { 157 | // Expand shorthand form (e.g. "03F") to full form (e.g. "0033FF") 158 | var shorthandRegex = /^#?([a-f\d])([a-f\d])([a-f\d])$/i; 159 | hex = hex.replace(shorthandRegex, function (m, r, g, b) { 160 | return r + r + g + g + b + b; 161 | }); 162 | 163 | var result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex); 164 | return result ? { 165 | r: parseInt(result[1], 16), 166 | g: parseInt(result[2], 16), 167 | b: parseInt(result[3], 16) 168 | } : null; 169 | } 170 | } -------------------------------------------------------------------------------- /random-forest-demo/demo.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
 
10 |
 
11 |
 
12 |
 
13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /random-forest-demo/demo_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lagodiuk/decision-tree-js/91977780b71775802aacbc6e1a415bf6e7234965/random-forest-demo/demo_2d.png -------------------------------------------------------------------------------- /random-forest-demo/style.css: -------------------------------------------------------------------------------- 1 | canvas { 2 | border: 1px solid #ccc; 3 | margin: 30px; 4 | margin-top: 5px; 5 | } 6 | 7 | .color-select { 8 | margin: 1px; 9 | width : 30px; 10 | height : 30px; 11 | display:inline-block; 12 | } 13 | 14 | .color-select:first-child { 15 | margin-left: 30px; 16 | } 17 | 18 | .clearBtn { 19 | margin-top: 3px; 20 | margin-left: 15px; 21 | -moz-border-radius: 5px; 22 | -webkit-border-radius: 5px; 23 | border-radius:5px; 24 | color: green; 25 | background-color: #EDEDED; 26 | font-size: 15px; 27 | text-decoration: none; 28 | cursor: pointer; /* was misspelled 'poiner', an invalid value that browsers silently drop */ 29 | height: 30px; 30 | position: absolute; 31 | } --------------------------------------------------------------------------------