├── main.cpp ├── input.txt ├── README.md ├── .gitattributes ├── .gitignore └── DecisionTree.h /main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include"DecisionTree.h" 3 | using namespace std; 4 | int main(){ 5 | Tree* tree=new Tree; 6 | cout<<"Trained..........."<GetOutPut(); 10 | 11 | return 0; 12 | } 13 | 14 | -------------------------------------------------------------------------------- /input.txt: -------------------------------------------------------------------------------- 1 | 3 0 2 | 3 0 3 | 2 0 4 | 2 1 5 | 0 0 0 0 0 0 6 | 0 0 0 1 0 0 7 | 1 0 0 0 1 0 8 | 2 1 0 0 1 0 9 | 2 2 1 0 1 0 10 | 2 2 1 1 0 0 11 | 1 2 1 1 1 0 12 | 0 1 0 0 0 0 13 | 0 2 1 0 1 0 14 | 2 1 1 0 1 0 15 | 0 1 1 1 1 0 16 | 1 1 0 1 1 0 17 | 1 0 1 0 1 0 18 | 2 1 0 1 0 1 19 | 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DecisionTree 2 | # A C++ implementation of the decision tree learning algorithm,决策树学习算法,C++实现 3 | 4 | # 代码中有详细注释 5 | # 此决策树采用最普通的ID3算法 6 | 7 | 8 | 输入格式 9 | 先输入每个属性的分类属性的数量 10 | 再输入训练样例,比如有如下属性和训练样例 11 | 12 | 年龄:幼儿,儿童,少年,青年,中年,老年 13 | 压力:小,中,大 14 | 是否幸福:不幸福,普通,很幸福 15 | 16 | 样例: 17 | 1.幼儿 小 很幸福 18 | 2.儿童 中 普通 19 | 3.少年 中 不幸福 20 | 21 | 则对应的输入为(“是否幸福”是输出,不作为属性输入;每行最后一个数字是结束标志:0 继续输入,1 结束) 22 | 6 0 23 | 3 1 24 | 0 0 2 0 25 | 1 1 1 0 26 | 2 1 0 1 27 | 28 | 29 | 项目中有一个简单的训练样例 30 | 31 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | *.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF
diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Windows image file caches 2 | Thumbs.db 3 | ehthumbs.db 4 | 5 | # Folder config file 6 | Desktop.ini 7 | 8 | # Recycle Bin used on file shares 9 | $RECYCLE.BIN/ 10 | 11 | # Windows Installer files 12 | *.cab 13 | *.msi 14 | *.msm 15 | *.msp 16 | 17 | # Windows shortcuts 18 | *.lnk 19 | 20 | # ========================= 21 | # Operating System Files 22 | # ========================= 23 | 24 | # OSX 25 | # ========================= 26 | 27 | .DS_Store 28 | .AppleDouble 29 | .LSOverride 30 | 31 | # Thumbnails 32 | ._* 33 | 34 | # Files that might appear on external disk 35 | .Spotlight-V100 36 | .Trashes 37 | 38 | # Directories potentially created on remote AFP share 39 | .AppleDB 40 | .AppleDesktop 41 | Network Trash Folder 42 | Temporary Items 43 | .apdisk 44 | -------------------------------------------------------------------------------- /DecisionTree.h: -------------------------------------------------------------------------------- 1 | #ifndef DECISIONTREE 2 | #define DECISIONTREE 3 | #include 4 | #include 5 | using namespace std; 6 | 7 | //Training data set. NOTE(review): template arguments and parts of the stream expressions appear to be missing from this listing - confirm against the original header 8 | class TrainData{ 9 | public: 10 | vector > Input;//one sample per row 11 | vector OutPut;//discrete output values are allowed 12 | void InSertData(vector data,int out){//append one row of input data together with its target output 13 | Input.push_back(data); 14 | OutPut.push_back(out); 15 | } 16 | }; 17 | 18 | class Node{ 19 | public: 20 | int Attribute;//attribute index (leaf nodes are constructed with an output value in this field) 21 | bool IsLeaf;//whether this node is a leaf 22 | vector Num;//child nodes 23 | Node(int ID,bool a):Attribute(ID),IsLeaf(a){} 24 | 25 | }; 26 | 27 | //Decision tree 28 | class Tree{ 29 | private: 30 | Node* Root;//root node 31 | 32 | vector > AttrData;//attribute list 33 | 34 | Node* CreateTree(TrainData data,vector usedAttr);//builds the tree with the ID3 algorithm 35 | int MostNormalOutPut(TrainData data);//uses the most common output as the node value 36 | int Best(TrainData data,vector
usedAttr);//finds the attribute with the highest information gain 37 | double Entropy(TrainData data);//computes the information entropy 38 | public: 39 | Tree(); 40 | void GetOutPut();//reads one case from input and prints its output 41 | }; 42 | 43 | Tree::Tree(){ 44 | /*read the attribute list: for each attribute only the number of its categories is entered*/ 45 | int stop=0,num=0; 46 | while(!stop){ 47 | vector temp; 48 | cout<<"Attribute"<<"["<>aa; 51 | for(int i=0;i>stop; 56 | num++; 57 | } 58 | 59 | /*read the training data: the category indices are entered directly, in attribute order*/ 60 | TrainData data; 61 | stop=0; 62 | while(!stop){ 63 | vector train; 64 | cout<<"TrainData:"; 65 | int aa=0; 66 | for(unsigned int i=0;i>aa; 68 | train.push_back(aa); 69 | } 70 | cout<<"OutPut:"; 71 | int aaa; 72 | cin>>aaa; 73 | data.InSertData(train,aaa); 74 | cout<<"Stop?"<>stop; 76 | } 77 | 78 | vector temp2; 79 | Root=CreateTree(data,temp2); 80 | cout<<"Training........."< usedAttr){ 84 | 85 | Node* root=new Node(0,0);//create the root node 86 | 87 | /*if all outputs are identical, create a single leaf node whose value is that output*/ 88 | int stop=1; 89 | for(unsigned int i=1;iAttribute=A; 104 | 105 | /*recursively build a new subtree under each category value*/ 106 | for(unsigned int i=0;iNum.push_back(new Node(MostNormalOutPut(data),1)); 115 | } 116 | else{ 117 | root->Num.push_back(CreateTree(tempExample,usedAttr)); 118 | } 119 | } 120 | 121 | return root; 122 | } 123 | 124 | int Tree::MostNormalOutPut(TrainData data){ 125 | vector out;//distinct output values seen so far 126 | vector count;//number of occurrences of each output value 127 | for(unsigned int i=0;imax){ 149 | maxi=i; 150 | max=count[i]; 151 | } 152 | } 153 | return out[maxi]; 154 | } 155 | 156 | double Tree::Entropy(TrainData data){ 157 | /*collect the distinct output values and their counts*/ 158 | vector out; 159 | vector count; 160 | for(unsigned int i=0;i usedAttr){ 192 | vector Gain;//information gain of each attribute 193 | 194 | bool used; 195 | /*set the information gain of already-used attributes to 0*/ 196 | for(unsigned int i=0;imax){ 228 | maxi=i; 229 | max=Gain[i]; 230 | } 231 | } 232 | return maxi; 233 | } 234 | 235 | void Tree::GetOutPut(){ 236 | vector data; 237 | cout<<"TestData:"; 238 | int aa=0; 239 | for(int i=0;i>aa; 241 | data.push_back(aa); 242 | } 243 | if(Root->IsLeaf){ 244 | cout<<"OutPut:"<Attribute<Num[data[Root->Attribute]]; 248 | while(!current->IsLeaf) 249 |
current=current->Num[data[current->Attribute]]; 250 | cout<<"OutPut:"<Attribute<