├── .gitignore
├── README.md
├── RandomForest.sln
├── RandomForest.vcxproj
├── RandomForest.vcxproj.filters
├── doc
│   ├── kaggle.png
│   ├── multi.png
│   ├── report.tex
│   └── single.png
├── premake5.lua
└── src
    ├── Config.cpp
    ├── Config.h
    ├── DecisionTree.cpp
    ├── DecisionTree.h
    ├── RandomForest.cpp
    ├── RandomForest.h
    ├── RandomForest.vcxproj
    ├── RandomForest.vcxproj.filters
    ├── RandomForest.vcxproj.user
    └── main.cpp

/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.opensdf
3 | *.sdf
4 | *.suo
5 | *.exe
6 | *.ilk
7 | *.pdb
8 | obj
9 | data
10 | *.make
11 | Makefile
12 | 
13 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Parallel Random Forest
2 | 
3 | Parallel random forest implementation in C++ (with OpenMP).
4 | 
5 | ## Dependencies
6 | 
7 | * C++11 support (smart pointers, range-based for loops, lambdas, etc.)
8 | * OpenMP (available in most compilers, including VC++/g++/clang++)
9 | 
10 | Tested under the following environments:
11 | 
12 | 1. Windows 8.1 + VS2013
13 | 2. Windows 8.1 + g++ (GCC) 4.8.1 (mingw32)
14 | 3. Ubuntu 14.04 + g++ (GCC) 4.8.2
15 | 4. Ubuntu 14.04 + clang++ 3.5.2
16 | 
17 | ## Predefined parameters
18 | 
19 | * Number of features: 617
20 | * Number of labels: 26
21 | * Minimum node size: 2
22 | 
23 | Defined in [Config.h](src/Config.h).
24 | 
25 | ## Directory structure
26 | 
27 | ```
28 | .
29 | ├─ premake5.lua (premake build scripts)
30 | ├─ RandomForest.vcxproj.filters, RandomForest.vcxproj, RandomForest.sln (VS2013 project files)
31 | ├─ RandomForest.exe (executable built with VS2013)
32 | ├─ README.md (you are reading this)
33 | ├─ doc
34 | │   └── report.pdf
35 | ├─ data (dataset and output)
36 | │   ├── train.csv (the training set)
37 | │   ├── 1000.csv (subset of the training set for validation)
38 | │   ├── test.csv (the test set)
39 | │   └── submit.csv (the output)
40 | └─ src (source code)
41 |     ├── Config.h, Config.cpp (configuration, common headers)
42 |     ├── DecisionTree.h, DecisionTree.cpp (decision tree implementation)
43 |     ├── RandomForest.h, RandomForest.cpp (random forest implementation)
44 |     └── main.cpp (the entry file)
45 | ```
46 | 
47 | ## About the executable
48 | 
49 | The executable is built for Windows with VS2013, so it needs some .dlls that come with VS2013. If you want to run the program under other environments, you need to build it from source.
50 | 
51 | ### Generate output for Kaggle submission
52 | 
53 | 1. Put the training set in `data/train.csv` and the test set in `data/test.csv` (the header line will be ignored)
54 | 2. Run `RandomForest` (`./RandomForest` on Linux). You can pass in an optional number of trees, e.g. `RandomForest 1000` will generate 1000 trees.
55 | 3. The results will be saved in `data/submit.csv`
56 | 
57 | ### Validate against the training set
58 | 
59 | Note: you need to uncomment the `VALIDATE` flag in `src/Config.h` and build the executable again.
60 | 
61 | 1. Put the training set in `data/train.csv` and the validation set in `data/1000.csv` (the header line will be ignored)
62 | 2. Run `RandomForest` (`./RandomForest` on Linux). You can pass in an optional number of trees, e.g. `RandomForest 1000` will generate 1000 trees.
63 | 3. The results will be saved in `data/submit.csv`
64 | 
65 | ## Build
66 | 
67 | On Windows, you can build it with VS2013 (or possibly a lower version of VS), or GNU Make (Win32 port) and MinGW.
Note that this program needs C++11 support, so an old compiler might not be able to build it.
68 | 
69 | On Linux, you can build it with Make and g++ or clang++.
70 | 
71 | ### Windows with VS2013
72 | 
73 | Open `RandomForest.sln` with VS and build the `RandomForest` target. Make sure to build in Release mode, otherwise the resulting executable will be significantly slower. The executable will appear under the project directory, named `RandomForest.exe`.
74 | 
75 | ### Windows with Make+MinGW / older VS
76 | 
77 | Download premake5 from [here](https://premake.github.io/download.html#v5), extract the executable in the archive (e.g. `premake5.exe`), and add the path to the executable to your `PATH` environment variable. Then open cmd and run `premake5 --help` to see what project files you can generate. I've written the premake script `premake5.lua` to generate the proper project files.
78 | 
79 | For example, to generate the project files for VS2012, simply run `premake5 vs2012` under the project directory, then open `RandomForest.sln` with your VS and build the `RandomForest` target. The executable will appear under the project directory, named `RandomForest.exe`.
80 | 
81 | WARNING: make sure your compiler has enough C++11 support or the build could fail.
82 | 
83 | ### Linux with Make and g++/clang++
84 | 
85 | Download premake5 from [here](https://premake.github.io/download.html#v5), extract the executable in the archive (e.g. `premake5`), and add the path to the executable to your `PATH` environment variable (e.g. extract the file to `/usr/local/bin` with root permission so you don't have to touch `PATH`). To generate the project files for Make, simply run `premake5 gmake`. If you want to use g++, just run `make` to build it. If you want to use clang, run `make config=clang` (but first make sure you have a symlink `clang++` pointing to your clang++ executable). The executable will appear under the project directory, named `RandomForest`.
86 | 
87 | WARNING: make sure your compiler has enough C++11 support or the build could fail.
88 | 
89 | ## About
90 | 
91 | * [Github repository](https://github.com/joyeecheung/parallel-random-forest)
92 | * Author: Qiuyi Zhang
93 | * Jul. 2015
94 | 
--------------------------------------------------------------------------------
/RandomForest.sln:
--------------------------------------------------------------------------------
1 | 
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 2013
4 | VisualStudioVersion = 12.0.31101.0
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "RandomForest", "RandomForest.vcxproj", "{7B24B04F-0456-46EF-BC9A-5E9B2203A71F}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|Win32 = Debug|Win32
11 | Release|Win32 = Release|Win32
12 | EndGlobalSection
13 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
14 | {7B24B04F-0456-46EF-BC9A-5E9B2203A71F}.Debug|Win32.ActiveCfg = Debug|Win32
15 | {7B24B04F-0456-46EF-BC9A-5E9B2203A71F}.Debug|Win32.Build.0 = Debug|Win32
16 | {7B24B04F-0456-46EF-BC9A-5E9B2203A71F}.Release|Win32.ActiveCfg = Release|Win32
17 | {7B24B04F-0456-46EF-BC9A-5E9B2203A71F}.Release|Win32.Build.0 = Release|Win32
18 | EndGlobalSection
19 | GlobalSection(SolutionProperties) = preSolution
20 | HideSolutionNode = FALSE
21 | EndGlobalSection
22 | EndGlobal
23 | 
--------------------------------------------------------------------------------
/RandomForest.vcxproj:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | Debug
6 | Win32
7 | 
8 | 
9 | Release
10 | Win32
11 | 
12 | 
13 | 
14 | {7B24B04F-0456-46EF-BC9A-5E9B2203A71F}
15 | Win32Proj
16 | 
17 | 
18 | 
19 | Application
20 | true
21 | v120
22 | 
23 | 
24 | Application
25 | false
26 | v120
27 | 
28 | 
29 | 
30 | 
31 | 
32 | 
33 | 
34 | 
35 | 
36 | 
37 | 
38 | 
39 | true
40 | $(SolutionDir)\
41 | $(SolutionDir)\obj\
42 | 
43 | 
44 | true
45 | $(SolutionDir)\
46 | $(SolutionDir)\obj\
47 | 
48 | 
49 | 
50 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)
51 | MultiThreadedDebugDLL
52 | Level3
53 | ProgramDatabase
54 | Disabled
55 | true
56 | 
57 | 
58 | MachineX86
59 | true
60 | Console
61 | 
62 | 
63 | 
64 | 
65 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)
66 | MultiThreadedDLL
67 | Level3
68 | ProgramDatabase
69 | true
70 | 
71 | 
72 | MachineX86
73 | true
74 | Console
75 | true
76 | true
77 | 
78 | 
79 | 
80 | 
81 | 
82 | 
83 | 
84 | 
85 | 
86 | 
87 | 
88 | 
89 | 
90 | 
91 | 
92 | 
93 | 
94 | 
95 | 
96 | 
97 | 
--------------------------------------------------------------------------------
/RandomForest.vcxproj.filters:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF}
6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx
7 | 
8 | 
9 | {93995380-89BD-4b04-88EB-625FBE52EBFB}
10 | h;hh;hpp;hxx;hm;inl;inc;xsd
11 | 
12 | 
13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01}
14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav
15 | 
16 | 
17 | 
18 | 
19 | Source Files
20 | 
21 | 
22 | Source Files
23 | 
24 | 
25 | Source Files
26 | 
27 | 
28 | Source Files
29 | 
30 | 
31 | 
32 | 
33 | Header Files
34 | 
35 | 
36 | Header Files
37 | 
38 | 
39 | Header Files
40 | 
41 | 
42 | 
43 | 
44 | Resource Files
45 | 
46 | 
47 | Resource Files
48 | 
49 | 
50 | 
--------------------------------------------------------------------------------
/doc/kaggle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joyeecheung/parallel-random-forest/7c3eb937b24904ea3835ea321e3cd045e280979d/doc/kaggle.png
--------------------------------------------------------------------------------
/doc/multi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joyeecheung/parallel-random-forest/7c3eb937b24904ea3835ea321e3cd045e280979d/doc/multi.png
--------------------------------------------------------------------------------
/doc/report.tex:
--------------------------------------------------------------------------------
1 | \documentclass{article}
2 | %\usepackage[a4paper,top=0.75in, bottom=0.75in, left=1in, right=1in,footskip=0.2in]{geometry}
3 | \usepackage{fullpage}
4 | %-----------------Hyperlink Packages--------------------
5 | \usepackage{hyperref}
6 | \hypersetup{
7 | colorlinks = true,
8 | citecolor = black,
9 | linkcolor = black,
10 | urlcolor = black
11 | }
12 | %-----------------Figure Packages--------------------
13 | \usepackage{graphicx} % For figures
14 | %\usepackage{epsfig} % for postscript graphics files
15 | %------------------Math Packages------------------------
16 | \usepackage{amssymb,amsmath}
17 | \usepackage{textcomp}
18 | \usepackage{mdwmath}
19 | \usepackage{mdwtab}
20 | \usepackage{eqparbox}
21 | %------------------Table Packages-----------------------
22 | \usepackage{rotating} % Used to rotate tables
23 | \usepackage{array} % Fixed column widths for tables
24 | %-----------------Algorithm Packages--------------------
25 | \usepackage{listings} % Source code
26 | \usepackage{algorithm} % Pseudo Code
27 | \usepackage{algpseudocode}
28 | %---------------------------------------------------------
29 | 
30 | %opening
31 | 
32 | \begin{document}
33 | 
34 | \title{
35 | Data Mining Course Project \\
36 | Parallel Random Forest
37 | }
38 | \author {
39 | Computer Application Class 2, 12330402\\
40 | Qiuyi Zhang (\href{mailto:joyeec9h3@gmail.com}{joyeec9h3@gmail.com})
41 | }
42 | \date{\today}
43 | 
44 | \maketitle
45 | \tableofcontents
46 | 
47 | 
48 | \section{Problem Description}
49 | 
50 | The dataset contains 6238 records for training and 1559 records for testing. Each record has 617 features and an integer label in $[0, 25]$. Given this dataset, the goal is to train a random forest that predicts the labels of the test set as accurately as possible.
51 | 
52 | \section{Algorithms}
53 | 
54 | For this project I implemented a random forest (with parallelism support) consisting of binary decision trees that choose their split points using Gini impurity.
55 | 
56 | \subsection{Decision Tree}
57 | 
58 | \subsubsection{Building Decision Trees}
59 | 
60 | The algorithm for building the decision tree is described in Algorithm~\ref{alg:dt}.
61 | 
62 | For simplicity, the trees are built recursively.
63 | 
64 | \begin{algorithm}[H]
65 | \centering
66 | \caption{Decision tree}
67 | \label{alg:dt}
68 | \begin{algorithmic}[1]
69 | \Function{DecisionTree.fit}{$data$, $attributes$}
70 | \If{$data$ is empty}
71 | \State \Return an empty node
72 | \EndIf
73 | 
74 | \State $gain_{best}$ = 0.0, $attr_{best}$ = 0, $value_{best}$ = 0.0
75 | \State $set1_{best}$ = $\{\}$, $set2_{best}$ = $\{\}$
76 | 
77 | \If{$data.size > MIN\_NODE\_SIZE$}
78 | \For{each attribute $a$ in $attributes$}
79 | \State Sort $data$ by the values of $a$; let $threshold$ be the median of those values
80 | \State Divide $data$ by $threshold$ into $set1$, $set2$
81 | \State Calculate the information gain of splitting $data$ into $set1$ and $set2$
82 | \If{the new information gain $> gain_{best}$}
83 | \State Update $gain_{best}$, $attr_{best}$, $value_{best}$, $set1_{best}$, $set2_{best}$
84 | \EndIf
85 | \EndFor
86 | \EndIf
87 | \If{$attributes$ is not empty and $gain_{best} \neq 0.0$}
88 | \State remove $attr_{best}$ from $attributes$
89 | \State $this.left =$ \Call{DecisionTree.fit}{$set1_{best}$, $attributes$}
90 | \State $this.right =$ \Call{DecisionTree.fit}{$set2_{best}$, $attributes$}
91 | \State $this.attr = attr_{best}$
92 | \State $this.threshold = value_{best}$
93 | \State $this.count = data.size$
94 | \State \Return $this$
95 | \Else
96 | \State Create a $counter$ of the labels in $data$
97 | \State $this.leaf = counter$
98 | \State \Return $this$
99 | \EndIf
100 | \EndFunction
101 | \end{algorithmic}
102 | \end{algorithm}
103 | 
104 | \begin{description}
105 | \item Remark 1. \hfill \\
106 | Suppose $i \in \{1, 2, ..., m\}$, and let $f_i$ be the fraction of items labeled with value $i$ in the set. The \textit{Gini impurity} of a set of items is then defined as:
107 | 
108 | $$I_{G}(f) = \sum_{i=1}^{m} f_i (1-f_i) = \sum_{i=1}^{m} (f_i - {f_i}^2) = \sum_{i=1}^m f_i - \sum_{i=1}^{m} {f_i}^2 = 1 - \sum^{m}_{i=1} {f_i}^{2}$$
109 | 
110 | For example, a two-label set split evenly has $I_G = 1 - (0.5^2 + 0.5^2) = 0.5$, while a pure set has $I_G = 1 - 1^2 = 0$.
111 | 
112 | \item Remark 2.\hfill \\
113 | When a data set $D$ with goal attribute $V$ is split into subsets $D_k$ using attribute $A$, each with $n_k$ records, the \textit{information gain} is defined as:
114 | 
115 | $$
116 | Gain(A) = H(V, D) - Remainder(A)
117 | $$
118 | 
119 | where
120 | 
121 | $$
122 | Remainder(A) = \sum_k P(n_k) H(V, D_k)
123 | $$
124 | 
125 | $P(n_k)$ is the proportion of $D_k$ in $D$. In this implementation, the impurity measure $H$ is the Gini impurity from Remark 1 rather than entropy.
126 | \end{description}
127 | 
128 | \subsubsection{Using Decision Trees for Prediction}
129 | 
130 | The algorithm for predicting a label from a set of observed feature values is shown in Algorithm~\ref{alg:predictdt}. It is also recursive, and therefore rather intuitive.
131 | 
132 | \begin{algorithm}[H]
133 | \centering
134 | \caption{Prediction using Decision Trees}
135 | \label{alg:predictdt}
136 | \begin{algorithmic}[1]
137 | \Function{DecisionTree.predict}{$this$, $observation$}
138 | \If{$this.leaf$ is a counter}
139 | \State\Return the most frequent item in $this.leaf$
140 | \EndIf
141 | 
142 | \State $value$ = $observation[this.attr]$
143 | \If{$value < this.threshold$}
144 | \State \Return \Call{DecisionTree.predict}{$this.left$, $observation$}
145 | \Else
146 | \State \Return \Call{DecisionTree.predict}{$this.right$, $observation$}
147 | \EndIf
148 | \EndFunction
149 | \end{algorithmic}
150 | \end{algorithm}
151 | 
152 | \subsection{Random Forest}
153 | 
154 | \subsubsection{Building a Random Forest}
155 | 
156 | The algorithm for building a random forest out of a set of decision trees is described in Algorithm~\ref{alg:rf}.
Notice that the records are sampled with replacement (bootstrapping), while the features are sampled without replacement.
157 | 
158 | \begin{algorithm}[H]
159 | \centering
160 | \caption{Building a Random Forest}
161 | \label{alg:rf}
162 | \begin{algorithmic}[1]
163 | \Function{RandomForest.fit}{$data$, $numTrees$, $numFeatures$, $numSampleCoeff$}
164 | \State Sample $numSampleCoeff \times data.size$ records (with replacement) from $data$ as the $bootstrap$
165 | \For{each $tree$ in this forest}
166 | \State Sample $numFeatures$ attributes (without replacement) from the attribute set as $attributes$
167 | \State $tree$ = \Call{$DecisionTree.fit$}{$bootstrap$, $attributes$}
168 | \EndFor
169 | \EndFunction
170 | \end{algorithmic}
171 | \end{algorithm}
172 | 
173 | \subsubsection{Using the Random Forest for Prediction}
174 | 
175 | The algorithm for classifying an observation is described in Algorithm~\ref{alg:predictrf}. It returns the most frequent label among the predictions produced by the individual trees.
176 | 
177 | \begin{algorithm}[H]
178 | \centering
179 | \caption{Prediction using a Random Forest}
180 | \label{alg:predictrf}
181 | \begin{algorithmic}[1]
182 | \Function{RandomForest.predict}{$observations$}
183 | \State $result = []$
184 | 
185 | \For{each $row$ with index $i$ in $observations$}
186 | \State Create a new $counter$
187 | \For {each $tree$ in the forest}
188 | \State $counter.add$(\Call{$tree$.predict}{$row$})
189 | \EndFor
190 | \State $result[i]$ = the most frequent item in $counter$
191 | \EndFor
192 | 
193 | \State \Return $result$
194 | \EndFunction
195 | \end{algorithmic}
196 | \end{algorithm}
197 | 
198 | \section{Implementation}
199 | 
200 | \subsection{Parallelism}
201 | 
202 | The random forest can be easily parallelized, since the training and classification done by each tree are completely independent. There are many options for building a parallel implementation: multi-threading, MPI, MapReduce (Hadoop), Spark, etc. Because the computational resources I have are limited, I chose threads to parallelize the implementation. Using OpenMP, which is built into most compilers nowadays, it is fairly trivial to parallelize the for loops in the random forest; a few simple \texttt{\#pragma} directives suffice, as shown in the sketch at the end of this section.
203 | 
204 | On a machine with an Intel i7-4710HQ @ 2.5GHz (4 cores, 8 hardware threads) and 8 GB RAM, it takes the program (built with VS2013) about 170s to build a random forest with 1000 trees when multi-threading is not enabled (see Figure~\ref{fig:single}), and about 40s when multi-threading is enabled (see Figure~\ref{fig:multi}). In the first case only 25\% of the CPU is utilized, while in the second case 100\% of the CPU is utilized.
205 | 
206 | \begin{figure}[htbp]
207 | \centering
208 | \includegraphics[width=\linewidth]{single.png}
209 | \caption{Performance analysis of the implementation without multi-threading}
210 | \label{fig:single}
211 | \end{figure}
212 | 
213 | \begin{figure}[htbp]
214 | \centering
215 | \includegraphics[width=\linewidth]{multi.png}
216 | \caption{Performance analysis of the implementation with multi-threading}
217 | \label{fig:multi}
218 | \end{figure}
219 | 
220 | \subsection{Cross-platform}
221 | 
222 | This implementation was originally developed with Visual Studio 2013. By using Premake, it can also be built with MinGW under Windows, or with g++/clang++ under Linux. See README.md for instructions on how to build it on other platforms.
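
For reference, the following sketch (condensed from \texttt{RandomForest.cpp}) shows how the training loop is parallelized. The iterations are independent: the training data and the bootstrap sample are only read, and each iteration writes a distinct slot of the forest, so a single directive is enough.

\begin{lstlisting}[language=C++]
// Train the trees in parallel; no synchronization is needed
// since forest[i] is written by exactly one iteration.
#pragma omp parallel for
for (int i = 0; i < (int)numOfTrees; ++i) {
    const IndicesSet features = chooseFeaturess(FEATURE_NUM, maxValues);
    DecisionTree tree;
    tree.fit(X, y, bootstrap, features);
    forest[i] = tree;
}
\end{lstlisting}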
223 | 
224 | \section{Experiment Results}
225 | 
226 | For applications of random forests, there are some critical parameters that need to be tuned:
227 | 
228 | \begin{enumerate}
229 | \item \textbf{Number of trees in the forest}. Generally speaking, the more trees in the forest, the better the predictions will be.
230 | \item \textbf{Number of sampled features}. Common rules of thumb are $\sqrt{n}$ and $\log_2 n$, where $n$ is the number of features in the dataset.
231 | \item \textbf{Minimum size of tree nodes}. The minimum number of samples required to split a node.
232 | \item \textbf{Size of the bootstrap}. Generally speaking, if the size of the bootstrap is the same as the size of the training set (which means some rows could be chosen multiple times while some would be left out), the random forest will typically provide near-optimal performance.
233 | \item And more...
234 | \end{enumerate}
235 | 
236 | For this project, I used 1000 trees in the forest, sampled $\sqrt{n}$ features (25 for this dataset), set the minimum size of tree nodes to 2, and used a bootstrap the same size as the training set. This configuration gives an accuracy of $96.144\%$ when tested on the Kaggle platform, see Figure~\ref{fig:kaggle}.
237 | 
238 | \begin{figure}[htbp]
239 | \centering
240 | \includegraphics[width=\linewidth]{kaggle.png}
241 | \caption{Score on Kaggle}
242 | \label{fig:kaggle}
243 | \end{figure}
244 | 
245 | \end{document}
246 | 
--------------------------------------------------------------------------------
/doc/single.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joyeecheung/parallel-random-forest/7c3eb937b24904ea3835ea321e3cd045e280979d/doc/single.png
--------------------------------------------------------------------------------
/premake5.lua:
--------------------------------------------------------------------------------
1 | -- premake5.lua
2 | solution "RandomForest"
3 | configurations { "Debug", "Release", "clang" }
4 | 
5 | -- A project defines one build target
6 | project "RandomForest"
7 | kind "ConsoleApp"
8 | language "C++"
9 | files { "src/*.cpp" }
10 | includedirs { "src" }
11 | 
12 | configuration { "gmake" }
13 | buildoptions { "-std=c++11", "-fopenmp" }
14 | links { "gomp" }
15 | 
16 | configuration { "vs*" }
17 | buildoptions { "/openmp" }
18 | 
19 | configuration "Debug"
20 | defines { "DEBUG" } -- -DDEBUG
21 | flags { "Symbols" }
22 | 
23 | configuration "Release"
24 | defines { "NDEBUG" } -- -DNDEBUG
25 | flags { "Optimize" }
26 | 
27 | configuration "clang"
28 | toolset "clang"
--------------------------------------------------------------------------------
/src/Config.cpp:
--------------------------------------------------------------------------------
1 | #include "Config.h"
2 | 
--------------------------------------------------------------------------------
/src/Config.h:
--------------------------------------------------------------------------------
1 | #ifndef __CONFIG__
2 | #define __CONFIG__
3 | 
4 | #define JDEBUG
5 | //#define VALIDATE
6 | //#define DEBUG_TREE
7 | #define DEBUG_FOREST
8 | 
9 | #include <memory>
10 | #include <vector>
11 | #include <array>
12 | #include <set>
13 | #include <map>
14 | #include <string>
15 | #include <utility>
16 | #include <cstdio>
17 | 
18 | using std::shared_ptr;
19 | using std::vector;
20 | using std::array;
21 | using std::set;
22 | using std::map;
23 | using std::string;
24 | using std::pair;
25 | 
26 | typedef vector<int> MutLabels;
27 | typedef const vector<int> Labels;
28 | typedef vector<size_t> Indices;
29 | typedef set<size_t> IndicesSet;
30 | 
31 | #define FEATURE_NUM 617
32 | #define LABLE_NUM 26
33 | #define MIN_NODE_SIZE 2
34 | 
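// Each data row is a fixed-size array of FEATURE_NUM doubles, and a data
// set is a vector of such rows. As with the label/index aliases above,
// the Mut* variants are mutable, while the const aliases are used for
// read-only parameters.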
35 | typedef array<double, FEATURE_NUM> MutRow;
36 | typedef const array<double, FEATURE_NUM> Row;
37 | typedef vector<array<double, FEATURE_NUM> > MutValues;
38 | typedef const vector<array<double, FEATURE_NUM> > Values;
39 | 
40 | extern void printRow(Row row);
41 | #endif
--------------------------------------------------------------------------------
/src/DecisionTree.cpp:
--------------------------------------------------------------------------------
1 | #include "DecisionTree.h"
2 | 
3 | // ids: indices of the available data rows
4 | // features: indices of the sampled features
5 | void DecisionTree::fit(Values &X, Labels &y,
6 |     const Indices &ids, const IndicesSet &features) {
7 |     if (ids.size() == 0) {
8 |         return; // empty node
9 |     }
10 | 
11 |     // find the best feature to split on
12 |     double best_score = 0.0, best_value = 0.0, score;
13 |     size_t best_attr = 0;
14 |     Indices best_set1, best_set2;
15 | 
16 |     if (ids.size() > MIN_NODE_SIZE) {
17 |         double initial = gini(y, ids);
18 | #ifdef DEBUG_TREE
19 |         //printf("======================================\n");
20 |         //printf("Initial gini: %f\n", initial);
21 | #endif
22 |         // Note: if features.size() == 0, best_score stays 0.0
23 |         for (auto &attr : features) {
24 |             // choose the threshold: the median value of this attribute
25 |             const Indices sorted_idx = argsort(X, ids, attr);
26 |             size_t id_count = sorted_idx.size();
27 |             double threshold = X[sorted_idx[id_count / 2]][attr];
28 |             // divide the data set into two sets
29 |             Indices set1, set2;
30 |             bool missed = split(X, sorted_idx, set1, set2, attr);
31 | 
32 |             // get the score of this attribute
33 |             if (missed || set1.size() == 0 || set2.size() == 0) {
34 |                 score = 0.0;
35 |             } else {
36 |                 score = gain(X, y, ids, set1, set2, initial);
37 |             }
38 | 
39 |             // update the best split so far
40 |             if (score > best_score) {
41 |                 best_score = score;
42 |                 best_attr = attr;
43 |                 best_value = threshold;
44 |                 best_set1 = set1;
45 |                 best_set2 = set2;
46 |             }
47 |         }
48 |     }
49 | 
50 |     if (best_score > 0.0) { // found a split that improves purity
51 |         IndicesSet new_attr = features;
52 |         new_attr.erase(best_attr);
53 |         this->left = shared_ptr<DecisionTree>(new DecisionTree);
54 |         this->right = shared_ptr<DecisionTree>(new DecisionTree);
55 |         this->left->fit(X, y, best_set1, new_attr);
56 |         this->right->fit(X, y, best_set2, new_attr);
57 | 
58 |         this->attr = best_attr;
59 |         this->threshold = best_value;
60 |         this->count = ids.size();
61 | #ifdef DEBUG_TREE
62 |         printf("Select %d = %f as the split point for %d samples\n",
63 |             (int)best_attr, best_value, (int)count);
64 | #endif
65 |     } else { // no useful split left: make this node a leaf
66 |         this->leaf = shared_ptr<Counter>(new Counter(y, ids));
67 | #ifdef DEBUG_TREE
68 |         printf("This is a leaf\n");
69 |         leaf->print();
70 | #endif
71 |         this->count = ids.size();
72 |     }
73 | }
74 | 
75 | 
76 | MutLabels DecisionTree::predict(Values &X) {
77 |     int total = X.size();
78 |     MutLabels y(total);
79 |     // TODO: parallel
80 |     for (int i = 0; i < total; ++i) {
81 |         y[i] = predict(X[i]);
82 |     }
83 |     return y;
84 | }
85 | 
86 | int DecisionTree::predict(Row &x) {
87 |     if (leaf != nullptr) { // leaf node: return the majority label
88 |         return leaf->getMostFrequent();
89 |     }
90 | 
91 |     double value = x[attr];
92 |     if (value < threshold) {
93 |         return left->predict(x);
94 |     } else {
95 |         return right->predict(x);
96 |     }
97 | }
98 | 
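// Note: gini() below returns -(sum of f_i^2) rather than the textbook
// Gini impurity 1 - sum of f_i^2. The two differ by the constant 1,
// which appears once in `initial` and once in the weighted remainder
// inside gain(), so it cancels out and the chosen split is unaffected.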
99 | double DecisionTree::gini(Labels &y, const Indices &ids) {
100 |     size_t total = ids.size();
101 |     Counter counter(y, ids);
102 |     double imp = 0.0;
103 |     const map<int, int> &freqs = counter.data;
104 |     double normalized_freq;
105 |     for (auto &freq : freqs) {
106 |         normalized_freq = (double)freq.second / total;
107 |         imp -= normalized_freq * normalized_freq;
108 |     }
109 |     return imp;
110 | }
111 | 
112 | // information gain, computed with the Gini impurity above
113 | double DecisionTree::gain(Values &X, Labels &y, const Indices &ids,
114 |     const Indices &set1, const Indices &set2, double initial) {
115 |     double p = (double)set1.size() / ids.size();
116 |     double remainder = p * gini(y, set1) + (1 - p) * gini(y, set2);
117 |     return initial - remainder;
118 | }
119 | 
120 | // sort the ids by their values of the given attribute
121 | Indices DecisionTree::argsort(Values &X, const Indices &ids, size_t attr) {
122 |     // initialize original index locations
123 |     Indices idx(ids.begin(), ids.end());
124 | 
125 |     // sort indexes based on the values of attribute attr
126 |     sort(idx.begin(), idx.end(),
127 |         [&X, &attr](size_t i1, size_t i2) {
128 |             return X[i1][attr] < X[i2][attr];
129 |         });
130 | 
131 |     return idx;
132 | }
133 | 
134 | bool DecisionTree::split(Values &X, const Indices &sorted_idx,
135 |     Indices &set1, Indices &set2, size_t attr) {
136 |     // report a miss if the median threshold is out of range
137 |     int id_count = sorted_idx.size();
138 |     double threshold = X[sorted_idx[id_count / 2]][attr];
139 |     if (X[sorted_idx[id_count - 1]][attr] < threshold || X[sorted_idx[0]][attr] > threshold) {
140 |         return true;
141 |     }
142 | 
143 |     set1 = Indices(sorted_idx.begin(), sorted_idx.begin() + id_count / 2);
144 | 
145 |     if (id_count > 1) {
146 |         set2 = Indices(sorted_idx.begin() + id_count / 2, sorted_idx.end());
147 |     } else {
148 |         set2 = Indices();
149 |     }
150 |     return false;
151 | }
152 | 
153 | void DecisionTree::print(int indent) {
154 |     if (leaf != nullptr) {
155 |         leaf->print(0);
156 |     } else {
157 |         printf("%d < %f ? (%d)\n", (int)attr, threshold, (int)count);
158 |         for (int i = 0; i < indent + 3; ++i) printf(" ");
159 |         printf("T->");
160 |         left->print(indent + 3);
161 |         for (int i = 0; i < indent + 3; ++i) printf(" ");
162 |         printf("F->");
163 |         right->print(indent + 3);
164 |     }
165 | }
166 | 
--------------------------------------------------------------------------------
/src/DecisionTree.h:
--------------------------------------------------------------------------------
1 | #ifndef DECISION_TREE_H
2 | #define DECISION_TREE_H
3 | 
4 | #include "Config.h"
5 | 
6 | #include <algorithm>
7 | #include <cstdio>
8 | 
9 | class Counter {
10 | public:
11 |     Counter(Labels &y, const Indices &ids) {
12 |         for (auto &id : ids) {
13 |             if (data.find(y[id]) != data.end()) {
14 |                 data[y[id]] += 1;
15 |             } else {
16 |                 data[y[id]] = 1;
17 |             }
18 |         }
19 |     }
20 | 
21 |     Counter(Labels &y) {
22 |         for (auto &label : y) {
23 |             if (data.find(label) != data.end()) {
24 |                 data[label] += 1;
25 |             } else {
26 |                 data[label] = 1;
27 |             }
28 |         }
29 |     }
30 | 
31 |     int getMostFrequent() const {
32 |         std::vector<pair<int, int> > pairs(data.begin(), data.end());
33 |         std::sort(pairs.begin(), pairs.end(),
34 |             [](const pair<int, int> &a, const pair<int, int> &b) {
35 |                 return a.second > b.second;
36 |             });
37 |         return pairs.begin()->first;
38 |     }
39 | 
40 |     void print(int indent = 0) {
41 |         for (int i = 0; i < indent; ++i) printf(" ");
42 |         for (auto &kv : data) {
43 |             printf("%d: %d, ", kv.first, kv.second);
44 |         }
45 |         printf("\n");
46 |     }
47 | 
48 |     map<int, int> data;
49 | };
50 | 
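// A DecisionTree node is either an internal node (attr, threshold and the
// left/right subtrees are set, leaf is null) or a leaf (leaf points to a
// Counter of the training labels that reached it). During prediction,
// rows with x[attr] < threshold go left and the rest go right.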
51 | class DecisionTree {
52 | public:
53 |     DecisionTree() : left(nullptr), right(nullptr), leaf(nullptr) {}
54 | 
55 |     // rule of three
56 |     DecisionTree(const DecisionTree &other) {
57 |         this->left = other.left;
58 |         this->right = other.right;
59 |         this->leaf = other.leaf;
60 |         this->attr = other.attr;
61 |         this->threshold = other.threshold;
62 |         this->count = other.count;
63 |     }
64 | 
65 |     void swap(DecisionTree &other) {
66 |         std::swap(this->left, other.left);
67 |         std::swap(this->right, other.right);
68 |         std::swap(this->leaf, other.leaf);
69 |         std::swap(this->attr, other.attr);
70 |         std::swap(this->threshold, other.threshold);
71 |         std::swap(this->count, other.count);
72 |     }
73 | 
74 |     DecisionTree &operator=(const DecisionTree &other) {
75 |         DecisionTree temp(other);
76 |         temp.swap(*this);
77 |         return *this;
78 |     }
79 | 
80 |     ~DecisionTree() {
81 |         // members are shared_ptrs, so everything is reclaimed automatically
82 |     }
83 | 
84 |     void print(int indent = 2);
85 | 
86 |     void fit(Values &X, Labels &y, const Indices &ids,
87 |         const IndicesSet &features);
88 | 
89 |     MutLabels predict(Values &X);
90 |     int predict(Row &x);
91 | private:
92 |     double gini(Labels &y, const Indices &ids);
93 |     double gain(Values &X, Labels &y, const Indices &ids,
94 |         const Indices &set1, const Indices &set2, double initial);
95 |     Indices argsort(Values &X, const Indices &ids, size_t attr);
96 |     bool split(Values &X, const Indices &sorted_idx,
97 |         Indices &set1, Indices &set2, size_t attr);
98 | 
99 |     std::shared_ptr<DecisionTree> left;
100 |     std::shared_ptr<DecisionTree> right;
101 | 
102 |     size_t attr;
103 |     double threshold;
104 |     size_t count;
105 | 
106 |     std::shared_ptr<Counter> leaf;
107 | };
108 | 
109 | #endif
110 | 
--------------------------------------------------------------------------------
/src/RandomForest.cpp:
--------------------------------------------------------------------------------
1 | #include "RandomForest.h"
2 | #include <chrono>
3 | #include <random>
4 | 
5 | RandomForest::RandomForest(size_t numOfTrees, size_t maxValues, size_t numLabels,
6 |     double sampleCoeff) {
7 |     this->numOfTrees = numOfTrees;
8 |     this->maxValues = maxValues;
9 |     this->numLabels = numLabels;
10 |     this->sampleCoeff = sampleCoeff;
11 |     this->forest.resize(numOfTrees);
12 | }
13 | 
14 | 
15 | void RandomForest::fit(Values &X, Labels &y, const Indices &ids) {
16 |     // draw a random sample of the rows, shared by all trees
17 |     bootstrap = sample(ids);
18 | #ifdef DEBUG_FOREST
19 |     printf("Bootstrapping done. Sample size = %d\n", (int)bootstrap.size());
20 |     printf("Start training...\nTrees completed:\n");
21 | #endif
22 | 
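    // The iterations below are independent: X, y and bootstrap are only
    // read, and each iteration writes a distinct forest[i], so the loop
    // can be distributed across threads without locking.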
23 | #pragma omp parallel for
24 |     for (int i = 0; i < (int)numOfTrees; ++i) {
25 |         const IndicesSet features = chooseFeaturess(FEATURE_NUM, maxValues);
26 | #ifdef DEBUG_TREE
27 |         printf("Chosen features: ");
28 |         for (auto &f : features) {
29 |             printf("%d ", (int)f);
30 |         }
31 |         printf("\n");
32 | #endif
33 |         DecisionTree tree;
34 |         // train a tree with the bootstrap sample
35 |         tree.fit(X, y, bootstrap, features);
36 |         // put it into the forest
37 |         forest[i] = tree;
38 | #ifdef DEBUG_TREE
39 |         tree.print();
40 | #endif
41 | #ifdef DEBUG_FOREST
42 |         printf("================ %d ==============\n", i);
43 | #endif
44 |     }
45 | 
46 | #ifdef DEBUG_FOREST
47 |     printf("\nTraining done.\n");
48 | #endif
49 | }
50 | 
51 | Indices RandomForest::sample(const Indices &ids) {
52 |     size_t data_size = ids.size();
53 |     size_t sample_size = (size_t)(sampleCoeff * data_size);
54 |     Indices idx;
55 | 
56 |     for (size_t i = 0; i < sample_size; ++i) {
57 |         size_t next = rand() % data_size; // draw with replacement
58 |         idx.push_back(ids[next]);
59 |     }
60 |     return idx;
61 | }
62 | 
63 | 
64 | IndicesSet RandomForest::chooseFeaturess(size_t numValues, size_t maxValues) {
65 |     // randomly choose maxValues distinct numbers from [0, numValues - 1]
66 |     Indices idx(numValues);
67 |     unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
68 | 
69 |     for (size_t i = 0; i < numValues; ++i) idx[i] = i;
70 |     std::shuffle(idx.begin(), idx.end(), std::default_random_engine(seed));
71 | 
72 |     return IndicesSet(idx.begin(), idx.begin() + maxValues);
73 | }
74 | 
75 | 
76 | MutLabels RandomForest::predict(Values &X) {
77 |     int total = X.size();
78 |     MutLabels y(total);
79 | #ifdef DEBUG_FOREST
80 |     printf("\nStart prediction, data size %d...\n", (int)X.size());
81 | #endif
82 | 
83 |     // predictions for different rows are independent
84 | #pragma omp parallel for
85 |     for (int i = 0; i < total; ++i) {
86 |         y[i] = predict(X[i]);
87 |     }
88 | #ifdef DEBUG_FOREST
89 |     printf("Prediction completed.\n");
90 | #endif
91 |     return y;
92 | }
93 | 
94 | int RandomForest::predict(Row &x) {
95 |     // get the predictions from all trees
96 |     MutLabels results(numOfTrees);
97 | 
98 |     for (int i = 0; i < (int)numOfTrees; ++i) {
99 |         results[i] = forest[i].predict(x);
100 |     }
101 | 
102 |     // majority vote
103 |     Counter counter(results);
104 |     return counter.getMostFrequent();
105 | }
106 | 
107 | bool csv2data(const char* filename, MutValues &X, MutLabels &y,
108 |     Indices &ids, int idIdx, int labelIdx) {
109 |     std::ifstream file(filename);
110 | 
111 |     if (!file) {
112 | #ifdef JDEBUG
113 |         printf("Failed to open the file.\n");
114 | #endif
115 |         return false;
116 |     }
117 | 
118 | #ifdef JDEBUG
119 |     printf("Skipping the first line...\n");
120 | #endif
121 |     string line, value;
122 |     std::getline(file, line); // skip the header line
123 | 
124 | #ifdef JDEBUG
125 |     printf("Start loading the data...\n");
126 | #endif
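    // Each line consists of an id column, FEATURE_NUM feature columns and,
    // for the training set (labelIdx != -1), a trailing label column. For
    // the test set the real id is recorded; for the training set the row
    // index is used as the id.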
127 |     // for each line
128 |     size_t row = 0;
129 |     while (std::getline(file, line)) {
130 |         MutRow temp;
131 |         std::istringstream ss(line);
132 | 
133 |         // split by commas
134 |         // Note: strtol and strtod are used instead of std::stoi and
135 |         // std::stod because mingw32 has a bug and lacks them;
136 |         // g++ under Linux works fine with either.
137 |         size_t col = 0;
138 |         size_t row_size = labelIdx == -1 ? FEATURE_NUM + 1 : FEATURE_NUM + 2;
139 |         for (size_t i = 0; i < row_size; ++i) {
140 |             std::getline(ss, value, ',');
141 |             if ((int)i == idIdx) { // id column
142 |                 if (labelIdx == -1) {
143 |                     ids.push_back(strtol(value.c_str(), nullptr, 10)); // real id
144 |                 } else {
145 |                     ids.push_back(row); // row index as id
146 |                 }
147 |             } else if ((int)i == labelIdx) { // label column
148 |                 y.push_back(strtol(value.c_str(), nullptr, 10));
149 |             } else { // feature value
150 |                 temp[col] = strtod(value.c_str(), nullptr);
151 |                 col++;
152 |             }
153 |         }
154 |         X.push_back(temp);
155 |         row++;
156 |     }
157 | #ifdef JDEBUG
158 |     printf("Data loading completed.\nThe first line is:\n");
159 |     printRow(X[0]);
160 |     printf("The last line is:\n");
161 |     printRow(X[X.size() - 1]);
162 | #endif
163 |     return true;
164 | }
165 | 
166 | void printRow(Row row) {
167 |     printf("[\t");
168 |     size_t row_size = row.size();
169 |     if (row_size < 11) {
170 |         for (auto &value : row) {
171 |             printf("%f\t", value);
172 |         }
173 |     } else {
174 |         for (size_t i = 0; i < 5; ++i) {
175 |             printf("%f\t", row[i]);
176 |         }
177 |         printf("\t...\t");
178 |         for (size_t i = row_size - 5; i < row_size; ++i) {
179 |             printf("%f\t", row[i]);
180 |         }
181 |     }
182 |     printf("]\n");
183 | }
184 | 
--------------------------------------------------------------------------------
/src/RandomForest.h:
--------------------------------------------------------------------------------
1 | #ifndef RANDOM_FOREST_H
2 | #define RANDOM_FOREST_H
3 | #include "Config.h"
4 | #include "DecisionTree.h"
5 | 
6 | #include <algorithm>
7 | #include <fstream>
8 | #include <sstream>
9 | 
10 | bool csv2data(const char* filename, MutValues &X, MutLabels &y,
11 |     Indices &ids, int idIdx, int labelIdx = -1);
12 | 
13 | class RandomForest {
14 | public:
15 |     RandomForest(size_t numOfTrees, size_t maxValues, size_t numLabels,
16 |         double sampleCoeff);
17 |     void fit(Values &X, Labels &y, const Indices &ids);
18 |     MutLabels predict(Values &X);
19 |     int predict(Row &x);
20 |     bool loadDataSet(const char* filename, size_t idIdx, size_t labelIdx); // note: declared but never defined
21 | private:
22 |     IndicesSet chooseFeaturess(size_t numValues, size_t maxValues);
23 |     Indices sample(const Indices &ids);
24 | 
25 |     vector<DecisionTree> forest;
26 |     MutValues X;
27 |     MutLabels y;
28 |     Indices ids;
29 |     size_t numOfTrees;
30 |     size_t maxValues; // usually sqrt(FEATURE_NUM)
31 |     size_t numLabels;
32 |     double sampleCoeff; // 1.0 works well
33 |     Indices bootstrap; // the bootstrap sample, shared by all trees
34 | };
35 | 
36 | #endif
--------------------------------------------------------------------------------
/src/RandomForest.vcxproj:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | Debug
6 | Win32
7 | 
8 | 
9 | Release
10 | Win32
11 | 
12 | 
13 | 
14 | {CE156241-9F55-45E4-BE0D-E9C4CEE42FB1}
15 | Win32Proj
16 | RandomForest
17 | main
18 | 
19 | 
20 | 
21 | Application
22 | true
23 | v120
24 | Unicode
25 | 
26 | 
27 | Application
28 | false
29 | v120
30 | true
31 | Unicode
32 | 
33 | 
34 | 
35 | 
36 | 
37 | 
38 | 
39 | 
40 | 
41 | 
42 | 
43 | 
44 | true
45 | $(SolutionDir)
46 | $(SolutionDir)\obj\
47 | 
48 | 
49 | false
50 | 
51 | 
52 | 
53 | 
54 | 
55 | Level3
56 | Disabled
57 | WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)
58 | true
59 | 
60 | 
61 | Console
62 | true
63 | 
64 | 
65 | 
66 | 
67 | Level3
68 | 
69 | 
70 | MaxSpeed
71 | true
72 | true
73 | WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)
74 | 
75 | 
76 | Console
77 | true
78 | true
79 | true
80 | 
81 | 
82 | 
83 | 
84 | 
85 | 
86 | 
87 | 
88 | 
89 | 
90 | 
91 | 
92 | 
93 | 
94 | 
95 | 
96 | 
--------------------------------------------------------------------------------
/src/RandomForest.vcxproj.filters:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF}
6 | cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx
7 | 
8 | 
9 | {93995380-89BD-4b04-88EB-625FBE52EBFB}
10 | h;hh;hpp;hxx;hm;inl;inc;xsd
11 | 
12 | 
13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01}
14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms
15 | 
16 | 
17 | 
18 | 
19 | Source Files
20 | 
21 | 
22 | Source Files
23 | 
24 | 
25 | Source Files
26 | 
27 | 
28 | Source Files
29 | 
30 | 
31 | Source Files
32 | 
33 | 
34 | 
35 | 
36 | Source Files
37 | 
38 | 
39 | Source Files
40 | 
41 | 
42 | 
--------------------------------------------------------------------------------
/src/RandomForest.vcxproj.user:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | true
5 | 
6 | 
--------------------------------------------------------------------------------
/src/main.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdio>
2 | #include <cstdlib>
3 | #include <cmath>
4 | #include <fstream>
5 | 
6 | #include "Config.h"
7 | #include "DecisionTree.h"
8 | #include "RandomForest.h"
9 | 
10 | const int MAX_FEATURE = (int)round(sqrt(FEATURE_NUM));
11 | const double SAMPLE_COEFF = 1.0;
12 | 
13 | int main(int argc, char *argv[]) {
14 |     int tree_num = 1000;
15 |     if (argc > 1) {
16 |         tree_num = atoi(argv[1]);
17 |     }
18 | #ifdef JDEBUG
19 |     printf("=======================================\n");
20 |     printf("Initializing forest:\nSize of forest\t%d\n", tree_num);
21 |     printf("number of sampled features\t%d\n", MAX_FEATURE);
22 |     printf("number of labels\t%d\n", LABLE_NUM);
23 | #endif
24 |     RandomForest rf(tree_num, MAX_FEATURE, LABLE_NUM, SAMPLE_COEFF);
25 | #ifdef JDEBUG
26 |     printf("=======================================\n");
27 | #endif
28 | 
29 |     MutValues X, test_X;
30 |     MutLabels y, dummy;
31 |     Indices ids, test_ids;
32 | 
33 | #ifdef VALIDATE
34 |     csv2data("data/train.csv", X, y, ids, 0, FEATURE_NUM + 1);
35 |     rf.fit(X, y, ids);
36 |     csv2data("data/1000.csv", test_X, dummy, test_ids, 0);
37 | #else
38 |     csv2data("data/train.csv", X, y, ids, 0, FEATURE_NUM + 1);
39 |     rf.fit(X, y, ids);
40 |     csv2data("data/test.csv", test_X, dummy, test_ids, 0);
41 | #endif
42 | 
43 |     Labels yhat = rf.predict(test_X);
44 |     size_t count = yhat.size();
45 | 
46 | #ifdef VALIDATE
47 |     // note: this assumes data/1000.csv contains the first rows of
48 |     // data/train.csv in the same order, so y[i] is the true label
49 |     int errors = 0;
50 |     for (size_t i = 0; i < count; ++i) {
51 |         if (yhat[i] != y[i]) errors++;
52 |         printf("(%d, %d); ", yhat[i], y[i]);
53 |     }
54 |     printf("\nErrors: %d, %f\n", errors, (double)errors / count);
55 | #else
56 |     std::ofstream out("data/submit.csv");
57 |     out << "id,label\n";
58 |     for (size_t i = 0; i < count; ++i) {
59 |         out << test_ids[i] << ',' << yhat[i] << '\n';
60 |     }
61 | #endif
62 | 
63 |     return 0;
64 | }
--------------------------------------------------------------------------------