├── RandomForest.h ├── .gitignore ├── Makefile ├── MnistPreProcess.h ├── README.md ├── MnistPreProcess.cpp ├── Sample.h ├── Sample.cpp ├── main.cpp ├── Tree.h ├── Node.h ├── Tree.cpp ├── RandomForest.cpp └── Node.cpp /RandomForest.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/handspeaker/RandomForests/HEAD/RandomForest.h -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.obj 6 | 7 | # Precompiled Headers 8 | *.gch 9 | *.pch 10 | 11 | # Compiled Dynamic libraries 12 | *.so 13 | *.dylib 14 | *.dll 15 | 16 | # Fortran module files 17 | *.mod 18 | *.smod 19 | 20 | # Compiled Static libraries 21 | *.lai 22 | *.la 23 | *.a 24 | *.lib 25 | 26 | # Executables 27 | *.exe 28 | *.out 29 | *.app 30 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | objects = MnistPreProcess.o Node.o RandomForest.o Sample.o Tree.o main.o 2 | 3 | INCLUDE_DIRS := /usr/local/include 4 | LIBRARY_DIRS := /usr/local/lib 5 | #LIBRARY_DIRS := 6 | 7 | LDFLAGS := $(foreach librarydir, $(LIBRARY_DIRS),-L$(librarydir)) $(PKG_CONFIG) \ 8 | $(foreach library,$(LIBRARIES),-l$(library)) 9 | COMMON_FLAGS := $(foreach includedir, $(INCLUDE_DIRS), -I$(includedir)) 10 | CXXFLAGS := -g $(COMMON_FLAGS) 11 | 12 | RandomForests : $(objects) 13 | g++ -o RandomForests $(objects) $(CXXFLAGS) $(LDFLAGS) 14 | 15 | .PHONY : clean 16 | clean : 17 | -rm RandomForests $(objects) 18 | -------------------------------------------------------------------------------- /MnistPreProcess.h: -------------------------------------------------------------------------------- 1 | /************************************************ 2 | *Random Forest Program 3 | 
*Function: read mnist dataset and do preprocess 4 | *Author: handspeaker@163.com 5 | *CreateTime: 2014.7.10 6 | *Version: V0.1 7 | *************************************************/ 8 | #ifndef MNISTPREPROCESS_H 9 | #define MNISTPREPROCESS_H 10 | #include 11 | 12 | inline void revertInt(int&x) 13 | { 14 | x=((x&0x000000ff)<<24)|((x&0x0000ff00)<<8)|((x&0x00ff0000)>>8)|((x&0xff000000)>>24); 15 | }; 16 | void readData(float** dataset,float*labels,const char* dataPath,const char*labelPath); 17 | #endif//MNISTPREPROCESS_H 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RandomForests 2 | A 1500 lines simple C++ implementation of RandomForests with detailed comments. No dependency, support regression and classification. 3 | 4 | # quick start 5 | 1. download MNIST dataset from http://yann.lecun.com/exdb/mnist/ and unzip train&test data in some dir 6 | 2. git clone https://github.com/handspeaker/RandomForests.git 7 | 3. cd RandomForests and make 8 | 4. run the following command:
9 | ./RandomForests
10 | mnist_train_image_file_path
11 | mnist_train_label_file_path
12 | mnist_test_image_file_path
13 | mnist_test_label_file_path
14 | 15 | # more 16 | Open main.cpp file and modify as you want. In RandomForest.h there are some parameters you can change. input samples is a matrix, every row is a sample. 17 | -------------------------------------------------------------------------------- /MnistPreProcess.cpp: -------------------------------------------------------------------------------- 1 | #include"MnistPreProcess.h" 2 | 3 | void readData(float** dataset,float*labels,const char* dataPath,const char*labelPath) 4 | { 5 | FILE* dataFile=fopen(dataPath,"rb"); 6 | FILE* labelFile=fopen(labelPath,"rb"); 7 | int mbs=0,number=0,col=0,row=0; 8 | fread(&mbs,4,1,dataFile); 9 | fread(&number,4,1,dataFile); 10 | fread(&row,4,1,dataFile); 11 | fread(&col,4,1,dataFile); 12 | revertInt(mbs); 13 | revertInt(number); 14 | revertInt(row); 15 | revertInt(col); 16 | fread(&mbs,4,1,labelFile); 17 | fread(&number,4,1,labelFile); 18 | revertInt(mbs); 19 | revertInt(number); 20 | unsigned char temp; 21 | for(int i=0;i(temp); 27 | } 28 | fread(&temp,1,1,labelFile); 29 | labels[i]=static_cast(temp); 30 | } 31 | fclose(dataFile); 32 | fclose(labelFile); 33 | }; 34 | -------------------------------------------------------------------------------- /Sample.h: -------------------------------------------------------------------------------- 1 | /************************************************ 2 | *Random Forest Program 3 | *Function: this class is used to store the selected index 4 | of the samples and features. 
5 | *Author: handspeaker@163.com 6 | *CreateTime: 2014.7.10 7 | *Version: V0.1 8 | *************************************************/ 9 | #ifndef SAMPLE_H 10 | #define SAMPLE_H 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | class Sample 17 | { 18 | public: 19 | static const int SAMPLESELECTION=1; 20 | static const int FEATURESELECTION=2; 21 | 22 | //create a empty samples 23 | Sample(float**dataset,float*labels,int classNum,int sampleNum,int featureNum); 24 | //copy infomation from samples 25 | Sample(Sample* samples); 26 | //copy a part[start,end] from samples 27 | Sample(Sample* samples,int start,int end); 28 | ~Sample(); 29 | //random select samples with replacement 30 | void randomSelectSample(int*sampleIndex,int SampleNum,int selectedSampleNum); 31 | //random select features without replacement 32 | void randomSelectFeature(int*featureIndex,int featureNum,int selectedFeatureNum); 33 | 34 | inline int getClassNum(){return _classNum;}; 35 | 36 | inline int getSampleNum(){return _sampleNum;}; 37 | inline int getFeatureNum(){return _featureNum;}; 38 | 39 | inline int getSelectedSampleNum(){return _selectedSampleNum;}; 40 | inline int getSelectedFeatureNum(){return _selectedFeatureNum;}; 41 | 42 | inline int*getSampleIndex(){return _sampleIndex;}; 43 | inline int*getFeatureIndex(){return _featureIndex;}; 44 | 45 | inline void releaseSampleIndex() 46 | { 47 | if(_sampleIndex!=NULL) 48 | { 49 | delete[] _sampleIndex; 50 | _sampleIndex=NULL; 51 | } 52 | }; 53 | 54 | 55 | float**_dataset; //pointer to the input dataset 56 | float*_labels; //pointer to the input labels 57 | 58 | private: 59 | int*_sampleIndex; //all sample index 60 | int*_featureIndex; //all feature index 61 | int _classNum; //class number 62 | int _featureNum; //all feature dimension 63 | int _sampleNum; //all sample number 64 | int _selectedSampleNum; //selected sample number 65 | int _selectedFeatureNum;//selected feature number 66 | }; 67 | #endif//SAMPLE_H 68 | 
-------------------------------------------------------------------------------- /Sample.cpp: -------------------------------------------------------------------------------- 1 | #include"Sample.h" 2 | 3 | Sample::Sample(float**dataset,float*labels,int classNum,int sampleNum,int featureNum) 4 | { 5 | _dataset=dataset; 6 | _labels=labels; 7 | _sampleNum=sampleNum; 8 | _featureNum=featureNum; 9 | _classNum=classNum; 10 | _selectedSampleNum=sampleNum; 11 | _selectedFeatureNum=featureNum; 12 | _sampleIndex=NULL; 13 | _featureIndex=NULL; 14 | } 15 | 16 | Sample::Sample(Sample* samples) 17 | { 18 | _dataset=samples->_dataset; 19 | _labels=samples->_labels; 20 | _classNum=samples->getClassNum(); 21 | _featureNum=samples->getFeatureNum(); 22 | _sampleNum=samples->getSampleNum(); 23 | _selectedSampleNum=samples->getSelectedSampleNum(); 24 | _selectedFeatureNum=samples->getSelectedFeatureNum(); 25 | _sampleIndex=samples->getSampleIndex(); 26 | _featureIndex=samples->getFeatureIndex(); 27 | } 28 | 29 | Sample::Sample(Sample* samples,int start,int end) 30 | { 31 | _dataset=samples->_dataset; 32 | _labels=samples->_labels; 33 | _classNum=samples->getClassNum(); 34 | _sampleNum=samples->getSampleNum(); 35 | _selectedSampleNum=end-start+1; 36 | _featureNum=samples->getFeatureNum(); 37 | _selectedFeatureNum=samples->getSelectedFeatureNum(); 38 | _sampleIndex=new int[_selectedSampleNum]; 39 | memcpy(_sampleIndex,samples->getSampleIndex()+start,sizeof(float)*_selectedSampleNum); 40 | } 41 | 42 | Sample::~Sample() 43 | { 44 | _sampleIndex=NULL; 45 | _featureIndex=NULL; 46 | } 47 | 48 | void Sample::randomSelectSample(int*sampleIndex,int SampleNum,int selectedSampleNum) 49 | { 50 | _sampleNum=SampleNum; 51 | _selectedSampleNum=selectedSampleNum; 52 | if(_sampleIndex!=NULL) 53 | {delete[] _sampleIndex;} 54 | _sampleIndex=sampleIndex; 55 | int i=0,index=0; 56 | //sampling trainset with replacement 57 | for(i=0;i 16 | #include 17 | #include"Sample.h" 18 | #include"Node.h" 19 | 20 | class 
Tree 21 | { 22 | public: 23 | /************************************************************* 24 | *MaxDepth:the max Depth of one single tree 25 | *trainFeatureNumPerNode:the feature number used in every node while training 26 | *minLeafSample:terminate criterion,the min samples in a leaf 27 | *minInfoGain:terminate criterion,the min information gain in 28 | *a node if it can be splitted 29 | *isRegression:if the problem is regression(true) or classification(false) 30 | **************************************************************/ 31 | Tree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression); 32 | virtual ~Tree(); 33 | virtual void train(Sample*Sample)=0; 34 | Result predict(float*data); 35 | inline Node**getTreeArray(){return _cartreeArray;}; 36 | virtual void createNode(int id,int featureIndex,float threshold)=0; 37 | protected: 38 | bool _isRegression; //the type of this tree 39 | int _MaxDepth; 40 | int _nodeNum; //the number of node,=2^_MaxDepth-1 41 | int _minLeafSample; 42 | int _trainFeatureNumPerNode; 43 | float _minInfoGain; 44 | // Sample*_samples;//all samples used while training the tree 45 | Node** _cartreeArray; //utilize a node array to store the tree, 46 | //every node is a split or leaf node 47 | }; 48 | //Classification Tree 49 | class ClasTree:public Tree 50 | { 51 | public: 52 | ClasTree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression); 53 | ~ClasTree(); 54 | void train(Sample*Sample); 55 | void createNode(int id,int featureIndex,float threshold); 56 | void createLeaf(int id,float clas,float prob); 57 | }; 58 | //Regression Tree 59 | class RegrTree:public Tree 60 | { 61 | public: 62 | RegrTree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression); 63 | ~RegrTree(); 64 | void train(Sample*Sample); 65 | void createNode(int id,int featureIndex,float threshold); 66 | void createLeaf(int id,float value); 67 | }; 68 | 
#endif//CARTREE_H 69 | -------------------------------------------------------------------------------- /Node.h: -------------------------------------------------------------------------------- 1 | /************************************************ 2 | *Random Forest Program 3 | *Function: implementation of two kinds of node 4 | for classification and regression 5 | *Author: handspeaker@163.com 6 | *CreateTime: 2014.7.10 7 | *Version: V0.1 8 | *************************************************/ 9 | #ifndef NODE_H 10 | #define NODE_H 11 | 12 | #include"Sample.h" 13 | #include 14 | 15 | struct Result 16 | { 17 | float label; //label or value 18 | float prob; //prob or 1 19 | }; 20 | 21 | struct Pair 22 | { 23 | float feature; 24 | int id; 25 | }; 26 | 27 | int compare_pair( const void* a, const void* b ); 28 | 29 | class Node 30 | { 31 | public: 32 | Node(); 33 | virtual ~Node(); 34 | //sort the selected samples in the ascending order based on featureId 35 | void sortIndex(int featureId); 36 | Sample*_samples;//the samples hold by this node 37 | //set this node as leaf node 38 | inline void setLeaf(bool flag){_isLeaf=flag;}; 39 | inline bool isLeaf(){return _isLeaf;}; 40 | //calculate the information gain 41 | virtual void calculateInfoGain(Node**_cartreeArray,int id,float minInfoGain)=0; 42 | virtual void calculateParams()=0; 43 | //create a leaf node 44 | virtual void createLeaf()=0; 45 | //predict the data 46 | virtual int predict(float*data,int id)=0; 47 | virtual void getResult(Result&r)=0; 48 | 49 | inline int getFeatureIndex(){return _featureIndex;}; 50 | inline void setFeatureIndex(int featureIndex){_featureIndex=featureIndex;}; 51 | inline float getThreshold(){return _threshold;}; 52 | inline void setThreshold(float threshold){_threshold=threshold;}; 53 | protected: 54 | bool _isLeaf; 55 | int _featureIndex; 56 | float _threshold; 57 | }; 58 | 59 | class ClasNode:public Node 60 | { 61 | public: 62 | ClasNode(); 63 | ~ClasNode(); 64 | void 
calculateInfoGain(Node**_cartreeArray,int id,float minInfoGain); 65 | void calculateParams(); 66 | void createLeaf(); 67 | int predict(float*data,int id); 68 | void getResult(Result&r); 69 | 70 | inline float getClass(){return _class;}; 71 | inline float getProb(){return _prob;}; 72 | 73 | inline void setClass(float clas){_class=clas;}; 74 | inline void setProb(float prob){_prob=prob;}; 75 | //parameters for training 76 | float _gini; 77 | float*_probs; 78 | private: 79 | float _class; //the class 80 | float _prob; //the probablity 81 | }; 82 | 83 | class RegrNode:public Node 84 | { 85 | public: 86 | RegrNode(); 87 | ~RegrNode(); 88 | void calculateInfoGain(Node**_cartreeArray,int id,float minInfoGain); 89 | void calculateParams(); 90 | void createLeaf(); 91 | int predict(float*data,int id); 92 | void getResult(Result&r); 93 | 94 | inline float getValue(){return _value;}; 95 | inline void setValue(float value){_value=value;}; 96 | //parameters for training 97 | float _mean; 98 | float _variance; 99 | private: 100 | float _value; //the regression value 101 | }; 102 | 103 | #endif//NODE_H 104 | -------------------------------------------------------------------------------- /Tree.cpp: -------------------------------------------------------------------------------- 1 | #include"Tree.h" 2 | 3 | Tree::Tree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression) 4 | { 5 | _MaxDepth=MaxDepth; 6 | _trainFeatureNumPerNode=trainFeatureNumPerNode; 7 | _minLeafSample=minLeafSample; 8 | _minInfoGain=minInfoGain; 9 | _nodeNum=static_cast(pow(2.0,_MaxDepth)-1); 10 | _cartreeArray=new Node*[_nodeNum]; 11 | _isRegression=isRegression; 12 | for(int i=0;i<_nodeNum;++i) 13 | {_cartreeArray[i]=NULL;} 14 | } 15 | 16 | Tree::~Tree() 17 | { 18 | if(_cartreeArray!=NULL) 19 | { 20 | for(int i=0;i<_nodeNum;++i) 21 | { 22 | if(_cartreeArray[i]!=NULL) 23 | { 24 | delete _cartreeArray[i]; 25 | _cartreeArray[i]=NULL; 26 | } 27 | } 28 | delete[] 
_cartreeArray; 29 | _cartreeArray=NULL; 30 | } 31 | } 32 | 33 | Result Tree::predict(float*data) 34 | { 35 | int position=0; 36 | Node*head=_cartreeArray[position]; 37 | while(!head->isLeaf()) 38 | { 39 | position=head->predict(data,position); 40 | head=_cartreeArray[position]; 41 | } 42 | Result r; 43 | head->getResult(r); 44 | return r; 45 | } 46 | /************************************************/ 47 | //Classification Tree 48 | ClasTree::ClasTree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression) 49 | :Tree(MaxDepth,trainFeatureNumPerNode,minLeafSample,minInfoGain,isRegression) 50 | {} 51 | ClasTree::~ClasTree() 52 | {} 53 | void ClasTree::train(Sample*sample) 54 | { 55 | //initialize root node 56 | //random generate feature index 57 | int*_featureIndex=new int[_trainFeatureNumPerNode]; 58 | Sample*nodeSample=new Sample(sample,0,sample->getSelectedSampleNum()-1); 59 | _cartreeArray[0]=new ClasNode(); 60 | _cartreeArray[0]->_samples=nodeSample; 61 | //calculate the probablity and gini 62 | _cartreeArray[0]->calculateParams(); 63 | for(int i=0;i<_nodeNum;++i) 64 | { 65 | int parentId=(i-1)/2; 66 | //if current node's parent node is NULL,continue 67 | if(_cartreeArray[parentId]==NULL) 68 | {continue;} 69 | //if the current node's parent node is a leaf,continue 70 | if(i>0&&_cartreeArray[parentId]->isLeaf()) 71 | {continue;} 72 | //if it reach the max depth 73 | //set current node as a leaf and continue 74 | if(i*2+1>=_nodeNum) //if left child node is out of range 75 | { 76 | _cartreeArray[i]->createLeaf(); 77 | continue; 78 | } 79 | //if current samples in this node is less than the threshold 80 | //set current node as a leaf and continue 81 | if(_cartreeArray[i]->_samples->getSelectedSampleNum()<=_minLeafSample) 82 | { 83 | _cartreeArray[i]->createLeaf(); 84 | continue; 85 | } 86 | _cartreeArray[i]->_samples->randomSelectFeature 87 | (_featureIndex,sample->getFeatureNum(),_trainFeatureNumPerNode); 88 | //else calculate 
the information gain 89 | _cartreeArray[i]->calculateInfoGain(_cartreeArray,i,_minInfoGain); 90 | _cartreeArray[i]->_samples->releaseSampleIndex(); 91 | } 92 | delete[] _featureIndex; 93 | _featureIndex=NULL; 94 | delete nodeSample; 95 | } 96 | 97 | void ClasTree::createNode(int id,int featureIndex,float threshold) 98 | { 99 | _cartreeArray[id]=new ClasNode(); 100 | _cartreeArray[id]->setLeaf(false); 101 | _cartreeArray[id]->setFeatureIndex(featureIndex); 102 | _cartreeArray[id]->setThreshold(threshold); 103 | } 104 | 105 | void ClasTree::createLeaf(int id,float clas,float prob) 106 | { 107 | _cartreeArray[id]=new ClasNode(); 108 | _cartreeArray[id]->setLeaf(true); 109 | ((ClasNode*)_cartreeArray[id])->setClass(clas); 110 | ((ClasNode*)_cartreeArray[id])->setProb(prob); 111 | } 112 | /************************************************/ 113 | //Regression Tree 114 | RegrTree::RegrTree(int MaxDepth,int trainFeatureNumPerNode,int minLeafSample,float minInfoGain,bool isRegression) 115 | :Tree(MaxDepth,trainFeatureNumPerNode,minLeafSample,minInfoGain,isRegression) 116 | {} 117 | RegrTree::~RegrTree() 118 | {} 119 | void RegrTree::train(Sample*sample) 120 | { 121 | //initialize root node 122 | //random generate feature index 123 | int*_featureIndex=new int[_trainFeatureNumPerNode]; 124 | Sample*nodeSample=new Sample(sample,0,sample->getSelectedSampleNum()-1); 125 | _cartreeArray[0]=new RegrNode(); 126 | _cartreeArray[0]->_samples=nodeSample; 127 | //calculate the mean and variance 128 | _cartreeArray[0]->calculateParams(); 129 | for(int i=0;i<_nodeNum;++i) 130 | { 131 | int parentId=(i-1)/2; 132 | //if current node's parent node is NULL,continue 133 | if(_cartreeArray[parentId]==NULL) 134 | {continue;} 135 | //if the current node's parent node is a leaf,continue 136 | if(i>0&&_cartreeArray[parentId]->isLeaf()) 137 | {continue;} 138 | //if it reach the max depth 139 | //set current node as a leaf and continue 140 | if(i*2+1>=_nodeNum) //if left child node is out of range 
141 | { 142 | _cartreeArray[i]->createLeaf(); 143 | continue; 144 | } 145 | //if current samples in this node is less than the threshold 146 | //set current node as a leaf and continue 147 | if(_cartreeArray[i]->_samples->getSelectedSampleNum()<=_minLeafSample) 148 | { 149 | _cartreeArray[i]->createLeaf(); 150 | continue; 151 | } 152 | _cartreeArray[i]->_samples->randomSelectFeature 153 | (_featureIndex,sample->getFeatureNum(),_trainFeatureNumPerNode); 154 | //else calculate the information gain 155 | _cartreeArray[i]->calculateInfoGain(_cartreeArray,i,_minInfoGain); 156 | _cartreeArray[i]->_samples->releaseSampleIndex(); 157 | } 158 | delete[] _featureIndex; 159 | _featureIndex=NULL; 160 | delete nodeSample; 161 | } 162 | 163 | void RegrTree::createNode(int id,int featureIndex,float threshold) 164 | { 165 | _cartreeArray[id]=new RegrNode(); 166 | _cartreeArray[id]->setLeaf(false); 167 | _cartreeArray[id]->setFeatureIndex(featureIndex); 168 | _cartreeArray[id]->setThreshold(threshold); 169 | } 170 | 171 | void RegrTree::createLeaf(int id,float value) 172 | { 173 | _cartreeArray[id]=new RegrNode(); 174 | _cartreeArray[id]->setLeaf(true); 175 | ((RegrNode*)_cartreeArray[id])->setValue(value); 176 | } 177 | -------------------------------------------------------------------------------- /RandomForest.cpp: -------------------------------------------------------------------------------- 1 | #include"RandomForest.h" 2 | 3 | RandomForest::RandomForest(int treeNum,int maxDepth,int minLeafSample,float minInfoGain) 4 | { 5 | _treeNum=treeNum; 6 | _maxDepth=maxDepth; 7 | _minLeafSample=minLeafSample; 8 | _minInfoGain=minInfoGain; 9 | _trainSample=NULL; 10 | printf("total tree number:%d\n",_treeNum); 11 | printf("max depth of a single tree:%d\n",_maxDepth); 12 | printf("the minimum samples in a leaf:%d\n",_minLeafSample); 13 | printf("the minimum information gain:%f\n",_minInfoGain); 14 | 15 | _forest=new Tree*[_treeNum]; 16 | for(int i=0;i<_treeNum;++i) 17 | 
{_forest[i]=NULL;} 18 | } 19 | 20 | RandomForest::RandomForest(const char*modelPath) 21 | { 22 | readModel(modelPath); 23 | } 24 | 25 | RandomForest::~RandomForest() 26 | { 27 | //printf("destroy RandomForest...\n"); 28 | if(_forest!=NULL) 29 | { 30 | for(int i=0;i<_treeNum;++i) 31 | { 32 | if(_forest[i]!=NULL) 33 | { 34 | delete _forest[i]; 35 | _forest[i]=NULL; 36 | } 37 | } 38 | delete[] _forest; 39 | _forest=NULL; 40 | } 41 | if(_trainSample!=NULL) 42 | { 43 | delete _trainSample; 44 | _trainSample=NULL; 45 | } 46 | } 47 | 48 | void RandomForest::train(float**trainset,float*labels,int SampleNum,int featureNum, 49 | int classNum,bool isRegression) 50 | { 51 | int trainFeatureNumPerNode=static_cast(sqrt(static_cast(featureNum))); 52 | train(trainset,labels,SampleNum,featureNum,classNum,isRegression,trainFeatureNumPerNode); 53 | } 54 | 55 | void RandomForest::train(float**trainset,float*labels,int SampleNum,int featureNum, 56 | int classNum,bool isRegression,int trainFeatureNumPerNode) 57 | { 58 | if(_treeNum<1) 59 | { 60 | printf("total tree number must bigger than 0!\n"); 61 | printf("training failed\n"); 62 | return; 63 | } 64 | if(_maxDepth<1) 65 | { 66 | printf("the max depth must bigger than 0!\n"); 67 | printf("training failed\n"); 68 | return; 69 | } 70 | if(_minLeafSample<2) 71 | { 72 | printf("the minimum samples in a leaf must bigger than 1!\n"); 73 | printf("training failed\n"); 74 | return; 75 | } 76 | _trainSampleNum=SampleNum; 77 | _featureNum=featureNum; 78 | _classNum=classNum; 79 | _trainFeatureNumPerNode=trainFeatureNumPerNode; 80 | _isRegression=isRegression; 81 | //initialize every tree 82 | if(_isRegression) 83 | { 84 | _classNum=1; 85 | for(int i=0;i<_treeNum;++i) 86 | { 87 | _forest[i]=new RegrTree(_maxDepth,_trainFeatureNumPerNode, 88 | _minLeafSample,_minInfoGain,_isRegression); 89 | } 90 | } 91 | else 92 | { 93 | for(int i=0;i<_treeNum;++i) 94 | { 95 | _forest[i]=new ClasTree(_maxDepth,_trainFeatureNumPerNode, 96 | 
_minLeafSample,_minInfoGain,_isRegression); 97 | } 98 | } 99 | //this object hold the whole trainset&labels 100 | _trainSample=new Sample(trainset,labels,_classNum,_trainSampleNum,_featureNum); 101 | srand(static_cast(time(NULL))); 102 | int*_sampleIndex=new int[_trainSampleNum]; 103 | //start to train every tree in the forest 104 | for(int i=0;i<_treeNum;++i) 105 | { 106 | printf("train the %d th tree...\n",i); 107 | //random sampling from trainset 108 | Sample*sample=new Sample(_trainSample); 109 | sample->randomSelectSample(_sampleIndex,_trainSampleNum,_trainSampleNum); 110 | _forest[i]->train(sample); 111 | delete sample; 112 | } 113 | delete[] _sampleIndex; 114 | _sampleIndex=NULL; 115 | } 116 | 117 | void RandomForest::predict(float*data,float&response) 118 | { 119 | //get the predict from every tree 120 | //if regression,_classNum=1 121 | float*result=new float[_classNum]; 122 | int i=0; 123 | for(i=0;i<_classNum;++i) 124 | {result[i]=0;} 125 | for(i=0;i<_treeNum;++i)//_treeNum 126 | { 127 | Result r; 128 | r.label=0; 129 | r.prob=0;//Result 130 | r=_forest[i]->predict(data); 131 | result[static_cast(r.label)]+=r.prob; 132 | } 133 | if(_isRegression) 134 | {response=result[0]/_treeNum;} 135 | else 136 | { 137 | float maxProbLabel=0; 138 | float maxProb=result[0]; 139 | for(i=1;i<_classNum;++i) 140 | { 141 | if(result[i]>maxProb) 142 | { 143 | maxProbLabel=i; 144 | maxProb=result[i]; 145 | } 146 | } 147 | response=maxProbLabel; 148 | } 149 | delete[] result; 150 | } 151 | 152 | void RandomForest::predict(float**testset,int SampleNum,float*responses) 153 | { 154 | //get the predict from every tree 155 | for(int i=0;i(pow(2.0,_maxDepth)-1); 169 | int isLeaf=0; 170 | for(int i=0;i<_treeNum;++i) 171 | { 172 | Node**arr=_forest[i]->getTreeArray(); 173 | isLeaf=0; 174 | for(int j=0;jisLeaf()) 179 | { 180 | isLeaf=1; 181 | fwrite(&isLeaf,sizeof(int),1,saveFile); 182 | if(_isRegression) 183 | { 184 | float value=((RegrNode*)arr[j])->getValue(); 185 | 
fwrite(&value,sizeof(float),1,saveFile); 186 | } 187 | else 188 | { 189 | float clas=((ClasNode*)arr[j])->getClass(); 190 | float prob=((ClasNode*)arr[j])->getProb(); 191 | fwrite(&clas,sizeof(float),1,saveFile); 192 | fwrite(&prob,sizeof(float),1,saveFile); 193 | } 194 | } 195 | else 196 | { 197 | isLeaf=0; 198 | fwrite(&isLeaf,sizeof(int),1,saveFile); 199 | int featureIndex=arr[j]->getFeatureIndex(); 200 | float threshold=arr[j]->getThreshold(); 201 | fwrite(&featureIndex,sizeof(int),1,saveFile); 202 | fwrite(&threshold,sizeof(float),1,saveFile); 203 | } 204 | } 205 | } 206 | ////write an numb node to denote the tree end 207 | //isLeaf=-1; 208 | //fwrite(&isLeaf,sizeof(int),1,saveFile); 209 | } 210 | fclose(saveFile); 211 | } 212 | 213 | void RandomForest::readModel(const char*path) 214 | { 215 | _minLeafSample=0; 216 | _minInfoGain=0; 217 | _trainFeatureNumPerNode=0; 218 | FILE* modelFile=fopen(path,"rb"); 219 | fread(&_treeNum,sizeof(int),1,modelFile); 220 | fread(&_maxDepth,sizeof(int),1,modelFile); 221 | fread(&_classNum,sizeof(int),1,modelFile); 222 | fread(&_isRegression,sizeof(bool),1,modelFile); 223 | int nodeNum=static_cast(pow(2.0,_maxDepth)-1); 224 | _trainSample=NULL; 225 | printf("total tree number:%d\n",_treeNum); 226 | printf("max depth of a single tree:%d\n",_maxDepth); 227 | printf("_classNum:%d\n",_classNum); 228 | printf("_isRegression:%d\n",_isRegression); 229 | _forest=new Tree*[_treeNum]; 230 | //initialize every tree 231 | if(_isRegression) 232 | { 233 | for(int i=0;i<_treeNum;++i) 234 | { 235 | _forest[i]=new RegrTree(_maxDepth,_trainFeatureNumPerNode, 236 | _minLeafSample,_minInfoGain,_isRegression); 237 | } 238 | } 239 | else 240 | { 241 | for(int i=0;i<_treeNum;++i) 242 | { 243 | _forest[i]=new ClasTree(_maxDepth,_trainFeatureNumPerNode, 244 | _minLeafSample,_minInfoGain,_isRegression); 245 | } 246 | } 247 | int*nodeTable=new int[nodeNum]; 248 | int isLeaf=-1; 249 | int featureIndex=0; 250 | float threshold=0; 251 | float value=0; 252 | 
float clas=0; 253 | float prob=0; 254 | for(int i=0;i<_treeNum;++i) 255 | { 256 | memset(nodeTable,0,sizeof(int)*nodeNum); 257 | nodeTable[0]=1; 258 | for(int j=0;jcreateNode(j,featureIndex,threshold); 271 | } 272 | else if(isLeaf==1) //leaf 273 | { 274 | if(_isRegression) 275 | { 276 | fread(&value,sizeof(float),1,modelFile); 277 | ((RegrTree*)_forest[i])->createLeaf(j,value); 278 | } 279 | else 280 | { 281 | fread(&clas,sizeof(float),1,modelFile); 282 | fread(&prob,sizeof(float),1,modelFile); 283 | ((ClasTree*)_forest[i])->createLeaf(j,clas,prob); 284 | } 285 | } 286 | } 287 | //fread(&isLeaf,sizeof(int),1,modelFile); 288 | } 289 | fclose(modelFile); 290 | delete[] nodeTable; 291 | } 292 | -------------------------------------------------------------------------------- /Node.cpp: -------------------------------------------------------------------------------- 1 | #include"Node.h" 2 | /***************************************************************/ 3 | //Node 4 | Node::Node() 5 | { 6 | _isLeaf=false; 7 | _featureIndex=-1; 8 | _threshold=0; 9 | _samples=NULL; 10 | } 11 | 12 | Node::~Node() 13 | { 14 | } 15 | 16 | void Node::sortIndex(int featureId) 17 | { 18 | float**data=_samples->_dataset; 19 | int*sampleId=_samples->getSampleIndex(); 20 | Pair*pairs=new Pair[_samples->getSelectedSampleNum()]; 21 | for(int i=0;i<_samples->getSelectedSampleNum();++i) 22 | { 23 | pairs[i].id=sampleId[i]; 24 | pairs[i].feature=data[sampleId[i]][featureId]; 25 | } 26 | qsort(pairs,_samples->getSelectedSampleNum(),sizeof(Pair),compare_pair); 27 | for(int i=0;i<_samples->getSelectedSampleNum();++i) 28 | {sampleId[i]=pairs[i].id;} 29 | delete[] pairs; 30 | } 31 | 32 | int compare_pair( const void* a, const void* b ) 33 | { 34 | Pair* arg1 = (Pair*) a; 35 | Pair* arg2 = (Pair*) b; 36 | if( arg1->feature < arg2->feature ) return -1; 37 | else if( arg1->feature == arg2->feature ) return 0; 38 | else return 1; 39 | } 40 | /***************************************************************/ 41 
| //ClasNode 42 | ClasNode::ClasNode() 43 | :Node() 44 | { 45 | _class=-1; 46 | _prob=0; 47 | } 48 | 49 | ClasNode::~ClasNode() 50 | { 51 | if(_probs!=NULL) 52 | { 53 | delete[] _probs; 54 | _probs=NULL; 55 | } 56 | } 57 | 58 | void ClasNode::calculateParams() 59 | { 60 | int i=0; 61 | int*sampleId=_samples->getSampleIndex(); 62 | int sampleNum=_samples->getSelectedSampleNum(); 63 | int classNum=_samples->getClassNum(); 64 | float gini=0; 65 | _probs=new float[classNum]; 66 | for(i=0;i(_samples->_labels[sampleId[i]])]++;} 70 | for(i=0;igetSampleIndex(); 83 | int*featureId=_samples->getFeatureIndex(); 84 | float**data=_samples->_dataset; 85 | float*labels=_samples->_labels; 86 | int featureNum=_samples->getSelectedFeatureNum(); 87 | int sampleNum=_samples->getSelectedSampleNum(); 88 | int classNum=_samples->getClassNum(); 89 | //the final params need to store 90 | float maxInfoGain=0; 91 | int maxFeatureId=0; 92 | float maxThreshold=0; 93 | float maxGiniLeft=0; 94 | float maxGiniRight=0; 95 | int maxSamplesOnLeft=0; 96 | float*maxProbsLeft=new float[classNum]; 97 | float*maxProbsRight=new float[classNum]; 98 | for(i=0;i(labels[sampleId[j]])]++; 147 | probsRight[static_cast(labels[sampleId[j]])]--; 148 | //do not do calculation if the nearby samples' feature are too similar(<0.000001) 149 | if((data[sampleId[j+1]][featureId[i]]-data[sampleId[j]][featureId[i]])<0.000001) 150 | {continue;} 151 | for(k=0;kfMaxinfoGain) 167 | { 168 | fMaxinfoGain=infoGain; 169 | fMaxGiniLeft=giniLeft; 170 | fMaxGiniRight=giniRight; 171 | fMaxThreshold=(data[sampleId[j]][featureId[i]]+data[sampleId[j+1]][featureId[i]])/2; 172 | fMaxSamplesOnLeft=j; 173 | memcpy(fMaxProbsLeft,probsLeft,sizeof(float)*classNum); 174 | memcpy(fMaxProbsRight,probsRight,sizeof(float)*classNum); 175 | } 176 | } 177 | if(fMaxinfoGain>maxInfoGain) 178 | { 179 | maxInfoGain=fMaxinfoGain; 180 | maxGiniLeft=fMaxGiniLeft; 181 | maxGiniRight=fMaxGiniRight; 182 | maxFeatureId=fMaxFeatureId; 183 | 
maxThreshold=fMaxThreshold; 184 | maxSamplesOnLeft=fMaxSamplesOnLeft; 185 | memcpy(maxProbsLeft,fMaxProbsLeft,sizeof(float)*classNum); 186 | memcpy(maxProbsRight,fMaxProbsRight,sizeof(float)*classNum); 187 | } 188 | } 189 | sortIndex(maxFeatureId); 190 | if(maxInfoGain_gini=maxGiniLeft; 199 | ((ClasNode*)nodeArray[id*2+1])->_probs=maxProbsLeft; 200 | ((ClasNode*)nodeArray[id*2+2])->_gini=maxGiniRight; 201 | ((ClasNode*)nodeArray[id*2+2])->_probs=maxProbsRight; 202 | //assign samples to left and right 203 | Sample*leftSamples=new Sample(_samples,0,maxSamplesOnLeft); 204 | Sample*rightSamples=new Sample(_samples,maxSamplesOnLeft+1,sampleNum-1); 205 | nodeArray[id*2+1]->_samples=leftSamples; 206 | nodeArray[id*2+2]->_samples=rightSamples; 207 | } 208 | delete[] _probs; 209 | _probs=NULL; 210 | delete[] fMaxProbsLeft; 211 | delete[] fMaxProbsRight; 212 | delete[] probsLeft; 213 | delete[] probsRight; 214 | } 215 | 216 | void ClasNode::createLeaf() 217 | { 218 | _class=0; 219 | _prob=_probs[0]; 220 | for(int i=1;i<_samples->getClassNum();++i) 221 | { 222 | if(_probs[i]>_prob) 223 | { 224 | _class=i; 225 | _prob=_probs[i]; 226 | } 227 | } 228 | _prob/=_samples->getSelectedSampleNum(); 229 | _isLeaf=true; 230 | } 231 | 232 | int ClasNode::predict(float*data,int id) 233 | { 234 | if(data[_featureIndex]<_threshold) 235 | {return id*2+1;} 236 | else 237 | {return id*2+2;} 238 | } 239 | 240 | void ClasNode::getResult(Result&r) 241 | { 242 | r.label=_class; 243 | r.prob=_prob; 244 | } 245 | /***************************************************************/ 246 | //RegrNode 247 | RegrNode::RegrNode() 248 | :Node() 249 | { 250 | _value=0; 251 | } 252 | 253 | RegrNode::~RegrNode() 254 | {} 255 | 256 | void RegrNode::calculateParams() 257 | { 258 | int i=0; 259 | int*labelId=_samples->getSampleIndex(); 260 | int sampleNum=_samples->getSelectedSampleNum(); 261 | double mean=0,variance=0; 262 | for(i=0;i_labels[labelId[i]];} 264 | mean/=sampleNum; 265 | 
for(i=0;i_labels[labelId[i]]-mean; 268 | variance+=diff*diff; 269 | } 270 | _mean=mean; 271 | _variance=variance/sampleNum; 272 | } 273 | 274 | void RegrNode::calculateInfoGain(Node**nodeArray,int id,float minInfoGain) 275 | { 276 | //some used variables 277 | int i=0,j=0,k=0; 278 | int*sampleId=_samples->getSampleIndex(); 279 | int*featureId=_samples->getFeatureIndex(); 280 | float**data=_samples->_dataset; 281 | float*labels=_samples->_labels; 282 | int featureNum=_samples->getSelectedFeatureNum(); 283 | int sampleNum=_samples->getSelectedSampleNum(); 284 | //the final params need to store 285 | float maxInfoGain=0; 286 | int maxFeatureId=0; 287 | float maxThreshold=0; 288 | float maxVarLeft=0; 289 | float maxVarRight=0; 290 | int maxSamplesOnLeft=0; 291 | float maxMeanLeft=0; 292 | float maxMeanRight=0; 293 | //the params need to store in first loop 294 | float fMaxinfoGain=0; 295 | int fMaxFeatureId=0; 296 | float fMaxThreshold=0; 297 | float fMaxVarLeft=0; 298 | float fMaxVarRight=0; 299 | int fMaxSamplesOnLeft=0; 300 | float fMaxMeanLeft=0; 301 | float fMaxMeanRight=0; 302 | //the temp params in inner loop 303 | float infoGain=0; 304 | float varLeft=0,varRight=0; 305 | float meanLeft=0,meanRight=0; 306 | for(i=0;ifMaxinfoGain) 352 | { 353 | fMaxinfoGain=infoGain; 354 | fMaxVarLeft=varLeft; 355 | fMaxVarRight=varRight; 356 | fMaxThreshold=(data[sampleId[j]][featureId[i]]+data[sampleId[j+1]][featureId[i]])/2; 357 | fMaxSamplesOnLeft=j; 358 | fMaxMeanLeft=meanLeft; 359 | fMaxMeanRight=meanRight; 360 | } 361 | } 362 | if(fMaxinfoGain>maxInfoGain) 363 | { 364 | maxInfoGain=fMaxinfoGain; 365 | maxVarLeft=fMaxVarLeft; 366 | maxVarRight=fMaxVarRight; 367 | maxFeatureId=fMaxFeatureId; 368 | maxThreshold=fMaxThreshold; 369 | maxSamplesOnLeft=fMaxSamplesOnLeft; 370 | maxMeanLeft=fMaxMeanLeft; 371 | maxMeanRight=fMaxMeanRight; 372 | } 373 | } 374 | 375 | if(maxInfoGain_variance=maxVarLeft; 388 | ((RegrNode*)nodeArray[id*2+1])->_mean=maxMeanLeft; 389 | 
((RegrNode*)nodeArray[id*2+2])->_variance=maxVarRight; 390 | ((RegrNode*)nodeArray[id*2+2])->_mean=maxMeanRight; 391 | //assign samples to left and right 392 | Sample*leftSamples=new Sample(_samples,0,maxSamplesOnLeft); 393 | Sample*rightSamples=new Sample(_samples,maxSamplesOnLeft+1,sampleNum-1); 394 | nodeArray[id*2+1]->_samples=leftSamples; 395 | nodeArray[id*2+2]->_samples=rightSamples; 396 | } 397 | } 398 | 399 | void RegrNode::createLeaf() 400 | { 401 | _value=_mean; 402 | _isLeaf=true; 403 | } 404 | 405 | int RegrNode::predict(float*data,int id) 406 | { 407 | if(data[_featureIndex]<_threshold) 408 | {return id*2+1;} 409 | else 410 | {return id*2+2;} 411 | } 412 | 413 | void RegrNode::getResult(Result&r) 414 | { 415 | r.label=0; 416 | r.prob=_value; 417 | } 418 | --------------------------------------------------------------------------------