├── .gitignore ├── Gemfile ├── README.md ├── Rakefile ├── decisiontree.gemspec ├── examples ├── continuous-id3.rb ├── data │ ├── continuous-test.txt │ ├── continuous-training.txt │ ├── discrete-test.txt │ └── discrete-training.txt ├── discrete-id3.rb └── simple.rb ├── lib ├── .DS_Store ├── core_extensions │ ├── array.rb │ └── object.rb ├── decisiontree.rb └── decisiontree │ └── id3_tree.rb └── spec ├── id3_spec.rb └── spec_helper.rb /.gitignore: -------------------------------------------------------------------------------- 1 | *.gem 2 | *.rbc 3 | .bundle 4 | .config 5 | .yardoc 6 | Gemfile.lock 7 | InstalledFiles 8 | _yardoc 9 | coverage 10 | doc/ 11 | lib/bundler/man 12 | pkg 13 | rdoc 14 | spec/reports 15 | test/tmp 16 | test/version_tmp 17 | tmp 18 | *.png -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source 'https://rubygems.org' 2 | 3 | # Specify your gem's dependencies in decisiontree.gemspec 4 | gemspec 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Decision Tree 2 | 3 | A Ruby library that implements the [ID3 (information gain)](https://en.wikipedia.org/wiki/ID3_algorithm) algorithm for decision tree learning. Currently, both continuous and discrete datasets can be learned. 4 | 5 | - The discrete model assumes unique labels and can be graphed and converted into a PNG for visual analysis. 6 | - The continuous model looks at all possible values for a variable and iteratively chooses the best threshold among all possible assignments. This results in a binary tree that is partitioned by the threshold at every step (e.g. temperature > 20C). 7 | 8 | ## Features 9 | - ID3 algorithms for continuous and discrete cases, with support for inconsistent datasets.
10 | - [Graphviz component](http://rockit.sourceforge.net/subprojects/graphr/) to visualize the learned tree 11 | - Support for multiple and symbolic outputs, and graphing of continuous trees. 12 | - Returns the default value when no branch matches the input 13 | 14 | ## Implementation 15 | 16 | - Ruleset is a class that trains an ID3Tree with 2/3 of the training data, converts it into a set of rules and prunes the rules with the remaining 1/3 of the training data (in a [C4.5](https://en.wikipedia.org/wiki/C4.5_algorithm) way). 17 | - Bagging is a bagging-based trainer that trains 10 Ruleset trainers and, when predicting, chooses the best output by voting. 18 | 19 | [Blog post with explanation & examples](http://www.igvita.com/2007/04/16/decision-tree-learning-in-ruby/) 20 | 21 | ## Example 22 | 23 | ```ruby 24 | require 'decisiontree' 25 | 26 | attributes = ['Temperature'] 27 | training = [ 28 | [36.6, 'healthy'], 29 | [37, 'sick'], 30 | [38, 'sick'], 31 | [36.7, 'healthy'], 32 | [40, 'sick'], 33 | [50, 'really sick'], 34 | ] 35 | 36 | # Instantiate the tree and train it on the data (the default prediction is 'sick') 37 | dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous) 38 | dec_tree.train 39 | 40 | test = [37, 'sick'] 41 | decision = dec_tree.predict(test) 42 | puts "Predicted: #{decision} ... True decision: #{test.last}" 43 | 44 | # => Predicted: sick ...
True decision: sick 45 | 46 | # Specify type ("discrete" or "continuous") in the training data 47 | labels = ["hunger", "color"] 48 | training = [ 49 | [8, "red", "angry"], 50 | [6, "red", "angry"], 51 | [7, "red", "angry"], 52 | [7, "blue", "not angry"], 53 | [2, "red", "not angry"], 54 | [3, "blue", "not angry"], 55 | [2, "blue", "not angry"], 56 | [1, "red", "not angry"] 57 | ] 58 | 59 | dec_tree = DecisionTree::ID3Tree.new(labels, training, "not angry", color: :discrete, hunger: :continuous) 60 | dec_tree.train 61 | 62 | test = [7, "red", "angry"] 63 | decision = dec_tree.predict(test) 64 | puts "Predicted: #{decision} ... True decision: #{test.last}" 65 | 66 | # => Predicted: angry ... True decision: angry 67 | ``` 68 | 69 | ## License 70 | 71 | The [MIT License](https://opensource.org/licenses/MIT) - Copyright (c) 2006 Ilya Grigorik 72 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require 'bundler' 2 | Bundler::GemHelper.install_tasks 3 | 4 | require 'rspec/core/rake_task' 5 | RSpec::Core::RakeTask.new 6 | 7 | task :default => :spec 8 | -------------------------------------------------------------------------------- /decisiontree.gemspec: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | $:.push File.expand_path("../lib", __FILE__) 3 | 4 | Gem::Specification.new do |s| 5 | s.name = "decisiontree" 6 | s.version = "0.5.0" 7 | s.platform = Gem::Platform::RUBY 8 | s.authors = ["Ilya Grigorik"] 9 | s.email = ["ilya@igvita.com"] 10 | s.homepage = "https://github.com/igrigorik/decisiontree" 11 | s.summary = %q{ID3-based implementation of the M.L. 
Decision Tree algorithm} 12 | s.description = s.summary 13 | s.license = "MIT" 14 | 15 | s.rubyforge_project = "decisiontree" 16 | 17 | s.add_development_dependency "graphr" 18 | s.add_development_dependency "rspec" 19 | s.add_development_dependency "rspec-given" 20 | s.add_development_dependency "pry" 21 | 22 | s.files = `git ls-files`.split("\n") 23 | s.test_files = `git ls-files -- {test,spec,features}/*`.split("\n") 24 | s.executables = `git ls-files -- bin/*`.split("\n").map{ |f| File.basename(f) } 25 | s.require_paths = ["lib"] 26 | end 27 | -------------------------------------------------------------------------------- /examples/continuous-id3.rb: -------------------------------------------------------------------------------- 1 | require 'rubygems' 2 | require 'decisiontree' 3 | include DecisionTree 4 | 5 | # ---Continuous--- 6 | 7 | # Read in the training data 8 | training = [] 9 | attributes = nil 10 | 11 | File.open('data/continuous-training.txt', 'r').each_line do |line| 12 | data = line.strip.chomp('.').split(',') 13 | attributes ||= data 14 | training_data = data.collect do |v| 15 | case v 16 | when 'healthy' 17 | 1 18 | when 'colic' 19 | 0 20 | else 21 | v.to_f 22 | end 23 | end 24 | training.push(training_data) 25 | end 26 | 27 | # Remove the attribute row from the training data 28 | training.shift 29 | 30 | # Instantiate the tree, and train it based on the data (set default to '1') 31 | dec_tree = ID3Tree.new(attributes, training, 1, :continuous) 32 | dec_tree.train 33 | 34 | # ---Test the tree--- 35 | 36 | # Read in the test cases 37 | # Note: omit the attribute line (first line), we know the labels from the training data 38 | test = [] 39 | File.open('data/continuous-test.txt', 'r').each_line do |line| 40 | data = line.strip.chomp('.').split(',') 41 | test_data = data.collect do |v| 42 | if v == 'healthy' || v == 'colic' 43 | v == 'healthy' ? 
1 : 0 44 | else 45 | v.to_f 46 | end 47 | end 48 | test.push(test_data) 49 | end 50 | 51 | # Let the tree predict the output and compare it to the true specified value 52 | test.each do |t| 53 | predict = dec_tree.predict(t) 54 | puts "Predict: #{predict} ... True: #{t.last}" 55 | end 56 | -------------------------------------------------------------------------------- /examples/data/continuous-test.txt: -------------------------------------------------------------------------------- 1 | 4.60000,139.00000,101.00000,28.80000,7.64000,13.80000,265.06000,1.50000,0.60000,60.00000,12.00000,40.00000,40.00000,3.52393,0.20000,17.61965,healthy. 2 | 4.30000,139.00000,101.00000,26.20000,3.61000,16.10000,518.74103,1.90000,0.01000,68.00000,12.00000,38.00000,36.00000,5.70834,0.20000,28.54170,healthy. 3 | 4.20000,139.00000,101.00000,29.20000,4.96000,13.00000,265.06000,2.10000,0.50000,62.00000,12.00000,39.00000,44.00000,3.44906,0.20000,17.24530,healthy. 4 | 4.40000,141.00000,103.00000,28.30000,12.65000,14.10000,197.60699,2.20000,0.10000,66.00000,12.00000,32.00000,44.00000,3.30135,0.20000,16.50675,healthy. 5 | 4.50000,136.00000,101.00000,26.10000,3.27000,13.40000,300.61499,1.40000,0.01000,68.00000,16.00000,33.00000,50.00000,6.94524,0.70000,9.92177,healthy. 6 | 4.30000,151.00000,112.00000,21.90000,42.66000,21.40000,613.52301,11.50000,172.89999,68.00000,26.00000,63.00000,92.00000,2.69917,0.50000,5.39834,colic. 7 | 3.00000,145.00000,103.00000,22.30000,83.93000,22.70000,476.97101,43.40000,139.50000,86.00000,60.00000,67.00000,68.00000,2.73668,0.20000,13.68340,colic. 8 | 3.40000,134.00000,98.00000,25.90000,90.15000,13.50000,265.06000,2.10000,1.30000,66.00000,20.00000,40.00000,52.00000,3.13565,0.50000,6.27130,colic. 9 | 2.90000,136.00000,92.00000,34.70000,5.81000,12.20000,243.71800,4.20000,22.80000,61.00000,20.00000,41.00000,48.00000,3.20928,0.20000,16.04640,colic. 
10 | 3.80000,140.00000,99.00000,28.20000,88.92000,16.60000,695.82800,7.00000,2.60000,60.00000,28.00000,49.00000,80.00000,1.67106,0.50000,3.34212,colic. 11 | 3.70000,143.00000,105.00000,21.60000,93.67000,20.10000,265.06000,4.60000,38.80000,68.00000,16.00000,43.00000,48.00000,3.51757,0.50000,7.03514,colic. 12 | 3.70000,142.00000,103.00000,27.00000,100.24000,15.70000,386.71301,2.30000,0.01000,85.00000,40.00000,45.00000,48.00000,2.81077,0.50000,5.62154,colic. 13 | 3.20000,138.00000,99.00000,29.80000,80.77000,12.40000,224.11301,2.30000,3.90000,61.00000,24.00000,37.00000,40.00000,3.32568,0.50000,6.65136,colic. -------------------------------------------------------------------------------- /examples/data/continuous-training.txt: -------------------------------------------------------------------------------- 1 | K,Na,CL,HCO,Endotoxin,Aniongap,PLA2,SDH,GLDH,TPP,Breath rate,PCV,Pulse rate,Fibrinogen,Dimer,FibPerDim 2 | 4.60000,138.00000,102.00000,27.50000,3.45000,13.10000,420.62299,4.00000,1.00000,56.00000,10.00000,38.00000,48.00000,3.78216,0.20000,18.91080,healthy. 3 | 4.50000,141.00000,103.00000,26.50000,7.64000,16.00000,695.82800,0.70000,1.00000,72.00000,16.00000,37.00000,36.00000,4.86282,0.20000,24.31410,healthy. 4 | 4.60000,143.00000,104.00000,25.30000,3.04000,18.30000,243.71800,3.10000,0.40000,68.00000,20.00000,46.00000,52.00000,4.14486,0.20000,20.72430,healthy. 5 | 4.70000,140.00000,102.00000,27.60000,3.75000,15.10000,243.71800,3.10000,1.50000,66.00000,20.00000,32.00000,40.00000,4.11386,0.20000,20.56930,healthy. 6 | 4.50000,140.00000,101.00000,23.90000,4.12000,19.60000,233.71001,3.60000,6.90000,60.00000,12.00000,52.00000,48.00000,3.47588,0.20000,17.37940,healthy. 7 | 4.00000,139.00000,101.00000,29.30000,4.05000,12.70000,153.64301,1.60000,0.01000,55.00000,16.00000,41.00000,44.00000,3.63289,0.20000,18.16445,healthy. 8 | 3.20000,139.00000,98.00000,30.70000,101.18000,13.50000,564.12097,6.80000,16.40000,66.00000,56.00000,53.00000,80.00000,5.83544,1.00000,5.83544,colic. 
9 | 3.20000,144.00000,105.00000,24.40000,51.15000,17.80000,386.71301,43.60000,471.60001,58.00000,20.00000,35.00000,48.00000,2.65903,0.50000,5.31806,colic. 10 | 3.90000,144.00000,99.00000,20.30000,94.45000,28.60000,1305.69495,16.60000,58.60000,64.00000,48.00000,75.00000,88.00000,1.86868,0.20000,9.34340,colic. 11 | 3.60000,134.00000,96.00000,26.30000,79.33000,15.30000,386.71301,4.50000,2.80000,48.00000,28.00000,35.00000,100.00000,3.86725,0.50000,7.73450,colic. 12 | 3.80000,148.00000,111.00000,23.90000,45.27000,16.90000,895.03497,1.60000,10.10000,84.00000,16.00000,55.00000,60.00000,4.58211,0.20000,22.91055,colic. 13 | 3.30000,140.00000,102.00000,20.90000,68.33000,20.40000,326.93799,2.00000,1.70000,84.00000,20.00000,46.00000,56.00000,3.57136,0.50000,7.14272,colic. 14 | 3.50000,140.00000,99.00000,25.10000,97.40000,19.40000,420.53101,5.40000,8.80000,94.00000,16.00000,53.00000,80.00000,4.02566,0.70000,5.75094,colic. 15 | 3.30000,137.00000,98.00000,30.80000,74.87000,11.50000,789.14801,168.60001,465.10001,60.00000,36.00000,40.00000,48.00000,5.79638,0.70000,8.28054,colic. 16 | 3.10000,126.00000,88.00000,27.90000,9.31000,13.20000,206.06100,2.10000,0.01000,70.00000,36.00000,37.00000,52.00000,5.55303,0.50000,11.10606,colic. 17 | 3.10000,138.00000,94.00000,39.80000,57.39000,7.30000,420.53101,3.80000,10.50000,68.00000,20.00000,46.00000,68.00000,2.45303,0.20000,12.26515,colic. 18 | 5.00000,136.00000,100.00000,31.40000,12.28000,9.60000,276.43900,4.90000,0.01000,58.00000,16.00000,40.00000,48.00000,4.00226,0.20000,20.01130,healthy. 19 | 3.60000,139.00000,100.00000,29.20000,7.25000,13.40000,288.27600,1.10000,1.10000,65.00000,12.00000,38.00000,48.00000,2.85107,0.20000,14.25535,healthy. 20 | 4.30000,142.00000,102.00000,29.90000,3.80000,14.40000,243.71800,3.00000,0.30000,67.00000,12.00000,44.00000,44.00000,3.87469,0.20000,19.37345,healthy. 
21 | 4.60000,139.00000,100.00000,29.40000,2.40000,14.20000,288.27600,2.40000,2.10000,65.00000,16.00000,43.00000,52.00000,4.84979,0.20000,24.24895,healthy. 22 | 4.10000,136.00000,98.00000,28.40000,2.97000,13.70000,300.61499,2.00000,1.10000,62.00000,12.00000,43.00000,48.00000,5.19111,0.50000,10.38222,healthy. 23 | 4.20000,136.00000,98.00000,25.30000,2.93000,16.90000,224.11301,9.90000,0.70000,64.00000,16.00000,36.00000,52.00000,3.91034,0.20000,19.55170,healthy. 24 | 3.00000,132.00000,89.00000,29.40000,88.25000,16.60000,162.05200,3.40000,0.01000,52.00000,28.00000,45.00000,76.00000,1.64083,0.50000,3.28166,colic. 25 | 3.30000,139.00000,99.00000,25.70000,49.80000,17.60000,174.25400,0.90000,0.30000,62.00000,16.00000,38.00000,60.00000,3.20091,1.50000,2.13394,colic. 26 | 2.90000,138.00000,92.00000,24.80000,94.45000,24.10000,355.59201,9.20000,4.00000,51.00000,45.00000,44.00000,42.00000,2.42420,1.50000,1.61613,colic. 27 | 2.60000,131.00000,89.00000,26.50000,6.54000,18.10000,725.62500,4.70000,11.00000,80.00000,48.00000,43.00000,52.00000,4.10642,0.50000,8.21284,colic. 28 | 3.60000,135.00000,95.00000,26.70000,65.86000,16.90000,243.71800,4.80000,1.60000,58.00000,38.00000,50.00000,88.00000,2.92609,0.20000,14.63045,colic. 29 | 3.30000,147.00000,105.00000,28.00000,61.56000,17.30000,313.50201,3.70000,2.60000,75.00000,40.00000,48.00000,88.00000,3.60096,1.50000,2.40064,colic. 30 | 3.20000,142.00000,100.00000,26.70000,78.69000,18.50000,370.81000,42.90000,333.79999,80.00000,24.00000,55.00000,100.00000,4.53422,2.00000,2.26711,colic. 31 | 3.70000,136.00000,86.00000,25.30000,65.54000,28.40000,1103.97498,6.40000,4.80000,100.00000,20.00000,55.00000,132.00000,7.76240,1.00000,7.76240,colic. 32 | 3.30000,142.00000,99.00000,29.50000,82.42000,16.80000,420.53101,6.80000,40.70000,71.00000,28.00000,48.00000,72.00000,3.29344,0.50000,6.58688,colic. 33 | 3.30000,141.00000,99.00000,32.40000,87.43000,12.90000,326.93799,3.00000,1.50000,47.00000,36.00000,48.00000,48.00000,3.24353,0.20000,16.21765,colic. 
34 | 3.10000,146.00000,103.00000,26.10000,79.08000,20.00000,476.97101,3.50000,1.20000,78.00000,24.00000,54.00000,80.00000,3.76666,0.50000,7.53332,colic. 35 | 4.10000,138.00000,101.00000,27.30000,8.01000,13.80000,147.29100,6.30000,5.20000,67.00000,10.00000,43.00000,40.00000,3.68016,0.20000,18.40080,healthy. 36 | 4.10000,136.00000,98.00000,28.50000,6.15000,13.60000,174.25400,2.10000,1.30000,60.00000,8.00000,35.00000,40.00000,1.94448,0.20000,9.72240,healthy. 37 | 4.50000,136.00000,99.00000,26.80000,5.08000,14.70000,189.47200,2.00000,0.60000,55.00000,12.00000,35.00000,44.00000,3.67257,0.20000,18.36285,healthy. 38 | 3.50000,142.00000,105.00000,22.20000,6.77000,18.30000,276.43900,3.40000,1.20000,64.00000,10.00000,39.00000,48.00000,3.45945,0.20000,17.29725,healthy. 39 | 3.90000,140.00000,101.00000,28.50000,3.61000,14.40000,340.96799,0.20000,0.01000,61.00000,12.00000,37.00000,48.00000,2.51116,0.20000,12.55580,healthy. 40 | 3.60000,145.00000,106.00000,27.50000,89.65000,15.10000,224.11301,2.80000,1.20000,78.00000,60.00000,48.00000,80.00000,2.42001,0.20000,12.10005,colic. 41 | 3.50000,136.00000,98.00000,25.40000,22.39000,16.10000,1420.03601,3.60000,0.80000,60.00000,20.00000,21.00000,56.00000,9.81956,4.00000,2.45489,colic. 42 | 3.60000,140.00000,98.00000,19.50000,99.57000,26.10000,789.14801,36.10000,293.20001,73.00000,48.00000,64.00000,100.00000,2.24781,2.00000,1.12390,colic. 43 | 3.60000,131.00000,92.00000,22.60000,76.04000,20.00000,564.12097,3.70000,4.70000,48.00000,56.00000,38.00000,120.00000,3.33932,0.50000,6.67864,colic. 44 | 3.50000,144.00000,104.00000,18.90000,64.19000,24.60000,1149.99500,4.80000,3.10000,60.00000,28.00000,40.00000,80.00000,4.12378,0.70000,5.89111,colic. 45 | 2.90000,142.00000,100.00000,30.00000,49.20000,14.90000,497.39899,2.50000,0.01000,74.00000,40.00000,52.00000,64.00000,3.21284,0.50000,6.42568,colic. 
46 | 3.60000,138.00000,99.00000,24.40000,50.32000,18.20000,1610.51404,14.20000,1.30000,66.00000,20.00000,37.00000,60.00000,6.60548,2.00000,3.30274,colic. 47 | 3.40000,137.00000,93.00000,24.40000,6.29000,23.00000,4227.66113,43.60000,3.00000,71.00000,36.00000,60.00000,72.00000,5.17514,6.00000,0.86252,colic. 48 | 3.50000,144.00000,100.00000,32.50000,51.49000,15.00000,129.87900,7.90000,83.00000,61.00000,36.00000,44.00000,84.00000,3.42922,0.20000,17.14610,colic. 49 | 3.10000,136.00000,98.00000,23.40000,5.97000,17.70000,243.71800,2.10000,2.70000,66.00000,28.00000,45.00000,52.00000,2.84968,0.20000,14.24840,colic. 50 | 4.50000,137.00000,100.00000,27.20000,11.48000,14.30000,181.70300,2.00000,3.60000,62.00000,8.00000,38.00000,52.00000,4.01342,0.20000,20.06710,healthy. 51 | 4.20000,141.00000,103.00000,29.10000,3.77000,13.10000,288.27600,6.70000,5.60000,64.00000,8.00000,42.00000,40.00000,4.20329,0.20000,21.01645,healthy. 52 | 4.20000,138.00000,101.00000,28.30000,6.22000,12.90000,288.27600,5.40000,2.10000,65.00000,12.00000,43.00000,44.00000,5.08152,0.20000,25.40760,healthy. 53 | 4.50000,137.00000,101.00000,27.40000,6.68000,13.10000,167.07899,2.10000,1.10000,60.00000,16.00000,38.00000,48.00000,3.25795,0.20000,16.28975,healthy. 54 | 4.00000,141.00000,102.00000,27.20000,12.44000,15.80000,338.17999,3.40000,3.10000,72.00000,12.00000,33.00000,48.00000,4.98961,0.20000,24.94805,healthy. 55 | 4.20000,138.00000,96.00000,23.70000,51.83000,22.50000,355.59201,2.70000,4.20000,60.00000,20.00000,39.00000,100.00000,3.61817,0.50000,7.23634,colic. 56 | 3.60000,141.00000,101.00000,28.60000,97.70000,15.00000,667.21997,5.00000,3.70000,70.00000,12.00000,48.00000,60.00000,3.13410,1.00000,3.13410,colic. 57 | 3.20000,137.00000,100.00000,24.40000,71.53000,15.80000,224.11301,2.40000,2.20000,79.00000,28.00000,42.00000,60.00000,3.92367,1.00000,3.92367,colic. 
58 | 3.50000,141.00000,102.00000,27.40000,51.93000,15.10000,1015.08801,3.10000,0.80000,62.00000,72.00000,54.00000,88.00000,2.50883,0.20000,12.54415,colic. 59 | 4.20000,143.00000,106.00000,24.00000,5.31000,17.20000,265.06000,8.00000,32.90000,77.00000,16.00000,38.00000,40.00000,3.98583,1.00000,3.98583,colic. 60 | 3.20000,138.00000,97.00000,25.00000,8.76000,19.20000,288.27600,5.40000,3.10000,70.00000,12.00000,47.00000,88.00000,5.01596,1.00000,5.01596,colic. 61 | 4.10000,132.00000,91.00000,28.60000,19.74000,16.50000,639.79999,6.70000,0.01000,78.00000,24.00000,38.00000,112.00000,8.94970,6.00000,1.49162,colic. 62 | 6.00000,140.00000,97.00000,32.20000,48.15000,16.80000,153.64301,17.00000,52.60000,48.00000,40.00000,67.00000,80.00000,2.18364,1.50000,1.45576,colic. 63 | 3.10000,138.00000,95.00000,29.30000,10.98000,16.80000,822.96600,3.90000,0.60000,58.00000,36.00000,36.00000,48.00000,2.52015,0.50000,5.04030,colic. 64 | 3.70000,144.00000,107.00000,25.40000,85.30000,15.30000,457.36600,3.10000,1.10000,66.00000,24.00000,48.00000,60.00000,2.81775,0.50000,5.63550,colic. 65 | 4.20000,139.00000,100.00000,29.40000,2.33000,13.80000,233.71001,3.40000,0.90000,64.00000,12.00000,40.00000,44.00000,3.78293,0.20000,18.91465,healthy. 66 | 4.20000,144.00000,107.00000,23.90000,7.87000,17.30000,300.61499,5.90000,16.40000,68.00000,20.00000,48.00000,48.00000,4.42355,0.20000,22.11775,healthy. 67 | 4.10000,139.00000,100.00000,28.60000,4.12000,14.50000,170.78101,0.70000,0.01000,60.00000,10.00000,43.00000,32.00000,3.22927,0.20000,16.14635,healthy. 68 | 4.70000,136.00000,99.00000,28.60000,10.43000,13.10000,288.27600,1.70000,0.20000,62.00000,8.00000,35.00000,40.00000,4.18454,0.20000,20.92270,healthy. 69 | 3.70000,140.00000,102.00000,28.20000,6.57000,13.50000,174.25400,3.20000,2.10000,60.00000,10.00000,39.00000,44.00000,3.40799,0.20000,17.03995,healthy. 70 | 3.70000,142.00000,101.00000,30.60000,94.68000,14.10000,300.61499,1.90000,0.10000,58.00000,32.00000,40.00000,80.00000,2.66538,0.20000,13.32690,colic. 
71 | 3.00000,135.00000,95.00000,27.30000,8.19000,15.70000,265.06000,2.30000,0.01000,60.00000,40.00000,37.00000,48.00000,2.96841,0.20000,14.84205,colic. 72 | 2.70000,143.00000,96.00000,24.60000,83.61000,25.10000,386.71301,6.50000,3.80000,62.00000,28.00000,33.00000,52.00000,3.44921,0.50000,6.89842,colic. 73 | 4.00000,140.00000,103.00000,20.30000,99.16000,20.70000,300.61499,3.50000,1.70000,64.00000,24.00000,44.00000,64.00000,3.75317,0.20000,18.76585,colic. 74 | 3.50000,130.00000,93.00000,29.90000,4.35000,10.60000,265.06000,1.90000,0.70000,70.00000,20.00000,42.00000,52.00000,5.66107,0.50000,11.32214,colic. 75 | 3.10000,139.00000,96.00000,30.80000,20.02000,15.30000,167.07899,3.30000,1.80000,58.00000,20.00000,44.00000,72.00000,3.30615,0.20000,16.53075,colic. 76 | 3.00000,137.00000,91.00000,14.80000,7.32000,34.20000,181.70300,20.10000,1.70000,61.00000,16.00000,59.00000,72.00000,4.94729,0.50000,9.89458,colic. 77 | 3.70000,138.00000,99.00000,29.10000,97.72000,13.60000,214.92700,1.50000,0.01000,58.00000,20.00000,35.00000,56.00000,2.61113,0.20000,13.05565,colic. 78 | 4.00000,137.00000,98.00000,27.50000,56.43000,15.50000,243.71800,3.70000,0.90000,62.00000,16.00000,38.00000,60.00000,4.75695,0.50000,9.51390,colic. 79 | 3.20000,139.00000,98.00000,30.00000,76.75000,14.20000,276.43900,2.40000,0.01000,61.00000,60.00000,47.00000,72.00000,2.74397,0.20000,13.71985,colic. 80 | 4.50000,141.00000,103.00000,27.40000,9.08000,15.10000,457.36600,4.60000,5.50000,70.00000,8.00000,39.00000,32.00000,3.92956,0.20000,19.64780,healthy. 81 | 3.90000,134.00000,98.00000,25.10000,5.35000,14.80000,695.82800,1.90000,0.01000,72.00000,16.00000,33.00000,48.00000,8.01149,0.70000,11.44499,healthy. 82 | 3.90000,138.00000,102.00000,25.90000,4.05000,14.00000,564.12097,5.70000,5.50000,70.00000,10.00000,41.00000,40.00000,5.33758,0.20000,26.68790,healthy. 83 | 3.90000,141.00000,103.00000,25.20000,7.55000,16.70000,153.64301,2.90000,7.90000,70.00000,16.00000,34.00000,48.00000,3.46906,0.50000,6.93812,healthy. 
84 | 4.60000,137.00000,101.00000,24.70000,3.18000,15.90000,206.06100,1.40000,1.10000,70.00000,10.00000,38.00000,40.00000,5.13267,0.20000,25.66335,healthy. 85 | 3.50000,131.00000,92.00000,30.70000,14.41000,11.80000,420.53101,3.30000,1.10000,64.00000,16.00000,41.00000,48.00000,2.23278,0.20000,11.16390,colic. 86 | 3.80000,141.00000,100.00000,29.20000,82.01000,15.60000,233.71001,2.20000,0.70000,62.00000,14.00000,33.00000,52.00000,4.07480,0.50000,8.14960,colic. 87 | 4.40000,140.00000,98.00000,24.10000,82.76000,22.30000,403.25699,2.80000,2.00000,60.00000,32.00000,62.00000,112.00000,2.15636,0.50000,4.31272,colic. 88 | 3.60000,144.00000,97.00000,19.90000,38.61000,30.70000,822.96600,10.60000,6.20000,80.00000,24.00000,62.00000,64.00000,3.64002,1.00000,3.64002,colic. 89 | 3.30000,144.00000,101.00000,28.90000,61.44000,17.40000,476.97101,28.90000,138.60001,89.00000,16.00000,54.00000,80.00000,5.20165,1.00000,5.20165,colic. 90 | 3.80000,136.00000,98.00000,23.90000,87.61000,17.90000,318.07199,6.10000,7.70000,100.00000,28.00000,54.00000,92.00000,3.27562,1.00000,3.27562,colic. 91 | 4.00000,139.00000,99.00000,26.00000,46.76000,18.00000,476.97101,5.30000,6.50000,73.00000,36.00000,37.00000,82.00000,3.37621,0.50000,6.75242,colic. 92 | 3.00000,141.00000,99.00000,32.10000,97.13000,12.90000,420.53101,2.90000,1.80000,73.00000,12.00000,28.00000,80.00000,3.37575,0.70000,4.82250,colic. 93 | 3.50000,145.00000,93.00000,20.00000,86.12000,35.50000,895.03497,5.70000,5.60000,80.00000,34.00000,65.00000,88.00000,2.57734,0.50000,5.15468,colic. 94 | 4.00000,137.00000,99.00000,29.70000,4.71000,12.30000,403.25699,2.40000,1.20000,56.00000,12.00000,37.00000,44.00000,3.37110,0.20000,16.85550,healthy. 95 | 4.20000,140.00000,103.00000,25.60000,4.80000,15.60000,386.71301,2.60000,3.50000,54.00000,12.00000,33.00000,40.00000,2.99693,0.20000,14.98465,healthy. 96 | 4.70000,139.00000,101.00000,27.40000,6.95000,15.30000,197.60699,1.30000,0.30000,58.00000,12.00000,37.00000,44.00000,2.50155,0.50000,5.00310,healthy. 
97 | 5.20000,138.00000,99.00000,28.00000,4.46000,16.20000,340.96799,3.10000,2.70000,55.00000,12.00000,35.00000,56.00000,4.22825,0.20000,21.14125,healthy. 98 | 4.50000,137.00000,98.00000,26.40000,2.49000,17.10000,197.60699,14.10000,9.00000,54.00000,12.00000,42.00000,56.00000,3.47526,0.20000,17.37630,healthy. 99 | 4.40000,138.00000,101.00000,20.10000,65.74000,21.30000,476.97101,14.00000,88.60000,72.00000,14.00000,43.00000,82.00000,2.78303,0.50000,5.56606,colic. 100 | 3.80000,143.00000,101.00000,29.20000,100.22000,16.60000,313.50201,4.30000,26.50000,67.00000,20.00000,63.00000,80.00000,3.35963,1.00000,3.35963,colic. 101 | 3.50000,142.00000,101.00000,29.10000,73.95000,15.40000,386.71301,5.10000,4.30000,65.00000,28.00000,41.00000,56.00000,4.12300,0.20000,20.61500,colic. 102 | 4.30000,141.00000,104.00000,23.10000,82.72000,18.20000,386.71301,4.90000,1.60000,72.00000,36.00000,45.00000,92.00000,3.47479,0.50000,6.94958,colic. 103 | 3.60000,135.00000,98.00000,30.10000,83.79000,10.50000,254.18300,1.50000,0.01000,58.00000,20.00000,41.00000,48.00000,2.64120,0.50000,5.28240,colic. 104 | 2.80000,140.00000,101.00000,26.90000,31.25000,14.90000,463.62701,4.30000,3.80000,46.00000,28.00000,48.00000,64.00000,4.19771,0.50000,8.39542,colic. 105 | 3.30000,140.00000,99.00000,32.70000,97.22000,11.60000,300.61499,3.70000,3.40000,58.00000,24.00000,34.00000,44.00000,2.04600,0.70000,2.92286,colic. 106 | 3.10000,146.00000,103.00000,21.60000,83.65000,24.50000,288.27600,4.30000,3.50000,82.00000,32.00000,46.00000,64.00000,3.65040,0.50000,7.30080,colic. 107 | 4.10000,139.00000,102.00000,24.20000,88.23000,16.90000,214.92700,1.80000,0.01000,63.00000,12.00000,40.00000,42.00000,2.97430,0.20000,14.87150,colic. 108 | 4.50000,139.00000,100.00000,29.20000,6.04000,14.30000,210.72301,2.00000,0.20000,68.00000,10.00000,40.00000,40.00000,3.52393,0.50000,7.04786,healthy. 
109 | 4.20000,130.00000,102.00000,27.90000,6.68000,4.30000,386.71301,1.90000,1.10000,56.00000,14.00000,37.00000,48.00000,4.05697,0.20000,20.28485,healthy. 110 | 5.30000,137.00000,99.00000,25.80000,4.35000,17.50000,276.43900,1.90000,0.60000,62.00000,16.00000,40.00000,52.00000,5.01906,0.50000,10.03812,healthy. 111 | 4.40000,135.00000,100.00000,25.10000,2.77000,14.30000,197.60699,0.60000,1.60000,60.00000,16.00000,36.00000,36.00000,3.56702,0.20000,17.83510,healthy. 112 | 2.90000,129.00000,86.00000,27.30000,82.85000,18.60000,756.74597,5.40000,29.50000,79.00000,16.00000,43.00000,84.00000,2.38374,1.00000,2.38374,colic. 113 | 3.40000,139.00000,98.00000,29.80000,54.42000,14.60000,695.82800,5.50000,30.00000,52.00000,24.00000,35.00000,52.00000,1.95393,0.50000,3.90786,colic. 114 | 3.30000,137.00000,96.00000,30.50000,53.76000,13.80000,233.71001,7.20000,28.90000,55.00000,24.00000,30.00000,100.00000,2.11327,0.20000,10.56635,colic. 115 | 2.50000,127.00000,88.00000,17.80000,88.37000,23.70000,588.29602,3.90000,3.20000,70.00000,24.00000,54.00000,88.00000,3.32398,2.00000,1.66199,colic. 116 | 3.30000,146.00000,97.00000,23.10000,70.02000,29.20000,1420.03601,42.70000,327.50000,70.00000,28.00000,68.00000,68.00000,2.19294,3.00000,0.73098,colic. 117 | 3.80000,140.00000,100.00000,26.70000,92.83000,17.10000,457.36600,4.60000,2.10000,61.00000,32.00000,38.00000,76.00000,2.07359,1.50000,1.38239,colic. 118 | 3.30000,134.00000,95.00000,31.60000,73.63000,10.70000,224.11301,3.30000,1.70000,62.00000,20.00000,37.00000,56.00000,3.68947,0.50000,7.37894,colic. 119 | 3.30000,140.00000,99.00000,29.60000,88.66000,14.70000,233.71001,1.60000,2.40000,74.00000,40.00000,38.00000,52.00000,2.76427,1.00000,2.76427,colic. 120 | 2.80000,145.00000,101.00000,35.40000,31.96000,11.40000,243.71800,0.40000,0.70000,70.00000,20.00000,47.00000,84.00000,3.82587,0.20000,19.12935,colic. 
121 | 4.40000,136.00000,98.00000,28.50000,8.69000,13.90000,725.62500,1.90000,1.50000,60.00000,16.00000,40.00000,52.00000,3.41419,0.20000,17.07095,healthy. 122 | 3.70000,140.00000,100.00000,29.80000,5.15000,13.90000,189.47200,2.30000,0.70000,78.00000,12.00000,42.00000,48.00000,3.33607,0.20000,16.68035,healthy. 123 | 4.60000,138.00000,100.00000,28.60000,9.79000,14.00000,224.11301,1.60000,2.00000,61.00000,16.00000,35.00000,40.00000,3.58624,0.20000,17.93120,healthy. 124 | 4.00000,138.00000,102.00000,25.90000,90.54000,14.10000,326.93799,0.40000,1.70000,70.00000,20.00000,48.00000,79.00000,3.34645,0.20000,16.73225,colic. 125 | 2.70000,132.00000,93.00000,29.30000,52.57000,12.40000,1058.59497,5.00000,8.00000,78.00000,28.00000,48.00000,76.00000,4.77013,0.50000,9.54026,colic. 126 | 3.40000,133.00000,95.00000,28.50000,64.71000,12.90000,276.43900,8.70000,43.70000,76.00000,16.00000,47.00000,76.00000,4.15168,0.20000,20.75840,colic. 127 | 3.00000,139.00000,93.00000,33.30000,96.88000,15.70000,224.11301,6.90000,3.30000,48.00000,80.00000,43.00000,56.00000,2.32748,0.20000,11.63740,colic. 128 | 2.80000,139.00000,101.00000,25.90000,71.32000,14.90000,676.35999,2.30000,0.30000,71.00000,16.00000,46.00000,52.00000,2.50558,0.20000,12.52790,colic. 129 | 2.80000,142.00000,97.00000,29.80000,53.21000,18.00000,160.22400,4.70000,5.10000,50.00000,60.00000,44.00000,88.00000,2.31710,0.70000,3.31014,colic. 130 | 3.50000,140.00000,102.00000,23.00000,87.86000,18.50000,189.47200,2.20000,0.90000,73.00000,24.00000,47.00000,96.00000,3.73721,0.50000,7.47442,colic. 131 | 3.00000,142.00000,100.00000,22.60000,93.17000,22.40000,355.59201,16.30000,124.10000,80.00000,24.00000,45.00000,68.00000,2.75668,0.70000,3.93811,colic. 132 | 3.30000,149.00000,110.00000,19.20000,96.46000,23.10000,667.21997,5.70000,0.20000,59.00000,16.00000,41.00000,54.00000,3.18324,0.20000,15.91620,colic. 
133 | 3.50000,141.00000,96.00000,31.20000,11.00000,17.30000,214.92700,3.80000,1.70000,53.00000,48.00000,39.00000,64.00000,2.89664,0.70000,4.13806,colic. -------------------------------------------------------------------------------- /examples/data/discrete-test.txt: -------------------------------------------------------------------------------- 1 | 36 - 55,masters,high,single,will buy 2 | 18 - 35,high school,low,single,won't buy 3 | 18 - 35,masters,high,single,won't buy 4 | 36 - 55,high school,low,single,will buy -------------------------------------------------------------------------------- /examples/data/discrete-training.txt: -------------------------------------------------------------------------------- 1 | Age,Education,Income,Marital Status 2 | 36 - 55,masters,high,single,will buy 3 | 18 - 35,high school,low,single,won't buy 4 | 36 - 55,masters,low,single,will buy 5 | 18 - 35,bachelors,high,single,won't buy 6 | < 18,high school,low,single,will buy 7 | 18 - 35,bachelors,high,married,won't buy 8 | 36 - 55,bachelors,low,married,won't buy 9 | > 55,bachelors,high,single,will buy 10 | 36 - 55,masters,low,married,won't buy 11 | > 55,masters,low,married,will buy 12 | 36 - 55,masters,high,single,will buy 13 | > 55,masters,high,single,will buy 14 | < 18,high school,high,single,won't buy 15 | 36 - 55,masters,low,single,will buy 16 | 36 - 55,high school,low,single,will buy 17 | < 18,high school,low,married,will buy 18 | 18 - 35,bachelors,high,married,won't buy 19 | > 55,high school,high,married,will buy 20 | > 55,bachelors,low,single,will buy 21 | 36 - 55,high school,high,married,won't buy -------------------------------------------------------------------------------- /examples/discrete-id3.rb: -------------------------------------------------------------------------------- 1 | require 'rubygems' 2 | require 'decisiontree' 3 | 4 | # ---Discrete--- 5 | 6 | # Read in the training data 7 | training = [] 8 | attributes = nil 9 | 10 | 
File.open('data/discrete-training.txt', 'r').each_line do |line| 11 | data = line.strip.split(',') 12 | attributes ||= data 13 | training_data = data.collect do |v| 14 | case v 15 | when 'will buy' 16 | 1 17 | when "won't buy" 18 | 0 19 | else 20 | v 21 | end 22 | end 23 | training.push(training_data) 24 | end 25 | 26 | # Remove the attribute row from the training data 27 | training.shift 28 | 29 | # Instantiate the tree, and train it based on the data (set default to '1') 30 | dec_tree = DecisionTree::ID3Tree.new(attributes, training, 1, :discrete) 31 | dec_tree.train 32 | 33 | # ---Test the tree--- 34 | 35 | # Read in the test cases 36 | # Note: omit the attribute line (first line), we know the labels from the training data 37 | test = [] 38 | File.open('data/discrete-test.txt', 'r').each_line do |line| 39 | data = line.strip.split(',') 40 | test_data = data.collect do |v| 41 | case v 42 | when 'will buy' 43 | 1 44 | when "won't buy" 45 | 0 46 | else 47 | v 48 | end 49 | end 50 | test.push(test_data) 51 | end 52 | 53 | # Let the tree predict the output and compare it to the true specified value 54 | test.each do |t| 55 | predict = dec_tree.predict(t) 56 | puts "Predict: #{predict} ... 
True: #{t.last}" 57 | end 58 | 59 | # Graph the tree, save to 'discrete.png' 60 | dec_tree.graph('discrete') 61 | -------------------------------------------------------------------------------- /examples/simple.rb: -------------------------------------------------------------------------------- 1 | #!/usr/bin/ruby 2 | 3 | require 'rubygems' 4 | require 'decisiontree' 5 | 6 | attributes = ['Temperature'] 7 | training = [ 8 | [36.6, 'healthy'], 9 | [37, 'sick'], 10 | [38, 'sick'], 11 | [36.7, 'healthy'], 12 | [40, 'sick'], 13 | [50, 'really sick'] 14 | ] 15 | 16 | # Instantiate the tree, and train it based on the data (set default to '1') 17 | dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous) 18 | dec_tree.train 19 | 20 | test = [37, 'sick'] 21 | 22 | decision = dec_tree.predict(test) 23 | puts "Predicted: #{decision} ... True decision: #{test.last}" 24 | 25 | # Graph the tree, save to 'tree.png' 26 | dec_tree.graph('tree') 27 | -------------------------------------------------------------------------------- /lib/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/igrigorik/decisiontree/e30c18853ad654cc1ba4b239e46bf858ad67363b/lib/.DS_Store -------------------------------------------------------------------------------- /lib/core_extensions/array.rb: -------------------------------------------------------------------------------- 1 | class Array 2 | def entropy 3 | each_with_object(Hash.new(0)) do |i, result| 4 | result[i] += 1 5 | end.values.inject(0, :+) do |count| 6 | percentage = count.to_f / length 7 | 8 | -percentage * Math.log2(percentage) 9 | end 10 | end 11 | end 12 | 13 | module ArrayClassification 14 | refine Array do 15 | def classification 16 | collect(&:last) 17 | end 18 | end 19 | end 20 | 21 | -------------------------------------------------------------------------------- /lib/core_extensions/object.rb: 
-------------------------------------------------------------------------------- 1 | class Object 2 | def save_to_file(filename) 3 | File.open(filename, 'w+') { |f| f << Marshal.dump(self) } 4 | end 5 | 6 | def self.load_from_file(filename) 7 | Marshal.load(File.read(filename)) 8 | end 9 | end 10 | -------------------------------------------------------------------------------- /lib/decisiontree.rb: -------------------------------------------------------------------------------- 1 | require 'core_extensions/object' 2 | require 'core_extensions/array' 3 | require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb' 4 | -------------------------------------------------------------------------------- /lib/decisiontree/id3_tree.rb: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # 3 | ### Copyright (c) 2007 Ilya Grigorik 4 | ### Modifed at 2007 by José Ignacio Fernández 5 | 6 | module DecisionTree 7 | Node = Struct.new(:attribute, :threshold, :gain) 8 | 9 | using ArrayClassification 10 | 11 | class ID3Tree 12 | def initialize(attributes, data, default, type) 13 | @used = {} 14 | @tree = {} 15 | @type = type 16 | @data = data 17 | @attributes = attributes 18 | @default = default 19 | end 20 | 21 | def train(data = @data, attributes = @attributes, default = @default) 22 | attributes = attributes.map(&:to_s) 23 | initialize(attributes, data, default, @type) 24 | 25 | # Remove samples with same attributes leaving most common classification 26 | data2 = data.inject({}) do |hash, d| 27 | hash[d.slice(0..-2)] ||= Hash.new(0) 28 | hash[d.slice(0..-2)][d.last] += 1 29 | hash 30 | end 31 | 32 | data2 = data2.map do |key, val| 33 | key + [val.sort_by { |_, v| v }.last.first] 34 | end 35 | 36 | @tree = id3_train(data2, attributes, default) 37 | end 38 | 39 | def type(attribute) 40 | @type.is_a?(Hash) ? 
@type[attribute.to_sym] : @type 41 | end 42 | 43 | def fitness_for(attribute) 44 | case type(attribute) 45 | when :discrete 46 | proc { |*args| id3_discrete(*args) } 47 | when :continuous 48 | proc { |*args| id3_continuous(*args) } 49 | end 50 | end 51 | 52 | def id3_train(data, attributes, default, _used={}) 53 | return default if data.empty? 54 | 55 | # return classification if all examples have the same classification 56 | return data.first.last if data.classification.uniq.size == 1 57 | 58 | # Choose best attribute: 59 | # 1. enumerate all attributes 60 | # 2. Pick best attribute 61 | # 3. If attributes all score the same, then pick a random one to avoid infinite recursion. 62 | performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) } 63 | max = performance.max { |a,b| a[0] <=> b[0] } 64 | min = performance.min { |a,b| a[0] <=> b[0] } 65 | max = performance.sample if max[0] == min[0] 66 | best = Node.new(attributes[performance.index(max)], max[1], max[0]) 67 | best.threshold = nil if @type == :discrete 68 | @used.has_key?(best.attribute) ? 
@used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] 69 | tree, l = {best => {}}, ['>=', '<'] 70 | 71 | case type(best.attribute) 72 | when :continuous 73 | partitioned_data = data.partition do |d| 74 | d[attributes.index(best.attribute)] >= best.threshold 75 | end 76 | partitioned_data.each_with_index do |examples, i| 77 | tree[best][String.new(l[i])] = id3_train(examples, attributes, (data.classification.mode rescue 0)) 78 | end 79 | when :discrete 80 | values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort 81 | partitions = values.collect do |val| 82 | data.select do |d| 83 | d[attributes.index(best.attribute)] == val 84 | end 85 | end 86 | partitions.each_with_index do |examples, i| 87 | tree[best][values[i]] = id3_train(examples, attributes - [values[i]], (data.classification.mode rescue 0)) 88 | end 89 | end 90 | 91 | tree 92 | end 93 | 94 | # ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds) 95 | def id3_continuous(data, attributes, attribute) 96 | values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort 97 | thresholds = [] 98 | return [-1, -1] if values.size == 1 99 | values.each_index do |i| 100 | thresholds.push((values[i] + (values[i + 1].nil? ? values[i] : values[i + 1])).to_f / 2) 101 | end 102 | thresholds.pop 103 | #thresholds -= used[attribute] if used.has_key? 
attribute 104 | 105 | gain = thresholds.collect do |threshold| 106 | sp = data.partition { |d| d[attributes.index(attribute)] >= threshold } 107 | pos = (sp[0].size).to_f / data.size 108 | neg = (sp[1].size).to_f / data.size 109 | 110 | [data.classification.entropy - pos * sp[0].classification.entropy - neg * sp[1].classification.entropy, threshold] 111 | end 112 | gain = gain.max { |a, b| a[0] <=> b[0] } 113 | 114 | return [-1, -1] if gain.size == 0 115 | gain 116 | end 117 | 118 | # ID3 for discrete label cases 119 | def id3_discrete(data, attributes, attribute) 120 | index = attributes.index(attribute) 121 | 122 | values = data.map { |row| row[index] }.uniq 123 | remainder = values.sort.inject(0, :+) do |val| 124 | classification = data.each_with_object([]) do |row, result| 125 | result << row.last if row[index] == val 126 | end 127 | 128 | ((classification.size.to_f / data.size) * classification.entropy) 129 | end 130 | 131 | [data.classification.entropy - remainder, index] 132 | end 133 | 134 | def predict(test) 135 | descend(@tree, test) 136 | end 137 | 138 | def graph(filename, file_type = 'png') 139 | require 'graphr' 140 | dgp = DotGraphPrinter.new(build_tree) 141 | dgp.size = '' 142 | dgp.node_labeler = proc { |n| n.split("\n").first } 143 | dgp.write_to_file("#{filename}.#{file_type}", file_type) 144 | rescue LoadError 145 | STDERR.puts "Error: Cannot generate graph." 146 | STDERR.puts " The 'graphr' gem doesn't seem to be installed." 147 | STDERR.puts " Run 'gem install graphr' or add it to your Gemfile." 
148 | end 149 | 150 | def ruleset 151 | rs = Ruleset.new(@attributes, @data, @default, @type) 152 | rs.rules = build_rules 153 | rs 154 | end 155 | 156 | def build_rules(tree = @tree) 157 | attr = tree.to_a.first 158 | cases = attr[1].to_a 159 | rules = [] 160 | cases.each do |c, child| 161 | if child.is_a?(Hash) 162 | build_rules(child).each do |r| 163 | r2 = r.clone 164 | r2.premises.unshift([attr.first, c]) 165 | rules << r2 166 | end 167 | else 168 | rules << Rule.new(@attributes, [[attr.first, c]], child) 169 | end 170 | end 171 | rules 172 | end 173 | 174 | private 175 | 176 | def descend(tree, test) 177 | attr = tree.to_a.first 178 | return @default unless attr 179 | if type(attr.first.attribute) == :continuous 180 | return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] >= attr.first.threshold 181 | return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) && test[@attributes.index(attr.first.attribute)] < attr.first.threshold 182 | return descend(attr[1]['>='], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold 183 | return descend(attr[1]['<'], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold 184 | else 185 | return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash) 186 | return descend(attr[1][test[@attributes.index(attr[0].attribute)]], test) 187 | end 188 | end 189 | 190 | def build_tree(tree = @tree) 191 | return [] unless tree.is_a?(Hash) 192 | return [['Always', @default]] if tree.empty? 
193 | 194 | attr = tree.to_a.first 195 | 196 | links = attr[1].keys.collect do |key| 197 | parent_text = "#{attr[0].attribute}\n(#{attr[0].object_id})" 198 | if attr[1][key].is_a?(Hash) 199 | child = attr[1][key].to_a.first[0] 200 | child_text = "#{child.attribute}\n(#{child.object_id})" 201 | else 202 | child = attr[1][key] 203 | child_text = "#{child}\n(#{child.to_s.clone.object_id})" 204 | end 205 | 206 | if type(attr[0].attribute) == :continuous 207 | label_text = "#{key} #{attr[0].threshold}" 208 | else 209 | label_text = key 210 | end 211 | 212 | [parent_text, child_text, label_text] 213 | end 214 | attr[1].keys.each { |key| links += build_tree(attr[1][key]) } 215 | 216 | links 217 | end 218 | end 219 | 220 | class Rule 221 | attr_accessor :premises 222 | attr_accessor :conclusion 223 | attr_accessor :attributes 224 | 225 | def initialize(attributes, premises = [], conclusion = nil) 226 | @attributes = attributes 227 | @premises = premises 228 | @conclusion = conclusion 229 | end 230 | 231 | def to_s 232 | str = '' 233 | @premises.each do |p| 234 | if p.first.threshold 235 | str += "#{p.first.attribute} #{p.last} #{p.first.threshold}" 236 | else 237 | str += "#{p.first.attribute} = #{p.last}" 238 | end 239 | str += "\n" 240 | end 241 | str += "=> #{@conclusion} (#{accuracy})" 242 | end 243 | 244 | def predict(test) 245 | verifies = true 246 | @premises.each do |p| 247 | if p.first.threshold # Continuous 248 | if !(p.last == '>=' && test[@attributes.index(p.first.attribute)] >= p.first.threshold) && !(p.last == '<' && test[@attributes.index(p.first.attribute)] < p.first.threshold) 249 | verifies = false 250 | break 251 | end 252 | else # Discrete 253 | if test[@attributes.index(p.first.attribute)] != p.last 254 | verifies = false 255 | break 256 | end 257 | end 258 | end 259 | return @conclusion if verifies 260 | nil 261 | end 262 | 263 | def get_accuracy(data) 264 | correct = 0 265 | total = 0 266 | data.each do |d| 267 | prediction = predict(d) 268 | correct 
+= 1 if d.last == prediction 269 | total += 1 unless prediction.nil? 270 | end 271 | (correct.to_f + 1) / (total.to_f + 2) 272 | end 273 | 274 | def accuracy(data = nil) 275 | data.nil? ? @accuracy : @accuracy = get_accuracy(data) 276 | end 277 | end 278 | 279 | class Ruleset 280 | attr_accessor :rules 281 | 282 | def initialize(attributes, data, default, type) 283 | @attributes = attributes 284 | @default = default 285 | @type = type 286 | mixed_data = data.sort_by { rand } 287 | cut = (mixed_data.size.to_f * 0.67).to_i 288 | @train_data = mixed_data.slice(0..cut - 1) 289 | @prune_data = mixed_data.slice(cut..-1) 290 | end 291 | 292 | def train(train_data = @train_data, attributes = @attributes, default = @default) 293 | dec_tree = DecisionTree::ID3Tree.new(attributes, train_data, default, @type) 294 | dec_tree.train 295 | @rules = dec_tree.build_rules 296 | @rules.each { |r| r.accuracy(train_data) } # Calculate accuracy 297 | prune 298 | end 299 | 300 | def prune(data = @prune_data) 301 | @rules.each do |r| 302 | (1..r.premises.size).each do 303 | acc1 = r.accuracy(data) 304 | p = r.premises.pop 305 | if acc1 > r.get_accuracy(data) 306 | r.premises.push(p) 307 | break 308 | end 309 | end 310 | end 311 | @rules = @rules.sort_by { |r| -r.accuracy(data) } 312 | end 313 | 314 | def to_s 315 | str = '' 316 | @rules.each { |rule| str += "#{rule}\n\n" } 317 | str 318 | end 319 | 320 | def predict(test) 321 | @rules.each do |r| 322 | prediction = r.predict(test) 323 | return prediction, r.accuracy unless prediction.nil? 
324 | end 325 | [@default, 0.0] 326 | end 327 | end 328 | 329 | class Bagging 330 | attr_accessor :classifiers 331 | 332 | def initialize(attributes, data, default, type) 333 | @classifiers = [] 334 | @type = type 335 | @data = data 336 | @attributes = attributes 337 | @default = default 338 | end 339 | 340 | def train(data = @data, attributes = @attributes, default = @default) 341 | @classifiers = 10.times.map do |i| 342 | Ruleset.new(attributes, data, default, @type) 343 | end 344 | 345 | @classifiers.each_with_index do |classifier, index| 346 | puts "Processing classifier ##{index + 1}" 347 | classifier.train(data, attributes, default) 348 | end 349 | end 350 | 351 | def predict(test) 352 | predictions = Hash.new(0) 353 | @classifiers.each do |c| 354 | p, accuracy = c.predict(test) 355 | predictions[p] += accuracy unless p.nil? 356 | end 357 | return @default, 0.0 if predictions.empty? 358 | winner = predictions.sort_by { |_k, v| -v }.first 359 | [winner[0], winner[1].to_f / @classifiers.size.to_f] 360 | end 361 | end 362 | end 363 | 364 | -------------------------------------------------------------------------------- /spec/id3_spec.rb: -------------------------------------------------------------------------------- 1 | require 'spec_helper' 2 | 3 | describe describe DecisionTree::ID3Tree do 4 | 5 | describe "simple discrete case" do 6 | Given(:labels) { ["sun", "rain"]} 7 | Given(:data) do 8 | [ 9 | [1,0,1], 10 | [0,1,0] 11 | ] 12 | end 13 | Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) } 14 | When { tree.train } 15 | Then { expect(tree.predict([1,0])).to eq 1 } 16 | Then { expect(tree.predict([0,1])).to eq 0 } 17 | end 18 | 19 | describe "discrete attributes" do 20 | Given(:labels) { ["hungry", "color"] } 21 | Given(:data) do 22 | [ 23 | ["yes", "red", "angry"], 24 | ["no", "blue", "not angry"], 25 | ["yes", "blue", "not angry"], 26 | ["no", "red", "not angry"] 27 | ] 28 | end 29 | Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 
"not angry", :discrete) } 30 | When { tree.train } 31 | Then { expect(tree.predict(["yes", "red"])).to eq "angry" } 32 | Then { expect(tree.predict(["no", "red"])).to eq "not angry" } 33 | end 34 | 35 | describe "discrete attributes" do 36 | Given(:labels) { ["hunger", "happiness"] } 37 | Given(:data) do 38 | [ 39 | [8, 7, "angry"], 40 | [6, 7, "angry"], 41 | [7, 9, "angry"], 42 | [7, 1, "not angry"], 43 | [2, 9, "not angry"], 44 | [3, 2, "not angry"], 45 | [2, 3, "not angry"], 46 | [1, 4, "not angry"] 47 | ] 48 | end 49 | Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", :continuous) } 50 | When { tree.train } 51 | Then { expect(tree.predict([7, 7])).to eq "angry" } 52 | Then { expect(tree.predict([2, 3])).to eq "not angry" } 53 | end 54 | 55 | describe "a mixture" do 56 | Given(:labels) { ["hunger", "color"] } 57 | Given(:data) do 58 | [ 59 | [8, "red", "angry"], 60 | [6, "red", "angry"], 61 | [7, "red", "angry"], 62 | [7, "blue", "not angry"], 63 | [2, "red", "not angry"], 64 | [3, "blue", "not angry"], 65 | [2, "blue", "not angry"], 66 | [1, "red", "not angry"] 67 | ] 68 | end 69 | Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "not angry", color: :discrete, hunger: :continuous) } 70 | When { tree.train } 71 | Then { expect(tree.predict([7, "red"])).to eq "angry" } 72 | Then { expect(tree.predict([2, "blue"])).to eq "not angry" } 73 | end 74 | 75 | describe "infinite recursion case" do 76 | Given(:labels) { [:a, :b, :c] } 77 | Given(:data) do 78 | [ 79 | ["a1", "b0", "c0", "RED"], 80 | ["a1", "b1", "c1", "RED"], 81 | ["a1", "b1", "c0", "BLUE"], 82 | ["a1", "b0", "c1", "BLUE"] 83 | ] 84 | end 85 | Given(:tree) { DecisionTree::ID3Tree.new(labels, data, "RED", :discrete) } 86 | When { tree.train } 87 | Then { expect(tree.predict(["a1","b0","c0"])).to eq "RED" } 88 | end 89 | 90 | describe "numerical labels case" do 91 | Given(:labels) { [1, 2] } 92 | Given(:data) do 93 | [ 94 | [1, 1, true], 95 | [1, 2, false], 96 | [2, 1, false], 97 | 
[2, 2, true] 98 | ] 99 | end 100 | Given(:tree) { DecisionTree::ID3Tree.new labels, data, nil, :discrete } 101 | When { tree.train } 102 | Then { 103 | expect { tree.predict([1, 1]) }.to_not raise_error 104 | } 105 | end 106 | 107 | describe "create a figure" do 108 | after(:all) do 109 | File.delete("#{FIGURE_FILENAME}.png") if File.file?("#{FIGURE_FILENAME}.png") 110 | end 111 | 112 | Given(:labels) { ["sun", "rain"]} 113 | Given(:data) do 114 | [ 115 | [1,0,1], 116 | [0,1,0] 117 | ] 118 | end 119 | Given(:tree) { DecisionTree::ID3Tree.new(labels, data, 1, :discrete) } 120 | When { tree.train } 121 | When(:result) { tree.graph(FIGURE_FILENAME) } 122 | Then { expect(result).to_not have_failed } 123 | And { File.file?("#{FIGURE_FILENAME}.png") } 124 | end 125 | end 126 | -------------------------------------------------------------------------------- /spec/spec_helper.rb: -------------------------------------------------------------------------------- 1 | require 'rspec/given' 2 | require 'decisiontree' 3 | require 'pry' 4 | 5 | FIGURE_FILENAME = "just_a_spec" 6 | --------------------------------------------------------------------------------