├── .gitignore
├── .travis.yml
├── MIT-LICENSE.txt
├── README.md
├── binding.gyp
├── index.js
├── irf
│   ├── .gitignore
│   ├── MurmurHash3.cpp
│   ├── MurmurHash3.h
│   ├── irfmodule.cpp
│   ├── node.cpp
│   ├── randomForest.cpp
│   ├── randomForest.h
│   └── setup.py
├── libsparsehash.pc
├── package.json
├── tests
│   ├── .gitignore
│   ├── mushrooms
│   ├── mushrooms.js
│   ├── mushrooms.py
│   └── simple.py
└── wscript

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*~
*#*#
.lock-wscript
build
irf.node
node_modules

/.travis.yml:
--------------------------------------------------------------------------------
language: node_js
node_js:
  - 0.6
  - 0.8
before_install:
  - sudo apt-get update
  - sudo apt-get install libsparsehash-dev
  - sudo cp libsparsehash.pc /usr/lib/pkgconfig/

/MIT-LICENSE.txt:
--------------------------------------------------------------------------------
Copyright (c) 2010-2012 Carlos Guerreiro, http://perceptiveconstructs.com
Portions Copyright (c) 2012 Igalia S.L., http://igalia.com

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

/README.md:
--------------------------------------------------------------------------------
Incremental Random Forest
=========================

[![Build Status](https://secure.travis-ci.org/pconstr/irf.png)](http://travis-ci.org/pconstr/irf)

An implementation in C++ (with [node.js](http://nodejs.org) and Python bindings) of a variant of [Leo Breiman's Random Forests](http://stat-www.berkeley.edu/users/breiman/RandomForests/cc_home.htm).

The forest is maintained incrementally as samples are added or removed - rather than being fully rebuilt from scratch every time - to save effort.

It is not a streaming implementation: all samples are stored and are revisited whenever invalidated subtrees need to be rebuilt recursively. The effort to update an individual tree can vary substantially, but the overall effort to update the forest is averaged across the trees, so it tends not to vary as much.

IRF is licensed under the MIT license.
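The bindings below wrap a small C++ core (`irf/randomForest.h`). Since the "C++ usage" section further down is still to be written, here is a minimal sketch of how that core is driven; the signatures and the ownership convention are inferred from the calls made in `irf/irfmodule.cpp` and `irf/node.cpp`, so treat it as illustrative rather than documented API:

```cpp
// Illustrative sketch only - names and ownership inferred from the bindings.
#include <iostream>
#include "randomForest.h"

using namespace IncrementalRandomForest;

int main() {
  Forest* f = create(99);     // a forest of 99 trees

  Sample* s = new Sample();   // ownership appears to pass to the forest
  s->suid = "1";              // sample id
  s->y = 0;                   // target class, 0 or 1
  s->xCodes[1] = 1;           // sparse feature vector, code -> value
  s->xCodes[3] = 1;
  s->xCodes[5] = 1;
  add(f, s);

  Sample q;                   // a query sample only needs features
  q.xCodes[1] = 1;
  q.xCodes[3] = 1;
  std::cout << classify(f, &q) << std::endl; // P(class 1); updates run lazily

  destroy(f);
  return 0;
}
```

`commit(f)` forces the pending updates explicitly, mirroring `f.commit()` in the bindings.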
Features and limitations
------------------------

* Sparse feature vectors
* Samples can be added, removed and changed
* Learning can be performed lazily or initiated explicitly
* The forest can be serialized to JSON for transmission/storage
* The forest needs to fit fully in RAM; performance suffers dramatically when swapping
* Currently only binary classification - 0 or 1. The classifier estimates the probability of belonging to class 1, as a float from 0 to 1
* Currently only binary features: a feature value >= 0.5 is treated as 1, anything else as 0

Node.js setup
-------------
`npm install irf`

Node.js usage
-------------

```javascript
var irf = require('irf');

var f = new irf.IRF(99); // create a forest of 99 trees

f.add('1', {1:1, 3:1, 5:1}, 0); // add a sample identified as '1' with the given feature values, classified as 0
f.add('2', {1:0, 3:0, 4:1}, 0); // features are stored sparsely; when a value is not given it is taken as 0
f.add('3', {2:0, 3:0, 5:0}, 0); // but 0s can also be given explicitly
// ...

var y = f.classify({1:1, 3:1, 5:1}); // classify a feature vector
// the forest will be lazily updated before classification
f.commit(); // but you can force an update at any time
// classification yields a probability estimate, from 0 to 1, of belonging to class 1
var c = Math.round(y); // round to nearest to get the class (0 or 1)

f.remove('8'); // remove a sample
f.add('8', {1:0, 2:0, 3:0, 4:0, 5:1}, 0); // and add it again with new values

console.log(f.asJSON()); // serialize to JSON (for classification, not suitable for incremental update)

f.each(function(suid, features, y) {
  // ...
});

var b = f.toBuffer(); // serialize (completely) to a buffer
var f2 = new irf.IRF(b); // construct from buffer contents
```

Python setup
------------

    cd irf
    python setup.py install

Python usage
------------

```python
import irf

f = irf.IRF(99) # create a forest of 99 trees

f.add('1', {1:1, 3:1, 5:1}, 0) # add a sample identified as '1' with the given feature values, classified as 0
f.add('2', {1:0, 3:0, 4:1}, 0) # features are stored sparsely; when a value is not given it is taken as 0
f.add('3', {2:0, 3:0, 5:0}, 0) # but 0s can also be given explicitly
# ...

y = f.classify({1:1, 2:1, 5:1}); print y, int(round(y)) # classify a feature vector, round to nearest to get the class

f.save('simple.rf') # save the forest to a file

f = irf.load('simple.rf') # load a forest from a file

f.remove('8') # remove a sample
f.add('8', {1:0, 2:0, 3:0, 4:0, 5:1}, 0) # and add it again with new values

y = f.classify({1:1, 2:1, 5:1}); print y, int(round(y)) # the forest will be lazily updated before classification
# f.commit() # but you can force an update at any time

for (sId, x, y) in f.samples(): # iterate through samples in the forest, in lexicographic ID order
    print sId, x, y # and print them
```

C++ usage
---------

_to be written_ - for now, the sketch near the top of this README gives the rough shape of the API

Dependencies
------------

System:

* STL

Included:

* MurmurHash3 (from [smhasher](http://code.google.com/p/smhasher/))

External:

* [google sparse hash](http://goog-sparsehash.sourceforge.net/)


Tests
-----

* simple.py - trivial made-up data to illustrate how to use the API
* mushrooms.js, mushrooms.py - using the [mushrooms dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#mushrooms) collected by [LIBSVM](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/) from the [UCI Machine Learning Repository](http://archive.ics.uci.edu/ml/)

/binding.gyp:
--------------------------------------------------------------------------------
{
  "targets": [
    {
      "target_name": "irf",
      "sources": [
        "irf/node.cpp",
        "irf/randomForest.h",
        "irf/randomForest.cpp",
        "irf/MurmurHash3.h",
        "irf/MurmurHash3.cpp"
      ],
      'cflags': [ '
[...]

/irf/MurmurHash3.cpp:
--------------------------------------------------------------------------------
[...]

#define ROTL32(x,y) _rotl(x,y)
#define ROTL64(x,y) _rotl64(x,y)

#define BIG_CONSTANT(x) (x)

// Other compilers

#else // defined(_MSC_VER)

#define FORCE_INLINE __attribute__((always_inline))

inline uint32_t rotl32 ( uint32_t x, int8_t r )
{
  return (x << r) | (x >> (32 - r));
}

inline uint64_t rotl64 ( uint64_t x, int8_t r )
{
  return (x << r) | (x >> (64 - r));
}

#define ROTL32(x,y) rotl32(x,y)
#define ROTL64(x,y) rotl64(x,y)

#define BIG_CONSTANT(x) (x##LLU)

#endif // !defined(_MSC_VER)

//-----------------------------------------------------------------------------
// Block read - if your platform needs to do endian-swapping or can only
// handle aligned reads, do the conversion here

FORCE_INLINE uint32_t getblock ( const uint32_t * p, int i )
{
  return p[i];
}

FORCE_INLINE uint64_t getblock ( const uint64_t * p, int i )
{
  return p[i];
}

//-----------------------------------------------------------------------------
// Finalization mix - force all bits of a hash block to avalanche

FORCE_INLINE uint32_t fmix ( uint32_t h )
{
  h ^= h >> 16;
  h *= 0x85ebca6b;
  h ^= h >> 13;
  h *= 0xc2b2ae35;
  h ^= h >> 16;

  return h;
}

//----------

FORCE_INLINE uint64_t fmix ( uint64_t k )
{
  k ^= k >> 33;
  k *= BIG_CONSTANT(0xff51afd7ed558ccd);
  k ^= k >> 33;
  k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53);
  k ^= k >> 33;

  return k;
}
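// The three hash variants that follow share the same overall shape: a body
// loop that mixes whole input blocks with the per-lane constants, a tail
// switch that folds in any remaining bytes (the case fall-through is
// deliberate), and a finalization step that xors in the length and applies
// fmix() above so that every input bit avalanches across the output bits.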
//----------------------------------------------------------------------------- 93 | 94 | void MurmurHash3_x86_32 ( const void * key, int len, 95 | uint32_t seed, void * out ) 96 | { 97 | const uint8_t * data = (const uint8_t*)key; 98 | const int nblocks = len / 4; 99 | 100 | uint32_t h1 = seed; 101 | 102 | uint32_t c1 = 0xcc9e2d51; 103 | uint32_t c2 = 0x1b873593; 104 | 105 | //---------- 106 | // body 107 | 108 | const uint32_t * blocks = (const uint32_t *)(data + nblocks*4); 109 | 110 | for(int i = -nblocks; i; i++) 111 | { 112 | uint32_t k1 = getblock(blocks,i); 113 | 114 | k1 *= c1; 115 | k1 = ROTL32(k1,15); 116 | k1 *= c2; 117 | 118 | h1 ^= k1; 119 | h1 = ROTL32(h1,13); 120 | h1 = h1*5+0xe6546b64; 121 | } 122 | 123 | //---------- 124 | // tail 125 | 126 | const uint8_t * tail = (const uint8_t*)(data + nblocks*4); 127 | 128 | uint32_t k1 = 0; 129 | 130 | switch(len & 3) 131 | { 132 | case 3: k1 ^= tail[2] << 16; 133 | case 2: k1 ^= tail[1] << 8; 134 | case 1: k1 ^= tail[0]; 135 | k1 *= c1; k1 = ROTL32(k1,16); k1 *= c2; h1 ^= k1; 136 | }; 137 | 138 | //---------- 139 | // finalization 140 | 141 | h1 ^= len; 142 | 143 | h1 = fmix(h1); 144 | 145 | *(uint32_t*)out = h1; 146 | } 147 | 148 | //----------------------------------------------------------------------------- 149 | 150 | void MurmurHash3_x86_128 ( const void * key, const int len, 151 | uint32_t seed, void * out ) 152 | { 153 | const uint8_t * data = (const uint8_t*)key; 154 | const int nblocks = len / 16; 155 | 156 | uint32_t h1 = seed; 157 | uint32_t h2 = seed; 158 | uint32_t h3 = seed; 159 | uint32_t h4 = seed; 160 | 161 | uint32_t c1 = 0x239b961b; 162 | uint32_t c2 = 0xab0e9789; 163 | uint32_t c3 = 0x38b34ae5; 164 | uint32_t c4 = 0xa1e38b93; 165 | 166 | //---------- 167 | // body 168 | 169 | const uint32_t * blocks = (const uint32_t *)(data + nblocks*16); 170 | 171 | for(int i = -nblocks; i; i++) 172 | { 173 | uint32_t k1 = getblock(blocks,i*4+0); 174 | uint32_t k2 = getblock(blocks,i*4+1); 175 | uint32_t k3 = getblock(blocks,i*4+2); 176 | uint32_t k4 = getblock(blocks,i*4+3); 177 | 178 | k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; 179 | 180 | h1 = ROTL32(h1,19); h1 += h2; h1 = h1*5+0x561ccd1b; 181 | 182 | k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; 183 | 184 | h2 = ROTL32(h2,17); h2 += h3; h2 = h2*5+0x0bcaa747; 185 | 186 | k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; 187 | 188 | h3 = ROTL32(h3,15); h3 += h4; h3 = h3*5+0x96cd1c35; 189 | 190 | k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; 191 | 192 | h4 = ROTL32(h4,13); h4 += h1; h4 = h4*5+0x32ac3b17; 193 | } 194 | 195 | //---------- 196 | // tail 197 | 198 | const uint8_t * tail = (const uint8_t*)(data + nblocks*16); 199 | 200 | uint32_t k1 = 0; 201 | uint32_t k2 = 0; 202 | uint32_t k3 = 0; 203 | uint32_t k4 = 0; 204 | 205 | switch(len & 15) 206 | { 207 | case 15: k4 ^= tail[14] << 16; 208 | case 14: k4 ^= tail[13] << 8; 209 | case 13: k4 ^= tail[12] << 0; 210 | k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; 211 | 212 | case 12: k3 ^= tail[11] << 24; 213 | case 11: k3 ^= tail[10] << 16; 214 | case 10: k3 ^= tail[ 9] << 8; 215 | case 9: k3 ^= tail[ 8] << 0; 216 | k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; 217 | 218 | case 8: k2 ^= tail[ 7] << 24; 219 | case 7: k2 ^= tail[ 6] << 16; 220 | case 6: k2 ^= tail[ 5] << 8; 221 | case 5: k2 ^= tail[ 4] << 0; 222 | k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; 223 | 224 | case 4: k1 ^= tail[ 3] << 24; 225 | case 3: k1 ^= tail[ 2] << 16; 226 | case 2: k1 ^= tail[ 1] << 8; 227 | case 1: k1 ^= tail[ 0] << 
0; 228 | k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; 229 | }; 230 | 231 | //---------- 232 | // finalization 233 | 234 | h1 ^= len; h2 ^= len; h3 ^= len; h4 ^= len; 235 | 236 | h1 += h2; h1 += h3; h1 += h4; 237 | h2 += h1; h3 += h1; h4 += h1; 238 | 239 | h1 = fmix(h1); 240 | h2 = fmix(h2); 241 | h3 = fmix(h3); 242 | h4 = fmix(h4); 243 | 244 | h1 += h2; h1 += h3; h1 += h4; 245 | h2 += h1; h3 += h1; h4 += h1; 246 | 247 | ((uint32_t*)out)[0] = h1; 248 | ((uint32_t*)out)[1] = h2; 249 | ((uint32_t*)out)[2] = h3; 250 | ((uint32_t*)out)[3] = h4; 251 | } 252 | 253 | //----------------------------------------------------------------------------- 254 | 255 | void MurmurHash3_x64_128 ( const void * key, const int len, 256 | const uint32_t seed, void * out ) 257 | { 258 | const uint8_t * data = (const uint8_t*)key; 259 | const int nblocks = len / 16; 260 | 261 | uint64_t h1 = seed; 262 | uint64_t h2 = seed; 263 | 264 | uint64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); 265 | uint64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); 266 | 267 | //---------- 268 | // body 269 | 270 | const uint64_t * blocks = (const uint64_t *)(data); 271 | 272 | for(int i = 0; i < nblocks; i++) 273 | { 274 | uint64_t k1 = getblock(blocks,i*2+0); 275 | uint64_t k2 = getblock(blocks,i*2+1); 276 | 277 | k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; 278 | 279 | h1 = ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; 280 | 281 | k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; 282 | 283 | h2 = ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; 284 | } 285 | 286 | //---------- 287 | // tail 288 | 289 | const uint8_t * tail = (const uint8_t*)(data + nblocks*16); 290 | 291 | uint64_t k1 = 0; 292 | uint64_t k2 = 0; 293 | 294 | switch(len & 15) 295 | { 296 | case 15: k2 ^= uint64_t(tail[14]) << 48; 297 | case 14: k2 ^= uint64_t(tail[13]) << 40; 298 | case 13: k2 ^= uint64_t(tail[12]) << 32; 299 | case 12: k2 ^= uint64_t(tail[11]) << 24; 300 | case 11: k2 ^= uint64_t(tail[10]) << 16; 301 | case 10: k2 ^= uint64_t(tail[ 9]) << 8; 302 | case 9: k2 ^= uint64_t(tail[ 8]) << 0; 303 | k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; 304 | 305 | case 8: k1 ^= uint64_t(tail[ 7]) << 56; 306 | case 7: k1 ^= uint64_t(tail[ 6]) << 48; 307 | case 6: k1 ^= uint64_t(tail[ 5]) << 40; 308 | case 5: k1 ^= uint64_t(tail[ 4]) << 32; 309 | case 4: k1 ^= uint64_t(tail[ 3]) << 24; 310 | case 3: k1 ^= uint64_t(tail[ 2]) << 16; 311 | case 2: k1 ^= uint64_t(tail[ 1]) << 8; 312 | case 1: k1 ^= uint64_t(tail[ 0]) << 0; 313 | k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; 314 | }; 315 | 316 | //---------- 317 | // finalization 318 | 319 | h1 ^= len; h2 ^= len; 320 | 321 | h1 += h2; 322 | h2 += h1; 323 | 324 | h1 = fmix(h1); 325 | h2 = fmix(h2); 326 | 327 | h1 += h2; 328 | h2 += h1; 329 | 330 | ((uint64_t*)out)[0] = h1; 331 | ((uint64_t*)out)[1] = h2; 332 | } 333 | 334 | //----------------------------------------------------------------------------- 335 | 336 | -------------------------------------------------------------------------------- /irf/MurmurHash3.h: -------------------------------------------------------------------------------- 1 | //----------------------------------------------------------------------------- 2 | // MurmurHash3 was written by Austin Appleby, and is placed in the public 3 | // domain. The author hereby disclaims copyright to this source code. 
4 | 5 | #ifndef _MURMURHASH3_H_ 6 | #define _MURMURHASH3_H_ 7 | 8 | //----------------------------------------------------------------------------- 9 | // Platform-specific functions and macros 10 | 11 | // Microsoft Visual Studio 12 | 13 | #if defined(_MSC_VER) 14 | 15 | typedef unsigned char uint8_t; 16 | typedef unsigned long uint32_t; 17 | typedef unsigned __int64 uint64_t; 18 | 19 | // Other compilers 20 | 21 | #else // defined(_MSC_VER) 22 | 23 | #include 24 | 25 | #endif // !defined(_MSC_VER) 26 | 27 | //----------------------------------------------------------------------------- 28 | 29 | void MurmurHash3_x86_32 ( const void * key, int len, uint32_t seed, void * out ); 30 | 31 | void MurmurHash3_x86_128 ( const void * key, int len, uint32_t seed, void * out ); 32 | 33 | void MurmurHash3_x64_128 ( const void * key, int len, uint32_t seed, void * out ); 34 | 35 | //----------------------------------------------------------------------------- 36 | 37 | #endif // _MURMURHASH3_H_ 38 | -------------------------------------------------------------------------------- /irf/irfmodule.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2010-2011 Carlos Guerreiro 2 | * Licensed under the MIT license */ 3 | 4 | #include 5 | #include "structmember.h" 6 | 7 | #include "randomForest.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include "MurmurHash3.h" 18 | 19 | using namespace std; 20 | using namespace IncrementalRandomForest; 21 | 22 | struct IRF { 23 | PyObject_HEAD 24 | Forest* forest; 25 | 26 | IRF(void) { 27 | forest = 0; 28 | } 29 | ~IRF(void) { 30 | if(forest) 31 | destroy(forest); 32 | } 33 | }; 34 | 35 | static void IRF_dealloc(IRF* self) { 36 | self->~IRF(); 37 | self->ob_type->tp_free((PyObject*)self); 38 | } 39 | 40 | static PyObject* IRF_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { 41 | IRF *self; 42 | 43 | PyObject* firstArg = PyTuple_GetItem(args, 0); 44 | 45 | bool fromFile = firstArg && PyString_Check(firstArg); 46 | 47 | char* fname; 48 | 49 | if(fromFile) { 50 | if(!PyArg_ParseTuple(args, "s", 51 | &fname)) 52 | return 0; 53 | 54 | ifstream inF(fname); 55 | if(!inF.is_open()) { 56 | // FIXME: raise exception? 
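// Returning 0 from tp_new without setting an exception surfaces in CPython
// as a bare SystemError ("NULL result without error"); setting one first,
// e.g. PyErr_SetString(PyExc_IOError, "cannot open forest file"),
// would give the caller a meaningful error instead.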
57 | return 0; 58 | } 59 | 60 | self = new (type->tp_alloc(type, 0)) IRF(); 61 | if(self) { 62 | self->forest = load(inF); 63 | } 64 | } else { 65 | int nTrees; 66 | if(!PyArg_ParseTuple(args, "i", &nTrees)) 67 | return 0; 68 | 69 | self = new (type->tp_alloc(type, 0)) IRF(); 70 | if(self) { 71 | self->forest = create(nTrees); 72 | } 73 | } 74 | 75 | return (PyObject *)self; 76 | } 77 | 78 | static int IRF_init(IRF *self, PyObject *args, PyObject *kwds) { 79 | return 0; 80 | } 81 | 82 | static PyMemberDef IRF_members[] = { 83 | {NULL} /* Sentinel */ 84 | }; 85 | 86 | static PyObject* IRF_commit(IRF* self) { 87 | commit(self->forest); 88 | return Py_BuildValue(""); 89 | } 90 | 91 | static PyObject* IRF_validate(IRF* self) { 92 | return PyBool_FromLong(validate(self->forest)); 93 | } 94 | 95 | static PyObject* IRF_asJSON(IRF* self) { 96 | stringstream ss; 97 | asJSON(self->forest, ss); 98 | ss.flush(); 99 | return Py_BuildValue("s", ss.str().c_str()); 100 | } 101 | 102 | static PyObject* IRF_statsJSON(IRF* self) { 103 | stringstream ss; 104 | statsJSON(self->forest, ss); 105 | ss.flush(); 106 | return Py_BuildValue("s", ss.str().c_str()); 107 | } 108 | 109 | static PyObject* IRF_save(IRF* self, PyObject* args) { 110 | char* fname; 111 | 112 | if(!PyArg_ParseTuple(args, "s", 113 | &fname)) { 114 | return 0; 115 | } 116 | 117 | ofstream outS(fname); 118 | if(!outS.is_open()) 119 | return PyBool_FromLong(false); 120 | 121 | return PyBool_FromLong(save(self->forest, outS)); 122 | } 123 | 124 | static PyObject* packFeatures(Sample* s) { 125 | PyObject* d = PyDict_New(); 126 | for(map::const_iterator it = s->xCodes.begin(); it != s->xCodes.end(); ++it) { 127 | PyObject* k = Py_BuildValue("i", it->first); 128 | PyObject* v = Py_BuildValue("f", it->second); 129 | PyDict_SetItem(d, k, v); 130 | Py_DECREF(k); 131 | Py_DECREF(v); 132 | } 133 | return d; 134 | } 135 | 136 | static bool extractFeatures(PyObject* features, Sample* s) { 137 | PyObject *key, *value; 138 | Py_ssize_t pos = 0; 139 | while (PyDict_Next(features, &pos, &key, &value)) { 140 | long k = PyInt_AsLong(key); 141 | if(k == -1 && PyErr_Occurred() != 0) { 142 | return false; 143 | } 144 | double v = PyFloat_AsDouble(value); 145 | if(v == -1 && PyErr_Occurred() != 0) { 146 | return false; 147 | } 148 | s->xCodes[k] = v; 149 | } 150 | return true; 151 | } 152 | 153 | static PyObject* IRF_classify(IRF* self, PyObject* args) { 154 | PyObject* features; 155 | if(!PyArg_ParseTuple(args, "O", 156 | &features 157 | )) 158 | return 0; 159 | 160 | Sample s; 161 | extractFeatures(features, &s); 162 | 163 | return Py_BuildValue("f", classify(self->forest, &s)); 164 | } 165 | 166 | static PyObject* IRF_classifyPartial(IRF* self, PyObject* args) { 167 | PyObject* features; 168 | int nTrees; 169 | if(!PyArg_ParseTuple(args, "Oi", 170 | &features, 171 | &nTrees)) 172 | return 0; 173 | 174 | Sample s; 175 | extractFeatures(features, &s); 176 | 177 | return Py_BuildValue("f", classifyPartial(self->forest, &s, nTrees)); 178 | } 179 | 180 | static PyObject* IRF_remove(IRF* self, PyObject* args) { 181 | char* sampleId; 182 | if(!PyArg_ParseTuple(args, "s", 183 | &sampleId)) 184 | return 0; 185 | return PyBool_FromLong(remove(self->forest, sampleId)); 186 | } 187 | 188 | static PyObject* IRF_add(IRF* self, PyObject* args) { 189 | char* sampleId; 190 | PyObject* features; 191 | float target; 192 | 193 | if(!PyArg_ParseTuple(args, "sOf", 194 | &sampleId, 195 | &features, 196 | &target)) { 197 | return 0; 198 | } 199 | 200 | Sample* s = new Sample(); 201 | 202 
| s->suid = sampleId; 203 | s->y = target; 204 | 205 | if(!extractFeatures(features, s)) { 206 | cerr << "failed to extract features!" << endl; 207 | delete s; 208 | return 0; 209 | } 210 | 211 | return PyBool_FromLong(add(self->forest, s)); 212 | } 213 | 214 | static PyObject* IRF_samples(IRF* self, PyObject* args); 215 | 216 | static PyMethodDef IRF_methods[] = { 217 | {"commit", (PyCFunction)IRF_commit, METH_NOARGS, 218 | "Commit pending changes" 219 | }, 220 | {"asJSON", (PyCFunction)IRF_asJSON, METH_NOARGS, 221 | "Encode as JSON" 222 | }, 223 | {"statsJSON", (PyCFunction)IRF_statsJSON, METH_NOARGS, 224 | "Encode stats as JSON" 225 | }, 226 | {"save", (PyCFunction)IRF_save, METH_VARARGS, 227 | "Save forest to file" 228 | }, 229 | {"validate", (PyCFunction)IRF_validate, METH_NOARGS, 230 | "Validate forest" 231 | }, 232 | {"classify", (PyCFunction)IRF_classify, METH_VARARGS, 233 | "Classify according to features" 234 | }, 235 | {"classifyPartial", (PyCFunction)IRF_classifyPartial, METH_VARARGS, 236 | "Classify according to features, using only N trees" 237 | }, 238 | {"add", (PyCFunction)IRF_add, METH_VARARGS, 239 | "Add a sample" 240 | }, 241 | {"remove", (PyCFunction)IRF_remove, METH_VARARGS, 242 | "Remove a sample" 243 | }, 244 | {"samples", (PyCFunction)IRF_samples, METH_NOARGS, 245 | "Get stored samples" 246 | }, 247 | {NULL} /* Sentinel */ 248 | }; 249 | 250 | static PyTypeObject IRFType = { 251 | PyObject_HEAD_INIT(NULL) 252 | 0, /*ob_size*/ 253 | "irf.IRF", /*tp_name*/ 254 | sizeof(IRF), /*tp_basicsize*/ 255 | 0, /*tp_itemsize*/ 256 | (destructor)IRF_dealloc, /*tp_dealloc*/ 257 | 0, /*tp_print*/ 258 | 0, /*tp_getattr*/ 259 | 0, /*tp_setattr*/ 260 | 0, /*tp_compare*/ 261 | 0, /*tp_repr*/ 262 | 0, /*tp_as_number*/ 263 | 0, /*tp_as_sequence*/ 264 | 0, /*tp_as_mapping*/ 265 | 0, /*tp_hash */ 266 | 0, /*tp_call*/ 267 | 0, /*tp_str*/ 268 | 0, /*tp_getattro*/ 269 | 0, /*tp_setattro*/ 270 | 0, /*tp_as_buffer*/ 271 | Py_TPFLAGS_DEFAULT, /*tp_flags*/ 272 | "IRF objects", /* tp_doc */ 273 | 0, /* tp_traverse */ 274 | 0, /* tp_clear */ 275 | 0, /* tp_richcompare */ 276 | 0, /* tp_weaklistoffset */ 277 | 0, /* tp_iter */ 278 | 0, /* tp_iternext */ 279 | IRF_methods, /* tp_methods */ 280 | IRF_members, /* tp_members */ 281 | 0, /* tp_getset */ 282 | 0, /* tp_base */ 283 | 0, /* tp_dict */ 284 | 0, /* tp_descr_get */ 285 | 0, /* tp_descr_set */ 286 | 0, /* tp_dictoffset */ 287 | (initproc)IRF_init, /* tp_init */ 288 | 0, /* tp_alloc */ 289 | IRF_new, /* tp_new */ 290 | }; 291 | 292 | struct SampleIter { 293 | PyObject_HEAD 294 | SampleWalker* walker; 295 | 296 | void setRange(SampleWalker* w) { 297 | delete walker; 298 | walker = w; 299 | } 300 | 301 | SampleIter(SampleWalker* w) { 302 | walker = w; 303 | } 304 | 305 | SampleIter(void) { 306 | walker = 0; 307 | } 308 | 309 | ~SampleIter(void) { 310 | delete walker; 311 | } 312 | }; 313 | 314 | static void SampleIter_dealloc(SampleIter* self) { 315 | self->~SampleIter(); 316 | self->ob_type->tp_free((PyObject*)self); 317 | } 318 | 319 | static PyObject* SampleIter_new(PyTypeObject *type, PyObject *ars, PyObject *kwds) { 320 | SampleIter* self; 321 | 322 | self = new (type->tp_alloc(type, 0)) SampleIter(); 323 | if( self != NULL) { 324 | } 325 | return (PyObject*)self; 326 | } 327 | 328 | static int SampleIter_init(SampleIter* self, PyObject* args, PyObject* kwds) { 329 | return 0; 330 | } 331 | 332 | PyObject* SampleIter_iter(PyObject *self) { 333 | Py_INCREF(self); 334 | return self; 335 | } 336 | 337 | PyObject* SampleIter_iternext(PyObject 
*self) { 338 | SampleIter *p = (SampleIter*)self; 339 | if(p->walker->stillSome()) { 340 | Sample* s = p->walker->get(); 341 | return Py_BuildValue("(sNf)", s->suid.c_str(), packFeatures(s), s->y); 342 | } else { 343 | /* Raising of standard StopIteration exception with empty value. */ 344 | PyErr_SetNone(PyExc_StopIteration); 345 | return NULL; 346 | } 347 | } 348 | 349 | static PyTypeObject SampleIterType = { 350 | PyObject_HEAD_INIT(NULL) 351 | 0, /*ob_size*/ 352 | "irf.SampleIter", /*tp_name*/ 353 | sizeof(SampleIter), /*tp_basicsize*/ 354 | 0, /*tp_itemsize*/ 355 | (destructor)SampleIter_dealloc, /*tp_dealloc*/ 356 | 0, /*tp_print*/ 357 | 0, /*tp_getattr*/ 358 | 0, /*tp_setattr*/ 359 | 0, /*tp_compare*/ 360 | 0, /*tp_repr*/ 361 | 0, /*tp_as_number*/ 362 | 0, /*tp_as_sequence*/ 363 | 0, /*tp_as_mapping*/ 364 | 0, /*tp_hash */ 365 | 0, /*tp_call*/ 366 | 0, /*tp_str*/ 367 | 0, /*tp_getattro*/ 368 | 0, /*tp_setattro*/ 369 | 0, /*tp_as_buffer*/ 370 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_ITER, 371 | /* tp_flags: Py_TPFLAGS_HAVE_ITER tells python to 372 | use tp_iter and tp_iternext fields. */ 373 | "Internal Sample iterator objects", /* tp_doc */ 374 | 0, /* tp_traverse */ 375 | 0, /* tp_clear */ 376 | 0, /* tp_richcompare */ 377 | 0, /* tp_weaklistoffset */ 378 | SampleIter_iter, /* tp_iter: __iter__() method */ 379 | SampleIter_iternext, /* tp_iternext: next() method */ 380 | 0, /* tp_methods */ 381 | 0, /* tp_members */ 382 | 0, /* tp_getset */ 383 | 0, /* tp_base */ 384 | 0, /* tp_dict */ 385 | 0, /* tp_descr_get */ 386 | 0, /* tp_descr_set */ 387 | 0, /* tp_dictoffset */ 388 | (initproc)SampleIter_init, /* tp_init */ 389 | 0, /* tp_alloc */ 390 | SampleIter_new, /* tp_new */ 391 | }; 392 | 393 | static PyObject* IRF_load(PyObject* self, PyObject* args) { 394 | IRF* p; 395 | 396 | char* fname; 397 | if(!PyArg_ParseTuple(args, "s", 398 | &fname)) 399 | return 0; 400 | 401 | p = (IRF*) PyObject_CallObject((PyObject*) &IRFType, args); 402 | 403 | if (!p) return NULL; 404 | 405 | return (PyObject*) p; 406 | } 407 | 408 | static PyMethodDef module_methods[] = { 409 | {"load", (PyCFunction)IRF_load, METH_VARARGS, 410 | "load random forest from file" 411 | }, 412 | {NULL} /* Sentinel */ 413 | }; 414 | 415 | static PyObject* IRF_samples(IRF* self, PyObject* args) { 416 | SampleIter *p; 417 | 418 | PyObject *argList = Py_BuildValue("()"); 419 | p = (SampleIter*) PyObject_CallObject((PyObject*) &SampleIterType, argList); 420 | Py_DECREF(argList); 421 | 422 | if (!p) return NULL; 423 | 424 | /* I'm not sure if it's strictly necessary. */ 425 | if (!PyObject_Init((PyObject *)p, &SampleIterType)) { 426 | Py_DECREF(p); 427 | return NULL; 428 | } 429 | 430 | p->setRange(getSamples(self->forest)); 431 | 432 | return (PyObject *)p; 433 | } 434 | 435 | #ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ 436 | #define PyMODINIT_FUNC void 437 | #endif 438 | PyMODINIT_FUNC initirf(void) { 439 | PyObject* m; 440 | 441 | if (PyType_Ready(&IRFType) < 0) 442 | return; 443 | if (PyType_Ready(&SampleIterType) < 0) 444 | return; 445 | 446 | m = Py_InitModule3("irf", module_methods, 447 | "Incremental Random Forest."); 448 | 449 | Py_INCREF(&IRFType); 450 | PyModule_AddObject(m, "IRF", (PyObject *)&IRFType); 451 | PyModule_AddObject(m, "SampleIter", (PyObject *)&SampleIterType); 452 | } 453 | -------------------------------------------------------------------------------- /irf/node.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2012, Igalia S.L. 
2 | * Author Carlos Guerreiro cguerreiro@igalia.com 3 | * Licensed under the MIT license */ 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include "randomForest.h" 13 | 14 | using namespace v8; 15 | using namespace node; 16 | using namespace std; 17 | using namespace IncrementalRandomForest; 18 | 19 | class IRF: ObjectWrap { 20 | private: 21 | Forest* f; 22 | 23 | static void setFeatures(Sample* s, Local& features) { 24 | Local featureNames = features->GetOwnPropertyNames(); 25 | int featureCount = featureNames->Length(); 26 | int i; 27 | for(i = 0; i < featureCount; ++i) { 28 | // FIXME: verify that this is an integer 29 | Local n = featureNames->Get(i)->ToInteger(); 30 | // FIXME: verify that this is a number 31 | Local v = features->Get(n->Value())->ToNumber(); 32 | s->xCodes[n->Value()] = v->Value(); 33 | } 34 | } 35 | 36 | static void getFeatures(Sample* s, Local& features) { 37 | map::const_iterator it; 38 | char key[16]; 39 | for(it = s->xCodes.begin(); it != s->xCodes.end(); ++it) { 40 | sprintf(key, "%d", it->first); 41 | features->Set(String::New(key), Number::New(it->second)); 42 | } 43 | } 44 | 45 | public: 46 | IRF(uint32_t count) : ObjectWrap() { 47 | f = create(count); 48 | } 49 | 50 | IRF(Forest* withF) : ObjectWrap(), f(withF) { 51 | } 52 | 53 | ~IRF() { 54 | destroy(f); 55 | } 56 | 57 | static void init(Handle < Object > target, Handle (*func)(const Arguments&), Persistent& ct, const char* name) { 58 | Local t = FunctionTemplate::New(func); 59 | ct = Persistent::New(t); 60 | ct->InstanceTemplate()->SetInternalFieldCount(1); 61 | Local nameSymbol = String::NewSymbol(name); 62 | ct->SetClassName(nameSymbol); 63 | NODE_SET_PROTOTYPE_METHOD(ct, "add", add); 64 | NODE_SET_PROTOTYPE_METHOD(ct, "remove", remove); 65 | NODE_SET_PROTOTYPE_METHOD(ct, "classify", classify); 66 | NODE_SET_PROTOTYPE_METHOD(ct, "classifyPartial", classifyPartial); 67 | NODE_SET_PROTOTYPE_METHOD(ct, "asJSON", asJSON); 68 | NODE_SET_PROTOTYPE_METHOD(ct, "statsJSON", statsJSON); 69 | NODE_SET_PROTOTYPE_METHOD(ct, "each", each); 70 | NODE_SET_PROTOTYPE_METHOD(ct, "commit", commit); 71 | NODE_SET_PROTOTYPE_METHOD(ct, "toBuffer", toBuffer); 72 | target->Set(nameSymbol, ct->GetFunction()); 73 | } 74 | 75 | static Handle fromBuffer(const Arguments& args) { 76 | if(args.Length() != 1) { 77 | return ThrowException(Exception::Error(String::New("add takes 3 arguments"))); 78 | } 79 | 80 | if(!Buffer::HasInstance(args[0])) 81 | return ThrowException(Exception::Error(String::New("argument must be a Buffer"))); 82 | 83 | Local o = args[0]->ToObject(); 84 | 85 | cerr << Buffer::Length(o) << endl; 86 | 87 | stringstream ss(Buffer::Data(o)); 88 | 89 | IRF* ih = new IRF(load(ss)); 90 | ih->Wrap(args.This()); 91 | return args.This(); 92 | } 93 | 94 | static Handle New(const Arguments& args) { 95 | HandleScope scope; 96 | 97 | if (!args.IsConstructCall()) { 98 | return ThrowException(Exception::TypeError(String::New("Use the new operator to create instances of this object."))); 99 | } 100 | 101 | IRF* ih; 102 | if(args.Length() >= 1) { 103 | if(args[0]->IsNumber()) { 104 | uint32_t count = args[0]->ToInteger()->Value(); 105 | ih = new IRF(count); 106 | } else if(Buffer::HasInstance(args[0])) { 107 | Local o = args[0]->ToObject(); 108 | stringstream ss(Buffer::Data(o)); 109 | ih = new IRF(load(ss)); 110 | } else { 111 | return ThrowException(Exception::Error(String::New("argument 1 must be a number (number of trees) or a Buffer (to create from)"))); 112 | } 113 | } else 114 | ih = new 
IRF(1); 115 | 116 | ih->Wrap(args.This()); 117 | return args.This(); 118 | } 119 | 120 | static Handle add(const Arguments& args) { 121 | HandleScope scope; 122 | 123 | if(args.Length() != 3) { 124 | return ThrowException(Exception::Error(String::New("add takes 3 arguments"))); 125 | } 126 | 127 | Local suid = *args[0]->ToString(); 128 | if(suid.IsEmpty()) 129 | return ThrowException(Exception::Error(String::New("argument 1 must be a string"))); 130 | 131 | if(!args[1]->IsObject()) 132 | return ThrowException(Exception::Error(String::New("argument 2 must be a object"))); 133 | Local features = *args[1]->ToObject(); 134 | 135 | if(!args[2]->IsNumber()) 136 | return ThrowException(Exception::Error(String::New("argument 3 must be a number"))); 137 | Local y = *args[2]->ToNumber(); 138 | 139 | IRF* ih = ObjectWrap::Unwrap(args.This()); 140 | Sample* s = new Sample(); 141 | s->suid = *String::AsciiValue(suid); 142 | s->y = y->Value(); 143 | setFeatures(s, features); 144 | 145 | return scope.Close(Boolean::New(IncrementalRandomForest::add(ih->f, s))); 146 | } 147 | 148 | static Handle remove(const Arguments& args) { 149 | HandleScope scope; 150 | 151 | if(args.Length() != 1) { 152 | return ThrowException(Exception::Error(String::New("remove takes 1 argument"))); 153 | } 154 | 155 | Local suid = *args[0]->ToString(); 156 | if(suid.IsEmpty()) 157 | return ThrowException(Exception::Error(String::New("argument 1 must be a string"))); 158 | 159 | IRF* ih = ObjectWrap::Unwrap(args.This()); 160 | 161 | return scope.Close(Boolean::New(IncrementalRandomForest::remove(ih->f, *String::AsciiValue(suid)))); 162 | } 163 | 164 | static Handle classify(const Arguments& args) { 165 | HandleScope scope; 166 | 167 | if(args.Length() != 1) { 168 | return ThrowException(Exception::Error(String::New("classify takes 1 argument"))); 169 | } 170 | 171 | if(!args[0]->IsObject()) 172 | return ThrowException(Exception::Error(String::New("argument 1 must be a object"))); 173 | Local features = *args[0]->ToObject(); 174 | 175 | IRF* ih = ObjectWrap::Unwrap(args.This()); 176 | 177 | IncrementalRandomForest::Sample s; 178 | setFeatures(&s, features); 179 | 180 | return scope.Close(Number::New(IncrementalRandomForest::classify(ih->f, &s))); 181 | } 182 | 183 | static Handle classifyPartial(const Arguments& args) { 184 | HandleScope scope; 185 | 186 | if(args.Length() != 2) { 187 | return ThrowException(Exception::Error(String::New("classifyPartial takes 2 argument"))); 188 | } 189 | 190 | if(!args[0]->IsObject()) 191 | return ThrowException(Exception::Error(String::New("argument 1 must be a object"))); 192 | Local features = *args[0]->ToObject(); 193 | 194 | 195 | if(!args[1]->IsNumber()) 196 | return ThrowException(Exception::Error(String::New("argument 2 must be a number"))); 197 | Local nTrees = *args[1]->ToNumber(); 198 | 199 | IRF* ih = ObjectWrap::Unwrap(args.This()); 200 | 201 | IncrementalRandomForest::Sample s; 202 | setFeatures(&s, features); 203 | 204 | return scope.Close(Number::New(IncrementalRandomForest::classifyPartial(ih->f, &s, nTrees->Value()))); 205 | } 206 | 207 | static Handle asJSON(const Arguments& args) { 208 | HandleScope scope; 209 | 210 | if(args.Length() != 0) { 211 | return ThrowException(Exception::Error(String::New("toJSON takes 0 arguments"))); 212 | } 213 | 214 | IRF* ih = ObjectWrap::Unwrap(args.This()); 215 | 216 | stringstream ss; 217 | IncrementalRandomForest::asJSON(ih->f, ss); 218 | ss.flush(); 219 | 220 | return scope.Close(String::New(ss.str().c_str())); 221 | } 222 | 223 | static 
Handle statsJSON(const Arguments& args) { 224 | HandleScope scope; 225 | 226 | if(args.Length() != 0) { 227 | return ThrowException(Exception::Error(String::New("statsJSON takes 0 arguments"))); 228 | } 229 | 230 | IRF* ih = ObjectWrap::Unwrap(args.This()); 231 | 232 | stringstream ss; 233 | IncrementalRandomForest::statsJSON(ih->f, ss); 234 | ss.flush(); 235 | 236 | return scope.Close(String::New(ss.str().c_str())); 237 | } 238 | 239 | static Handle each(const Arguments& args) { 240 | HandleScope scope; 241 | 242 | if(args.Length() != 1) { 243 | return ThrowException(Exception::Error(String::New("each takes 1 argument"))); 244 | } 245 | if (!args[0]->IsFunction()) { 246 | return ThrowException(Exception::TypeError(String::New("argument must be a callback function"))); 247 | } 248 | // There's no ToFunction(), use a Cast instead. 249 | Local callback = Local::Cast(args[0]); 250 | 251 | Local k = Local::New(Undefined()); 252 | Local v = Local::New(Undefined()); 253 | 254 | const unsigned argc = 3; 255 | Local argv[argc] = { v }; 256 | 257 | IRF* ih = ObjectWrap::Unwrap(args.This()); 258 | 259 | SampleWalker* walker = getSamples(ih->f); 260 | 261 | Local globalObj = Context::GetCurrent()->Global(); 262 | Local objectConstructor = Local::Cast(globalObj->Get(String::New("Object"))); 263 | 264 | while(walker->stillSome()) { 265 | Sample* s = walker->get(); 266 | argv[0] = Local::New(String::New(s->suid.c_str())); 267 | Local features = Object::New(); 268 | getFeatures(s, features); 269 | argv[1] = features; 270 | argv[2] = Local::New(Number::New(s->y)); 271 | TryCatch tc; 272 | Local ret = callback->Call(Context::GetCurrent()->Global(), argc, argv); 273 | if(ret.IsEmpty() || ret->IsFalse()) 274 | break; 275 | } 276 | 277 | delete walker; 278 | 279 | return Undefined(); 280 | } 281 | 282 | static Handle commit(const Arguments& args) { 283 | HandleScope scope; 284 | 285 | if(args.Length() != 0) { 286 | return ThrowException(Exception::Error(String::New("commit takes 0 arguments"))); 287 | } 288 | 289 | IRF* ih = ObjectWrap::Unwrap(args.This()); 290 | 291 | IncrementalRandomForest::commit(ih->f); 292 | 293 | return scope.Close(Undefined()); 294 | } 295 | 296 | static Handle toBuffer(const Arguments& args) { 297 | HandleScope scope; 298 | 299 | if(args.Length() != 0) { 300 | return ThrowException(Exception::Error(String::New("save takes 0 arguments"))); 301 | } 302 | 303 | IRF* ih = ObjectWrap::Unwrap(args.This()); 304 | stringstream ss(stringstream::out | stringstream::binary); 305 | save(ih->f, ss); 306 | ss.flush(); 307 | 308 | Buffer* out = Buffer::New(const_cast(ss.str().c_str()), ss.tellp()); 309 | 310 | return scope.Close(out->handle_); 311 | } 312 | }; 313 | 314 | static Persistent irf_ct; 315 | 316 | void RegisterModule(Handle target) { 317 | IRF::init(target, IRF::New, irf_ct, "IRF"); 318 | } 319 | 320 | NODE_MODULE(irf, RegisterModule); 321 | -------------------------------------------------------------------------------- /irf/randomForest.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2010-2011 Carlos Guerreiro 2 | * Licensed under the MIT license */ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "randomForest.h" 20 | #include "MurmurHash3.h" 21 | 22 | #include 23 | 24 | #include 25 | 26 | using namespace std; 27 | using google::sparse_hash_map; 28 | 29 | 30 | namespace 
IncrementalRandomForest { 31 | 32 | static const unsigned int maxCodesToConsider = 30; 33 | static const unsigned int maxCodesToKeep = 40; 34 | static const unsigned int minEvidence = 2; 35 | static const unsigned int maxUnsplit = 30; 36 | static const unsigned int minBalanceSplit = 10; 37 | static const float minProbDiff = 0; 38 | static const float minEntropyGain = 0.01; 39 | 40 | template 41 | static inline string to_string (const T& t) 42 | { 43 | stringstream ss; 44 | ss << t; 45 | return ss.str(); 46 | } 47 | 48 | static void printSample(ostream& out, Sample* s) { 49 | out << s->y; 50 | out << " " << s->xCodes.size(); 51 | map::const_iterator itCodes; 52 | for(itCodes = s->xCodes.begin(); itCodes != s->xCodes.end(); ++itCodes) { 53 | out << " " << itCodes->first << " " << itCodes->second; 54 | } 55 | out << endl; 56 | } 57 | 58 | static float entropyBinary(int c0, int c1) { 59 | float h = 0; 60 | const int n = c0 + c1; 61 | if(c0 > 0) { 62 | float p0 = (float) c0 / n; 63 | h -= p0 * log(p0); 64 | } 65 | if(c1 > 0) { 66 | float p1 = (float) c1 / n; 67 | h -= p1 * log(p1); 68 | } 69 | return h; 70 | } 71 | 72 | static void findUsedCodes(SampleWalker& sw, set& uc) { 73 | map::const_iterator itC; 74 | while(sw.stillSome()) { 75 | Sample* s = sw.get(); 76 | for(itC = s->xCodes.begin(); itC != s->xCodes.end(); ++itC) 77 | uc.insert(itC->first); 78 | } 79 | } 80 | 81 | static void splitListByTarget(const vector& sl, vector& sl0, vector& sl1) { 82 | vector::const_iterator it; 83 | for(it = sl.begin(); it != sl.end(); ++it) { 84 | Sample* s = *it; 85 | if(s->y > 0.5) 86 | sl1.push_back(s); 87 | else 88 | sl0.push_back(s); 89 | } 90 | } 91 | 92 | class VectorSampleWalker : public SampleWalker { 93 | private: 94 | vector::const_iterator itCurr; 95 | vector::const_iterator itEnd; 96 | public: 97 | VectorSampleWalker(const vector& sv) : 98 | itCurr(sv.begin()), itEnd(sv.end()) { 99 | } 100 | virtual bool stillSome(void) const { 101 | return itCurr != itEnd; 102 | } 103 | virtual Sample* get(void) { 104 | return *itCurr++; 105 | } 106 | }; 107 | 108 | static void splitListAgainstCode(SampleWalker& sw, int c, vector& sl0, vector& sl1) { 109 | while(sw.stillSome()) { 110 | Sample* s = sw.get(); 111 | map::const_iterator itCode = s->xCodes.find(c); 112 | bool cc; 113 | if(itCode != s->xCodes.end()) { 114 | cc = itCode->second > 0.5; 115 | } else 116 | cc = false; 117 | 118 | if(cc) 119 | sl1.push_back(s); 120 | else 121 | sl0.push_back(s); 122 | } 123 | } 124 | 125 | typedef unsigned int CodeRankType; 126 | 127 | struct DecisionCounts { 128 | unsigned int c0p; 129 | unsigned int c1p; 130 | CodeRankType rank; 131 | 132 | DecisionCounts(void) { 133 | c0p = 0; 134 | c1p = 0; 135 | rank = 0; // needs to be set afterwards 136 | } 137 | 138 | bool enoughEvidence(DecisionTreeNode* dt) const; 139 | 140 | bool isZeroFor(DecisionTreeNode* dt) const; 141 | 142 | void print(ostream& outs, DecisionTreeNode* dt) const; 143 | 144 | float entropy(DecisionTreeNode* dt) const; 145 | }; 146 | 147 | static bool operator == (const DecisionCounts& dc1, const DecisionCounts& dc2) { 148 | return (dc1.c0p == dc2.c0p) && (dc1.c1p == dc2.c1p); 149 | } 150 | 151 | static bool operator != (const DecisionCounts& dc1, const DecisionCounts& dc2) { 152 | return !(dc1 == dc2); 153 | } 154 | 155 | struct DecisionTreeInternal; 156 | struct DecisionTreeLeaf; 157 | 158 | struct DecisionTreeNode { 159 | int code; // iff code == -1 it's a leaf node 160 | unsigned int c0; 161 | unsigned int c1; 162 | sparse_hash_map decisionCountMap; 163 | 
unsigned long id; 164 | pair minValidRank; 165 | DecisionTreeNode() : decisionCountMap() { 166 | decisionCountMap.set_deleted_key(-1); 167 | minValidRank = make_pair(0U, 0); 168 | } 169 | DecisionTreeInternal* checkInternal(void); 170 | DecisionTreeLeaf* checkLeaf(void); 171 | bool checkType(DecisionTreeInternal**, DecisionTreeLeaf**); 172 | }; 173 | 174 | struct DecisionTreeInternal : public DecisionTreeNode { 175 | DecisionTreeNode* negative; 176 | DecisionTreeNode* positive; 177 | }; 178 | 179 | struct DecisionTreeLeaf : public DecisionTreeNode { 180 | float value; 181 | vector samples; 182 | }; 183 | 184 | DecisionTreeInternal* DecisionTreeNode::checkInternal(void) { 185 | if(code != -1) 186 | return static_cast(this); 187 | else 188 | return 0; 189 | } 190 | 191 | DecisionTreeLeaf* DecisionTreeNode::checkLeaf(void) { 192 | if(code == -1) 193 | return static_cast(this); 194 | else 195 | return 0; 196 | } 197 | 198 | bool DecisionTreeNode::checkType(DecisionTreeInternal** ni, DecisionTreeLeaf** nl) { 199 | if(code == -1) { 200 | *ni = 0; 201 | *nl = static_cast(this); 202 | return true; 203 | } else { 204 | *ni = static_cast(this); 205 | *nl = 0; 206 | return false; 207 | } 208 | } 209 | 210 | bool DecisionCounts::enoughEvidence(DecisionTreeNode* dt) const { 211 | const unsigned int c0n = dt->c0 - c0p; 212 | const unsigned int c1n = dt->c1 - c1p; 213 | return ((c0n + c1n) >= minEvidence) && ((c0p + c1p) >= minEvidence); 214 | } 215 | 216 | void DecisionCounts::print(ostream& outs, DecisionTreeNode* dt) const { 217 | const unsigned int c0n = dt->c0 - c0p; 218 | const unsigned int c1n = dt->c1 - c1p; 219 | outs << " c0n = " << c0n << endl; 220 | outs << " c1n = " << c1n << endl; 221 | outs << " c0p = " << c0p << endl; 222 | outs << " c1p = " << c1p << endl; 223 | outs << " rank = " << rank << endl; 224 | } 225 | 226 | float DecisionCounts::entropy(DecisionTreeNode* dt) const { 227 | // FIXME: redundant computation of c0n and c1n all over the place 228 | const unsigned int c0n = dt->c0 - c0p; 229 | const unsigned int c1n = dt->c1 - c1p; 230 | float hn = entropyBinary(c0n, c1n); 231 | float hp = entropyBinary(c0p, c1p); 232 | int cp = c0p + c1p; 233 | int cn = c0n + c1n; 234 | return (hn * cn + hp * cp) / (cn + cp); 235 | } 236 | 237 | static DecisionTreeLeaf* makeLeaf(TreeState& ts, float v) { 238 | DecisionTreeLeaf* n = new DecisionTreeLeaf(); 239 | n->code = -1; 240 | n->value = v; 241 | n->c0 = 0; 242 | n->c1 = 0; 243 | n->id = rand_r(&ts.seed); 244 | return n; 245 | } 246 | 247 | static DecisionTreeInternal* makeInternal(TreeState& ts, int c, DecisionTreeNode* n0, DecisionTreeNode* n1) { 248 | DecisionTreeInternal* n = new DecisionTreeInternal(); 249 | n->code = c; 250 | n->negative = n0; 251 | n->positive = n1; 252 | n->id = rand_r(&ts.seed); 253 | return n; 254 | } 255 | 256 | bool DecisionCounts::isZeroFor(DecisionTreeNode* dt) const { 257 | const unsigned int c0n = dt->c0 - c0p; 258 | const unsigned int c1n = dt->c1 - c1p; 259 | return (c0n == dt->c0) && (c1n == dt->c1) && (c0p == 0) && (c1p == 0); 260 | } 261 | 262 | static void destroyDecisionTreeNode(DecisionTreeNode* dt) { 263 | DecisionTreeInternal* ni; 264 | DecisionTreeLeaf* nl; 265 | if(!dt->checkType(&ni, &nl)) { 266 | destroyDecisionTreeNode(ni->negative); 267 | destroyDecisionTreeNode(ni->positive); 268 | delete ni; 269 | } else 270 | delete nl; 271 | } 272 | 273 | static DecisionTreeNode* emptyDecisionTree(TreeState& ts) { 274 | return makeLeaf(ts, 0); 275 | } 276 | 277 | static CodeRankType codeRankInNode(int code, 
unsigned long nodeId) { 278 | char s[64]; 279 | // FIXME: this is possibly quite slow 280 | sprintf(s, "%d%lu", code, nodeId); 281 | uint32_t out; 282 | MurmurHash3_x86_32(s, strlen(s), 42, &out); 283 | return out; 284 | } 285 | 286 | static void updateValue(DecisionTreeLeaf* l) { 287 | int n = l->c0 + l->c1; 288 | if(n == 0) 289 | l->value = 1; // assume positive, FIXME: does this make sense? 290 | else 291 | l->value = (float)l->c1 / n; 292 | } 293 | 294 | class TreeSampleWalker; 295 | 296 | static void computeDecisionCounters(DecisionTreeNode* dt, 297 | const TreeSampleWalker& origSW, 298 | sparse_hash_map& decisionCountMap, 299 | unsigned int& outC0, 300 | unsigned int& outC1, 301 | pair& minValidRank); 302 | 303 | static pair findMinRankToConsider(const sparse_hash_map& dcMap) { 304 | sparse_hash_map::const_iterator mapIt; 305 | pair minRankToConsider; 306 | minRankToConsider.first = 0; minRankToConsider.second = 0; // FIXME: is this necessary? 307 | if(dcMap.size() > maxCodesToConsider) { 308 | set > ranks; 309 | for(mapIt = dcMap.begin(); mapIt != dcMap.end(); ++mapIt) { 310 | ranks.insert(make_pair(mapIt->second.rank, mapIt->first)); 311 | if(ranks.size() > maxCodesToConsider) 312 | ranks.erase(ranks.begin()); 313 | } 314 | minRankToConsider = *(ranks.begin()); 315 | } 316 | return minRankToConsider; 317 | } 318 | 319 | static int findMinEntropyCode(float currentEntropy, DecisionTreeNode* dt) { 320 | float minEntropy = 10; 321 | int minEntropyCode = -1; 322 | 323 | sparse_hash_map::const_iterator mapIt; 324 | 325 | // FIXME: this is possibly too inneficient. keep DCs sorted by rank in a vector instead? 326 | pair minRankToConsider = findMinRankToConsider(dt->decisionCountMap); 327 | 328 | for(mapIt = dt->decisionCountMap.begin(); mapIt != dt->decisionCountMap.end(); ++mapIt) { 329 | const DecisionCounts& dc = mapIt->second; 330 | 331 | if(make_pair(dc.rank, mapIt->first) >= minRankToConsider && 332 | dc.enoughEvidence(dt)) { 333 | 334 | float ah = dc.entropy(dt); 335 | 336 | if(ah < minEntropy) { 337 | minEntropy = ah; 338 | minEntropyCode = mapIt->first; 339 | } 340 | } 341 | } 342 | 343 | if(minEntropyCode == -1) 344 | return -1; 345 | 346 | if(minEntropy < currentEntropy) 347 | return minEntropyCode; 348 | 349 | return -1; 350 | } 351 | 352 | static void splitNode(TreeState& ts, DecisionTreeInternal* dt, int minEntropyCode, SampleWalker& sw); 353 | 354 | class TreeSampleWalker : public SampleWalker { 355 | private: 356 | stack st; 357 | vector::iterator itCurr, itEnd; 358 | bool hasSome; 359 | 360 | bool stackDown(DecisionTreeNode* n) { 361 | DecisionTreeInternal* ni; 362 | DecisionTreeLeaf* nl; 363 | while(!n->checkType(&ni, &nl)) { 364 | st.push(ni); 365 | n = ni->negative; 366 | } 367 | itCurr = nl->samples.begin(); 368 | itEnd = nl->samples.end(); 369 | if(itCurr != itEnd) { 370 | st.push(nl); 371 | return true; 372 | } else { 373 | return false; 374 | } 375 | } 376 | private: 377 | void advanceToNext(void) { 378 | while(!st.empty()) { 379 | DecisionTreeInternal* ni = st.top()->checkInternal(); 380 | st.pop(); 381 | if(stackDown(ni->positive)) 382 | break; 383 | } 384 | } 385 | public: 386 | TreeSampleWalker(DecisionTreeNode* n): st() { 387 | if(!stackDown(n)) { 388 | advanceToNext(); 389 | } 390 | } 391 | virtual bool stillSome(void) const { 392 | return !st.empty(); 393 | } 394 | virtual Sample* get(void) { 395 | Sample* s = *itCurr; 396 | 397 | ++itCurr; 398 | if(itCurr == itEnd) { 399 | st.pop(); // pop leaf 400 | 401 | advanceToNext(); 402 | } 403 | 404 | return s; 405 
| } 406 | }; 407 | 408 | // samples 409 | static void setupLeafFromSamples(DecisionTreeLeaf* dt) { 410 | // FIXME: probably done already as we can only call this on a leaf 411 | dt->code = -1; 412 | 413 | computeDecisionCounters(dt, 414 | TreeSampleWalker(dt), 415 | dt->decisionCountMap, 416 | dt->c0, 417 | dt->c1, 418 | dt->minValidRank); 419 | updateValue(dt); 420 | } 421 | 422 | static void computeDecisionCounters(DecisionTreeNode* dt, 423 | const TreeSampleWalker& origSW, 424 | sparse_hash_map& decisionCountMap, 425 | unsigned int& outC0, 426 | unsigned int& outC1, 427 | pair& minValidRank) { 428 | 429 | minValidRank = make_pair(0U, 0); 430 | 431 | set usedCodes; 432 | TreeSampleWalker sw2 = origSW; 433 | findUsedCodes(sw2, usedCodes); 434 | 435 | set::const_iterator ucIt; 436 | 437 | int c0 = 0; 438 | int c1 = 0; 439 | 440 | vector::const_iterator sIt; 441 | 442 | decisionCountMap.clear(); 443 | 444 | for(TreeSampleWalker sw = origSW; sw.stillSome();) { 445 | Sample* s = sw.get(); 446 | bool classIn = s->y >= 0.5; 447 | if(classIn) 448 | ++c1; 449 | else 450 | ++c0; 451 | } 452 | outC0 = c0; 453 | outC1 = c1; 454 | 455 | set > ranks; 456 | for(ucIt = usedCodes.begin(); ucIt != usedCodes.end(); ++ucIt) { 457 | const int code = *ucIt; 458 | CodeRankType rank = codeRankInNode(code, dt->id); 459 | ranks.insert(make_pair(rank, code)); 460 | if(ranks.size() > maxCodesToKeep) { 461 | minValidRank = max(minValidRank, make_pair(ranks.begin()->first, code + 1)); 462 | ranks.erase(ranks.begin()); 463 | } 464 | } 465 | 466 | set >::const_iterator rIt; 467 | for(rIt = ranks.begin(); rIt != ranks.end(); ++rIt) { 468 | const int code = rIt->second; 469 | DecisionCounts& dc = decisionCountMap[code]; 470 | dc.rank = codeRankInNode(code, dt->id); 471 | for(TreeSampleWalker sw=origSW; sw.stillSome();) { 472 | Sample* s = sw.get(); 473 | bool classIn = s->y > 0.5; 474 | 475 | map::const_iterator itCodeInS = s->xCodes.find(code); 476 | bool hasCode = itCodeInS != s->xCodes.end() && itCodeInS->second > 0.5; 477 | 478 | if(classIn) { 479 | if(hasCode) { 480 | ++(dc.c1p); 481 | } 482 | } else { 483 | if(hasCode) { 484 | ++(dc.c0p); 485 | } 486 | } 487 | } 488 | } 489 | } 490 | 491 | static DecisionTreeNode* splitLeafIfPossible(TreeState& ts, DecisionTreeNode* dt) { 492 | float currentEntropy = entropyBinary(dt->c0, dt->c1); 493 | int minEntropyCode = findMinEntropyCode(currentEntropy, dt); 494 | 495 | bool shouldBeSplit = minEntropyCode != -1; 496 | 497 | if(shouldBeSplit) { 498 | TreeSampleWalker sw(dt); 499 | DecisionTreeInternal* newInternal = makeInternal(ts, minEntropyCode, 0, 0); 500 | newInternal->c0 = dt->c0; 501 | newInternal->c1 = dt->c1; 502 | newInternal->minValidRank = dt->minValidRank; 503 | newInternal->decisionCountMap = dt->decisionCountMap; 504 | newInternal->id = dt->id; 505 | splitNode(ts, newInternal, minEntropyCode, sw); 506 | destroyDecisionTreeNode(dt); 507 | return newInternal; 508 | } else 509 | return dt; 510 | } 511 | 512 | static void splitNode(TreeState& ts, DecisionTreeInternal* dt, int minEntropyCode, SampleWalker& sw) { 513 | DecisionTreeLeaf* dtn = makeLeaf(ts, 0); 514 | DecisionTreeLeaf* dtp = makeLeaf(ts, 0); 515 | 516 | if(dt->decisionCountMap.find(minEntropyCode) == dt->decisionCountMap.end()) { 517 | cerr << " code " << minEntropyCode << " not found in decisionCountMap!" 
<< endl; 518 | exit(1); 519 | } 520 | 521 | splitListAgainstCode(sw, minEntropyCode, dtn->samples, dtp->samples); 522 | 523 | setupLeafFromSamples(dtn); 524 | setupLeafFromSamples(dtp); 525 | 526 | if(dt->negative) { 527 | // resplit 528 | destroyDecisionTreeNode(dt->negative); 529 | } 530 | dt->negative = dtn; 531 | if(dt->positive) { 532 | // resplit 533 | destroyDecisionTreeNode(dt->positive); 534 | } 535 | dt->positive = dtp; 536 | 537 | dt->code = minEntropyCode; 538 | 539 | dt->negative = splitLeafIfPossible(ts, dt->negative); 540 | dt->positive = splitLeafIfPossible(ts, dt->positive); 541 | } 542 | 543 | static void updateDecisionCounters(DecisionTreeNode* dt, Sample* s, int addedBefore0, int addedBefore1, int direction = 1) { 544 | sparse_hash_map::iterator dcIt; 545 | for(dcIt = dt->decisionCountMap.begin(); dcIt != dt->decisionCountMap.end();) { 546 | const int code = dcIt->first; 547 | DecisionCounts& dc = dcIt->second; 548 | map::const_iterator codeIt = s->xCodes.find(code); 549 | if(codeIt == s->xCodes.end()) { 550 | // code not used in sample 551 | } else { 552 | // code used in sample 553 | if(s->y >= 0.5) { 554 | if(codeIt->second >= 0.5) { 555 | (dc.c1p) += direction; 556 | } 557 | } else { 558 | if(codeIt->second >= 0.5) { 559 | (dc.c0p) += direction; 560 | } 561 | } 562 | } 563 | 564 | if(direction < 0 && dc.c0p == 0 && dc.c1p == 0) { 565 | sparse_hash_map::iterator toErase = dcIt; 566 | ++dcIt; 567 | dt->decisionCountMap.erase(toErase); 568 | } else 569 | ++dcIt; 570 | } 571 | 572 | if(direction < 0) 573 | return; 574 | 575 | // FIXME: unnecessary if it's clear no new codes will be added 576 | set > ranks; 577 | for(dcIt = dt->decisionCountMap.begin(); dcIt != dt->decisionCountMap.end(); ++dcIt) 578 | ranks.insert(make_pair(dcIt->second.rank, dcIt->first)); 579 | 580 | map::const_iterator codeIt; 581 | for(codeIt = s->xCodes.begin(); codeIt != s->xCodes.end(); ++codeIt) { 582 | sparse_hash_map::iterator dcIt = dt->decisionCountMap.find(codeIt->first); 583 | if(dcIt == dt->decisionCountMap.end()) { 584 | 585 | CodeRankType newRank = codeRankInNode(codeIt->first, dt->id); 586 | 587 | bool doInsert = make_pair(newRank, codeIt->first) >= dt->minValidRank; 588 | 589 | if(doInsert) { 590 | dcIt = dt->decisionCountMap.insert(make_pair(codeIt->first, DecisionCounts())).first; 591 | DecisionCounts& dc = dcIt->second; 592 | dc.rank = newRank; 593 | 594 | if(s->y >= 0.5) { 595 | if(codeIt->second >= 0.5) { 596 | (dc.c1p) += direction; 597 | } 598 | } else { 599 | if(codeIt->second >= 0.5) { 600 | (dc.c0p) += direction; 601 | } 602 | } 603 | 604 | ranks.insert(make_pair(dc.rank, codeIt->first)); 605 | 606 | if(ranks.size() > maxCodesToKeep) { 607 | int toDrop = ranks.begin()->second; 608 | dt->minValidRank = max(dt->minValidRank, make_pair(ranks.begin()->first, ranks.begin()->second + 1)); 609 | ranks.erase(ranks.begin()); 610 | dt->decisionCountMap.erase(toDrop); 611 | } 612 | } 613 | } 614 | } 615 | } 616 | 617 | static void printDCs(const sparse_hash_map& dc, DecisionTreeNode* dt) { 618 | set > ranks; 619 | sparse_hash_map::const_iterator it; 620 | for(it = dc.begin(); it != dc.end(); ++it) { 621 | ranks.insert(make_pair(it->second.rank, it->first)); 622 | } 623 | set >::const_reverse_iterator rIt; 624 | for(rIt = ranks.rbegin(); rIt != ranks.rend(); ++rIt) 625 | cerr << " " << rIt->first << "," << rIt->second; 626 | cerr << endl; 627 | } 628 | 629 | static void printNodeSamples(DecisionTreeNode* dt) { 630 | TreeSampleWalker sw(dt); 631 | while(sw.stillSome()) { 632 | 
635 |
636 | static bool compareDCsDir(const sparse_hash_map<int, DecisionCounts>& dcM1,
637 |                           const sparse_hash_map<int, DecisionCounts>& dcM2,
638 |                           DecisionTreeNode* dt,
639 |                           const char* tag1,
640 |                           const char* tag2) {
641 |   bool valid = true;
642 |
643 |   pair<CodeRankType, int> minR1 = findMinRankToConsider(dcM1);
644 |   pair<CodeRankType, int> minR2 = findMinRankToConsider(dcM2);
645 |
646 |   sparse_hash_map<int, DecisionCounts>::const_iterator itDC;
647 |
648 |   int countIn = 0;
649 |   for(itDC = dcM1.begin(); itDC != dcM1.end(); ++itDC) {
650 |     int code = itDC->first;
651 |     const DecisionCounts& dc = itDC->second;
652 |
653 |     if(make_pair(dc.rank, code) >= minR1) {
654 |       ++countIn;
655 |       sparse_hash_map<int, DecisionCounts>::const_iterator itDC2 = dcM2.find(code);
656 |       if(itDC2 == dcM2.end()) {
657 |         if(!dc.isZeroFor(dt)) {
658 |           cerr << "ERROR: non-zero DC for code " << code << " not found: (" << tag1 << " in " << tag2 << ") in " << (long)dt << " : " << endl;
659 |           cerr << "minValidRank = " << dt->minValidRank.first << " , " << dt->minValidRank.second << endl;
660 |           cerr << "minR1 = (" << minR1.first << "," << minR1.second << ") minR2 = (" << minR2.first << "," << minR2.second << ")" << endl;
661 |           dc.print(cerr, dt);
662 |           valid = false;
663 |         }
664 |       } else {
665 |         const DecisionCounts& dc2 = itDC2->second;
666 |         if(make_pair(dc2.rank, code) < minR2) {
667 |           cerr << "ERROR: non-zero DC for code " << code << " found: (" << tag1 << " in " << tag2 << ") but not in top N in " << (long) dt << endl;
668 |           valid = false;
669 |         }
670 |         if(dc != dc2) {
671 |           cerr << "ERROR: DCs for code " << code << " don't match: (" << tag1 << " in " << tag2 << ") in " << (long) dt << " : " << endl;
672 |           dc.print(cerr, dt);
673 |           dc2.print(cerr, dt);
674 |           valid = false;
675 |         }
676 |       }
677 |     }
678 |   }
679 |
680 |   if(!valid) {
681 |     cerr << "there were " << countIn << " codes over minimum out of a total of " << dcM1.size() << endl;
682 |     cerr << tag1 << " : " << endl;
683 |     printDCs(dcM1, dt);
684 |     cerr << tag2 << " : " << endl;
685 |     printDCs(dcM2, dt);
686 |     printNodeSamples(dt);
687 |   }
688 |
689 |   return valid;
690 | }
691 |
692 | static bool compareDCs(const sparse_hash_map<int, DecisionCounts>& dcM1,
693 |                        const sparse_hash_map<int, DecisionCounts>& dcM2,
694 |                        DecisionTreeNode* dt,
695 |                        const char* tag1,
696 |                        const char* tag2) {
697 |   bool valid = true;
698 |
699 |   if(!compareDCsDir(dcM1, dcM2, dt, tag1, tag2))
700 |     valid = false;
701 |   if(!compareDCsDir(dcM2, dcM1, dt, tag2, tag1))
702 |     valid = false;
703 |
704 |   return valid;
705 | }
706 |
707 | static void insertLeafSamples(vector<Sample*>& v, SampleWalker& sw) {
708 |   while(sw.stillSome())
709 |     v.push_back(sw.get());
710 | }
711 |
712 | static void collectRecursive(DecisionTreeNode* dt, vector<Sample*>& v) {
713 |   DecisionTreeInternal* ni;
714 |   DecisionTreeLeaf* nl;
715 |   if(dt->checkType(&ni, &nl)) {
716 |     copy(nl->samples.begin(), nl->samples.end(), back_inserter(v));
717 |   } else {
718 |     collectRecursive(ni->negative, v);
719 |     collectRecursive(ni->positive, v);
720 |   }
721 | }
722 |
723 | static bool validateWalker(DecisionTreeNode* dt) {
724 |   vector<Sample*> vRec;
725 |   collectRecursive(dt, vRec);
726 |   vector<Sample*> vWalker;
727 |   TreeSampleWalker sw(dt);
728 |   while(sw.stillSome())
729 |     vWalker.push_back(sw.get());
730 |   if(vWalker != vRec) {
731 |     cerr << "ERROR: vWalker != vRec" << endl;
732 |     return false;
733 |   } else {
734 |     return true;
735 |   }
736 | }
737 |
738 | static bool validateDecisionTree(TreeState& ts, DecisionTreeNode* dt) {
739 |
740 |   bool valid = true;
741 |
742 |   DecisionTreeInternal* ni;
743 |   DecisionTreeLeaf* nl;
744 |   dt->checkType(&ni, &nl);
745 |
746 |   // make sure there are no multiple versions of the same sample
747 |
748 |   if(!validateWalker(dt))
749 |     valid = false;
750 |
751 |   if(nl) {
752 |     vector<Sample*>::const_iterator itS;
753 |     map<string, Sample*> suidMap;
754 |     for(itS = nl->samples.begin(); itS != nl->samples.end(); ++itS) {
755 |       Sample* s = *itS;
756 |       const string& suid = s->suid;
757 |
758 |       map<string, Sample*>::const_iterator itSS = suidMap.find(suid);
759 |       if(itSS != suidMap.end()) {
760 |         cerr << "ERROR: multiple occurrences of sample " << suid << endl;
761 |         valid = false;
762 |       }
763 |       suidMap[s->suid] = s;
764 |     }
765 |   }
766 |
767 |   // validate counters
768 |
769 |   if(nl) {
770 |     if(nl->samples.size() != (nl->c0 + nl->c1)) {
771 |       cerr << "ERROR: c0 + c1 != #samples" << endl;
772 |       valid = false;
773 |     }
774 |   }
775 |
776 |   sparse_hash_map<int, DecisionCounts>::const_iterator itDC;
777 |   for(itDC = dt->decisionCountMap.begin(); itDC != dt->decisionCountMap.end(); ++itDC) {
778 |     const DecisionCounts& dc = itDC->second;
779 |
780 |     const unsigned int c0n = dt->c0 - dc.c0p;
781 |     const unsigned int c1n = dt->c1 - dc.c1p;
782 |
783 |     if(dc.c0p > dt->c0) { // c0n underflowed; the unsigned c0n can never itself be < 0
784 |       cerr << "ERROR: c0n < 0: " << (int)c0n << endl;
785 |       valid = false;
786 |     }
787 |     if(dc.c1p > dt->c1) { // c1n underflowed
788 |       cerr << "ERROR: c1n < 0: " << (int)c1n << endl;
789 |       valid = false;
790 |     }
791 |     if(dc.c0p < 0) {
792 |       cerr << "ERROR: c0p < 0: " << dc.c0p << endl;
793 |       valid = false;
794 |     }
795 |     if(dc.c1p < 0) {
796 |       cerr << "ERROR: c1p < 0: " << dc.c1p << endl;
797 |       valid = false;
798 |     }
799 |
800 |     if((dc.c0p + c0n) != dt->c0) {
801 |       cerr << "ERROR: c0p + c0n != c0 : " << dc.c0p << " + " << c0n << " != " << dt->c0 << endl;
802 |       valid = false;
803 |     }
804 |
805 |     if((dc.c1p + c1n) != dt->c1) {
806 |       cerr << "ERROR: c1p + c1n != c1 : " << dc.c1p << " + " << c1n << " != " << dt->c1 << endl;
807 |       valid = false;
808 |     }
809 |   }
810 |
811 |   // validate counters against samples
812 |
813 |   {
814 |     sparse_hash_map<int, DecisionCounts> computedDCs;
815 |     unsigned int computedC0, computedC1;
816 |     pair<CodeRankType, int> computedMinValidRank;
817 |
818 |     // FIXME: validate minValidRank
819 |
820 |     TreeSampleWalker sw(dt);
821 |     computeDecisionCounters(dt, sw, computedDCs, computedC0, computedC1, computedMinValidRank);
822 |     if(computedC0 != dt->c0) {
823 |       cerr << "ERROR: c0 != computedC0 : " << dt->c0 << " != " << computedC0 << endl;
824 |       valid = false;
825 |     }
826 |     if(computedC1 != dt->c1) {
827 |       cerr << "ERROR: c1 != computedC1 : " << dt->c1 << " != " << computedC1 << endl;
828 |       valid = false;
829 |     }
830 |     if(!compareDCs(dt->decisionCountMap, computedDCs, dt, "stored", "computed")) {
831 |       cerr << "stored and computed decision counts disagree:" << endl;
832 |       cerr << "dt = " << (long) dt << endl;
833 |       cerr << "minValidRank = " << dt->minValidRank.first << " , " << dt->minValidRank.second << endl;
834 |       cerr << dt->decisionCountMap.size() << " DCs in stored" << endl;
835 |       cerr << computedDCs.size() << " DCs in computed" << endl;
836 |       cerr << "split code was " << dt->code << endl;
837 |       valid = false;
838 |     }
839 |   }
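  // --- Aside (editor's note, not part of the original source): the invariant
  // checked above is that, per candidate code, the node totals factor exactly
  // into a "has code" half and a "lacks code" half:
  //
  //   c0 = c0p + c0n   and   c1 = c1p + c1n
  //
  // e.g. a node with c0 = 10, c1 = 6 and a code with c0p = 3, c1p = 5 implies
  // c0n = 7, c1n = 1; any other reading of the counters indicates corruption.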
840 |
841 |   if(ni) {
842 |
843 |     sparse_hash_map<int, DecisionCounts>::const_iterator itDC;
844 |     itDC = dt->decisionCountMap.find(dt->code);
845 |     if(itDC != dt->decisionCountMap.end()) {
846 |       const DecisionCounts& dc = itDC->second;
847 |
848 |       const unsigned int c0n = dt->c0 - dc.c0p;
849 |       const unsigned int c1n = dt->c1 - dc.c1p;
850 |
851 |       if(ni->negative->c0 != c0n) {
852 |         cerr << "ERROR: negative->c0 != c0n : " << ni->negative->c0 << " != " << c0n << endl;
853 |         valid = false;
854 |       }
855 |       if(ni->negative->c1 != c1n) {
856 |         cerr << "ERROR: negative->c1 != c1n : " << ni->negative->c1 << " != " << c1n << endl;
857 |         valid = false;
858 |       }
859 |       if(ni->positive->c0 != dc.c0p) {
860 |         cerr << "ERROR: positive->c0 != c0p : " << ni->positive->c0 << " != " << dc.c0p << endl;
861 |         valid = false;
862 |       }
863 |       if(ni->positive->c1 != dc.c1p) {
864 |         cerr << "ERROR: positive->c1 != c1p : " << ni->positive->c1 << " != " << dc.c1p << endl;
865 |         valid = false;
866 |       }
867 |
868 |     } else {
869 |       cerr << "ERROR: code upon which node is split not found in decisionCountMap" << endl;
870 |       valid = false;
871 |     }
872 |
873 |     if(!valid)
874 |       cerr << "validating negative branch now" << endl;
875 |     if(!validateDecisionTree(ts, ni->negative))
876 |       valid = false;
877 |     if(!valid)
878 |       cerr << "validating positive branch now" << endl;
879 |     if(!validateDecisionTree(ts, ni->positive))
880 |       valid = false;
881 |
882 |     if((ni->negative->c0 + ni->positive->c0) != ni->c0) {
883 |
884 |       cerr << "ERROR: negative->c0 + positive->c0 != c0 : " << ni->negative->c0 << " + " << ni->positive->c0 << " != " << ni->c0 << endl;
885 |       valid = false;
886 |     }
887 |     if((ni->negative->c1 + ni->positive->c1) != ni->c1) {
888 |       cerr << "ERROR: negative->c1 + positive->c1 != c1 : " << ni->negative->c1 << " + " << ni->positive->c1 << " != " << ni->c1 << endl;
889 |       valid = false;
890 |     }
891 |   }
892 |
893 |   return valid;
894 | }
895 |
896 | static void updateDecisionTreeSamples(DecisionTreeNode* dt, const vector<Sample*>& batchAdd, const vector<Sample*>& batchRemove) {
897 |   DecisionTreeInternal* ni;
898 |   DecisionTreeLeaf* nl;
899 |
900 |   dt->checkType(&ni, &nl);
901 |
902 |   if(nl) {
903 |     vector<Sample*>::const_iterator bIt;
904 |     for(bIt = batchRemove.begin(); bIt != batchRemove.end(); ++bIt) {
905 |       Sample* s = *bIt;
906 |       vector<Sample*>::iterator sIt = find(nl->samples.begin(), nl->samples.end(), s);
907 |       if(sIt != nl->samples.end())
908 |         nl->samples.erase(sIt);
909 |       else
910 |         cerr << "ERROR: could not find sample to remove!" << endl;
911 |     }
912 |     nl->samples.insert(nl->samples.end(), batchAdd.begin(), batchAdd.end());
913 |   }
914 |
915 |   // FIXME: these splits will have to be done again when walking the tree to update (split/unsplit) nodes
916 |   // how to avoid the duplicate work?
917 |   if(ni) {
918 |     vector<Sample*> aN, aP;
919 |     VectorSampleWalker swAdd(batchAdd);
920 |     splitListAgainstCode(swAdd, ni->code, aN, aP);
921 |
922 |     vector<Sample*> rN, rP;
923 |     VectorSampleWalker swRemove(batchRemove);
924 |     splitListAgainstCode(swRemove, ni->code, rN, rP);
925 |
926 |     if(aN.size() > 0 || rN.size() > 0)
927 |       updateDecisionTreeSamples(ni->negative, aN, rN);
928 |
929 |     if(aP.size() > 0 || rP.size() > 0)
930 |       updateDecisionTreeSamples(ni->positive, aP, rP);
931 |
932 |   }
933 | }
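// --- Aside (editor's sketch, not part of the original source): both the
// function above and updateDecisionTreeNode() below route a batch by the
// node's split code. Assuming splitListAgainstCode() partitions on the same
// x >= 0.5 test used at classification time, its behaviour is equivalent to:
//
//   // for each sample s produced by the walker:
//   //   value = s->xCodes.count(code) ? s->xCodes[code] : 0
//   //   (value >= 0.5 ? positiveList : negativeList).push_back(s)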
934 |
935 | static DecisionTreeNode* updateDecisionTreeNode(TreeState& ts, DecisionTreeNode* dt, const vector<Sample*>& batchAdd, const vector<Sample*>& batchRemove) {
936 |   DecisionTreeInternal* ni;
937 |   DecisionTreeLeaf* nl;
938 |
939 |   dt->checkType(&ni, &nl);
940 |
941 |   {
942 |     // removals
943 |
944 |     vector<Sample*>::const_iterator bIt;
945 |     int addedBefore0 = 0;
946 |     int addedBefore1 = 0;
947 |     for(bIt = batchRemove.begin(); bIt != batchRemove.end(); ++bIt) {
948 |       updateDecisionCounters(dt, *bIt, addedBefore0, addedBefore1, -1);
949 |       if((*bIt)->y >= 0.5)
950 |         ++addedBefore1;
951 |       else
952 |         ++addedBefore0;
953 |     }
954 |
955 |     vector<Sample*> b0, b1;
956 |
957 |     // FIXME: more efficient just to count them!
958 |     splitListByTarget(batchRemove, b0, b1);
959 |
960 |     (dt->c1) -= b1.size();
961 |     (dt->c0) -= b0.size();
962 |
963 |   }
964 |
965 |   {
966 |     // additions
967 |
968 |     vector<Sample*>::const_iterator bIt;
969 |     int addedBefore0 = 0;
970 |     int addedBefore1 = 0;
971 |     for(bIt = batchAdd.begin(); bIt != batchAdd.end(); ++bIt) {
972 |       updateDecisionCounters(dt, *bIt, addedBefore0, addedBefore1);
973 |       if((*bIt)->y >= 0.5)
974 |         ++addedBefore1;
975 |       else
976 |         ++addedBefore0;
977 |     }
978 |
979 |     vector<Sample*> b0, b1;
980 |
981 |     // FIXME: more efficient just to count them!
982 |     splitListByTarget(batchAdd, b0, b1);
983 |
984 |     (dt->c1) += b1.size();
985 |     (dt->c0) += b0.size();
986 |   }
987 |
988 |   if((dt->decisionCountMap.size() < maxCodesToConsider)
989 |      && ((dt->minValidRank.first != 0) || (dt->minValidRank.second != 0))) {
990 |     computeDecisionCounters(dt,
991 |                             TreeSampleWalker(dt),
992 |                             dt->decisionCountMap,
993 |                             dt->c0,
994 |                             dt->c1,
995 |                             dt->minValidRank);
996 |   }
997 |
998 |   float currentEntropy = entropyBinary(dt->c0, dt->c1);
999 |   int minEntropyCode = findMinEntropyCode(currentEntropy, dt);
1000 |
1001 |   bool shouldBeSplit = minEntropyCode != -1;
1002 |
1003 |   if(nl) {
1004 |     // update leaf node
1005 |
1006 |     if(shouldBeSplit) {
1007 |
1008 |       // time to split
1009 |
1010 |       DecisionTreeInternal* newInternal = makeInternal(ts, minEntropyCode, 0, 0);
1011 |       newInternal->c0 = dt->c0;
1012 |       newInternal->c1 = dt->c1;
1013 |       newInternal->minValidRank = dt->minValidRank;
1014 |       newInternal->decisionCountMap = dt->decisionCountMap;
1015 |       newInternal->id = dt->id;
1016 |       VectorSampleWalker sw(nl->samples);
1017 |       splitNode(ts, newInternal, minEntropyCode, sw);
1018 |
1019 |       destroyDecisionTreeNode(nl);
1020 |
1021 |       return newInternal;
1022 |     } else {
1023 |       // staying a leaf
1024 |
1025 |       updateValue(nl);
1026 |
1027 |       return nl;
1028 |     }
1029 |   } else {
1030 |     // update internal node
1031 |
1032 |     if(!shouldBeSplit) {
1033 |       DecisionTreeLeaf* newLeaf = makeLeaf(ts, 0);
1034 |       newLeaf->id = dt->id;
1035 |
1036 |       TreeSampleWalker sw(dt);
1037 |
1038 |       insertLeafSamples(newLeaf->samples, sw);
1039 |
1040 |       setupLeafFromSamples(newLeaf);
1041 |
1042 |       destroyDecisionTreeNode(dt);
1043 |
1044 |       return newLeaf;
1045 |     } else {
1046 |
1047 |       if(minEntropyCode != ni->code) {
1048 |         TreeSampleWalker sw(dt);
1049 |         splitNode(ts, ni, minEntropyCode, sw);
1050 |       } else {
1051 |         vector<Sample*> aN, aP;
1052 |         VectorSampleWalker swAdd(batchAdd);
1053 |
1054 |         splitListAgainstCode(swAdd, ni->code, aN, aP);
1055 |
1056 |         vector<Sample*> rN, rP;
1057 |         VectorSampleWalker swRemove(batchRemove);
1058 |
1059 |         splitListAgainstCode(swRemove, ni->code, rN, rP);
1060 |
1061 |         if(aN.size() > 0 || rN.size() > 0) {
1062 |           ni->negative = updateDecisionTreeNode(ts, ni->negative, aN, rN);
1063 |         }
1064 |         if(aP.size() > 0 || rP.size() > 0) {
1065 |           ni->positive = updateDecisionTreeNode(ts, ni->positive, aP, rP);
1066 |         }
1067 |       }
1068 |
1069 |       return dt;
1070 |     }
1071 |   }
1072 | }
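// --- Aside (editor's note, not part of the original source): after the
// counters are adjusted, updateDecisionTreeNode() above resolves each node
// into one of four cases:
//   1. leaf, no profitable split     -> stays a leaf, value refreshed
//   2. leaf, a profitable split      -> replaced by an internal node
//   3. internal, split now worthless -> collapsed back into a leaf
//   4. internal, split still wanted  -> resplit on a new best code, or
//      recurse only into the children whose add/remove batches are non-empty
// Case 4's last branch is what makes updates incremental: subtrees whose
// batches are empty are never revisited.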
1073 |
1074 | static DecisionTreeNode* updateDecisionTree(TreeState& ts, DecisionTreeNode* dt, const vector<Sample*>& batchAdd, const vector<Sample*>& batchRemove) {
1075 |
1076 |   for(vector<Sample*>::const_iterator it1 = batchAdd.begin(); it1 != batchAdd.end(); ++it1) {
1077 |     vector<Sample*>::const_iterator it2 = find(batchRemove.begin(), batchRemove.end(), *it1);
1078 |     if(it2 != batchRemove.end()) {
1079 |       cerr << "sample in batchAdd also in batchRemove!" << endl;
1080 |       exit(1);
1081 |     }
1082 |   }
1083 |   for(vector<Sample*>::const_iterator it1 = batchRemove.begin(); it1 != batchRemove.end(); ++it1) {
1084 |     vector<Sample*>::const_iterator it2 = find(batchAdd.begin(), batchAdd.end(), *it1);
1085 |     if(it2 != batchAdd.end()) {
1086 |       cerr << "sample in batchRemove also in batchAdd!" << endl;
1087 |       exit(1);
1088 |     }
1089 |   }
1090 |
1091 |   updateDecisionTreeSamples(dt, batchAdd, batchRemove);
1092 |   DecisionTreeNode* n = updateDecisionTreeNode(ts, dt, batchAdd, batchRemove);
1093 |   return n;
1094 | }
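// --- Aside (editor's note, not part of the original source): the loader
// below and saveDecisionTreeNodeInForest()/Forest::save() further down agree
// on a plain whitespace-separated text format, roughly:
//
//   seed
//   nTrees
//   nSamples
//   { sampleId suid y nCodes { code value }* }*   // sampleId is the saver's
//                                                 // pointer value, used only
//                                                 // to re-link leaves
//   then, per tree, recursively per node:
//     code (-1 for a leaf)  id  minValidRank.first minValidRank.second
//     c0 c1  nDCs { code  0 0 c0p c1p rank }*
//     leaf:     nSamples { sampleId }*  value
//     internal: negative subtree, then positive subtree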
1095 |
1096 | static DecisionTreeNode* loadDecisionTreeNodeForForest(TreeState& ts, istream& forestS, map<long, Sample*>& sampleMap) {
1097 |   int nodeCode;
1098 |   forestS >> nodeCode;
1099 |
1100 |   DecisionTreeNode* n = 0;
1101 |   DecisionTreeLeaf* nl = 0;
1102 |   DecisionTreeInternal* ni = 0;
1103 |
1104 |   if(nodeCode == -1) {
1105 |     nl = makeLeaf(ts, 0);
1106 |     n = nl;
1107 |   } else {
1108 |     ni = makeInternal(ts, nodeCode, 0, 0);
1109 |     n = ni;
1110 |   }
1111 |
1112 |   forestS >> n->id;
1113 |   forestS >> n->minValidRank.first >> n->minValidRank.second;
1114 |   forestS >> n->c0 >> n->c1;
1115 |
1116 |   int countDC;
1117 |   forestS >> countDC;
1118 |
1119 |   n->decisionCountMap.resize(countDC);
1120 |
1121 |   for(int i = 0; i < countDC; ++i) {
1122 |     int code;
1123 |     forestS >> code;
1124 |     DecisionCounts dc;
1125 |     // FIXME: no need to be backwards compatible after complete deployment
1126 |     unsigned int dummy;
1127 |     forestS >> dummy >> dummy >> dc.c0p >> dc.c1p >> dc.rank;
1128 |     // not loading empty DC
1129 |     if(!(dc.c0p == 0 && dc.c1p == 0))
1130 |       n->decisionCountMap[code] = dc;
1131 |   }
1132 |
1133 |   if(nl) {
1134 |     int countSamples;
1135 |     forestS >> countSamples;
1136 |     nl->samples.resize(countSamples);
1137 |     for(int i = 0; i < countSamples; ++i) {
1138 |       long sampleId;
1139 |       forestS >> sampleId;
1140 |       if(sampleMap.find(sampleId) == sampleMap.end()) {
1141 |         cerr << "unknown sample!" << endl;
1142 |         exit(1);
1143 |       }
1144 |       nl->samples[i] = sampleMap[sampleId];
1145 |     }
1146 |
1147 |     forestS >> nl->value;
1148 |   } else {
1149 |     ni->negative = loadDecisionTreeNodeForForest(ts, forestS, sampleMap);
1150 |     ni->positive = loadDecisionTreeNodeForForest(ts, forestS, sampleMap);
1151 |   }
1152 |   return n;
1153 | }
1154 |
1155 | static void saveDecisionTreeNodeInForest(DecisionTreeNode* dt, ostream& forestS) {
1156 |   DecisionTreeInternal* ni;
1157 |   DecisionTreeLeaf* nl;
1158 |
1159 |   forestS << dt->code << endl;
1160 |   forestS << dt->id << endl;
1161 |   forestS << dt->minValidRank.first << " " << dt->minValidRank.second << endl;
1162 |   forestS << dt->c0 << " " << dt->c1 << endl;
1163 |   forestS << dt->decisionCountMap.size() << endl;
1164 |   sparse_hash_map<int, DecisionCounts>::const_iterator dcIt;
1165 |   for(dcIt = dt->decisionCountMap.begin(); dcIt != dt->decisionCountMap.end(); ++dcIt) {
1166 |     forestS << dcIt->first << endl;
1167 |     const DecisionCounts& dc = dcIt->second;
1168 |     forestS << 0 << " " << 0 << " " << dc.c0p << " " << dc.c1p << " " << dc.rank << endl;
1169 |   }
1170 |
1171 |   dt->checkType(&ni, &nl);
1172 |
1173 |   if(nl) {
1174 |     forestS << nl->samples.size() << endl;
1175 |     vector<Sample*>::const_iterator sIt;
1176 |     for(sIt = nl->samples.begin(); sIt != nl->samples.end(); ++sIt)
1177 |       forestS << (long)(*sIt) << endl;
1178 |   }
1179 |
1180 |   if(nl) {
1181 |     forestS << nl->value << endl;
1182 |   } else {
1183 |     saveDecisionTreeNodeInForest(ni->negative, forestS);
1184 |     saveDecisionTreeNodeInForest(ni->positive, forestS);
1185 |   }
1186 | }
1187 |
1188 | static void saveDecisionTreeInForest(TreeState& ts, DecisionTreeNode* dt, ostream& forestS) {
1189 |   saveDecisionTreeNodeInForest(dt, forestS);
1190 | }
1191 |
1192 | void outputDecisionTree(TreeState& ts, DecisionTreeNode* dt, ostream& outS) {
1193 |   DecisionTreeInternal* ni;
1194 |   DecisionTreeLeaf* nl;
1195 |
1196 |   if(dt->checkType(&ni, &nl)) {
1197 |     outS << nl->value;
1198 |   } else {
1199 |     outS << "[";
1200 |     outS << ni->code << ",";
1201 |     outputDecisionTree(ts, ni->negative, outS);
1202 |     outS << ",";
1203 |     outputDecisionTree(ts, ni->positive, outS);
1204 |     outS << "]";
1205 |   }
1206 | }
1207 |
1208 | void outputDecisionTreeWithStats(TreeState& ts, DecisionTreeNode* dt, ostream& outS) {
1209 |   DecisionTreeInternal* ni;
1210 |   DecisionTreeLeaf* nl;
1211 |
1212 |   outS << "{";
1213 |
1214 |
1215 |   if(dt->checkType(&ni, &nl)) {
1216 |     outS << "\"value\":" << nl->value;
1217 |   } else {
1218 |     outS << "\"split\":" << dt->code;
1219 |     outS << ",\"neg\":";
1220 |     outputDecisionTreeWithStats(ts, ni->negative, outS);
1221 |     outS << ",\"pos\":";
1222 |     outputDecisionTreeWithStats(ts, ni->positive, outS);
1223 |   }
1224 |
1225 |   outS << ",\"c0\":" << dt->c0;
1226 |   outS << ",\"c1\":" << dt->c1;
1227 |
1228 |   outS << ",\"counts\":{";
1229 |   sparse_hash_map<int, DecisionCounts>::const_iterator mapIt;
1230 |   for(mapIt = dt->decisionCountMap.begin();
1231 |       mapIt != dt->decisionCountMap.end();
1232 |       ++mapIt) {
1233 |     const int code = mapIt->first;
1234 |     if(mapIt != dt->decisionCountMap.begin())
1235 |       outS << ",";
1236 |     outS << "\"" << code << "\":{";
1237 |     const DecisionCounts& dc = mapIt->second;
1238 |     outS << "\"c0p\":" << dc.c0p;
1239 |     outS << ",\"c1p\":" << dc.c1p;
1240 |     outS << ",\"rank\":" << dc.rank;
1241 |     outS << "}";
1242 |   }
1243 |   outS << "}";
1244 |
1245 |   outS << "}";
1246 | }
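// --- Aside (editor's sketch, not part of the original source; numbers are
// invented for illustration): for a one-split tree on code 7, the writer
// above emits JSON of this shape:
//
//   {"split":7,
//    "neg":{"value":0.25,"c0":3,"c1":1,"counts":{...}},
//    "pos":{"value":0.8,"c0":1,"c1":4,"counts":{...}},
//    "c0":4,"c1":5,
//    "counts":{"7":{"c0p":1,"c1p":4,"rank":123}}}
//
// while the plain outputDecisionTree() form used by asJSON() is the compact
// nested list [7,0.25,0.8].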
1247 |
1248 | static void loadRandomForest(TreeState& ts, istream& forestS, vector<DecisionTreeNode*>& forest, map<string, Sample*>& samples) {
1249 |   forestS >> ts.seed;
1250 |   int nTrees;
1251 |   forestS >> nTrees;
1252 |   int nSamples;
1253 |   forestS >> nSamples;
1254 |   map<long, Sample*> sampleMap;
1255 |   for(int i = 0; i < nSamples; ++i) {
1256 |     long sampleId;
1257 |     forestS >> sampleId;
1258 |     Sample* s = new Sample();
1259 |     forestS >> s->suid;
1260 |     forestS >> s->y;
1261 |     int countSampleCodes;
1262 |     forestS >> countSampleCodes;
1263 |     for(int j = 0; j < countSampleCodes; ++j) {
1264 |       int code;
1265 |       float value;
1266 |       forestS >> code >> value;
1267 |       s->xCodes[code] = value;
1268 |     }
1269 |     sampleMap[sampleId] = s;
1270 |     samples[s->suid] = s;
1271 |   }
1272 |   for(int i = 0; i < nTrees; ++i) {
1273 |     forest.push_back(loadDecisionTreeNodeForForest(ts, forestS, sampleMap));
1274 |   }
1275 | }
1276 |
1277 | static float evaluateSampleAgainstDecisionTree(TreeState& ts, Sample* s, DecisionTreeNode* dt) {
1278 |   DecisionTreeNode* dtn = dt;
1279 |
1280 |   DecisionTreeInternal* ni;
1281 |   DecisionTreeLeaf* nl;
1282 |
1283 |   while(!dtn->checkType(&ni, &nl)) {
1284 |     float y;
1285 |
1286 |     map<int, float>::const_iterator xCodeIt = s->xCodes.find(dtn->code);
1287 |     if(xCodeIt != s->xCodes.end())
1288 |       y = xCodeIt->second;
1289 |     else
1290 |       y = 0;
1291 |
1292 |     if(y >= 0.5) {
1293 |       dtn = ni->positive;
1294 |     } else {
1295 |       dtn = ni->negative;
1296 |     }
1297 |   }
1298 |
1299 |   return nl->value;
1300 | }
1301 |
1302 | static bool sampleInTree(const Sample* sp, int t) {
1303 |   char s[64];
1304 |   snprintf(s, sizeof(s), "%d%s", t, sp->suid.c_str()); // snprintf guards against suids longer than the buffer
1305 |   uint32_t out;
1306 |   MurmurHash3_x86_32(s, strlen(s), 42, &out);
1307 |   return (out % 3) < 2; // 2 in 3 chance
1308 | }
1309 |
1310 | class MapSampleWalker : public SampleWalker {
1311 | private:
1312 |   map<string, Sample*>::const_iterator itCurr;
1313 |   map<string, Sample*>::const_iterator itEnd;
1314 | public:
1315 |   MapSampleWalker(const map<string, Sample*>& sm) :
1316 |     itCurr(sm.begin()), itEnd(sm.end()) {
1317 |   }
1318 |   virtual bool stillSome(void) const {
1319 |     return itCurr != itEnd;
1320 |   }
1321 |   virtual Sample* get(void) {
1322 |     Sample* ret = itCurr->second;
1323 |     ++itCurr;
1324 |     return ret;
1325 |   }
1326 | };
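// --- Aside (editor's note, not part of the original source): sampleInTree()
// replaces conventional bootstrap sampling with a deterministic hash of
// (tree index, sample id): every sample lands in roughly 2/3 of the trees,
// and - crucially for incremental updates - the same sample always maps to
// the same subset of trees, so a later remove('8')/add('8') touches exactly
// the trees that originally saw that sample.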
1327 |
1328 | class Forest {
1329 | private:
1330 |   map<string, Sample*> samples;
1331 |   map<string, Sample*> toAdd;
1332 |   map<string, Sample*> toRemove;
1333 |   vector<DecisionTreeNode*> forest;
1334 |   bool changesToCommit;
1335 |   TreeState ts;
1336 | public:
1337 |   Forest(istream& forestS) {
1338 |     loadRandomForest(ts, forestS, forest, samples);
1339 |     changesToCommit = false;
1340 |   }
1341 |
1342 |
1343 |   Forest(int nTrees) {
1344 |     for(int i = 0; i < nTrees; ++i)
1345 |       forest.push_back(emptyDecisionTree(ts));
1346 |     changesToCommit = false;
1347 |   }
1348 |
1349 |   ~Forest(void) {
1350 |     for(vector<DecisionTreeNode*>::iterator itTree = forest.begin();
1351 |         itTree != forest.end();
1352 |         ++itTree) {
1353 |       destroyDecisionTreeNode(*itTree);
1354 |     }
1355 |     map<string, Sample*>::iterator itAdd;
1356 |     for(itAdd = toAdd.begin(); itAdd != toAdd.end(); ++itAdd) {
1357 |       delete itAdd->second;
1358 |     }
1359 |     map<string, Sample*>::iterator itMap;
1360 |     for(itMap = samples.begin(); itMap != samples.end(); ++itMap) {
1361 |       delete itMap->second;
1362 |     }
1363 |   }
1364 |
1365 |   bool add(Sample* s) {
1366 |     changesToCommit = true;
1367 |     map<string, Sample*>::iterator itAdd = toAdd.find(s->suid);
1368 |
1369 |     bool added = false;
1370 |     if(itAdd != toAdd.end()) {
1371 |       delete itAdd->second;
1372 |     } else {
1373 |       added = true;
1374 |       map<string, Sample*>::iterator itRemove = toRemove.find(s->suid);
1375 |       map<string, Sample*>::iterator itMap = samples.find(s->suid);
1376 |
1377 |       if(itRemove == toRemove.end()) {
1378 |         // no remove record
1379 |         if(itMap != samples.end()) {
1380 |           toRemove[itMap->first] = itMap->second;
1381 |         }
1382 |       }
1383 |     }
1384 |
1385 |     toAdd[s->suid] = s;
1386 |     return added;
1387 |   }
1388 |
1389 |   bool remove(const char* sId) {
1390 |     map<string, Sample*>::iterator itAdd = toAdd.find(sId);
1391 |     if(itAdd != toAdd.end()) {
1392 |       delete itAdd->second;
1393 |       toAdd.erase(itAdd);
1394 |       changesToCommit = true;
1395 |       return true;
1396 |     }
1397 |
1398 |     map<string, Sample*>::iterator itRemove = toRemove.find(sId);
1399 |     if(itRemove != toRemove.end())
1400 |       return false;
1401 |
1402 |     map<string, Sample*>::iterator itMap = samples.find(sId);
1403 |     if(itMap == samples.end())
1404 |       return false;
1405 |
1406 |     changesToCommit = true;
1407 |
1408 |     toRemove[sId] = itMap->second;
1409 |
1410 |     return true;
1411 |   }
1412 |
1413 |   void commit(void) {
1414 |     if(!changesToCommit)
1415 |       return;
1416 |
1417 |     map<string, Sample*>::iterator sIt;
1418 |
1419 |     int treeId = 0;
1420 |     for(vector<DecisionTreeNode*>::iterator itTree = forest.begin();
1421 |         itTree != forest.end();
1422 |         ++itTree, ++treeId) {
1423 |       vector<Sample*> treeAdd, treeRemove;
1424 |       for(sIt = toRemove.begin(); sIt != toRemove.end(); ++sIt) {
1425 |         if(sampleInTree(sIt->second, treeId))
1426 |           treeRemove.push_back(sIt->second);
1427 |       }
1428 |       for(sIt = toAdd.begin(); sIt != toAdd.end(); ++sIt) {
1429 |         if(sampleInTree(sIt->second, treeId))
1430 |           treeAdd.push_back(sIt->second);
1431 |       }
1432 |       *itTree = updateDecisionTree(ts, *itTree, treeAdd, treeRemove);
1433 |     }
1434 |
1435 |     for(sIt = toRemove.begin(); sIt != toRemove.end(); ++sIt) {
1436 |       delete sIt->second;
1437 |       samples.erase(sIt->first);
1438 |     }
1439 |     for(sIt = toAdd.begin(); sIt != toAdd.end(); ++sIt)
1440 |       samples[sIt->first] = sIt->second;
1441 |
1442 |     toAdd.clear();
1443 |     toRemove.clear();
1444 |     changesToCommit = false;
1445 |   }
1446 |
1447 |   void asJSON(ostream& outS) {
1448 |     commit();
1449 |     outS << "[";
1450 |     for(vector<DecisionTreeNode*>::iterator itTree = forest.begin();
1451 |         itTree != forest.end();
1452 |         ++itTree) {
1453 |       if(itTree != forest.begin())
1454 |         outS << ",";
1455 |       outputDecisionTree(ts, *itTree, outS);
1456 |     }
1457 |     outS << "]";
1458 |   }
1459 |
1460 |   void statsJSON(ostream& outS) {
1461 |     commit();
1462 |     outS << "[";
1463 |     for(vector<DecisionTreeNode*>::iterator itTree = forest.begin();
1464 |         itTree != forest.end();
1465 |         ++itTree) {
1466 |       if(itTree != forest.begin())
1467 |         outS << ",";
1468 |       outputDecisionTreeWithStats(ts, *itTree, outS);
1469 |     }
1470 |     outS << "]";
1471 |   }
1472 |
1473 |   bool save(ostream& outS) {
1474 |     commit();
1475 |     outS << ts.seed << endl;
1476 |     outS << forest.size() << endl;
1477 |
1478 |     outS << samples.size() << endl;
1479 |     map<string, Sample*>::const_iterator sIt;
1480 |     for(sIt = samples.begin(); sIt != samples.end(); ++sIt) {
1481 |       const Sample* s = sIt->second;
1482 |       outS << (long)s << endl;
1483 |       outS << s->suid << endl;
1484 |       outS << s->y << endl;
1485 |       map<int, float>::const_iterator codeIt;
1486 |       outS << s->xCodes.size() << endl;
1487 |       for(codeIt = s->xCodes.begin(); codeIt != s->xCodes.end(); ++codeIt)
1488 |         outS << codeIt->first << " " << codeIt->second << endl;
1489 |     }
1490 |
1491 |     for(vector<DecisionTreeNode*>::iterator itTree = forest.begin();
1492 |         itTree != forest.end();
1493 |         ++itTree) {
1494 |       saveDecisionTreeInForest(ts, *itTree, outS);
1495 |     }
1496 |     return true;
1497 |   }
1498 |
1499 |   float classify(Sample* s) {
1500 |     commit();
1501 |     double v = 0;
1502 |     for(vector<DecisionTreeNode*>::iterator itTree = forest.begin();
1503 |         itTree != forest.end();
1504 |         ++itTree) {
1505 |       double dv = evaluateSampleAgainstDecisionTree(ts, s, *itTree);
1506 |       v += dv;
1507 |     }
1508 |     return v / forest.size();
1509 |   }
1510 |
1511 |   float classifyPartial(Sample* s, int n) {
1512 |     commit();
1513 |     double v = 0;
1514 |     vector<DecisionTreeNode*>::iterator itStop = forest.begin() + n;
1515 |     for(vector<DecisionTreeNode*>::iterator itTree = forest.begin();
1516 |         itTree != itStop;
1517 |         ++itTree) {
1518 |       double dv = evaluateSampleAgainstDecisionTree(ts, s, *itTree);
1519 |       v += dv;
1520 |     }
1521 |     return v / n;
1522 |   }
1523 |
1524 |   bool validate(void) {
1525 |     for(vector<DecisionTreeNode*>::iterator itTree = forest.begin();
1526 |         itTree != forest.end();
1527 |         ++itTree) {
1528 |       if(!validateDecisionTree(ts, *itTree))
1529 |         return false;
1530 |     }
1531 |     return true;
1532 |   }
1533 |
1534 |   SampleWalker* getSamples(void) {
1535 |     commit();
1536 |     return new MapSampleWalker(samples);
1537 |   }
1538 | };
1539 |
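// --- Aside (editor's sketch, not part of the original source): typical use
// of this class through the C++ API declared in randomForest.h; the Forest
// takes ownership of added Samples and frees them in its destructor:
//
//   using namespace IncrementalRandomForest;
//   Forest* f = create(99);
//   Sample* s = new Sample();
//   s->suid = "1"; s->y = 0;
//   s->xCodes[1] = 1; s->xCodes[3] = 1;
//   add(f, s);
//   commit(f);                        // or let classify() commit lazily
//   Sample probe; probe.xCodes[1] = 1;
//   float p = classify(f, &probe);    // estimate of P(class == 1)
//   destroy(f);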
1540 | /* visible outside module */
1541 |
1542 | Forest* create(int nTrees) {
1543 |   return new Forest(nTrees);
1544 | }
1545 |
1546 | void destroy(Forest* rf) {
1547 |   delete rf;
1548 | }
1549 |
1550 | Forest* load(istream& forestS) {
1551 |   return new Forest(forestS);
1552 | }
1553 |
1554 | bool save(Forest* rf, ostream& outS) {
1555 |   return rf->save(outS);
1556 | }
1557 |
1558 | void asJSON(Forest* rf, ostream& outS) {
1559 |   rf->asJSON(outS);
1560 | }
1561 |
1562 | void statsJSON(Forest* rf, ostream& outS) {
1563 |   rf->statsJSON(outS);
1564 | }
1565 |
1566 | bool add(Forest* rf, Sample* s) {
1567 |   return rf->add(s);
1568 | }
1569 |
1570 | bool remove(Forest* rf, const char* sId) {
1571 |   return rf->remove(sId);
1572 | }
1573 |
1574 | void commit(Forest* rf) {
1575 |   rf->commit();
1576 | }
1577 |
1578 | float classify(Forest* rf, Sample* s) {
1579 |   return rf->classify(s);
1580 | }
1581 |
1582 | float classifyPartial(Forest* rf, Sample* s, int n) {
1583 |   return rf->classifyPartial(s, n);
1584 | }
1585 |
1586 | bool validate(Forest* rf) {
1587 |   return rf->validate();
1588 | }
1589 |
1590 | SampleWalker* getSamples(Forest* rf) {
1591 |   return rf->getSamples();
1592 | }
1593 | }
1594 |
--------------------------------------------------------------------------------
/irf/randomForest.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2010-2011 Carlos Guerreiro
2 |  * Licensed under the MIT license */
3 |
4 | #ifndef PCONSTR_RANDOMFOREST_H
5 | #define PCONSTR_RANDOMFOREST_H
6 |
7 | #include <iostream>
8 | #include <map>
9 | #include <string>
10 |
11 | namespace IncrementalRandomForest {
12 |
13 | struct DecisionTreeNode;
14 |
15 | struct Sample {
16 |   std::string suid;
17 |   float y;
18 |   std::map<int, float> xCodes;
19 | };
20 |
21 | // FIXME: should be opaque
22 | struct TreeState {
23 |   unsigned int seed;
24 |   TreeState(void) : seed(1) {
25 |   }
26 | };
27 |
28 | class SampleWalker {
29 | private:
30 | public:
31 |   virtual ~SampleWalker(void) {
32 |   }
33 |   virtual bool stillSome(void) const = 0;
34 |   virtual Sample* get(void) = 0;
35 | };
36 |
37 | class Forest;
38 |
39 | Forest* create(int nTrees);
40 | void destroy(Forest* rf);
41 | Forest* load(std::istream& forestS);
42 | bool save(Forest* rf, std::ostream& outS);
43 | void asJSON(Forest* rf, std::ostream& outS);
44 | void statsJSON(Forest* rf, std::ostream& outS);
45 | bool add(Forest* rf, Sample* s);
46 | bool remove(Forest* rf, const char* sId);
47 | void commit(Forest* rf);
48 | float classify(Forest* rf, Sample* s);
49 | float classifyPartial(Forest* rf, Sample* s, int n);
50 | bool validate(Forest* rf);
51 | SampleWalker* getSamples(Forest* rf);
52 | }
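// --- Aside (editor's sketch, not part of the original header): save() and
// load() round-trip through any std::iostream, e.g.:
//
//   std::ofstream out("forest.rf");
//   IncrementalRandomForest::save(f, out);
//   out.close();
//   std::ifstream in("forest.rf");
//   IncrementalRandomForest::Forest* g = IncrementalRandomForest::load(in);
//
// which is presumably what the Python binding's save()/irf.load() and the
// node binding's toBuffer()/new IRF(buffer) wrap.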
53 |
54 | #endif
55 |
--------------------------------------------------------------------------------
/irf/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup, Extension
2 |
3 | module1 = Extension('irf',
4 |                     sources = ['irfmodule.cpp', 'randomForest.cpp', 'MurmurHash3.cpp'])
5 |
6 | setup(name = 'irf',
7 |       version = '0.1',
8 |       description = 'Incremental Random Forest',
9 |       ext_modules = [module1])
10 |
--------------------------------------------------------------------------------
/libsparsehash.pc:
--------------------------------------------------------------------------------
1 | prefix=/usr
2 | exec_prefix=${prefix}
3 | libdir=${exec_prefix}/lib
4 | includedir=${prefix}/include
5 |
6 | Name: sparsehash
7 | Version: 1.10
8 | Description: hash_map and hash_set classes with minimal space overhead
9 | URL: http://code.google.com/p/google-sparsehash
10 | Requires:
11 | Libs:
12 | Cflags: -I${includedir}
13 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "irf",
3 |   "version": "0.1.4",
4 |   "keywords": ["incremental", "random", "forest", "machine", "learning", "c++", "native", "sparse", "classifier", "ensemble", "supervised"],
5 |   "homepage": "https://github.com/pconstr/irf",
6 |   "description": "incremental random forest ensemble classifier (native)",
7 |   "main": "./index.js",
8 |   "repository": {
9 |     "type": "git",
10 |     "url": "https://github.com/pconstr/irf"
11 |   },
12 |   "engines": {
13 |     "node": ">=0.6.0 <0.9.0"
14 |   },
15 |   "author": "Carlos Guerreiro
--------------------------------------------------------------------------------
/tests/mushrooms.js:
--------------------------------------------------------------------------------
22 |     var y = rf.classify(instance[1]) >= 0.5 ? 1 : 0;
23 |     counts[y][instance[2]]++
24 |   });
25 |   return counts;
26 | }
27 |
28 | function printCounts(counts) {
29 |   var total = counts[0][0] + counts[0][1] + counts[1][0] + counts[1][1]
30 |   console.log("total =", total);
31 |   console.log("  correct negatives:", counts[0][0]);
32 |   console.log("  false negatives:", counts[0][1]);
33 |   console.log("  correct positives: ", counts[1][1]);
34 |   console.log("  false positives: ", counts[1][0]);
35 | }
36 |
37 | c.on('line', function(line) {
38 |   var data = line.split(' ');
39 |
40 |   var y = classValues[data[0]];
41 |
42 |   var features = {};
43 |   data.slice(1).forEach(function(t) {
44 |     if(t !== '') {
45 |       var elems = t.split(':');
46 |       features[elems[0]] = elems[1];
47 |     }
48 |   });
49 |
50 |   var iid = String(instanceID);
51 |   if(instanceID % 5 <= 3) {
52 |     testing.push([iid, features, y]);
53 |   } else {
54 |     rf.add(iid, features, y);
55 |   }
56 |
57 |   instanceID++;
58 | });
59 |
60 | c.on('end', function() {
61 |   console.log('learning...');
62 |   rf.commit();
63 |   console.log('classifying...');
64 |   var counts = test(rf, testing);
65 |   printCounts(counts);
66 |   console.log('saving...');
67 |   fs.writeFileSync('mushrooms.rf', rf.toBuffer());
68 |   console.log('loading...');
69 |   rf = new irf.IRF(fs.readFileSync('mushrooms.rf'));
70 |   console.log('classifying...');
71 |   counts = test(rf, testing);
72 |   printCounts(counts);
73 |   console.log('.');
74 | });
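// --- Aside (editor's note, not part of the original script): counts is
// indexed [predicted][actual], so counts[1][0] is a false positive and
// counts[0][1] a false negative; printCounts() above reads it accordingly.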
75 |
--------------------------------------------------------------------------------
/tests/mushrooms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 |
3 | import irf
4 |
5 | def printCounts(counts):
6 |     total = counts[0][0] + counts[0][1] + counts[1][0] + counts[1][1]
7 |     print "total = ", total
8 |     print "  correct negatives: ", counts[0][0]
9 |     print "  false negatives: ", counts[0][1]
10 |     print "  correct positives: ", counts[1][1]
11 |     print "  false positives: ", counts[1][0]
12 |
13 | def test(rf, testing):
14 |     counts = [[0,0],[0,0]]
15 |     for instance in testing:
16 |         c = int(rf.classify(instance[1]) >= 0.5)
17 |         counts[c][instance[2]] += 1
18 |     return counts
19 |
20 | def main():
21 |     rf = irf.IRF(99)
22 |
23 |     print 'reading...'
24 |     f = open('mushrooms')
25 |     instanceID = 0
26 |     testing = []
27 |     classValues = {'1':0, '2':1}
28 |     for rawL in f.readlines():
29 |         l = rawL.strip()
30 |         values = l.split(' ')
31 |         c = classValues[values[0]]
32 |         features = {}
33 |         for kCv in values[1:]:
34 |             k, v = kCv.split(':')
35 |             features[int(k)] = int(v)
36 |         instance = (str(instanceID), features, c)
37 |         if instanceID % 5 <= 3:
38 |             testing.append(instance)
39 |         else:
40 |             rf.add(*instance)
41 |         instanceID += 1
42 |
43 |     print 'learning...'
44 |     rf.commit()
45 |
46 |     print 'classifying...'
47 |     counts = test(rf, testing)
48 |     printCounts(counts)
49 |
50 |     print 'saving...'
51 |     rf.save('mushrooms.rf')
52 |
53 |     print 'loading...'
54 |     rf = irf.load('mushrooms.rf')
55 |
56 |     print 'classifying...'
57 |     counts = test(rf, testing)
58 |     printCounts(counts)
59 |
60 |     print '.'
61 |
62 | if __name__ == "__main__":
63 |     main()
64 |
--------------------------------------------------------------------------------
/tests/simple.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 |
3 | import irf
4 |
5 | f = irf.IRF(99) # create forest of 99 trees
6 |
7 | f.add('1', {1:1, 3:1, 5:1}, 0) # add a sample identified as '1' with the given feature values, classified as 0
8 | f.add('2', {1:0, 3:0, 4:1}, 0) # features are stored sparsely, when a value is not given it will be taken as 0
9 | f.add('3', {2:0, 3:0, 5:0}, 0) # but 0s can also be given explicitly
10 | f.add('4', {1:0, 2:0, 3:0, 5:0}, 0)
11 | f.add('5', {1:1, 2:1, 3:1, 4:1, 5:1}, 0)
12 | f.add('6', {2:1, 3:1, 4:0}, 1)
13 | f.add('7', {1:1, 2:1, 3:0, 4:1}, 1)
14 | f.add('8', {1:0, 2:0, 3:0, 4:1, 5:1}, 1)
15 | f.add('9', {1:0, 2:0, 3:1, 4:0, 5:1}, 1)
16 | f.add('10', {1:0, 3:1, 4:1, 5:0, 6:1}, 1)
17 | f.add('11', {1:1, 3:0, 5:1, 7:1}, 1)
18 | f.add('12', {1:0, 3:1, 5:1, 8:1}, 1)
19 | f.add('13', {1:0, 3:0, 4:1, 7:1}, 0)
20 | f.add('14', {1:1, 3:1, 5:1, 8:1}, 0)
21 | f.add('15', {1:1, 4:1, 8:1}, 1)
22 |
23 | y = f.classify({1:1, 2:1, 5:1}); print y, int(round(y)) # classify feature vector, round to nearest to get class
24 | y = f.classify({3:1, 2:1, 5:1}); print y, int(round(y))
25 | y = f.classify({1:1, 3:1, 5:1}); print y, int(round(y))
26 | y = f.classify({2:1, 5:1}); print y, int(round(y))
27 |
28 | f.save('simple.rf') # save forest to file
29 |
30 | f = irf.load('simple.rf') # load forest from file
31 |
32 | f.remove('8') # remove a sample
33 | f.add('8', {1:0, 2:0, 3:0, 4:0, 5:1}, 0) # and add it again with new values
34 |
35 | y = f.classify({1:1, 2:1, 5:1}); print y, int(round(y)) # the forest will be lazily updated before classification
36 | y = f.classify({3:1, 2:1, 5:1}); print y, int(round(y)) # it is also possible to force it with commit()
37 | y = f.classify({1:1, 3:1, 5:1}); print y, int(round(y))
38 | y = f.classify({2:1, 5:1}); print y, int(round(y))
39 |
40 | for (sId, x, y) in f.samples(): # iterate through samples in the forest, in lexicographic ID order
41 |     print sId, x, y # and print them
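# --- Aside (editor's note, not part of the original script): classify()
# averages one estimate per tree (presumably each tree's leaf returns its
# class-1 proportion), so the result is a soft score in [0, 1]; round() is
# just the 0.5-threshold decision, and the raw value can be kept as a
# confidence.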
42 |
--------------------------------------------------------------------------------
/wscript:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import Options
4 |
5 | from os import unlink, symlink
6 | from os.path import exists, lexists
7 |
8 | srcdir = "."
9 | blddir = "build"
10 | VERSION = "0.1.3"
11 |
12 | def set_options(opt):
13 |     opt.tool_options("compiler_cxx")
14 |
15 | def configure(conf):
16 |     conf.check_tool("compiler_cxx")
17 |     conf.check_tool("node_addon")
18 |     conf.check_cfg(package='libsparsehash', mandatory=1, args='--cflags --libs')
19 |     conf.env.append_value('CXXFLAGS', ['-O2'])
20 |
21 | def build(bld):
22 |     obj = bld.new_task_gen("cxx", "shlib", "node_addon")
23 |
24 |     obj.target = "irf"
25 |
26 |     obj.source = ['irf/MurmurHash3.cpp', 'irf/randomForest.cpp', 'irf/node.cpp']
27 |
--------------------------------------------------------------------------------