├── .gitignore ├── .travis.yml ├── LICENSE ├── NPtarget.go ├── README.md ├── adaboosttarget.go ├── adacosttarget.go ├── applyforest └── applyforest.go ├── arff.go ├── benchmark.png ├── benchmarks ├── README.md ├── benchmark.sh └── sklrf.py ├── benchmarks_test.go ├── catballotbox.go ├── catmap.go ├── data ├── README.md ├── forestfires.fm ├── forestfires.trans.fm ├── iris.data.fm └── iris.data.trans.fm ├── data_test.go ├── densecatfeature.go ├── densecatfeature_test.go ├── densenumfeature.go ├── densenumfeature_test.go ├── densitytarget.go ├── dentropytarget.go ├── doc.go ├── entropytarget.go ├── error.png ├── evaluator.go ├── evaluator_test.go ├── featureinterfaces.go ├── featurematrix.go ├── featurematrix_test.go ├── forest.go ├── forest_test.go ├── forestreader.go ├── forestwriter.go ├── forestwriterreader_test.go ├── gradboostclasstarget.go ├── gradboosttarget.go ├── growforest └── growforest.go ├── hdistancetarget.go ├── importance_test.go ├── install.sh ├── l1target.go ├── leafcount └── leafcount.go ├── libsvm.go ├── n.csv ├── node.go ├── numadaboostingtarget.go ├── numballotbox.go ├── ordinaltarget.go ├── preds.csv ├── regrettarget.go ├── sampling.go ├── sampling_test.go ├── sklearn_tree.go ├── sortablefeature.go ├── sortby ├── sortby.go └── sortby_test.go ├── splitallocations.go ├── splitter.go ├── stats ├── stats.go └── welchst_test.go ├── sumballotbox.go ├── transduction.go ├── tree.go ├── utils.go ├── utils ├── nfold │ └── main.go └── toafm │ └── main.go ├── utils_test.go ├── voter.go ├── wrappers └── python │ ├── CFClassifier.py │ └── test_CFClassifier.py └── wrftarget.go /.gitignore: -------------------------------------------------------------------------------- 1 | #ignore binaries 2 | growforest 3 | applyforest 4 | leafcount 5 | toafm 6 | nfold 7 | #ignore common files used/generated in local testing 8 | *.libsvm 9 | *.arff 10 | *.sf 11 | *.fm 12 | *.dot 13 | *.prof 14 | *.out 15 | *.imp 16 | *.tsv 17 | #ignore os x .DS_Store 18 | .DS_Store 19 | .httr-oauth 20 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | env: 4 | global: 5 | - TRAVIS_GO_VERSION=1.14.x 6 | 7 | go: 8 | - 1.13.x 9 | - 1.14.x 10 | - release 11 | - tip 12 | 13 | before_install: 14 | - go get -t -v ./... 15 | 16 | script: 17 | - go test -race -cpu 1,2 -v -timeout 5m 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, the CloudForest Contributors listed at 2 | http://github.com/ryanbressler/CloudForest/graphs/contributors 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name of the The Institute for Systems Biology, CloudForest 12 | nor the names of its contributors may be used to endorse or promote products 13 | derived from this software without specific prior written permission. 
14 | 
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY 
19 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 
22 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /NPtarget.go: -------------------------------------------------------------------------------- 
1 | package CloudForest 
2 | 
3 | import ( 
4 | "math" 
5 | ) 
6 | 
7 | /* 
8 | NPTarget wraps a categorical feature for use in experimental approximate Neyman-Pearson (NP) 
9 | classification. Constraints and optimization are done on the false 
10 | positive/negative rates. 
11 | 
12 | It uses an impurity measure with a soft constraint from the second family presented in 
13 | 
14 | "Comparison and Design of Neyman-Pearson Classifiers" 
15 | Clayton Scott, October 2005 
16 | 
17 | http://www.stat.rice.edu/~cscott/pubs/npdesign.pdf 
18 | 
19 | 
20 | N(f) = κ max((R0(f) − α), 0) + R1(f) 
21 | 
22 | Where f is the classifier, R0 is the false positive rate, R1 is the false negative rate, 
23 | α is the false positive constraint and κ controls the cost of violating 
24 | this constraint; β is a constant we can ignore as it subtracts out in differences. 
25 | 
26 | The vote assigned to each leaf node is a corrected mode where the count of the 
27 | positive/constrained label is corrected by 1/α. Without this modification, constraints 
28 | > .5 won't work, since nodes holding that large a fraction of negative cases would never vote positive under a plain mode. 
29 | */ 
30 | type NPTarget struct { 
31 | CatFeature 
32 | Posi int 
33 | Alpha float64 
34 | Kappa float64 
35 | } 
36 | 
37 | //NewNPTarget wraps a Categorical Feature for NP Classification. It accepts 
38 | //a string representing the constrained label and floats Alpha and Kappa 
39 | //representing the constraint and constraint weight. 
40 | func NewNPTarget(f CatFeature, Pos string, Alpha, Kappa float64) *NPTarget { 
41 | return &NPTarget{f, f.CatToNum(Pos), Alpha, Kappa} 
42 | } 
43 | 
44 | /* 
45 | SplitImpurity is a version of SplitImpurity that calls NPTarget.Impurity 
46 | */ 
47 | func (target *NPTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) { 
48 | nl := float64(len(*l)) 
49 | nr := float64(len(*r)) 
50 | nm := 0.0 
51 | 
52 | impurityDecrease = nl * target.Impurity(l, allocs.LCounter) 
53 | impurityDecrease += nr * target.Impurity(r, allocs.RCounter) 
54 | if m != nil && len(*m) > 0 { 
55 | nm = float64(len(*m)) 
56 | impurityDecrease += nm * target.Impurity(m, allocs.Counter) 
57 | } 
58 | 
59 | impurityDecrease /= nl + nr + nm 
60 | return 
61 | } 
62 | 
63 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l 
64 | //to avoid recalculating the entire split impurity. 
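//It assumes allocs.LCounter and allocs.RCounter already hold the class counts 
//for l and r from the previous evaluation, so only the moved cases need to be recounted. 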
65 | func (target *NPTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) { 
66 | var cat, i int 
67 | lcounter := *allocs.LCounter 
68 | rcounter := *allocs.RCounter 
69 | for _, i = range *movedRtoL { 
70 | 
71 | //most expensive statement: 
72 | cat = target.Geti(i) 
73 | lcounter[cat]++ 
74 | rcounter[cat]-- 
75 | //counter[target.Geti(i)]++ 
76 | 
77 | } 
78 | nl := float64(len(*l)) 
79 | nr := float64(len(*r)) 
80 | nm := 0.0 
81 | 
82 | impurityDecrease = nl * target.ImpFromCounts(len(*l), allocs.LCounter) 
83 | impurityDecrease += nr * target.ImpFromCounts(len(*r), allocs.RCounter) 
84 | if m != nil && len(*m) > 0 { 
85 | nm = float64(len(*m)) 
86 | impurityDecrease += nm * target.ImpFromCounts(len(*m), allocs.Counter) 
87 | } 
88 | 
89 | impurityDecrease /= nl + nr + nm 
90 | return 
91 | } 
92 | 
93 | //FindPredicted does a mode calculation with the count of the positive/constrained 
94 | //class corrected. 
95 | func (target *NPTarget) FindPredicted(cases []int) (pred string) { 
96 | 
97 | mi := 0 
98 | mc := 0.0 
99 | counts := make([]int, target.NCats()) 
100 | 
101 | target.CountPerCat(&cases, &counts) 
102 | 
103 | for cat, count := range counts { 
104 | cc := float64(count) 
105 | if cat == target.Posi { 
106 | cc /= target.Alpha 
107 | } 
108 | if cc > mc { 
109 | mi = cat 
110 | mc = cc 
111 | } 
112 | } 
113 | 
114 | return target.NumToCat(mi) 
115 | 
116 | } 
117 | 
118 | //ImpFromCounts recalculates the NP impurity from class counts for use in iterative updates. 
119 | func (target *NPTarget) ImpFromCounts(t int, counter *[]int) (e float64) { 
120 | 
121 | var totalpos, totalneg, mi int 
122 | 
123 | mc := 0.0 
124 | 
125 | for cat, count := range *counter { 
126 | cc := float64(count) 
127 | if cat == target.Posi { 
128 | totalpos += count 
129 | cc /= target.Alpha 
130 | } else { 
131 | totalneg += count 
132 | } 
133 | 
134 | if cc > mc { 
135 | mi = cat 
136 | mc = cc 
137 | } 
138 | 
139 | } 
140 | 
141 | if target.Posi == mi { 
142 | //False positive constraint 
143 | e = target.Kappa * math.Max(float64(totalneg)/float64(t)-target.Alpha, 0) 
144 | } else { 
145 | //False negative rate 
146 | e = float64(totalpos) / float64(t) 
147 | } 
148 | 
149 | return 
150 | 
151 | } 
152 | 
153 | //NPTarget.Impurity implements an impurity that minimizes false negatives subject 
154 | //to a soft constraint on false positives. 
155 | func (target *NPTarget) Impurity(cases *[]int, counter *[]int) (e float64) { 
156 | 
157 | target.CountPerCat(cases, counter) 
158 | t := len(*cases) 
159 | e = target.ImpFromCounts(t, counter) 
160 | 
161 | return 
162 | 
163 | } 
164 | 
-------------------------------------------------------------------------------- /adaboosttarget.go: -------------------------------------------------------------------------------- 
1 | package CloudForest 
2 | 
3 | import ( 
4 | "math" 
5 | ) 
6 | 
7 | /* 
8 | AdaBoostTarget wraps a categorical feature as a target for use in Adaptive Boosting (AdaBoost) 
9 | 
10 | */ 
11 | type AdaBoostTarget struct { 
12 | CatFeature 
13 | Weights []float64 
14 | } 
15 | 
16 | /* 
17 | NewAdaBoostTarget creates a categorical adaptive boosting target and initializes its weights. 
18 | */ 
19 | func NewAdaBoostTarget(f CatFeature) (abt *AdaBoostTarget) { 
20 | nCases := f.Length() 
21 | abt = &AdaBoostTarget{f, make([]float64, nCases)} 
22 | for i := range abt.Weights { 
23 | abt.Weights[i] = 1 / float64(nCases) 
24 | } 
25 | return 
26 | } 
27 | 
28 | /* 
29 | SplitImpurity is an AdaBoost version of SplitImpurity. 
30 | */ 
31 | func (target *AdaBoostTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) { 
32 | nl := float64(len(*l)) 
33 | nr := float64(len(*r)) 
34 | nm := 0.0 
35 | 
36 | impurityDecrease = nl * target.Impurity(l, allocs.LCounter) 
37 | impurityDecrease += nr * target.Impurity(r, allocs.RCounter) 
38 | if m != nil && len(*m) > 0 { 
39 | nm = float64(len(*m)) 
40 | impurityDecrease += nm * target.Impurity(m, allocs.Counter) 
41 | } 
42 | 
43 | impurityDecrease /= nl + nr + nm 
44 | return 
45 | } 
46 | 
47 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l as in learning from numerical variables. 
48 | //It updates the class counts for the moved cases and recalculates the impurity from the updated counts. 
49 | func (target *AdaBoostTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) { 
50 | var cat, i int 
51 | lcounter := *allocs.LCounter 
52 | rcounter := *allocs.RCounter 
53 | for _, i = range *movedRtoL { 
54 | 
55 | //most expensive statement: 
56 | cat = target.Geti(i) 
57 | lcounter[cat]++ 
58 | rcounter[cat]-- 
59 | //counter[target.Geti(i)]++ 
60 | 
61 | } 
62 | nl := float64(len(*l)) 
63 | nr := float64(len(*r)) 
64 | nm := 0.0 
65 | 
66 | impurityDecrease = nl * target.ImpFromCounts(l, allocs.LCounter) 
67 | impurityDecrease += nr * target.ImpFromCounts(r, allocs.RCounter) 
68 | if m != nil && len(*m) > 0 { 
69 | nm = float64(len(*m)) 
70 | impurityDecrease += nm * target.ImpFromCounts(m, allocs.Counter) 
71 | } 
72 | 
73 | impurityDecrease /= nl + nr + nm 
74 | return 
75 | } 
76 | 
77 | //Impurity is an AdaBoost impurity that uses the weights specified in Weights. 
78 | func (target *AdaBoostTarget) Impurity(cases *[]int, counter *[]int) (e float64) { 
79 | e = 0.0 
80 | //m := target.Modei(cases) 
81 | 
82 | target.CountPerCat(cases, counter) 
83 | e = target.ImpFromCounts(cases, counter) 
84 | 
85 | return 
86 | } 
87 | 
88 | //ImpFromCounts recalculates the weighted misclassification error from class counts for use in iterative updates. 
89 | func (target *AdaBoostTarget) ImpFromCounts(cases *[]int, counter *[]int) (e float64) { 
90 | 
91 | var m, mc int 
92 | 
93 | for i, c := range *counter { 
94 | if c > mc { 
95 | m = i 
96 | mc = c 
97 | } 
98 | } 
99 | 
100 | for _, c := range *cases { 
101 | 
102 | cat := target.Geti(c) 
103 | if cat != m { 
104 | e += target.Weights[c] 
105 | } 
106 | 
107 | } 
108 | 
109 | return 
110 | 
111 | } 
112 | 
113 | //Boost performs categorical adaptive boosting using the specified partition and 
114 | //returns the weight that the tree that generated the partition should be given. 
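//A returned weight of 0 means the weighted error was .5 or worse, so the 
//tree's votes should carry no weight. 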
115 | func (t *AdaBoostTarget) Boost(leaves *[][]int) (weight float64) { 
116 | weight = 0.0 
117 | counter := make([]int, t.NCats()) 
118 | for _, cases := range *leaves { 
119 | weight += t.Impurity(&cases, &counter) 
120 | } 
121 | if weight >= .5 { 
122 | return 0.0 
123 | } 
124 | weight = .5 * math.Log((1-weight)/weight) 
125 | 
126 | for _, cases := range *leaves { 
127 | 
128 | t.CountPerCat(&cases, &counter) 
129 | 
130 | var m, mc int 
131 | for i, c := range counter { 
132 | if c > mc { 
133 | m = i 
134 | mc = c 
135 | } 
136 | } 
137 | 
138 | for _, c := range cases { 
139 | if t.IsMissing(c) == false { 
140 | cat := t.Geti(c) 
141 | //update the case weights: 
142 | if cat != m { 
143 | t.Weights[c] = t.Weights[c] * math.Exp(weight) 
144 | } else { 
145 | t.Weights[c] = t.Weights[c] * math.Exp(-weight) 
146 | } 
147 | } 
148 | 
149 | } 
150 | } 
151 | normfactor := 0.0 
152 | for _, v := range t.Weights { 
153 | normfactor += v 
154 | } 
155 | for i, v := range t.Weights { 
156 | t.Weights[i] = v / normfactor 
157 | } 
158 | return 
159 | } 
-------------------------------------------------------------------------------- /adacosttarget.go: -------------------------------------------------------------------------------- 
1 | package CloudForest 
2 | 
3 | import ( 
4 | "math" 
5 | ) 
6 | 
7 | /* 
8 | AdaCostTarget wraps a categorical feature as a target for use in Cost Sensitive Adaptive Boosting (AdaC2.M1) 
9 | 
10 | "Boosting for Learning Multiple Classes with Imbalanced Class Distribution" 
11 | Yanmin Sun, Mohamed S. Kamel and Yang Wang 
12 | 
13 | See equations in slides here: 
14 | http://people.ee.duke.edu/~lcarin/Minhua4.18.08.pdf 
15 | 
16 | */ 
17 | type AdaCostTarget struct { 
18 | CatFeature 
19 | Weights []float64 
20 | Costs []float64 
21 | } 
22 | 
23 | /* 
24 | NewAdaCostTarget creates a categorical adaptive boosting target and initializes its weights. 
25 | */ 
26 | func NewAdaCostTarget(f CatFeature) (abt *AdaCostTarget) { 
27 | nCases := f.Length() 
28 | abt = &AdaCostTarget{f, make([]float64, nCases), make([]float64, f.NCats())} 
29 | for i := range abt.Weights { 
30 | abt.Weights[i] = 1 / float64(nCases) 
31 | } 
32 | return 
33 | } 
34 | 
35 | /*AdaCostTarget.SetCosts puts costs given in a map[string]float64 keyed by category name into the proper 
36 | entries in AdaCostTarget.Costs.*/ 
37 | func (target *AdaCostTarget) SetCosts(costmap map[string]float64) { 
38 | for i := 0; i < target.NCats(); i++ { 
39 | c := target.NumToCat(i) 
40 | target.Costs[i] = costmap[c] 
41 | } 
42 | } 
43 | 
44 | /* 
45 | SplitImpurity is an AdaCost version of SplitImpurity. 
46 | */ 
47 | func (target *AdaCostTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) { 
48 | nl := float64(len(*l)) 
49 | nr := float64(len(*r)) 
50 | nm := 0.0 
51 | 
52 | impurityDecrease = nl * target.Impurity(l, allocs.LCounter) 
53 | impurityDecrease += nr * target.Impurity(r, allocs.RCounter) 
54 | if m != nil && len(*m) > 0 { 
55 | nm = float64(len(*m)) 
56 | impurityDecrease += nm * target.Impurity(m, allocs.Counter) 
57 | } 
58 | 
59 | impurityDecrease /= nl + nr + nm 
60 | return 
61 | } 
62 | 
63 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l as in learning from numerical variables. 
64 | //It updates the class counts for the moved cases and recalculates the impurity from the updated counts. 
65 | func (target *AdaCostTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) { 
66 | var cat, i int 
67 | lcounter := *allocs.LCounter 
68 | rcounter := *allocs.RCounter 
69 | for _, i = range *movedRtoL { 
70 | 
71 | //most expensive statement: 
72 | cat = target.Geti(i) 
73 | lcounter[cat]++ 
74 | rcounter[cat]-- 
75 | //counter[target.Geti(i)]++ 
76 | 
77 | } 
78 | nl := float64(len(*l)) 
79 | nr := float64(len(*r)) 
80 | nm := 0.0 
81 | 
82 | impurityDecrease = nl * target.ImpFromCounts(l, allocs.LCounter) 
83 | impurityDecrease += nr * target.ImpFromCounts(r, allocs.RCounter) 
84 | if m != nil && len(*m) > 0 { 
85 | nm = float64(len(*m)) 
86 | impurityDecrease += nm * target.ImpFromCounts(m, allocs.Counter) 
87 | } 
88 | 
89 | impurityDecrease /= nl + nr + nm 
90 | return 
91 | } 
92 | 
93 | //Impurity is an AdaCost impurity that uses the weights specified in Weights and the costs in Costs. 
94 | func (target *AdaCostTarget) Impurity(cases *[]int, counter *[]int) (e float64) { 
95 | e = 0.0 
96 | //m := target.Modei(cases) 
97 | 
98 | target.CountPerCat(cases, counter) 
99 | e = target.ImpFromCounts(cases, counter) 
100 | 
101 | return 
102 | } 
103 | 
104 | //ImpFromCounts recalculates the cost-weighted misclassification error from class counts for use in iterative updates. 
105 | func (target *AdaCostTarget) ImpFromCounts(cases *[]int, counter *[]int) (e float64) { 
106 | 
107 | var m, mc int 
108 | 
109 | for i, c := range *counter { 
110 | if c > mc { 
111 | m = i 
112 | mc = c 
113 | } 
114 | } 
115 | 
116 | for _, c := range *cases { 
117 | 
118 | cat := target.Geti(c) 
119 | if cat != m { 
120 | e += target.Weights[c] * target.Costs[cat] 
121 | } 
122 | 
123 | } 
124 | 
125 | return 
126 | 
127 | } 
128 | 
129 | //Boost performs categorical adaptive boosting using the specified partition and 
130 | //returns the weight that the tree that generated the partition should be given. 
131 | func (t *AdaCostTarget) Boost(leaves *[][]int) (weight float64) { 
132 | weight = 0.0 
133 | counter := make([]int, t.NCats()) 
134 | for _, cases := range *leaves { 
135 | weight += t.Impurity(&cases, &counter) 
136 | } 
137 | if weight >= .5 { 
138 | return 0.0 
139 | } 
140 | weight = .5 * math.Log((1-weight)/weight) 
141 | 
142 | for _, cases := range *leaves { 
143 | t.CountPerCat(&cases, &counter) 
144 | 
145 | var m, mc int 
146 | for i, c := range counter { 
147 | if c > mc { 
148 | m = i 
149 | mc = c 
150 | } 
151 | } 
152 | 
153 | for _, c := range cases { 
154 | if t.IsMissing(c) == false { 
155 | cat := t.Geti(c) 
156 | //CHANGE from adaboost: the update is also scaled by the class cost 
157 | if cat != m { 
158 | t.Weights[c] = t.Weights[c] * math.Exp(weight) * t.Costs[cat] 
159 | } else { 
160 | t.Weights[c] = t.Weights[c] * math.Exp(-weight) * t.Costs[cat] 
161 | } 
162 | } 
163 | 
164 | } 
165 | } 
166 | normfactor := 0.0 
167 | for _, v := range t.Weights { 
168 | normfactor += v 
169 | } 
170 | for i, v := range t.Weights { 
171 | t.Weights[i] = v / normfactor 
172 | } 
173 | return 
174 | } 
175 | 
-------------------------------------------------------------------------------- /applyforest/applyforest.go: -------------------------------------------------------------------------------- 
1 | package main 
2 | 
3 | import ( 
4 | "flag" 
5 | "fmt" 
6 | "log" 
7 | "os" 
8 | "strings" 
9 | 
10 | "github.com/lytics/CloudForest" 
11 | ) 
12 | 
13 | func main() { 
14 | fm := flag.String("fm", 
15 | "featurematrix.afm", "AFM formatted feature matrix containing data.") 
16 | rf := flag.String("rfpred", 
17 | "rface.sf", "A predictor forest.") 
18 | predfn := flag.String("preds", 
19 | "", "The name of a file to write the predictions into.") 
20 | votefn := flag.String("votes", 
21 | "", "The name of a file to write categorical vote totals to.") 
22 | var num bool 
23 | flag.BoolVar(&num, "mean", false, "Force numeric (mean) voting.") 
24 | var sum bool 
25 | flag.BoolVar(&sum, "sum", false, "Force numeric sum voting (for gradient boosting etc).") 
26 | var expit bool 
27 | flag.BoolVar(&expit, "expit", false, "Expit (inverse logit) transform data (for gradient boosting classification).") 
28 | var cat bool 
29 | flag.BoolVar(&cat, "mode", false, "Force categorical (mode) voting.") 
30 | 
31 | flag.Parse() 
32 | 
33 | //Parse Data 
34 | data, err := CloudForest.LoadAFM(*fm) 
35 | if err != nil { 
36 | log.Fatal(err) 
37 | } 
38 | 
39 | forestfile, err := os.Open(*rf) // For read access. 
40 | if err != nil { 41 | log.Fatal(err) 42 | } 43 | defer forestfile.Close() 44 | forestreader := CloudForest.NewForestReader(forestfile) 45 | forest, err := forestreader.ReadForest() 46 | if err != nil { 47 | log.Fatal(err) 48 | } 49 | 50 | var predfile *os.File 51 | if *predfn != "" { 52 | predfile, err = os.Create(*predfn) 53 | if err != nil { 54 | log.Fatal(err) 55 | } 56 | defer predfile.Close() 57 | } 58 | 59 | var bb CloudForest.VoteTallyer 60 | switch { 61 | case sum: 62 | bb = CloudForest.NewSumBallotBox(data.Data[0].Length()) 63 | 64 | case !cat && (num || strings.HasPrefix(forest.Target, "N")): 65 | bb = CloudForest.NewNumBallotBox(data.Data[0].Length()) 66 | 67 | default: 68 | bb = CloudForest.NewCatBallotBox(data.Data[0].Length()) 69 | } 70 | 71 | for _, tree := range forest.Trees { 72 | tree.Vote(data, bb) 73 | } 74 | 75 | targeti, hasTarget := data.Map[forest.Target] 76 | if hasTarget { 77 | fmt.Printf("Target is %v in feature %v\n", forest.Target, targeti) 78 | er := bb.TallyError(data.Data[targeti]) 79 | fmt.Printf("Error: %v\n", er) 80 | } 81 | if *predfn != "" { 82 | fmt.Printf("Outputting label predicted actual tsv to %v\n", *predfn) 83 | for i, l := range data.CaseLabels { 84 | actual := "NA" 85 | if hasTarget { 86 | actual = data.Data[targeti].GetStr(i) 87 | } 88 | 89 | result := "" 90 | 91 | if sum || forest.Intercept != 0.0 { 92 | numresult := 0.0 93 | if sum { 94 | numresult = bb.(*CloudForest.SumBallotBox).TallyNum(i) + forest.Intercept 95 | } else { 96 | numresult = bb.(*CloudForest.NumBallotBox).TallyNum(i) + forest.Intercept 97 | } 98 | if expit { 99 | numresult = CloudForest.Expit(numresult) 100 | } 101 | result = fmt.Sprintf("%v", numresult) 102 | 103 | } else { 104 | result = bb.Tally(i) 105 | } 106 | fmt.Fprintf(predfile, "%v\t%v\t%v\n", l, result, actual) 107 | } 108 | } 109 | 110 | //Not thread safe code! 
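//(Reading cbb.Box and cbb.CatMap.Back below bypasses the per-CatBallot mutexes; 
//that is only safe here because all the tree.Vote calls above have completed.) 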
111 | if *votefn != "" { 
112 | fmt.Printf("Outputting vote totals to %v\n", *votefn) 
113 | cbb := bb.(*CloudForest.CatBallotBox) 
114 | votefile, err := os.Create(*votefn) 
115 | if err != nil { 
116 | log.Fatal(err) 
117 | } 
118 | defer votefile.Close() 
119 | fmt.Fprintf(votefile, ".") 
120 | 
121 | for _, label := range cbb.CatMap.Back { 
122 | fmt.Fprintf(votefile, "\t%v", label) 
123 | } 
124 | fmt.Fprintf(votefile, "\n") 
125 | 
126 | for i, box := range cbb.Box { 
127 | fmt.Fprintf(votefile, "%v", data.CaseLabels[i]) 
128 | 
129 | for j := range cbb.CatMap.Back { 
130 | total := 0.0 
131 | total = box.Map[j] 
132 | 
133 | fmt.Fprintf(votefile, "\t%v", total) 
134 | 
135 | } 
136 | fmt.Fprintf(votefile, "\n") 
137 | 
138 | } 
139 | } 
140 | } 
141 | 
-------------------------------------------------------------------------------- /arff.go: -------------------------------------------------------------------------------- 
1 | package CloudForest 
2 | 
3 | import ( 
4 | "bufio" 
5 | "encoding/csv" 
6 | "fmt" 
7 | "io" 
8 | "log" 
9 | "strings" 
10 | ) 
11 | 
12 | //ParseARFF reads a file in weka's ARFF format: 
13 | //http://www.cs.waikato.ac.nz/ml/weka/arff.html 
14 | //The relation is ignored and only categorical and numerical variables are supported 
15 | func ParseARFF(input io.Reader) *FeatureMatrix { 
16 | 
17 | reader := bufio.NewReader(input) 
18 | 
19 | data := make([]Feature, 0, 100) 
20 | lookup := make(map[string]int, 0) 
21 | //labels := make([]string, 0, 0) 
22 | 
23 | i := 0 
24 | for { 
25 | 
26 | line, err := reader.ReadString('\n') 
27 | if err != nil { 
28 | log.Print("Error:", err) 
29 | return nil 
30 | } 
31 | norm := strings.ToLower(line) 
32 | 
33 | if strings.HasPrefix(norm, "@data") { 
34 | break 
35 | } 
36 | 
37 | if strings.HasPrefix(norm, "@attribute") { 
38 | vals := strings.Fields(line) 
39 | 
40 | if strings.ToLower(vals[2]) == "numeric" || strings.ToLower(vals[2]) == "real" { 
41 | data = append(data, &DenseNumFeature{ 
42 | make([]float64, 0, 0), 
43 | make([]bool, 0, 0), 
44 | vals[1], 
45 | false}) 
46 | } else { 
47 | data = append(data, &DenseCatFeature{ 
48 | NewCatMap(), 
49 | make([]int, 0, 0), 
50 | make([]bool, 0, 0), 
51 | vals[1], 
52 | false, 
53 | false}) 
54 | } 
55 | 
56 | lookup[vals[1]] = i 
57 | //labels = append(labels, vals[1]) 
58 | i++ 
59 | } 
60 | 
61 | } 
62 | 
63 | fm := &FeatureMatrix{data, lookup, make([]string, 0, 0)} 
64 | 
65 | csvdata := csv.NewReader(reader) 
66 | csvdata.Comment = '%' 
67 | //csvdata.Comma = ',' 
68 | 
69 | fm.LoadCases(csvdata, false) 
70 | return fm 
71 | 
72 | } 
73 | 
74 | //WriteArffCases writes the specified cases from the provided feature matrix into an arff file with the given relation string. 
75 | func WriteArffCases(data *FeatureMatrix, cases []int, relation string, outfile io.Writer) error { 
76 | /*@RELATION iris 
77 | 
78 | @ATTRIBUTE sepallength NUMERIC 
79 | @ATTRIBUTE class {Iris-setosa,Iris-versicolor,Iris-virginica}*/ 
80 | 
81 | fmt.Fprintf(outfile, "@RELATION %v\n\n", relation) 
82 | 
83 | for _, f := range data.Data { 
84 | ftype := "NUMERIC" 
85 | switch f.(type) { 
86 | case (*DenseCatFeature): 
87 | ftype = fmt.Sprintf("{%v}", strings.Join(f.(*DenseCatFeature).Back, ",")) 
88 | } 
89 | 
90 | fmt.Fprintf(outfile, "@ATTRIBUTE %v %v\n", f.GetName(), ftype) 
91 | } 
92 | 
93 | fmt.Fprint(outfile, "\n@DATA\n") 
94 | 
95 | oucsv := csv.NewWriter(outfile) 
96 | oucsv.Comma = ',' 
97 | 
98 | for _, i := range cases { 
99 | entries := make([]string, 0, 10) 
100 | 
101 | for _, f := range data.Data { 
102 | v := "?" 
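// "?" is the ARFF missing-value marker; it is kept only when case i is missing this feature. 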
103 | if !f.IsMissing(i) { 
104 | v = f.GetStr(i) 
105 | } 
106 | entries = append(entries, v) 
107 | 
108 | } 
109 | //fmt.Println(entries) 
110 | err := oucsv.Write(entries) 
111 | if err != nil { 
112 | return err 
113 | } 
114 | 
115 | } 
116 | oucsv.Flush() 
117 | return nil 
118 | } 
-------------------------------------------------------------------------------- /benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lytics/CloudForest/381792ef996b4d29adf78fb053e14df7390dd508/benchmark.png -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 
1 | Benchmarks on Forest Coverage Data 
2 | ================================================ 
3 | 
4 | Learner | CloudForest | scikit.learn 0.14.1 | scikit.learn 0.15 | CloudForest 
5 | --------|-------------|---------------------|-------------------|------------ 
6 | Format | libsvm | libsvm | libsvm | arff 
7 | Time | 38 seconds | ??? seconds | 30 seconds | 29 seconds 
8 | 
9 | The arff format records which variables are binary or categorical, allowing CloudForest to use appropriate splitters for greater speed. Scikit.learn treats all data as numerical. This data set was chosen to allow comparison with benchmarks by [wise.io](http://about.wise.io/blog/2013/07/15/benchmarking-random-forest-part-1/) and [Alex Rubinsteyn](http://blog.explainmydata.com/2014/03/big-speedup-for-random-forest-learning.html). 
10 | 
11 | CloudForest and scikit.learn 0.15 were checked out on 3/10/2014. 
12 | 
13 | Hardware 
14 | --------- 
15 | Benchmarks were performed using 8 hyperthreads on a 15-inch MacBook Pro 10,1 with a 2.4 GHz Intel Core i7 (i7-3635QM) with a per-core L2 cache of 256 KB, 6 MB of L3 cache, and 8 GB of 1600 MHz RAM. 
16 | 
17 | Data Sources 
18 | ------------ 
19 | 
20 | The forest coverage data set in libsvm format was acquired [from here](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#covtype). 
21 | 
22 | Scaled forest coverage in arff format with categorical variables not converted to numerical was acquired [from the MOA project](http://sourceforge.net/projects/moa-datastream/files/Datasets/Classification/). 
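The growforest runs above can also be reproduced against the library API. Below is a minimal, illustrative sketch (not part of the benchmark suite) that grows 50 trees on the arff data; the file and target names come from benchmark.sh, and the positional Grow arguments mirror those used in benchmarks_test.go.

    package main
    
    import (
    	"log"
    	"os"
    
    	"github.com/lytics/CloudForest"
    )
    
    func main() {
    	f, err := os.Open("covtypeNorm.arff")
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer f.Close()
    	fm := CloudForest.ParseARFF(f)
    
    	targeti := fm.Map["class"]
    	target := fm.Data[targeti]
    
    	// Use every case and every non-target feature as a candidate.
    	cases := make([]int, 0, target.Length())
    	for i := 0; i < target.Length(); i++ {
    		cases = append(cases, i)
    	}
    	candidates := make([]int, 0, len(fm.Data)-1)
    	for i := range fm.Data {
    		if i != targeti {
    			candidates = append(candidates, i)
    		}
    	}
    
    	allocs := CloudForest.NewBestSplitAllocs(len(cases), target)
    	for i := 0; i < 50; i++ { // 50 trees, as in benchmark.sh
    		tree := CloudForest.NewTree()
    		tree.Grow(fm, target, cases, candidates, nil, 2, 1, 0,
    			false, false, false, false, false, nil, nil, allocs)
    	}
    }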
-------------------------------------------------------------------------------- /benchmarks/benchmark.sh: -------------------------------------------------------------------------------- 1 | python sklrf.py covtype.libsvm 2 | growforest -train covtype.libsvm -target "0" -nTrees 50 -nCores 8 -mTry 7 3 | growforest -train covtypeNorm.arff -target "class" -nTrees 50 -nCores 8 -mTry 7 4 | #time rf-ace --trainData covtype.fm --target C:class --nodeSize 1 -n 50 -e 8 -m 7 -------------------------------------------------------------------------------- /benchmarks/sklrf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from sklearn.datasets import load_svmlight_file 3 | 4 | from sklearn.ensemble import RandomForestClassifier 5 | 6 | from time import time 7 | 8 | import numpy as np 9 | 10 | 11 | def dumptree(atree, fn): 12 | from sklearn import tree 13 | f = open(fn,"w") 14 | tree.export_graphviz(atree,out_file=f) 15 | f.close() 16 | 17 | # def main(): 18 | fn = sys.argv[1] 19 | X,Y = load_svmlight_file(fn) 20 | 21 | rf_parameters = { 22 | "n_estimators": 2000, 23 | "n_jobs": 8 24 | } 25 | clf = RandomForestClassifier(**rf_parameters) 26 | X = X.toarray() 27 | 28 | print clf 29 | 30 | print "Starting Training" 31 | t0 = time() 32 | clf.fit(X, Y) 33 | train_time = time() - t0 34 | print "Training on %s took %s"%(fn, train_time) 35 | 36 | if len(sys.argv) == 2: 37 | score = clf.score(X, Y) 38 | count = np.sum(clf.predict(X)==Y) 39 | print "Score: %s, %s / %s "%(score, count, len(Y)) 40 | else: 41 | fn = sys.argv[2] 42 | X,Y = load_svmlight_file(fn) 43 | X = X.toarray() 44 | score = clf.score(X, Y) 45 | count = np.sum(clf.predict(X)==Y) 46 | c1 = np.sum(clf.predict(X[Y==1])==Y[Y==1] ) 47 | c0 = np.sum(clf.predict(X[Y==0])==Y[Y==0] ) 48 | l = len(Y) 49 | print "Testing Score: %s, %s / %s, %s, %s, %s "%(score, count, l, c1, c0, (float(c1)/float(sum(Y==1))+float(c0)/float(sum(Y==0)))/2.0) 50 | 51 | 52 | # if __name__ == '__main__': 53 | # main() 54 | -------------------------------------------------------------------------------- /benchmarks_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func BenchmarkIris(b *testing.B) { 9 | 10 | candidates := []int{1, 2, 3, 4} 11 | // irisreader := strings.NewReader(irisarff) 12 | // fm := ParseARFF(irisreader) 13 | // targeti := 4 14 | 15 | irisreader := strings.NewReader(irislibsvm) 16 | fm := ParseLibSVM(irisreader) 17 | targeti := 0 18 | 19 | target := fm.Data[targeti] 20 | 21 | cases := make([]int, 0, 150) 22 | for i := 0; i < fm.Data[0].Length(); i++ { 23 | cases = append(cases, i) 24 | } 25 | allocs := NewBestSplitAllocs(len(cases), target) 26 | 27 | b.ResetTimer() 28 | for i := 0; i < b.N; i++ { 29 | tree := NewTree() 30 | tree.Grow(fm, target, cases, candidates, nil, 2, 1, 0, false, false, false, false, false, nil, nil, allocs) 31 | 32 | } 33 | } 34 | 35 | func BenchmarkBoston(b *testing.B) { 36 | 37 | boston := strings.NewReader(boston_housing) 38 | 39 | fm := ParseARFF(boston) 40 | 41 | candidates := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} 42 | 43 | target := fm.Data[fm.Map["class"]] 44 | 45 | cases := make([]int, 0, fm.Data[0].Length()) 46 | for i := 0; i < fm.Data[0].Length(); i++ { 47 | cases = append(cases, i) 48 | } 49 | allocs := NewBestSplitAllocs(len(cases), target) 50 | 51 | b.ResetTimer() 52 | for i := 0; i < b.N; i++ { 53 | tree := NewTree() 54 | tree.Grow(fm, 
target, cases, candidates, nil, 2, 1, 0, false, false, false, false, false, nil, nil, allocs) 
55 | 
56 | } 
57 | } 
58 | 
59 | func BenchmarkBestNumSplit(b *testing.B) { 
60 | 
61 | // irisreader := strings.NewReader(irisarff) 
62 | // fm := ParseARFF(irisreader) 
63 | // targeti := 4 
64 | 
65 | irisreader := strings.NewReader(irislibsvm) 
66 | fm := ParseLibSVM(irisreader) 
67 | targeti := 0 
68 | 
69 | targetf := fm.Data[targeti] 
70 | 
71 | cases := make([]int, 0, 150) 
72 | for i := 0; i < fm.Data[0].Length(); i++ { 
73 | cases = append(cases, i) 
74 | } 
75 | allocs := NewBestSplitAllocs(len(cases), targetf) 
76 | 
77 | parentImp := targetf.Impurity(&cases, allocs.Counter) 
78 | 
79 | b.ResetTimer() 
80 | for i := 0; i < b.N; i++ { 
81 | _, _, _ = fm.Data[1].BestSplit(targetf, &cases, parentImp, 1, false, allocs) 
82 | 
83 | } 
84 | } 
85 | 
-------------------------------------------------------------------------------- /catballotbox.go: -------------------------------------------------------------------------------- 
1 | package CloudForest 
2 | 
3 | import ( 
4 | "sort" 
5 | "sync" 
6 | ) 
7 | 
8 | //CatBallot is used inside of CatBallotBox to record categorical votes in a thread-safe 
9 | //manner. 
10 | type CatBallot struct { 
11 | Mutex sync.RWMutex 
12 | Map map[int]float64 
13 | } 
14 | 
15 | //NewCatBallot returns a pointer to an initialized CatBallot with a zero-size Map. 
16 | func NewCatBallot() (cb *CatBallot) { 
17 | cb = new(CatBallot) 
18 | cb.Map = make(map[int]float64, 0) 
19 | return 
20 | } 
21 | 
22 | //CatBallotBox keeps track of votes by trees in a thread-safe manner. 
23 | type CatBallotBox struct { 
24 | *CatMap 
25 | Box []*CatBallot 
26 | } 
27 | 
28 | //NewCatBallotBox builds a new ballot box for the number of cases specified by "size". 
29 | func NewCatBallotBox(size int) *CatBallotBox { 
30 | bb := CatBallotBox{ 
31 | CatMap: NewCatMap(), 
32 | Box: make([]*CatBallot, 0, size), 
33 | } 
34 | for i := 0; i < size; i++ { 
35 | bb.Box = append(bb.Box, NewCatBallot()) 
36 | } 
37 | return &bb 
38 | } 
39 | 
40 | //Vote registers a vote that case "casei" should be predicted to be the 
41 | //category "pred". 
42 | func (bb *CatBallotBox) Vote(casei int, pred string, weight float64) { 
43 | predn := bb.CatToNum(pred) 
44 | bb.Box[casei].Mutex.Lock() 
45 | if _, ok := bb.Box[casei].Map[predn]; !ok { 
46 | bb.Box[casei].Map[predn] = 0 
47 | } 
48 | bb.Box[casei].Map[predn] = bb.Box[casei].Map[predn] + weight 
49 | bb.Box[casei].Mutex.Unlock() 
50 | } 
51 | 
52 | //Tally tallies the votes for the case specified by i as 
53 | //if it is a Categorical or boolean feature. I.e., it returns the mode 
54 | //(the most frequent value) of all votes. 
55 | func (bb *CatBallotBox) Tally(i int) (predicted string) { 
56 | var predictedn int 
57 | var maxVote float64 
58 | var ties []int 
59 | bb.Box[i].Mutex.RLock() 
60 | for k, v := range bb.Box[i].Map { 
61 | if v > maxVote { 
62 | predictedn = k 
63 | maxVote = v 
64 | ties = nil 
65 | } 
66 | 
67 | // keep track of the ties so that our predictions 
68 | // are deterministic 
69 | if v == maxVote { 
70 | ties = append(ties, k) 
71 | } 
72 | } 
73 | bb.Box[i].Mutex.RUnlock() 
74 | 
75 | // if there is a tie in the predictions, 
76 | // then pick the smaller key 
77 | if len(ties) > 1 { 
78 | sort.Ints(ties) 
79 | predictedn = ties[0] 
80 | } 
81 | 
82 | if maxVote > 0 { 
83 | predicted = bb.Back[predictedn] 
84 | } else { 
85 | predicted = "NA" 
86 | } 
87 | return 
88 | } 
89 | 
90 | /* 
91 | TallyError returns the balanced classification error for categorical features. 
92 | 
93 | 1 - (1/n) * sum( |{x in Xi : Y'(x) = Y(x)}| / |Xi| ) 
94 | 
95 | where n is the number of categories, 
96 | Y are the labels 
97 | Y' are the estimated labels 
98 | Xi is the set of samples with the ith actual label 
99 | 
100 | Cases for which the true category is not known are ignored. 
101 | 
102 | */ 
103 | func (bb *CatBallotBox) TallyError(feature Feature) float64 { 
104 | catfeature := feature.(CatFeature) 
105 | ncats := catfeature.NCats() 
106 | correct := make([]int, ncats) 
107 | total := make([]int, ncats) 
108 | for i := 0; i < feature.Length(); i++ { 
109 | value := catfeature.Geti(i) 
110 | predicted := bb.Tally(i) 
111 | if feature.IsMissing(i) { 
112 | continue 
113 | } 
114 | total[value]++ 
115 | if catfeature.NumToCat(value) == predicted { 
116 | correct[value]++ 
117 | } 
118 | } 
119 | 
120 | var e float64 
121 | for i, ncorrect := range correct { 
122 | e += float64(ncorrect) / float64(total[i]) 
123 | } 
124 | 
125 | e /= float64(ncats) 
126 | e = 1.0 - e 
127 | return e 
128 | } 
-------------------------------------------------------------------------------- /catmap.go: -------------------------------------------------------------------------------- 
1 | package CloudForest 
2 | 
3 | import ( 
4 | "sync" 
5 | ) 
6 | 
7 | /*CatMap is for mapping categorical values to integers. 
8 | It contains: 
9 | 
10 | privateMap : a map of ints by the string used for the category 
11 | Back : a slice of strings by the int that represents them 
12 | 
13 | And is embedded by Feature and CatBallotBox. 
14 | */ 
15 | type CatMap struct { 
16 | privateMap map[string]int //map categories from string to Num 
17 | Back []string // map categories from Num to string 
18 | 
19 | CatMapMut sync.Mutex 
20 | } 
21 | 
22 | func NewCatMap() *CatMap { 
23 | return &CatMap{ 
24 | privateMap: make(map[string]int, 0), 
25 | } 
26 | } 
27 | 
28 | func (cm *CatMap) CopyCatMap() *CatMap { 
29 | cm.CatMapMut.Lock() 
30 | defer cm.CatMapMut.Unlock() 
31 | 
32 | cp := NewCatMap() 
33 | for k, v := range cm.privateMap { 
34 | cp.privateMap[k] = v 
35 | } 
36 | cp.Back = make([]string, len(cm.Back)) 
37 | copy(cp.Back, cm.Back) 
38 | return cp 
39 | } 
40 | 
41 | //CatToNum provides the int equivalent of the provided categorical value 
42 | //if it already exists or adds it to the map and returns the new value if 
43 | //it doesn't. 
44 | func (cm *CatMap) CatToNum(value string) (numericv int) { 
45 | cm.CatMapMut.Lock() 
46 | numericv, ok := cm.privateMap[value] 
47 | if !ok { 
48 | numericv = len(cm.Back) 
49 | cm.privateMap[value] = numericv 
50 | cm.Back = append(cm.Back, value) 
51 | } 
52 | cm.CatMapMut.Unlock() 
53 | return 
54 | } 
55 | 
56 | //NumToCat returns the category label that has been assigned i 
57 | func (cm *CatMap) NumToCat(i int) (value string) { 
58 | return cm.Back[i] 
59 | } 
60 | 
61 | //NCats returns the number of distinct categories. 
62 | func (cm *CatMap) NCats() (n int) { 
63 | if cm.Back == nil { 
64 | n = 0 
65 | } else { 
66 | n = len(cm.Back) 
67 | } 
68 | return 
69 | } 
70 | 
-------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 
1 | CloudForest - Sample Data Sets 
2 | =============================== 
3 | 
4 | A few small sample data sets are included in this directory in ready-to-use formats for use in testing and benchmarking. 
5 | 
6 | Unless otherwise noted they are taken from the UCI Machine Learning Repository and licensed accordingly. 
7 | 
8 | http://archive.ics.uci.edu/ml 
9 | 
10 | Bache, K. & Lichman, M. (2013). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. 
Irvine, CA: University of California, School of Information and Computer Science. 11 | 12 | 13 | Iris Data Set 14 | -------------- 15 | iris.data.fm, iris.data.trans.fm 16 | 17 | Classify the C:Class feature. 18 | 19 | From http://archive.ics.uci.edu/ml/datasets/Iris 20 | 21 | Fisher,R.A. "The use of multiple measurements in taxonomic problems" Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to Mathematical Statistics" (John Wiley, NY, 1950). 22 | 23 | 24 | 25 | Fores Fires Data Set 26 | --------------------- 27 | forestfires.fm, forestfires.trans.fm 28 | 29 | Regress on the N:area feature. Also interesting for feature imporatnce. 30 | 31 | From http://archive.ics.uci.edu/ml/datasets/Forest+Fires 32 | 33 | P. Cortez and A. Morais. A Data Mining Approach to Predict Forest Fires using Meteorological Data. In J. Neves, M. F. Santos and J. Machado Eds., New Trends in Artificial Intelligence, Proceedings of the 13th EPIA 2007 - Portuguese Conference on Artificial Intelligence, December, Guimarães, Portugal, pp. 512-523, 2007. APPIA, ISBN-13 978-989-95618-0-9. -------------------------------------------------------------------------------- /data/iris.data.fm: -------------------------------------------------------------------------------- 1 | . 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 2 | N:SepalLength 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 5.4 4.8 4.8 4.3 5.8 5.7 5.4 5.1 5.7 5.1 5.4 5.1 4.6 5.1 4.8 5 5 5.2 5.2 4.7 4.8 5.4 5.2 5.5 4.9 5 5.5 4.9 4.4 5.1 5 4.5 4.4 5 5.1 4.8 5.1 4.6 5.3 5 7 6.4 6.9 5.5 6.5 5.7 6.3 4.9 6.6 5.2 5 5.9 6 6.1 5.6 6.7 5.6 5.8 6.2 5.6 5.9 6.1 6.3 6.1 6.4 6.6 6.8 6.7 6 5.7 5.5 5.5 5.8 6 5.4 6 6.7 6.3 5.6 5.5 5.5 6.1 5.8 5 5.6 5.7 5.7 6.2 5.1 5.7 6.3 5.8 7.1 6.3 6.5 7.6 4.9 7.3 6.7 7.2 6.5 6.4 6.8 5.7 5.8 6.4 6.5 7.7 7.7 6 6.9 5.6 7.7 6.3 6.7 7.2 6.2 6.1 6.4 7.2 7.4 7.9 6.4 6.3 6.1 7.7 6.3 6.4 6 6.9 6.7 6.9 5.8 6.8 6.7 6.7 6.3 6.5 6.2 5.9 3 | N:SepalWidth 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 3.7 3.4 3 3 4 4.4 3.9 3.5 3.8 3.8 3.4 3.7 3.6 3.3 3.4 3 3.4 3.5 3.4 3.2 3.1 3.4 4.1 4.2 3.1 3.2 3.5 3.1 3 3.4 3.5 2.3 3.2 3.5 3.8 3 3.8 3.2 3.7 3.3 3.2 3.2 3.1 2.3 2.8 2.8 3.3 2.4 2.9 2.7 2 3 2.2 2.9 2.9 3.1 3 2.7 2.2 2.5 3.2 2.8 2.5 2.8 2.9 3 2.8 3 2.9 2.6 2.4 2.4 2.7 2.7 3 3.4 3.1 2.3 3 2.5 2.6 3 2.6 2.3 2.7 3 2.9 2.9 2.5 2.8 3.3 2.7 3 2.9 3 3 2.5 2.9 2.5 3.6 3.2 2.7 3 2.5 2.8 3.2 3 3.8 2.6 2.2 3.2 2.8 2.8 2.7 3.3 3.2 2.8 3 2.8 3 2.8 3.8 2.8 2.8 2.6 3 3.4 3.1 3 3.1 3.1 3.1 2.7 3.2 3.3 3 2.5 3 3.4 3 4 | N:PetalLength 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 1.5 1.6 1.4 1.1 1.2 1.5 1.3 1.4 1.7 1.5 1.7 1.5 1 1.7 1.9 1.6 1.6 1.5 1.4 1.6 1.6 1.5 1.5 1.4 1.5 1.2 1.3 1.5 1.3 1.5 1.3 1.3 1.3 1.6 1.9 1.4 1.6 1.4 1.5 1.4 4.7 4.5 4.9 4 4.6 4.5 4.7 3.3 4.6 3.9 3.5 4.2 4 4.7 3.6 4.4 4.5 4.1 4.5 3.9 4.8 4 4.9 4.7 4.3 4.4 4.8 5 4.5 3.5 3.8 3.7 3.9 5.1 4.5 4.5 4.7 4.4 4.1 4 4.4 4.6 4 3.3 4.2 4.2 4.2 4.3 3 4.1 6 5.1 5.9 5.6 5.8 6.6 4.5 6.3 5.8 6.1 5.1 5.3 5.5 5 5.1 5.3 5.5 6.7 6.9 5 5.7 4.9 6.7 4.9 5.7 6 4.8 4.9 5.6 5.8 6.1 6.4 5.6 5.1 5.6 6.1 5.6 5.5 4.8 5.4 5.6 5.1 5.1 5.9 5.7 5.2 5 5.2 5.4 5.1 5 | N:PetalWidth 0.2 0.2 0.2 0.2 0.2 0.4 0.3 
0.2 0.2 0.1 0.2 0.2 0.1 0.1 0.2 0.4 0.4 0.3 0.3 0.3 0.2 0.4 0.2 0.5 0.2 0.2 0.4 0.2 0.2 0.2 0.2 0.4 0.1 0.2 0.1 0.2 0.2 0.1 0.2 0.2 0.3 0.3 0.2 0.6 0.4 0.3 0.2 0.2 0.2 0.2 1.4 1.5 1.5 1.3 1.5 1.3 1.6 1 1.3 1.4 1 1.5 1 1.4 1.3 1.4 1.5 1 1.5 1.1 1.8 1.3 1.5 1.2 1.3 1.4 1.4 1.7 1.5 1 1.1 1 1.2 1.6 1.5 1.6 1.5 1.3 1.3 1.3 1.2 1.4 1.2 1 1.3 1.2 1.3 1.3 1.1 1.3 2.5 1.9 2.1 1.8 2.2 2.1 1.7 1.8 1.8 2.5 2 1.9 2.1 2 2.4 2.3 1.8 2.2 2.3 1.5 2.3 2 2 1.8 2.1 1.8 1.8 1.8 2.1 1.6 1.9 2 2.2 1.5 1.4 2.3 2.4 1.8 1.8 2.1 2.4 2.3 1.9 2.3 2.5 2.3 1.9 2 2.3 1.8 6 | C:Class Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-setosa Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-versicolor Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica Iris-virginica -------------------------------------------------------------------------------- /data/iris.data.trans.fm: -------------------------------------------------------------------------------- 1 | . 
N:SepalLength N:SepalWidth N:PetalLength N:PetalWidth C:Class 2 | 1 5.1 3.5 1.4 0.2 Iris-setosa 3 | 2 4.9 3 1.4 0.2 Iris-setosa 4 | 3 4.7 3.2 1.3 0.2 Iris-setosa 5 | 4 4.6 3.1 1.5 0.2 Iris-setosa 6 | 5 5 3.6 1.4 0.2 Iris-setosa 7 | 6 5.4 3.9 1.7 0.4 Iris-setosa 8 | 7 4.6 3.4 1.4 0.3 Iris-setosa 9 | 8 5 3.4 1.5 0.2 Iris-setosa 10 | 9 4.4 2.9 1.4 0.2 Iris-setosa 11 | 10 4.9 3.1 1.5 0.1 Iris-setosa 12 | 11 5.4 3.7 1.5 0.2 Iris-setosa 13 | 12 4.8 3.4 1.6 0.2 Iris-setosa 14 | 13 4.8 3 1.4 0.1 Iris-setosa 15 | 14 4.3 3 1.1 0.1 Iris-setosa 16 | 15 5.8 4 1.2 0.2 Iris-setosa 17 | 16 5.7 4.4 1.5 0.4 Iris-setosa 18 | 17 5.4 3.9 1.3 0.4 Iris-setosa 19 | 18 5.1 3.5 1.4 0.3 Iris-setosa 20 | 19 5.7 3.8 1.7 0.3 Iris-setosa 21 | 20 5.1 3.8 1.5 0.3 Iris-setosa 22 | 21 5.4 3.4 1.7 0.2 Iris-setosa 23 | 22 5.1 3.7 1.5 0.4 Iris-setosa 24 | 23 4.6 3.6 1 0.2 Iris-setosa 25 | 24 5.1 3.3 1.7 0.5 Iris-setosa 26 | 25 4.8 3.4 1.9 0.2 Iris-setosa 27 | 26 5 3 1.6 0.2 Iris-setosa 28 | 27 5 3.4 1.6 0.4 Iris-setosa 29 | 28 5.2 3.5 1.5 0.2 Iris-setosa 30 | 29 5.2 3.4 1.4 0.2 Iris-setosa 31 | 30 4.7 3.2 1.6 0.2 Iris-setosa 32 | 31 4.8 3.1 1.6 0.2 Iris-setosa 33 | 32 5.4 3.4 1.5 0.4 Iris-setosa 34 | 33 5.2 4.1 1.5 0.1 Iris-setosa 35 | 34 5.5 4.2 1.4 0.2 Iris-setosa 36 | 35 4.9 3.1 1.5 0.1 Iris-setosa 37 | 36 5 3.2 1.2 0.2 Iris-setosa 38 | 37 5.5 3.5 1.3 0.2 Iris-setosa 39 | 38 4.9 3.1 1.5 0.1 Iris-setosa 40 | 39 4.4 3 1.3 0.2 Iris-setosa 41 | 40 5.1 3.4 1.5 0.2 Iris-setosa 42 | 41 5 3.5 1.3 0.3 Iris-setosa 43 | 42 4.5 2.3 1.3 0.3 Iris-setosa 44 | 43 4.4 3.2 1.3 0.2 Iris-setosa 45 | 44 5 3.5 1.6 0.6 Iris-setosa 46 | 45 5.1 3.8 1.9 0.4 Iris-setosa 47 | 46 4.8 3 1.4 0.3 Iris-setosa 48 | 47 5.1 3.8 1.6 0.2 Iris-setosa 49 | 48 4.6 3.2 1.4 0.2 Iris-setosa 50 | 49 5.3 3.7 1.5 0.2 Iris-setosa 51 | 50 5 3.3 1.4 0.2 Iris-setosa 52 | 51 7 3.2 4.7 1.4 Iris-versicolor 53 | 52 6.4 3.2 4.5 1.5 Iris-versicolor 54 | 53 6.9 3.1 4.9 1.5 Iris-versicolor 55 | 54 5.5 2.3 4 1.3 Iris-versicolor 56 | 55 6.5 2.8 4.6 1.5 Iris-versicolor 57 | 56 5.7 2.8 4.5 1.3 Iris-versicolor 58 | 57 6.3 3.3 4.7 1.6 Iris-versicolor 59 | 58 4.9 2.4 3.3 1 Iris-versicolor 60 | 59 6.6 2.9 4.6 1.3 Iris-versicolor 61 | 60 5.2 2.7 3.9 1.4 Iris-versicolor 62 | 61 5 2 3.5 1 Iris-versicolor 63 | 62 5.9 3 4.2 1.5 Iris-versicolor 64 | 63 6 2.2 4 1 Iris-versicolor 65 | 64 6.1 2.9 4.7 1.4 Iris-versicolor 66 | 65 5.6 2.9 3.6 1.3 Iris-versicolor 67 | 66 6.7 3.1 4.4 1.4 Iris-versicolor 68 | 67 5.6 3 4.5 1.5 Iris-versicolor 69 | 68 5.8 2.7 4.1 1 Iris-versicolor 70 | 69 6.2 2.2 4.5 1.5 Iris-versicolor 71 | 70 5.6 2.5 3.9 1.1 Iris-versicolor 72 | 71 5.9 3.2 4.8 1.8 Iris-versicolor 73 | 72 6.1 2.8 4 1.3 Iris-versicolor 74 | 73 6.3 2.5 4.9 1.5 Iris-versicolor 75 | 74 6.1 2.8 4.7 1.2 Iris-versicolor 76 | 75 6.4 2.9 4.3 1.3 Iris-versicolor 77 | 76 6.6 3 4.4 1.4 Iris-versicolor 78 | 77 6.8 2.8 4.8 1.4 Iris-versicolor 79 | 78 6.7 3 5 1.7 Iris-versicolor 80 | 79 6 2.9 4.5 1.5 Iris-versicolor 81 | 80 5.7 2.6 3.5 1 Iris-versicolor 82 | 81 5.5 2.4 3.8 1.1 Iris-versicolor 83 | 82 5.5 2.4 3.7 1 Iris-versicolor 84 | 83 5.8 2.7 3.9 1.2 Iris-versicolor 85 | 84 6 2.7 5.1 1.6 Iris-versicolor 86 | 85 5.4 3 4.5 1.5 Iris-versicolor 87 | 86 6 3.4 4.5 1.6 Iris-versicolor 88 | 87 6.7 3.1 4.7 1.5 Iris-versicolor 89 | 88 6.3 2.3 4.4 1.3 Iris-versicolor 90 | 89 5.6 3 4.1 1.3 Iris-versicolor 91 | 90 5.5 2.5 4 1.3 Iris-versicolor 92 | 91 5.5 2.6 4.4 1.2 Iris-versicolor 93 | 92 6.1 3 4.6 1.4 Iris-versicolor 94 | 93 5.8 2.6 4 1.2 Iris-versicolor 95 | 94 5 2.3 3.3 1 Iris-versicolor 96 | 95 5.6 2.7 4.2 1.3 
Iris-versicolor 97 | 96 5.7 3 4.2 1.2 Iris-versicolor 98 | 97 5.7 2.9 4.2 1.3 Iris-versicolor 99 | 98 6.2 2.9 4.3 1.3 Iris-versicolor 100 | 99 5.1 2.5 3 1.1 Iris-versicolor 101 | 100 5.7 2.8 4.1 1.3 Iris-versicolor 102 | 101 6.3 3.3 6 2.5 Iris-virginica 103 | 102 5.8 2.7 5.1 1.9 Iris-virginica 104 | 103 7.1 3 5.9 2.1 Iris-virginica 105 | 104 6.3 2.9 5.6 1.8 Iris-virginica 106 | 105 6.5 3 5.8 2.2 Iris-virginica 107 | 106 7.6 3 6.6 2.1 Iris-virginica 108 | 107 4.9 2.5 4.5 1.7 Iris-virginica 109 | 108 7.3 2.9 6.3 1.8 Iris-virginica 110 | 109 6.7 2.5 5.8 1.8 Iris-virginica 111 | 110 7.2 3.6 6.1 2.5 Iris-virginica 112 | 111 6.5 3.2 5.1 2 Iris-virginica 113 | 112 6.4 2.7 5.3 1.9 Iris-virginica 114 | 113 6.8 3 5.5 2.1 Iris-virginica 115 | 114 5.7 2.5 5 2 Iris-virginica 116 | 115 5.8 2.8 5.1 2.4 Iris-virginica 117 | 116 6.4 3.2 5.3 2.3 Iris-virginica 118 | 117 6.5 3 5.5 1.8 Iris-virginica 119 | 118 7.7 3.8 6.7 2.2 Iris-virginica 120 | 119 7.7 2.6 6.9 2.3 Iris-virginica 121 | 120 6 2.2 5 1.5 Iris-virginica 122 | 121 6.9 3.2 5.7 2.3 Iris-virginica 123 | 122 5.6 2.8 4.9 2 Iris-virginica 124 | 123 7.7 2.8 6.7 2 Iris-virginica 125 | 124 6.3 2.7 4.9 1.8 Iris-virginica 126 | 125 6.7 3.3 5.7 2.1 Iris-virginica 127 | 126 7.2 3.2 6 1.8 Iris-virginica 128 | 127 6.2 2.8 4.8 1.8 Iris-virginica 129 | 128 6.1 3 4.9 1.8 Iris-virginica 130 | 129 6.4 2.8 5.6 2.1 Iris-virginica 131 | 130 7.2 3 5.8 1.6 Iris-virginica 132 | 131 7.4 2.8 6.1 1.9 Iris-virginica 133 | 132 7.9 3.8 6.4 2 Iris-virginica 134 | 133 6.4 2.8 5.6 2.2 Iris-virginica 135 | 134 6.3 2.8 5.1 1.5 Iris-virginica 136 | 135 6.1 2.6 5.6 1.4 Iris-virginica 137 | 136 7.7 3 6.1 2.3 Iris-virginica 138 | 137 6.3 3.4 5.6 2.4 Iris-virginica 139 | 138 6.4 3.1 5.5 1.8 Iris-virginica 140 | 139 6 3 4.8 1.8 Iris-virginica 141 | 140 6.9 3.1 5.4 2.1 Iris-virginica 142 | 141 6.7 3.1 5.6 2.4 Iris-virginica 143 | 142 6.9 3.1 5.1 2.3 Iris-virginica 144 | 143 5.8 2.7 5.1 1.9 Iris-virginica 145 | 144 6.8 3.2 5.9 2.3 Iris-virginica 146 | 145 6.7 3.3 5.7 2.5 Iris-virginica 147 | 146 6.7 3 5.2 2.3 Iris-virginica 148 | 147 6.3 2.5 5 1.9 Iris-virginica 149 | 148 6.5 3 5.2 2 Iris-virginica 150 | 149 6.2 3.4 5.4 2.3 Iris-virginica 151 | 150 5.9 3 5.1 1.8 Iris-virginica -------------------------------------------------------------------------------- /densecatfeature_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestCatFeature(t *testing.T) { 9 | 10 | //Start with a small cat feature and do some simple spliting tests 11 | //then build it up and do some best split finding tests 12 | 13 | name := "catfeature" 14 | 15 | f := &DenseCatFeature{ 16 | NewCatMap(), 17 | make([]int, 0, 0), 18 | make([]bool, 0, 0), 19 | name, 20 | false, 21 | false} 22 | 23 | fm := FeatureMatrix{[]Feature{f}, 24 | map[string]int{name: 0}, 25 | []string{name}} 26 | 27 | f.Append("0") 28 | f.Append("1") 29 | f.Append("1") 30 | 31 | //f has 0 1 1 32 | 33 | if x := f.NCats(); x != 2 { 34 | t.Errorf("Boolean NCats = %v != 2", x) 35 | } 36 | 37 | fns := f.EncodeToNum() 38 | fn := fns[0].(NumFeature) 39 | 40 | if len(fns) != 1 || fn.Get(0) != 0.0 || fn.Get(1) != 1.0 || fn.Get(2) != 1.0 { 41 | t.Errorf("Error: cat feature %v encoded to %v", f.CatData, fn.(*DenseNumFeature).NumData) 42 | } 43 | 44 | codedSplit := 1 45 | cases := []int{0, 1, 2} 46 | 47 | l, r, m := f.Split(0, cases) 48 | if len(l) != 0 || len(r) != 3 || len(m) != 0 { 49 | t.Errorf("After Coded Boolean Split 0 Left, Right, 
Missing Lengths = %v %v %v not 0 3 0", len(l), len(r), len(m)) 50 | } 51 | 52 | decodedsplit := f.DecodeSplit(0) 53 | 54 | l, r, m = decodedsplit.Split(&fm, cases) 55 | 56 | if len(l) != 0 || len(r) != 3 || len(m) != 0 { 57 | t.Errorf("After Decoded Boolean Split 0 Left, Right, Missing Lengths = %v %v %v not 0 3 0", len(l), len(r), len(m)) 58 | } 59 | 60 | l, r, m = f.Split(1, cases) 61 | if len(l) != 1 || len(r) != 2 || len(m) != 0 { 62 | t.Errorf("After Coded Boolean Split 1 Left, Right, Missing Lengths = %v %v %v not 1 2 0", len(l), len(r), len(m)) 63 | } 64 | 65 | l, r, m = f.Split(2, cases) 66 | if len(l) != 2 || len(r) != 1 || len(m) != 0 { 67 | t.Errorf("After Coded Boolean Split 2 Left, Right, Missing Lengths = %v %v %v not 2 1 0", len(l), len(r), len(m)) 68 | } 69 | 70 | decodedsplit = f.DecodeSplit(codedSplit) 71 | 72 | l, r, m = decodedsplit.Split(&fm, cases) 73 | 74 | if len(l) != 1 || len(r) != 2 || len(m) != 0 { 75 | t.Errorf("After Decoded Boolean Split Left, Right, Missing Lengths = %v %v %v not 1 2 0", len(l), len(r), len(m)) 76 | } 77 | 78 | f.Append("0") 79 | cases = append(cases, 3) 80 | // f has 0 1 1 0 81 | 82 | l, r, m = decodedsplit.Split(&fm, cases) 83 | 84 | if len(l) != 2 || len(r) != 2 || len(m) != 0 { 85 | t.Errorf("After Decoded Boolean Split Left, Right, Missing Lengths = %v %v %v not 2 2 0", len(l), len(r), len(m)) 86 | } 87 | 88 | l, r, m = f.Split(codedSplit, cases) 89 | if len(l) != 2 || len(r) != 2 || len(m) != 0 { 90 | t.Errorf("After Coded Boolean Split Left, Right, Missing Lengths = %v %v %v not 2 2 0", len(l), len(r), len(m)) 91 | } 92 | 93 | f.Append("0") 94 | cases = append(cases, 4) 95 | 96 | allocs := NewBestSplitAllocs(5, f) 97 | 98 | _, split, _, _ := fm.BestSplitter(f, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 99 | 100 | if split.(int) != 1 { 101 | t.Errorf("Boolean feature didn't self split. Returned %v", split) 102 | } 103 | 104 | //f has 0 1 1 0 0 105 | 106 | target := f.Copy() 107 | target.Append("1") 108 | f.Append("NA") 109 | 110 | //f has 0 1 1 0 0 NA 111 | //target has 0 1 1 0 0 1 112 | 113 | if f.IsMissing(5) != true || f.MissingVals() != true || f.HasMissing != true { 114 | t.Error("Feature with missing values claims not") 115 | } 116 | 117 | if target.IsMissing(5) == true || target.MissingVals() == true || target.(*DenseCatFeature).HasMissing == true { 118 | t.Error("Target with missing values claims not") 119 | } 120 | 121 | cases = append(cases, 5) 122 | 123 | allocs = NewBestSplitAllocs(6, target) 124 | 125 | _, split, _, _ = fm.BestSplitter(target, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 126 | 127 | if split.(int) != 1 { 128 | t.Errorf("Boolean with missing val feature didn't split non missing copy. Returned %v", split) 129 | } 130 | 131 | target.PutStr(5, "2") 132 | 133 | //f has 0 1 1 0 0 NA 134 | //target has 0 1 1 0 0 2 135 | 136 | allocs = NewBestSplitAllocs(6, target) 137 | 138 | _, split, _, _ = fm.BestSplitter(target, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 139 | 140 | if split.(int) != 1 { 141 | t.Errorf("Trinary target with bool missing val feature didn't split. 
Returned %v", split) 142 | } 143 | 144 | f.PutStr(5, "0") 145 | 146 | //f has 0 1 1 0 0 0 147 | //target has 0 1 1 0 0 2 148 | // zero should go left by itself, code = 1 149 | 150 | allocs = NewBestSplitAllocs(6, target) 151 | 152 | _, split, _, _ = fm.BestSplitter(target, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 153 | 154 | if split.(int) != 1 { 155 | t.Errorf("Trinary target with bool missing val feature didn't split. Returned %v", split) 156 | } 157 | 158 | mediumf := &DenseCatFeature{ 159 | NewCatMap(), 160 | make([]int, 0, 0), 161 | make([]bool, 0, 0), 162 | "mediumf", 163 | false, 164 | false} 165 | 166 | for i := 0; i < 6; i++ { 167 | mediumf.Append(fmt.Sprintf("%v", i)) 168 | } 169 | 170 | mediumfm := FeatureMatrix{[]Feature{mediumf}, 171 | map[string]int{mediumf.Name: 0}, 172 | []string{mediumf.Name}} 173 | 174 | //f has 0 1 1 0 0 0 175 | //target has 0 1 1 0 0 2 176 | //medieumf has 0 1 2 3 4 5 177 | 178 | allocs = NewBestSplitAllocs(6, target) 179 | 180 | //split f by medium f should send 1 and 2 to one side, coded 6 181 | _, split, _, _ = mediumfm.BestSplitter(f, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 182 | 183 | if split.(int) != 6 { 184 | t.Errorf("Binary target with 6 valued feature didn't split. Returned %v", split) 185 | } 186 | 187 | l, r, m = mediumf.Split(split, cases) 188 | if len(l) != 2 || len(r) != 4 || len(m) != 0 { 189 | t.Errorf("After Coded Boolean vs Multivalued Split Left, Right, Missing Lengths = %v %v %v not 2 4 0", len(l), len(r), len(m)) 190 | } 191 | 192 | decodedsplit = mediumf.DecodeSplit(split) 193 | 194 | l, r, m = decodedsplit.Split(&mediumfm, cases) 195 | //fmt.Println(decodedsplit.Left) 196 | 197 | if len(l) != 2 || len(r) != 4 || len(m) != 0 { 198 | t.Errorf("After Decoded Boolean Split Left, Right, Missing Lengths = %v %v %v not 2 4 0", len(l), len(r), len(m)) 199 | } 200 | 201 | //target.Append(v) 202 | 203 | f = &DenseCatFeature{ 204 | NewCatMap(), 205 | make([]int, 0, 0), 206 | make([]bool, 0, 0), 207 | "tmp", 208 | false, 209 | false} 210 | 211 | f.Append("") 212 | 213 | if !f.IsMissing(0) || !f.IsZero(0) { 214 | t.Errorf("feature should be missing (%v), and be of zero value (%v)", f.IsMissing(0), f.IsZero(0)) 215 | } 216 | } 217 | 218 | func TestBigCatFeature(t *testing.T) { 219 | 220 | bigf := &DenseCatFeature{ 221 | NewCatMap(), 222 | make([]int, 0, 0), 223 | make([]bool, 0, 0), 224 | "big", 225 | false, 226 | false} 227 | 228 | boolf := &DenseCatFeature{ 229 | NewCatMap(), 230 | make([]int, 0, 0), 231 | make([]bool, 0, 0), 232 | "bool", 233 | false, 234 | false} 235 | 236 | cases := make([]int, 40, 40) 237 | for i := 0; i < 40; i++ { 238 | bigf.Append(fmt.Sprintf("%v", i)) 239 | boolf.Append(fmt.Sprintf("%v", i < 20)) 240 | cases[i] = i 241 | } 242 | 243 | bigfm := FeatureMatrix{[]Feature{bigf}, 244 | map[string]int{bigf.Name: 0}, 245 | []string{bigf.Name}} 246 | 247 | allocs := NewBestSplitAllocs(40, boolf) 248 | 249 | //split f by medium f should send 1 and 2 to one side, coded 6 250 | _, split, _, _ := bigfm.BestSplitter(boolf, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 251 | 252 | l, r, m := bigf.Split(split, cases) 253 | if len(l) != 20 || len(r) != 20 || len(m) != 0 { 254 | t.Errorf("After Coded big split Left, Right, Missing Lengths = %v %v %v not 20 20 0", len(l), len(r), len(m)) 255 | } 256 | 257 | decodedsplit := bigf.DecodeSplit(split) 258 | 259 | l, r, m = decodedsplit.Split(&bigfm, cases) 260 | //fmt.Println(decodedsplit.Left) 261 | 262 | if 
len(l) != 20 || len(r) != 20 || len(m) != 0 { 263 | t.Errorf("After Decoded big split Left, Right, Missing Lengths = %v %v %v not 20 20 0", len(l), len(r), len(m)) 264 | } 265 | 266 | bigf.RandomSearch = true 267 | _, split, _, _ = bigfm.BestSplitter(boolf, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 268 | 269 | l, r, m = bigf.Split(split, cases) 270 | //won't perfectly split but should do okay 271 | if len(l) < 18 || len(r) < 18 || len(m) != 0 { 272 | t.Errorf("After Coded big random split Left, Right, Missing Lengths = %v %v %v not >=18 >=18 0", len(l), len(r), len(m)) 273 | } 274 | 275 | bigf.PutMissing(23) 276 | bigf.RandomSearch = false 277 | 278 | //with a missing value present the best split should still separate the two classes 279 | _, split, _, _ = bigfm.BestSplitter(boolf, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 280 | 281 | l, r, m = bigf.Split(split, cases) 282 | if len(l) < 19 || len(r) < 19 || len(m) != 1 { 283 | t.Errorf("After Coded big missing split Left, Right, Missing Lengths = %v %v %v not 19 20 1", len(l), len(r), len(m)) 284 | } 285 | 286 | decodedsplit = bigf.DecodeSplit(split) 287 | 288 | l, r, m = decodedsplit.Split(&bigfm, cases) 289 | //fmt.Println(decodedsplit.Left) 290 | 291 | if len(l) < 19 || len(r) < 19 || len(m) != 1 { 292 | t.Errorf("After Decoded big split Left, Right, Missing Lengths = %v %v %v not 19 20 1", len(l), len(r), len(m)) 293 | } 294 | 295 | bigf.RandomSearch = true 296 | _, split, _, _ = bigfm.BestSplitter(boolf, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 297 | 298 | l, r, m = bigf.Split(split, cases) 299 | //won't perfectly split but should do okay 300 | if len(l) < 18 || len(r) < 18 || len(m) != 1 { 301 | t.Errorf("After Coded big random split Left, Right, Missing Lengths = %v %v %v not >=18 >=18 1", len(l), len(r), len(m)) 302 | } 303 | 304 | } 305 | -------------------------------------------------------------------------------- /densenumfeature_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import "testing" 4 | 5 | func TestNumFeature(t *testing.T) { 6 | 7 | name := "numfeature" 8 | 9 | f := &DenseNumFeature{ 10 | make([]float64, 0, 0), 11 | make([]bool, 0, 0), 12 | name, 13 | false} 14 | 15 | f.Append("0.1") 16 | f.Append("10.1") 17 | f.Append("10.2") 18 | 19 | if x := f.NCats(); x != 0 { 20 | t.Errorf("Numerical NCats = %v != 0", x) 21 | } 22 | 23 | codedSplit := 0.5 24 | cases := []int{0, 1, 2} 25 | 26 | l, r, m := f.Split(codedSplit, cases) 27 | if len(l) != 1 || len(r) != 2 || len(m) != 0 { 28 | t.Errorf("After Coded Numerical Split Left, Right, Missing Lengths = %v %v %v not 1 2 0", len(l), len(r), len(m)) 29 | } 30 | 31 | decodedsplit := f.DecodeSplit(codedSplit) 32 | 33 | fm := FeatureMatrix{[]Feature{f}, 34 | map[string]int{name: 0}, 35 | []string{name}} 36 | 37 | if !f.GoesLeft(0, decodedsplit) { 38 | t.Errorf("Value %v sent right by splitter decoded from %v", f.NumData[0], codedSplit) 39 | } 40 | if f.GoesLeft(1, decodedsplit) { 41 | t.Errorf("Value %v sent left by splitter decoded from %v", f.NumData[1], codedSplit) 42 | } 43 | 44 | l, r, m = decodedsplit.Split(&fm, cases) 45 | 46 | if len(l) != 1 || len(r) != 2 || len(m) != 0 { 47 | t.Errorf("After Decoded Numerical Split Left, Right, Missing Lengths = %v %v %v not 1 2 0", len(l), len(r), len(m)) 48 | } 49 | 50 | f.Append("0.0") 51 | cases = append(cases, 3) 52 | 53 | f.Append("0.0") 54 | cases = append(cases, 4) 55 | 56 | l, r, m = 
decodedsplit.Split(&fm, cases) 57 | 58 | if len(l) != 3 || len(r) != 2 || len(m) != 0 { 59 | t.Errorf("After Decoded Numerical Split Left, Right, Missing Lengths = %v %v %v not 3 2 0", len(l), len(r), len(m)) 60 | } 61 | 62 | l, r, m = f.Split(codedSplit, cases) 63 | if len(l) != 3 || len(r) != 2 || len(m) != 0 { 64 | t.Errorf("After Coded Numerical Split Left, Right, Missing Lengths = %v %v %v not 3 2 0", len(l), len(r), len(m)) 65 | } 66 | 67 | //check self splitting 68 | 69 | allocs := NewBestSplitAllocs(5, f) 70 | 71 | _, split, _, _ := fm.BestSplitter(f, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 72 | //fm.BestSplitter(target, cases, candidates, oob, leafSize, vet, evaloob, allocs) 73 | 74 | if split.(float64) != 5.1 { 75 | t.Errorf("Numerical feature didn't self split correctly. Returned %v not 5.1", split) 76 | } 77 | 78 | l, r, m = f.Split(split, cases) 79 | if len(l) != 3 || len(r) != 2 || len(m) != 0 { 80 | t.Errorf("After Coded Numerical Split Left, Right, Missing Lengths = %v %v %v not 3 2 0", len(l), len(r), len(m)) 81 | } 82 | 83 | //and with a run of equals 84 | f.Append(".1") 85 | cases = append(cases, 5) 86 | f.Append(".1") 87 | cases = append(cases, 6) 88 | 89 | allocs = NewBestSplitAllocs(7, f) 90 | 91 | _, split, _, constants := fm.BestSplitter(f, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 92 | //fm.BestSplitter(target, cases, candidates, oob, leafSize, vet, evaloob, allocs) 93 | 94 | if split.(float64) != 5.1 || constants != 0 { 95 | t.Errorf("Numerical feature didn't self split correctly with equal run. Returned %v not 5.1", split) 96 | } 97 | 98 | l, r, m = f.Split(split, cases) 99 | if len(l) != 5 || len(r) != 2 || len(m) != 0 { 100 | t.Errorf("After Coded Numerical Split with equal run Left, Right, Missing Lengths = %v %v %v not 5 2 0", len(l), len(r), len(m)) 101 | } 102 | 103 | //splitting between two runs of equals 104 | f.Append("10.1") 105 | cases = append(cases, 7) 106 | 107 | allocs = NewBestSplitAllocs(8, f) 108 | 109 | _, split, _, _ = fm.BestSplitter(f, &cases, &[]int{0}, 1, nil, 1, false, false, false, false, allocs, 0) 110 | //fm.BestSplitter(target, cases, candidates, oob, leafSize, vet, evaloob, allocs) 111 | 112 | sorted := true 113 | for i := 1; i < len(cases); i++ { 114 | if f.NumData[cases[i]] < f.NumData[cases[i-1]] { 115 | sorted = false 116 | } 117 | 118 | } 119 | if !sorted { 120 | t.Error("Numerical feature didn't sort cases.") 121 | } 122 | 123 | if split.(float64) != 5.1 { 124 | t.Errorf("Numerical feature didn't self split correctly between equal runs. Returned %v not 5.1", split) 125 | } 126 | 127 | l, r, m = f.Split(split, cases) 128 | if len(l) != 5 || len(r) != 3 || len(m) != 0 { 129 | t.Errorf("After Coded Numerical Split between equal runs Left, Right, Missing Lengths = %v %v %v not 5 3 0", len(l), len(r), len(m)) 130 | } 131 | 132 | f = &DenseNumFeature{ 133 | make([]float64, 0, 0), 134 | make([]bool, 0, 0), 135 | name, 136 | false} 137 | f.Append("0") 138 | if f.IsMissing(0) || !f.IsZero(0) { 139 | t.Errorf("feature should not be missing (%v), and be of zero value (%v)", !f.IsMissing(0), f.IsZero(0)) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /densitytarget.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | /* 8 | DensityTarget is used for density estimating trees. 
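A leaf's density estimate is the fraction of cases it holds divided by the volume it spans, following Ram and Gray's "Density Estimating Trees"; see Impurity and FindPredicted below.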
It contains a set of features and the 9 | count of cases. 10 | */ 11 | type DensityTarget struct { 12 | Features *[]Feature 13 | N int 14 | } 15 | 16 | func (target *DensityTarget) GetName() string { 17 | return "DensityTarget" 18 | } 19 | 20 | /* 21 | DensityTarget.SplitImpurity is a density estimating version of SplitImpurity. 22 | */ 23 | func (target *DensityTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) { 24 | nl := float64(len(*l)) 25 | nr := float64(len(*r)) 26 | nm := 0.0 27 | 28 | impurityDecrease = nl * target.Impurity(l, nil) 29 | impurityDecrease += nr * target.Impurity(r, nil) 30 | if m != nil && len(*m) > 0 { 31 | nm = float64(len(*m)) 32 | impurityDecrease += nm * target.Impurity(m, nil) 33 | } 34 | 35 | impurityDecrease /= nl + nr + nm 36 | return 37 | } 38 | 39 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l as in learning from numerical variables. 40 | //Here it just wraps SplitImpurity but it can be implemented to provide further optimization. 41 | func (target *DensityTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) { 42 | return target.SplitImpurity(l, r, m, allocs) 43 | } 44 | 45 | //DensityTarget.Impurity uses the impurity measure defined in "Density Estimating Trees" 46 | //by Parikshit Ram and Alexander G. Gray 47 | func (target *DensityTarget) Impurity(cases *[]int, counter *[]int) (e float64) { 48 | t := len(*cases) 49 | e = float64(t*t) / float64(target.N*target.N) 50 | for _, f := range *target.Features { 51 | switch f.(type) { 52 | case CatFeature: 53 | bigenoughcounter := make([]int, f.NCats()) 54 | e /= f.Span(cases, &bigenoughcounter) 55 | case NumFeature: 56 | e /= f.Span(cases, nil) 57 | } 58 | } 59 | 60 | return 61 | } 62 | 63 | //DensityTarget.FindPredicted returns the string representation of the density in the region 64 | //spanned by the specified cases. 65 | func (target *DensityTarget) FindPredicted(cases []int) string { 66 | t := len(cases) 67 | e := float64(t) / float64(target.N) 68 | 69 | for _, f := range *target.Features { 70 | switch f.(type) { 71 | case CatFeature: 72 | bigenoughcounter := make([]int, f.NCats()) 73 | e /= f.Span(&cases, &bigenoughcounter) 74 | case NumFeature: 75 | e /= f.Span(&cases, nil) 76 | } 77 | } 78 | 79 | return fmt.Sprintf("%v", e) 80 | } 81 | 82 | func (target *DensityTarget) NCats() int { 83 | return 0 84 | } 85 | -------------------------------------------------------------------------------- /dentropytarget.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | ) 7 | 8 | /* 9 | DEntropyTarget wraps a categorical feature for use in entropy driven classification 10 | as in Ross Quinlan's ID3 (Iterative Dichotomizer 3) with the entropy modified to use 11 | "disutility entropy" 12 | 13 | I = - k Sum ri * pi * log(pi) 14 | 15 | */ 16 | type DEntropyTarget struct { 17 | CatFeature 18 | Costs []float64 19 | } 20 | 21 | //NewDEntropyTarget creates a DEntropyTarget and initializes DEntropyTarget.Costs to the proper length. 
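//A minimal usage sketch (the category names and costs here are hypothetical):
//
//	target := NewDEntropyTarget(catFeature)
//	target.SetCosts(map[string]float64{"True": 5.0, "False": 1.0})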
22 | func NewDEntropyTarget(f CatFeature) *DEntropyTarget { 23 | return &DEntropyTarget{f, make([]float64, f.NCats())} 24 | } 25 | 26 | /*DEntropyTarget.SetCosts copies costs from a map[string]float64 keyed by category name into the proper 27 | entries in DEntropyTarget.Costs.*/ 28 | func (target *DEntropyTarget) SetCosts(costmap map[string]float64) { 29 | for i := 0; i < target.NCats(); i++ { 30 | c := target.NumToCat(i) 31 | target.Costs[i] = costmap[c] 32 | } 33 | } 34 | 35 | /* 36 | DEntropyTarget.SplitImpurity is a version of Split Impurity that calls DEntropyTarget.Impurity 37 | */ 38 | func (target *DEntropyTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) { 39 | nl := float64(len(*l)) 40 | nr := float64(len(*r)) 41 | nm := 0.0 42 | 43 | impurityDecrease = nl * target.Impurity(l, allocs.LCounter) 44 | impurityDecrease += nr * target.Impurity(r, allocs.RCounter) 45 | if m != nil && len(*m) > 0 { 46 | nm = float64(len(*m)) 47 | impurityDecrease += nm * target.Impurity(m, allocs.Counter) 48 | } 49 | 50 | impurityDecrease /= nl + nr + nm 51 | return 52 | } 53 | 54 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l as in learning from numerical variables. 55 | //Here it just wraps SplitImpurity but it can be implemented to provide further optimization. 56 | func (target *DEntropyTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) { 57 | target.MoveCountsRtoL(allocs, movedRtoL) 58 | nl := float64(len(*l)) 59 | nr := float64(len(*r)) 60 | nm := 0.0 61 | 62 | impurityDecrease = nl * target.ImpFromCounts(len(*l), allocs.LCounter) 63 | impurityDecrease += nr * target.ImpFromCounts(len(*r), allocs.RCounter) 64 | if m != nil && len(*m) > 0 { 65 | nm = float64(len(*m)) 66 | impurityDecrease += nm * target.ImpFromCounts(len(*m), allocs.Counter) 67 | } 68 | 69 | impurityDecrease /= nl + nr + nm 70 | return 71 | } 72 | 73 | func (target *DEntropyTarget) ImpFromCounts(total int, counts *[]int) (e float64) { //cost weighted entropy from precomputed category counts 74 | p := 0.0 75 | for c, i := range *counts { 76 | if i > 0 { 77 | p = float64(i) / float64(total) 78 | e -= target.Costs[c] * p * math.Log(p) 79 | } 80 | } 81 | return 82 | 83 | } 84 | 85 | func (target *DEntropyTarget) FindPredicted(cases []int) (pred string) { 86 | prob_true := 0.0 87 | t := target.CatToNum("True") //assumes a binary target with a "True" category 88 | weightedvoted := true 89 | if weightedvoted { 90 | count := 0.0 91 | total := 0.0 92 | for _, i := range cases { 93 | ti := target.Geti(i) 94 | cost := target.Costs[ti] 95 | if ti == t { 96 | count += cost 97 | } 98 | total += cost 99 | 100 | } 101 | prob_true = count / total 102 | 103 | } else { 104 | count := 0 105 | for _, i := range cases { 106 | if target.Geti(i) == t { 107 | count++ 108 | } 109 | 110 | } 111 | prob_true = float64(count) / float64(len(cases)) 112 | } 113 | return fmt.Sprintf("%v", prob_true) 114 | } 115 | 116 | //DEntropyTarget.Impurity implements cost weighted categorical entropy as -sum(rj*pj*log(pj)) where pj 117 | //is the fraction of cases with the j'th category and rj is that category's cost. 
118 | func (target *DEntropyTarget) Impurity(cases *[]int, counts *[]int) (e float64) { 119 | 120 | total := len(*cases) 121 | target.CountPerCat(cases, counts) 122 | 123 | p := 0.0 124 | for c, i := range *counts { 125 | if i > 0 { 126 | p = float64(i) / float64(total) 127 | e -= target.Costs[c] * p * math.Log(p) //weight by cost so this matches ImpFromCounts 128 | } 129 | 130 | } 131 | 132 | return 133 | 134 | } 135 | -------------------------------------------------------------------------------- /entropytarget.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "math" 5 | ) 6 | 7 | /* 8 | EntropyTarget wraps a categorical feature for use in entropy driven classification 9 | as in Ross Quinlan's ID3 (Iterative Dichotomizer 3). 10 | */ 11 | type EntropyTarget struct { 12 | CatFeature 13 | } 14 | 15 | //NewEntropyTarget wraps a categorical feature as an EntropyTarget. 16 | func NewEntropyTarget(f CatFeature) *EntropyTarget { 17 | return &EntropyTarget{f} 18 | } 19 | 20 | /* 21 | EntropyTarget.SplitImpurity is a version of Split Impurity that calls EntropyTarget.Impurity 22 | */ 23 | func (target *EntropyTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) { 24 | nl := float64(len(*l)) 25 | nr := float64(len(*r)) 26 | nm := 0.0 27 | 28 | impurityDecrease = nl * target.Impurity(l, allocs.LCounter) 29 | impurityDecrease += nr * target.Impurity(r, allocs.RCounter) 30 | if m != nil && len(*m) > 0 { 31 | nm = float64(len(*m)) 32 | impurityDecrease += nm * target.Impurity(m, allocs.Counter) 33 | } 34 | 35 | impurityDecrease /= nl + nr + nm 36 | return 37 | } 38 | 39 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l as in learning from numerical variables. 40 | //Here it just wraps SplitImpurity but it can be implemented to provide further optimization. 41 | func (target *EntropyTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) { 42 | target.MoveCountsRtoL(allocs, movedRtoL) 43 | nl := float64(len(*l)) 44 | nr := float64(len(*r)) 45 | nm := 0.0 46 | 47 | impurityDecrease = nl * target.ImpFromCounts(len(*l), allocs.LCounter) 48 | impurityDecrease += nr * target.ImpFromCounts(len(*r), allocs.RCounter) 49 | if m != nil && len(*m) > 0 { 50 | nm = float64(len(*m)) 51 | impurityDecrease += nm * target.ImpFromCounts(len(*m), allocs.Counter) 52 | } 53 | 54 | impurityDecrease /= nl + nr + nm 55 | return 56 | } 57 | 58 | func (target *EntropyTarget) ImpFromCounts(total int, counts *[]int) (e float64) { 59 | p := 0.0 60 | for _, i := range *counts { 61 | if i > 0 { 62 | p = float64(i) / float64(total) 63 | e -= p * math.Log(p) 64 | } 65 | } 66 | return 67 | 68 | } 69 | 70 | //EntropyTarget.Impurity implements categorical entropy as -sum(pj*log(pj)) where pj 71 | //is the fraction of cases with the j'th category. 
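//For example, two categories with counts {2, 2} give e = -2*(0.5*ln(0.5)) = ln(2) ≈ 0.693, while a pure node gives e = 0.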
72 | func (target *EntropyTarget) Impurity(cases *[]int, counts *[]int) (e float64) { 73 | 74 | total := len(*cases) 75 | target.CountPerCat(cases, counts) 76 | 77 | p := 0.0 78 | for _, i := range *counts { 79 | if i > 0 { 80 | p = float64(i) / float64(total) 81 | e -= p * math.Log(p) 82 | } 83 | 84 | } 85 | 86 | return 87 | 88 | } 89 | -------------------------------------------------------------------------------- /error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lytics/CloudForest/381792ef996b4d29adf78fb053e14df7390dd508/error.png -------------------------------------------------------------------------------- /evaluator.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "strconv" 5 | ) 6 | 7 | const leafFeature = -1 8 | 9 | // the evaluator interface implements high performance 10 | // decision tree evaluation strategies. 11 | // The PiecewiseFlatForest and the ContiguousFlatForest 12 | // provide faster analogs of the Predict function 13 | type NumEvaluator interface { 14 | EvaluateNum(fm *FeatureMatrix) []float64 15 | } 16 | 17 | type CatEvaluator interface { 18 | EvaluateCat(fm *FeatureMatrix) []string 19 | } 20 | 21 | type FlatNode struct { 22 | Feature int `json:"feature"` 23 | Float float64 `json:"float"` 24 | Value string `json:"value"` 25 | LeftChild uint32 `json:"leftchild"` 26 | } 27 | 28 | type FlatTree struct { 29 | Nodes []*FlatNode `json:"nodes"` 30 | Weight float64 31 | } 32 | 33 | func NewFlatTree(t *Tree) *FlatTree { 34 | f := &FlatTree{ 35 | Nodes: make([]*FlatNode, 1), 36 | Weight: adjustWeight(t.Weight), 37 | } 38 | f.recurse(t.Root, 0) 39 | return f 40 | } 41 | 42 | func (f *FlatTree) recurse(n *Node, idx uint32) { 43 | if n.Left == nil && n.Right == nil { 44 | fl, _ := strconv.ParseFloat(n.Pred, 64) 45 | f.Nodes[idx] = &FlatNode{ 46 | Feature: leafFeature, 47 | Float: fl, 48 | Value: n.Pred, 49 | } 50 | return 51 | } 52 | leftChild := uint32(len(f.Nodes)) 53 | f.Nodes = append(f.Nodes, make([]*FlatNode, 2)...) 
54 | var value string 55 | var fl float64 56 | switch x := n.CodedSplit.(type) { 57 | case float64: 58 | fl = x 59 | case int: 60 | fl = float64(x) 61 | case string: 62 | value = x 63 | } 64 | f.Nodes[idx] = &FlatNode{ 65 | Feature: n.Featurei, 66 | Value: value, 67 | Float: fl, 68 | LeftChild: leftChild, 69 | } 70 | f.recurse(n.Left, leftChild) 71 | f.recurse(n.Right, leftChild+1) 72 | } 73 | 74 | func (f *FlatTree) EvaluateNum(fm *FeatureMatrix) []float64 { 75 | sz := fm.Data[0].Length() 76 | preds := make([]float64, sz) 77 | for i := 0; i < sz; i++ { 78 | current := uint32(0) 79 | for { 80 | n := f.Nodes[current] 81 | // leaf node 82 | if n.Feature == leafFeature { 83 | preds[i] = n.Float * f.Weight 84 | break 85 | } 86 | switch f := fm.Data[n.Feature].(type) { 87 | case *DenseNumFeature: 88 | val := f.NumData[i] 89 | splitValue := n.Float 90 | if val < splitValue { 91 | current = n.LeftChild 92 | } else { 93 | current = n.LeftChild + 1 94 | } 95 | case *DenseCatFeature: 96 | val := f.GetStr(i) 97 | splitValue := n.Value 98 | if val == splitValue { 99 | current = n.LeftChild 100 | } else { 101 | current = n.LeftChild + 1 102 | } 103 | } 104 | } 105 | } 106 | return preds 107 | } 108 | 109 | func (f *FlatTree) EvaluateCat(fm *FeatureMatrix) []string { 110 | sz := fm.Data[0].Length() 111 | preds := make([]string, sz) 112 | for i := 0; i < sz; i++ { 113 | current := uint32(0) 114 | for { 115 | n := f.Nodes[current] 116 | // leaf node 117 | if n.Feature == leafFeature { 118 | preds[i] = n.Value 119 | break 120 | } 121 | switch f := fm.Data[n.Feature].(type) { 122 | case *DenseNumFeature: 123 | val := f.NumData[i] 124 | splitValue := n.Float 125 | if val < splitValue { 126 | current = n.LeftChild 127 | } else { 128 | current = n.LeftChild + 1 129 | } 130 | case *DenseCatFeature: 131 | val := f.GetStr(i) 132 | splitValue := n.Value 133 | if val == splitValue { 134 | current = n.LeftChild 135 | } else { 136 | current = n.LeftChild + 1 137 | } 138 | } 139 | } 140 | } 141 | return preds 142 | } 143 | 144 | type PiecewiseFlatForest struct { 145 | Trees []*FlatTree `json:"trees"` 146 | Intercept float64 147 | } 148 | 149 | func NewPiecewiseFlatForest(forest *Forest) *PiecewiseFlatForest { 150 | p := &PiecewiseFlatForest{ 151 | Trees: make([]*FlatTree, len(forest.Trees)), 152 | Intercept: forest.Intercept, 153 | } 154 | for i, n := range forest.Trees { 155 | p.Trees[i] = NewFlatTree(n) 156 | } 157 | return p 158 | } 159 | 160 | func (p *PiecewiseFlatForest) EvaluateNum(fm *FeatureMatrix) []float64 { 161 | sz := fm.Data[0].Length() 162 | n := float64(len(p.Trees)) 163 | preds := make([]float64, sz) 164 | for i := range preds { 165 | preds[i] = p.Intercept 166 | } 167 | 168 | for _, tree := range p.Trees { 169 | for i, pred := range tree.EvaluateNum(fm) { 170 | preds[i] += pred / n 171 | } 172 | } 173 | return preds 174 | } 175 | 176 | func (p *PiecewiseFlatForest) EvaluateCat(fm *FeatureMatrix) []string { 177 | sz := fm.Data[0].Length() 178 | bb := NewCatBallotBox(sz) 179 | for _, tree := range p.Trees { 180 | for i, pred := range tree.EvaluateCat(fm) { 181 | bb.Vote(i, pred, 1.0) 182 | } 183 | } 184 | 185 | preds := make([]string, sz) 186 | for i := 0; i < sz; i++ { 187 | preds[i] = bb.Tally(i) 188 | } 189 | return preds 190 | } 191 | 192 | type ContiguousFlatForest struct { 193 | Roots []uint32 194 | Nodes []*FlatNode 195 | Weights []float64 196 | Intercept float64 197 | } 198 | 199 | func NewContiguousFlatForest(forest *Forest) *ContiguousFlatForest { 200 | var roots []uint32 201 | var nodes 
[]*FlatNode 202 | var weights []float64 203 | for _, tree := range forest.Trees { 204 | idx := uint32(len(nodes)) 205 | roots = append(roots, idx) 206 | weights = append(weights, adjustWeight(tree.Weight)) 207 | for _, node := range NewFlatTree(tree).Nodes { 208 | node.LeftChild += idx 209 | nodes = append(nodes, node) 210 | } 211 | } 212 | return &ContiguousFlatForest{ 213 | Roots: roots, 214 | Nodes: nodes, 215 | Weights: weights, 216 | Intercept: forest.Intercept, 217 | } 218 | } 219 | 220 | func (c *ContiguousFlatForest) EvaluateNum(fm *FeatureMatrix) []float64 { 221 | sz := fm.Data[0].Length() 222 | preds := make([]float64, sz) 223 | for i := 0; i < sz; i++ { 224 | result := 0.0 225 | for ti, root := range c.Roots { 226 | current := root 227 | for { 228 | n := c.Nodes[current] 229 | if n.Feature == leafFeature { 230 | // leaf node: accumulate this tree's prediction scaled by that tree's weight 231 | result += n.Float * c.Weights[ti] 232 | break 233 | } 234 | switch f := fm.Data[n.Feature].(type) { 235 | case *DenseNumFeature: 236 | val := f.NumData[i] 237 | splitValue := n.Float 238 | if val < splitValue { 239 | current = n.LeftChild 240 | } else { 241 | current = n.LeftChild + 1 242 | } 243 | case *DenseCatFeature: 244 | val := f.GetStr(i) 245 | splitValue := n.Value 246 | if val == splitValue { 247 | current = n.LeftChild 248 | } else { 249 | current = n.LeftChild + 1 250 | } 251 | } 252 | } 253 | } 254 | preds[i] = result / float64(len(c.Roots)) 255 | preds[i] += c.Intercept 256 | } 257 | return preds 258 | } 259 | 260 | func (c *ContiguousFlatForest) EvaluateCat(fm *FeatureMatrix) []string { 261 | sz := fm.Data[0].Length() 262 | bb := NewCatBallotBox(sz) 263 | 264 | for i := 0; i < sz; i++ { 265 | for _, root := range c.Roots { 266 | current := root 267 | for { 268 | n := c.Nodes[current] 269 | if n.Feature == leafFeature { 270 | // leaf node 271 | bb.Vote(i, n.Value, 1.0) 272 | break 273 | } 274 | switch f := fm.Data[n.Feature].(type) { 275 | case *DenseNumFeature: 276 | val := f.NumData[i] 277 | splitValue := n.Float 278 | if val < splitValue { 279 | current = n.LeftChild 280 | } else { 281 | current = n.LeftChild + 1 282 | } 283 | case *DenseCatFeature: 284 | val := f.GetStr(i) 285 | splitValue := n.Value 286 | if val == splitValue { 287 | current = n.LeftChild 288 | } else { 289 | current = n.LeftChild + 1 290 | } 291 | } 292 | } 293 | } 294 | } 295 | 296 | preds := make([]string, sz) 297 | for i := 0; i < sz; i++ { 298 | preds[i] = bb.Tally(i) 299 | } 300 | return preds 301 | } 302 | 303 | func adjustWeight(x float64) float64 { 304 | if x <= 0.0 { 305 | return 1.0 306 | } 307 | return x 308 | } 309 | -------------------------------------------------------------------------------- /evaluator_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/bmizerany/assert" 9 | ) 10 | 11 | func setupCategorical() (*Forest, *FeatureMatrix) { 12 | irisreader := strings.NewReader(irislibsvm) 13 | fm := ParseLibSVM(irisreader) 14 | targeti := 0 15 | cattarget := fm.Data[targeti] 16 | config := &ForestConfig{ 17 | NSamples: fm.Data[0].Length(), 18 | MTry: 3, 19 | NTrees: 10, 20 | LeafSize: 1, 21 | } 22 | 23 | sample := &FeatureMatrix{ 24 | Data: make([]Feature, len(fm.Map)), 25 | Map: make(map[string]int), 26 | } 27 | for k, v := range fm.Map { 28 | var feature Feature 29 | if v == 0 { 30 | feature = NewDenseCatFeature(k) 31 | } else { 32 | feature = NewDenseNumFeature(k) 33 | } 34 | sample.Map[k] = v 35 | sample.Data[v] = 
feature 36 | sample.Data[v].Append(fm.Data[v].GetStr(0)) 37 | } 38 | 39 | model := GrowRandomForest(fm, cattarget, config) 40 | return model.Forest, sample 41 | } 42 | 43 | func setupNumeric() (*Forest, *FeatureMatrix) { 44 | boston := strings.NewReader(boston_housing) 45 | fm := ParseARFF(boston) 46 | target := fm.Data[fm.Map["class"]] 47 | sample := &FeatureMatrix{ 48 | Data: make([]Feature, len(fm.Map)), 49 | Map: make(map[string]int), 50 | } 51 | for k, v := range fm.Map { 52 | sample.Map[k] = v 53 | sample.Data[v] = NewDenseNumFeature(k) 54 | sample.Data[v].Append(fm.Data[v].GetStr(0)) 55 | } 56 | config := &ForestConfig{ 57 | NSamples: target.Length(), 58 | MTry: 4, 59 | NTrees: 20, 60 | LeafSize: 1, 61 | MaxDepth: 4, 62 | InBag: true, 63 | } 64 | model := GrowRandomForest(fm, target, config) 65 | return model.Forest, sample 66 | } 67 | 68 | func TestEvaluator(t *testing.T) { 69 | forest, sample := setupNumeric() 70 | predVal := forest.Predict(sample)[0] 71 | 72 | evalPW := NewPiecewiseFlatForest(forest) 73 | evalVal := evalPW.EvaluateNum(sample)[0] 74 | assert.Equal(t, fmt.Sprintf("%.4f", predVal), fmt.Sprintf("%.4f", evalVal)) 75 | 76 | evalCT := NewContiguousFlatForest(forest) 77 | evalVal = evalCT.EvaluateNum(sample)[0] 78 | assert.Equal(t, fmt.Sprintf("%.4f", predVal), fmt.Sprintf("%.4f", evalVal)) 79 | } 80 | 81 | func TestCatEvaluator(t *testing.T) { 82 | forest, sample := setupCategorical() 83 | pred := forest.PredictCat(sample)[0] 84 | 85 | pw := NewPiecewiseFlatForest(forest) 86 | predPW := pw.EvaluateCat(sample)[0] 87 | assert.Equal(t, pred, predPW) 88 | 89 | ct := NewContiguousFlatForest(forest) 90 | predCT := ct.EvaluateCat(sample)[0] 91 | assert.Equal(t, predPW, predCT) 92 | } 93 | 94 | // BenchmarkPredict-8 100000 12542 ns/op 95 | func BenchmarkPredict(b *testing.B) { 96 | forest, sample := setupNumeric() 97 | 98 | b.StartTimer() 99 | for i := 0; i < b.N; i++ { 100 | forest.Predict(sample) 101 | } 102 | b.StopTimer() 103 | } 104 | 105 | // BenchmarkFlatForest-8 2000000 821 ns/op 106 | func BenchmarkFlatForest(b *testing.B) { 107 | forest, sample := setupNumeric() 108 | pw := NewPiecewiseFlatForest(forest) 109 | 110 | b.StartTimer() 111 | for i := 0; i < b.N; i++ { 112 | pw.EvaluateNum(sample) 113 | } 114 | b.StopTimer() 115 | } 116 | 117 | // BenchmarkContiguousForest-8 5000000 339 ns/op 118 | func BenchmarkContiguousForest(b *testing.B) { 119 | forest, sample := setupNumeric() 120 | ct := NewContiguousFlatForest(forest) 121 | 122 | b.StartTimer() 123 | for i := 0; i < b.N; i++ { 124 | ct.EvaluateNum(sample) 125 | } 126 | b.StopTimer() 127 | } 128 | -------------------------------------------------------------------------------- /featureinterfaces.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | const maxExhaustiveCats = 5 4 | const maxNonRandomExahustive = 10 5 | const maxNonBigCats = 30 6 | const minImp = 1e-7 7 | const constant_cutoff = 1e-7 8 | 9 | //Feature contains all methods needed for a predictor feature. 
10 | type Feature interface { 11 | Span(cases *[]int, counter *[]int) float64 12 | NCats() (n int) 13 | Length() (l int) 14 | GetStr(i int) (value string) 15 | IsMissing(i int) bool 16 | IsZero(i int) bool 17 | MissingVals() bool 18 | GoesLeft(i int, splitter *Splitter) bool 19 | PutMissing(i int) 20 | PutStr(i int, v string) 21 | SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) 22 | UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) 23 | Impurity(cases *[]int, counter *[]int) (impurity float64) 24 | FindPredicted(cases []int) (pred string) 25 | BestSplit(target Target, 26 | cases *[]int, 27 | parentImp float64, 28 | leafSize int, 29 | randomSplit bool, 30 | allocs *BestSplitAllocs) (codedSplit interface{}, impurityDecrease float64, constant bool) 31 | DecodeSplit(codedSplit interface{}) (s *Splitter) 32 | ShuffledCopy() (fake Feature) 33 | Copy() (copy Feature) 34 | CopyInTo(copy Feature) 35 | Shuffle() 36 | ShuffleCases(cases *[]int) 37 | ImputeMissing() 38 | GetName() string 39 | Append(v string) 40 | Split(codedSplit interface{}, cases []int) (l []int, r []int, m []int) 41 | SplitPoints(codedSplit interface{}, cases *[]int) (lastl int, firstr int) 42 | } 43 | 44 | //NumFeature contains the methods of Feature plus methods needed to implement 45 | //different types of regression. It is usually embedded by regression targets to 46 | //provide access to the underlying data. 47 | type NumFeature interface { 48 | Feature 49 | Get(i int) float64 50 | Put(i int, v float64) 51 | Predicted(cases *[]int) float64 52 | Mean(cases *[]int) float64 53 | Norm(i int, v float64) float64 54 | Error(cases *[]int, predicted float64) (e float64) 55 | Less(i int, j int) bool 56 | SumAndSumSquares(cases *[]int) (float64, float64) 57 | } 58 | 59 | //CatFeature contains the methods of Feature plus methods needed to implement 60 | //different types of classification. It is usually embedded by classification targets to 61 | //provide access to the underlying data. 62 | type CatFeature interface { 63 | Feature 64 | CountPerCat(cases *[]int, counter *[]int) 65 | MoveCountsRtoL(allocs *BestSplitAllocs, movedRtoL *[]int) 66 | CatToNum(value string) (numericv int) 67 | NumToCat(i int) (value string) 68 | Geti(i int) int 69 | Puti(i int, v int) 70 | Modei(cases *[]int) int 71 | Mode(cases *[]int) string 72 | Gini(cases *[]int) float64 73 | GiniWithoutAlocate(cases *[]int, counts *[]int) (e float64) 74 | EncodeToNum() (fs []Feature) 75 | OneHot() (fs []Feature) 76 | } 77 | 78 | //Target abstracts the methods needed for a feature to be predictable 79 | //as either a categorical or numerical feature in a random forest. 80 | type Target interface { 81 | GetName() string 82 | NCats() (n int) 83 | SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) 84 | UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) 85 | Impurity(cases *[]int, counter *[]int) (impurity float64) 86 | FindPredicted(cases []int) (pred string) 87 | } 88 | 89 | //BoostingTarget augments Target with a "Boost" method that will be called after each 90 | //tree is grown with the partition generated by that tree. It will return the weight the 91 | //tree should be given and boost the target for the next tree. 
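//GradBoostTarget and GradBoostClassTarget implement this interface via their Boost methods.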
92 | type BoostingTarget interface { 93 | Target 94 | Boost(partition *[][]int, preds *[]string) (weight float64) 95 | } 96 | 97 | type TargetWithIntercept interface { 98 | Target 99 | Intercept() float64 100 | } 101 | -------------------------------------------------------------------------------- /featurematrix_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "bytes" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/bmizerany/assert" 9 | "gonum.org/v1/gonum/mat" 10 | ) 11 | 12 | //A toy feature matrix where either of the first 13 | //two variables should be easilly predictible 14 | //by the other by a single greedy tree. 15 | var constantsfm = `. 0 1 2 3 4 5 6 7 16 | C:CatTarget 0 0 0 0 0 1 1 1 17 | N:GoodVals 0 0 0 0 0 1 1 1 18 | C:Const1 0 0 0 0 0 0 0 1 19 | C:Const2 0 0 0 0 0 0 0 1 20 | C:Const3 0 0 0 0 0 0 0 1 21 | N:Const4 0 0 0 0 0 0 0 1 22 | N:Const5 0 0 0 0 0 0 0 1 23 | N:Const6 0 0 0 0 0 0 0 1` 24 | 25 | func TestBestSplitter(t *testing.T) { 26 | fm := readFm() 27 | 28 | target := fm.Data[0] 29 | cases := &[]int{0, 1, 2, 3, 4, 5, 6} 30 | candidates := []int{1, 2, 3, 4, 5, 6, 7} 31 | allocs := NewBestSplitAllocs(len(*cases), target) 32 | 33 | _, imp, constant := fm.Data[1].BestSplit(target, cases, 1, 1, false, allocs) 34 | if imp <= minImp || constant == true { 35 | t.Errorf("Good feature had imp %v and constant: %v", imp, constant) 36 | } 37 | 38 | _, imp, constant = fm.Data[2].BestSplit(target, cases, 1, 1, false, allocs) 39 | if imp > minImp || constant == false { 40 | t.Errorf("Constant cat feature had imp %v and constant: %v %v", imp, constant, fm.Data[2].(*DenseCatFeature).CatData) 41 | } 42 | 43 | _, imp, constant = fm.Data[7].BestSplit(target, cases, 1, 1, false, allocs) 44 | if imp > minImp || constant == false { 45 | t.Errorf("Constant num feature had imp %v and constant: %v", imp, constant) 46 | } 47 | 48 | fi, split, impDec, nconstants := fm.BestSplitter(target, cases, &candidates, len(candidates), nil, 1, true, false, false, false, allocs, 0) 49 | if fi != 1 || split == nil || impDec == minImp || nconstants != 6 { 50 | t.Errorf("BestSplitter couldn't find non constant feature and six constants fi: %v split: %v impDex: %v nconstants: %v ", fi, split, impDec, nconstants) 51 | } 52 | 53 | for i := 0; i < 7; i++ { 54 | 55 | candidates = []int{1, 2, 3, 4, 5, 6, 7} 56 | 57 | fi, split, impDec, nconstants = fm.BestSplitter(target, cases, &candidates, 1, nil, 1, true, false, false, false, allocs, i) 58 | if fi != 1 || split == nil || impDec == minImp { 59 | t.Errorf("BestSplitter couldn't find non constant feature with mTry=1 and %v known constants fi: %v split: %v impDex: %v nconstants: %v ", i, fi, split, impDec, nconstants) 60 | } 61 | 62 | candidates = []int{1, 2, 3, 4, 5, 6, 7} 63 | fi, split, impDec, nconstants = fm.BestSplitter(target, cases, &candidates, len(candidates), nil, 1, true, false, false, false, allocs, i) 64 | if fi != 1 || split == nil || impDec == minImp || nconstants != 6 { 65 | t.Errorf("BestSplitter couldn't find non constant feature and six constants with %v known constants fi: %v split: %v impDex: %v nconstants: %v ", i, fi, split, impDec, nconstants) 66 | } 67 | } 68 | } 69 | 70 | func TestFmWrite(t *testing.T) { 71 | fm := readFm() 72 | header := true 73 | 74 | writer := &bytes.Buffer{} 75 | if err := fm.WriteFM(writer, "\t", header, true); err != nil { 76 | t.Fatalf("could not write feature matrix: %v", err) 77 | } 78 | 79 | if writer.String() == "" { 80 | 
t.Fatalf("could not write FM - buffer is empty") 81 | } 82 | firstLen := writer.Len() 83 | 84 | writer = &bytes.Buffer{} 85 | if err := fm.WriteFM(writer, "\t", header, false); err != nil { 86 | t.Fatalf("could not write feature matrix: %v", err) 87 | } 88 | 89 | if writer.String() == "" { 90 | t.Fatalf("could not write FM - buffer is empty") 91 | } 92 | secondLen := writer.Len() 93 | 94 | if firstLen != secondLen { 95 | t.Fatalf("expected buffers to have the same length: %v != %v", firstLen, secondLen) 96 | } 97 | } 98 | 99 | func TestMat64(t *testing.T) { 100 | fm := readFm() 101 | dense := fm.Matrix(false, false) 102 | 103 | compareCol := func(i int, exp []float64) { 104 | col := mat.Col(nil, i, dense) 105 | assert.Equal(t, len(col), len(exp)) 106 | for i := range exp { 107 | assert.Equal(t, col[i], exp[i]) 108 | } 109 | } 110 | 111 | compareCol(1, []float64{0, 0, 0, 0, 0, 1, 1, 1}) 112 | compareCol(2, []float64{0, 0, 0, 0, 0, 0, 0, 1}) 113 | } 114 | 115 | func readFm() *FeatureMatrix { 116 | fmReader := strings.NewReader(constantsfm) 117 | return ParseAFM(fmReader) 118 | } 119 | 120 | func TestFeatureMatrixCopy(t *testing.T) { 121 | fm := readFm() 122 | fmCopy := fm.Copy() 123 | 124 | // check feature matrix copy 125 | for k, v := range fm.Map { 126 | assert.Equal(t, v, fmCopy.Map[k]) 127 | } 128 | for i, v := range fm.CaseLabels { 129 | assert.Equal(t, v, fmCopy.CaseLabels[i]) 130 | } 131 | assert.Equal(t, len(fm.Data), len(fmCopy.Data)) 132 | for i, feature := range fm.Data { 133 | rows := fm.Data[1].Length() 134 | featureCopy := fmCopy.Data[i] 135 | assert.Equal(t, rows, featureCopy.Length()) 136 | for r := 0; r < rows; r++ { 137 | assert.Equal(t, feature.GetStr(r), featureCopy.GetStr(r)) 138 | } 139 | } 140 | 141 | // alter original feature matrix 142 | fm.CaseLabels = fm.CaseLabels[0 : len(fm.CaseLabels)-1] 143 | delete(fm.Map, "C:Const1") 144 | fm.Data = fm.Data[0 : len(fm.CaseLabels)-1] 145 | assert.NotEqual(t, len(fm.Data), len(fmCopy.Data)) 146 | assert.NotEqual(t, len(fm.Map), len(fmCopy.Map)) 147 | assert.NotEqual(t, len(fm.CaseLabels), len(fmCopy.CaseLabels)) 148 | } 149 | -------------------------------------------------------------------------------- /forest_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "encoding/csv" 5 | "fmt" 6 | "os" 7 | "strconv" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/bmizerany/assert" 12 | ) 13 | 14 | var ( 15 | predFilePath = "preds.csv" 16 | inBagFilePath = "n.csv" 17 | ) 18 | 19 | func TestPartialDependencyCategorical(t *testing.T) { 20 | irisreader := strings.NewReader(irislibsvm) 21 | fm := ParseLibSVM(irisreader) 22 | 23 | tgt := fm.Data[1] 24 | model := GrowRandomForest(fm, tgt, &ForestConfig{ 25 | NSamples: fm.Data[1].Length(), 26 | MTry: 3, 27 | NTrees: 500, 28 | LeafSize: 1, 29 | }) 30 | forest := model.Forest 31 | 32 | // Partial Dependency Plot with 1 variable 33 | pdp, err := PDP(forest.Predict, fm, "0") 34 | assert.Equal(t, nil, err) 35 | assert.Equal(t, 3, len(pdp)) 36 | 37 | // ensure all the probabilities are unique 38 | uniq := make(map[float64]struct{}) 39 | for _, x := range pdp { 40 | assert.Equal(t, 2, len(x)) 41 | uniq[x[1]] = struct{}{} 42 | } 43 | assert.Equal(t, 3, len(uniq)) 44 | } 45 | 46 | func TestPartialDependencyNumeric(t *testing.T) { 47 | irisreader := strings.NewReader(irislibsvm) 48 | fm := ParseLibSVM(irisreader) 49 | 50 | // write dataset to CSV for R comparison/validation 51 | if os.Getenv("WRITEDATA") != "" { 52 
| iris, err := os.Create("iris.csv") 53 | assert.Equal(t, nil, err) 54 | 55 | for _, feature := range fm.Data { 56 | str := make([]string, feature.Length()) 57 | for i := 0; i < feature.Length(); i++ { 58 | str[i] = feature.GetStr(i) 59 | } 60 | iris.WriteString(strings.Join(str, ",")) 61 | iris.Write([]byte("\n")) 62 | } 63 | 64 | err = iris.Close() 65 | assert.Equal(t, nil, err) 66 | } 67 | 68 | tgt := fm.Data[0] 69 | model := GrowRandomForest(fm, tgt, &ForestConfig{ 70 | NSamples: fm.Data[0].Length(), 71 | MTry: 3, 72 | NTrees: 500, 73 | LeafSize: 1, 74 | }) 75 | forest := model.Forest 76 | 77 | // Partial Dependency Plot with 1 variable 78 | single, err := PDP(forest.Predict, fm, "3") 79 | assert.Equal(t, nil, err) 80 | assert.NotEqual(t, nil, single) 81 | 82 | // Partial Dependency Plot with 2 variables 83 | double, err := PDP(forest.Predict, fm, "3", "2") 84 | assert.Equal(t, nil, err) 85 | assert.NotEqual(t, nil, double) 86 | 87 | if os.Getenv("WRITEDATA") != "" { 88 | writeDeps("singleDep.csv", single) 89 | writeDeps("doubleDep.csv", double) 90 | } 91 | } 92 | 93 | func writeDeps(name string, vals [][]float64) { 94 | file, _ := os.Create(name) 95 | for _, val := range vals { 96 | writeSlice(file, val) 97 | } 98 | } 99 | 100 | func writeSlice(f *os.File, vals []float64) { 101 | str := make([]string, len(vals)) 102 | for i, v := range vals { 103 | str[i] = strconv.FormatFloat(v, 'f', -1, 64) 104 | } 105 | 106 | f.WriteString(strings.Join(str, ",")) 107 | f.Write([]byte("\n")) 108 | } 109 | 110 | func TestJackKnife(t *testing.T) { 111 | // read data 112 | preds := readCsv(t, predFilePath) 113 | inbag := readCsv(t, inBagFilePath) 114 | 115 | // run jackknife 116 | predictions, err := JackKnife(preds, inbag) 117 | if err != nil { 118 | t.Fatalf("error jack-knifing: %v", err) 119 | } 120 | 121 | if os.Getenv("EXPORT_JACKKNIFE") != "" { 122 | file, err := os.Create("validation.csv") 123 | if err != nil { 124 | t.Fatalf("error creating file: %v", err) 125 | } 126 | defer file.Close() 127 | 128 | fmt.Fprintln(file, "prediction, variance") 129 | for _, pred := range predictions { 130 | fmt.Fprintf(file, "%v, %v\n", pred.Value, pred.Variance) 131 | } 132 | } 133 | } 134 | 135 | func readCsv(t *testing.T, file string) [][]float64 { 136 | predFile, err := os.Open(file) 137 | if err != nil { 138 | t.Fatalf("could not open file %v: %v", predFile, err) 139 | } 140 | 141 | reader := csv.NewReader(predFile) 142 | all, err := reader.ReadAll() 143 | if err != nil { 144 | t.Fatalf("could not read file %s: %v", file, err) 145 | } 146 | 147 | values := make([][]float64, len(all)) 148 | for i, v := range all { 149 | values[i] = strToFloat(t, v) 150 | } 151 | return values 152 | } 153 | 154 | func strToFloat(t *testing.T, values []string) []float64 { 155 | f := make([]float64, len(values)) 156 | var err error 157 | for i := range f { 158 | f[i], err = strconv.ParseFloat(values[i], 64) 159 | if err != nil { 160 | t.Fatalf("could not convert %s, %v", values[i], err) 161 | } 162 | } 163 | return f 164 | } 165 | -------------------------------------------------------------------------------- /forestreader.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "log" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | /* 14 | ForestReader wraps an io.Reader to reads a forest. It includes ReadForest for reading an 15 | entire forest or ReadTree for reading a forest tree by tree. 
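A minimal usage sketch (the file name here is hypothetical):

	f, _ := os.Open("forest.sf")
	defer f.Close()
	forest, err := NewForestReader(f).ReadForest()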
16 | The forest should be in .sf format see the package doc's in doc.go for full format details. 17 | It ignores fields that are not use by CloudForest. 18 | */ 19 | type ForestReader struct { 20 | br *bufio.Reader 21 | } 22 | 23 | //NewForestReader wraps the supplied io.Reader as a ForestReader. 24 | func NewForestReader(r io.Reader) *ForestReader { 25 | return &ForestReader{bufio.NewReader(r)} 26 | } 27 | 28 | /* 29 | ForestReader.ReadForest reads the next forest from the underlying reader. 30 | If io.EOF or another error is encountered it returns that. 31 | */ 32 | func (fr *ForestReader) ReadForest() (forest *Forest, err error) { 33 | peek := []byte(" ") 34 | peek, err = fr.br.Peek(1) 35 | if err != nil { 36 | return 37 | } 38 | if peek[0] != 'F' && peek[0] != 'T' { 39 | err = errors.New("Forest Header Not Found.") 40 | return 41 | } 42 | for { 43 | peek, err = fr.br.Peek(1) 44 | if peek[0] == 'F' && forest != nil { 45 | return 46 | } 47 | t, f, e := fr.ReadTree() 48 | if forest != nil && f != nil { 49 | return 50 | } 51 | if forest == nil && f != nil { 52 | forest = f 53 | } 54 | if t != nil { 55 | if forest == nil { 56 | forest = new(Forest) 57 | forest.Target = t.Target 58 | } 59 | forest.Trees = append(forest.Trees, t) 60 | } 61 | if e == io.EOF { 62 | return forest, nil 63 | } 64 | if e != nil { 65 | return 66 | } 67 | 68 | } 69 | } 70 | 71 | /*ForestReader.ReadTree reads the next tree from the underlying reader. If the next tree 72 | is in a new forest it returns a forest object as well. If an io.EOF or other error is 73 | encountered it returns that as well as any partially parsed structs.*/ 74 | func (fr *ForestReader) ReadTree() (tree *Tree, forest *Forest, err error) { 75 | intree := false 76 | line := "" 77 | peek := []byte(" ") 78 | for { 79 | peek, err = fr.br.Peek(1) 80 | //If their is no next line or it starts a new Tree or Forest return 81 | if err != nil || (intree && (peek[0] == 'T' || peek[0] == 'F')) { 82 | return 83 | } 84 | 85 | line, err = fr.br.ReadString('\n') 86 | if err != nil { 87 | return 88 | } 89 | parsed := fr.ParseRfAcePredictorLine(line) 90 | switch { 91 | case strings.HasPrefix(line, "FOREST"): 92 | forest = new(Forest) 93 | forest.Target = parsed["TARGET"] 94 | i, ok := parsed["INTERCEPT"] 95 | if ok { 96 | intercept, err := strconv.ParseFloat(i, 64) 97 | if err != nil { 98 | log.Print("Error parsing forest intercept value ", err) 99 | } 100 | forest.Intercept = intercept 101 | } 102 | 103 | case strings.HasPrefix(line, "TREE"): 104 | intree = true 105 | tree = new(Tree) 106 | tree.Target = parsed["TARGET"] 107 | weights, ok := parsed["WEIGHT"] 108 | if ok { 109 | weight, err := strconv.ParseFloat(weights, 64) 110 | if err != nil { 111 | log.Print("Error parsing weight value ", err) 112 | } 113 | tree.Weight = weight 114 | } else { 115 | tree.Weight = -1.0 116 | } 117 | 118 | case strings.HasPrefix(line, "NODE"): 119 | if intree == false { 120 | err = errors.New("Poorly formed .sf file. 
Node found outside of tree.") 121 | return 122 | } 123 | var splitter *Splitter 124 | 125 | pred := "" 126 | if filepred, ok := parsed["PRED"]; ok { 127 | pred = filepred 128 | } 129 | 130 | if stype, ok := parsed["SPLITTERTYPE"]; ok { 131 | splitter = new(Splitter) 132 | splitter.Feature = parsed["SPLITTER"] 133 | switch stype { 134 | case "CATEGORICAL": 135 | splitter.Numerical = false 136 | 137 | splitter.Left = make(map[string]bool) 138 | for _, f := range strings.Split(parsed["LVALUES"], ":") { 139 | splitter.Left[f] = true 140 | } 141 | 142 | case "NUMERICAL": 143 | splitter.Numerical = true 144 | lvalue, err := strconv.ParseFloat(parsed["LVALUES"], 64) 145 | if err != nil { 146 | log.Print("Error parsing lvalues value ", err) 147 | } 148 | splitter.Value = float64(lvalue) 149 | } 150 | } 151 | 152 | tree.AddNode(parsed["NODE"], pred, splitter) 153 | 154 | } 155 | } 156 | 157 | } 158 | 159 | /* 160 | ParseRfAcePredictorLine parses a single line of an rf-ace sf "stochastic forest" 161 | and returns a map[string]string of the key value pairs. 162 | */ 163 | func (fr *ForestReader) ParseRfAcePredictorLine(line string) map[string]string { 164 | clauses := make([]string, 0) 165 | insidequotes := make([]string, 0) 166 | terms := strings.Split(strings.TrimSpace(line), ",") 167 | for _, term := range terms { 168 | term = strings.TrimSpace(term) 169 | quotes := strings.Count(term, "\"") 170 | //if quotes have been opend join terms 171 | if quotes == 1 || len(insidequotes) > 0 { 172 | insidequotes = append(insidequotes, term) 173 | } else { 174 | //If the term doesn't have an = in it join it to the last term 175 | if strings.Count(term, "=") == 0 { 176 | clauses[len(clauses)-1] += "," + term 177 | } else { 178 | clauses = append(clauses, term) 179 | } 180 | } 181 | //quotes were closed 182 | if quotes == 1 && len(insidequotes) > 1 { 183 | clauses = append(clauses, strings.Join(insidequotes, ",")) 184 | insidequotes = make([]string, 0) 185 | } 186 | 187 | } 188 | parsed := make(map[string]string, 0) 189 | for _, clause := range clauses { 190 | vs := strings.Split(clause, "=") 191 | for i, v := range vs { 192 | vs[i] = strings.Trim(strings.TrimSpace(v), "\"") 193 | } 194 | if len(vs) != 2 { 195 | fmt.Println("Parser Choked on : \"", line, "\"") 196 | } 197 | parsed[vs[0]] = vs[1] 198 | } 199 | 200 | return parsed 201 | } 202 | -------------------------------------------------------------------------------- /forestwriter.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "strings" 7 | ) 8 | 9 | /* 10 | ForestWriter wraps an io writer with functionality to write forests either with one 11 | call to WriteForest or incrementally using WriteForestHeader and WriteTree. 12 | ForestWriter saves a forest in .sf format; see the package doc's in doc.go for 13 | full format details. 14 | It won't include fields that are not use by CloudForest. 15 | */ 16 | type ForestWriter struct { 17 | w io.Writer 18 | } 19 | 20 | /*NewForestWriter returns a pointer to a new ForestWriter. */ 21 | func NewForestWriter(w io.Writer) *ForestWriter { 22 | return &ForestWriter{w} 23 | } 24 | 25 | //WriteForest writes an entire forest including all headers. 
26 | func (fw *ForestWriter) WriteForest(forest *Forest) { 27 | if forest.Intercept != 0.0 { 28 | fw.WriteForestHeader(0, forest.Target, forest.Intercept) 29 | } 30 | for i, tree := range forest.Trees { 31 | fw.WriteTree(tree, i) 32 | } 33 | } 34 | 35 | //WriteTree writes an entire Tree including the header. 36 | func (fw *ForestWriter) WriteTree(tree *Tree, ntree int) { 37 | fw.WriteTreeHeader(ntree, tree.Target, tree.Weight) 38 | fw.WriteNodeAndChildren(tree.Root, "*") 39 | } 40 | 41 | //WriteTreeHeader writes only the header line for a tree. 42 | func (fw *ForestWriter) WriteTreeHeader(ntree int, target string, weight float64) { 43 | weightterm := "" 44 | if weight >= 0.0 { 45 | weightterm = fmt.Sprintf(",WEIGHT=%v", weight) 46 | } 47 | fmt.Fprintf(fw.w, "TREE=%v,TARGET=\"%v\"%v\n", ntree, target, weightterm) 48 | } 49 | 50 | //WriteForestHeader writes only the header line for a forest. 51 | func (fw *ForestWriter) WriteForestHeader(nforest int, target string, intercept float64) { 52 | interceptterm := "" 53 | if intercept != 0.0 { 54 | interceptterm = fmt.Sprintf(",INTERCEPT=%v", intercept) 55 | } 56 | fmt.Fprintf(fw.w, "FOREST=%v,TARGET=\"%v\"%v\n", nforest, target, interceptterm) 57 | } 58 | 59 | //WriteNodeAndChildren recursively writes out the target node and all of its children. 60 | //WriteTree is preferred for most use cases. 61 | func (fw *ForestWriter) WriteNodeAndChildren(n *Node, path string) { 62 | 63 | fw.WriteNode(n, path) 64 | if n.Splitter != nil && n.Left != nil { 65 | fw.WriteNodeAndChildren(n.Left, path+"L") 66 | } 67 | if n.Splitter != nil && n.Right != nil { 68 | fw.WriteNodeAndChildren(n.Right, path+"R") 69 | } 70 | if n.Splitter != nil && n.Missing != nil { 71 | fw.WriteNodeAndChildren(n.Missing, path+"M") 72 | } 73 | 74 | } 75 | 76 | //WriteNode writes a single node but not its children. WriteTree will be used more 77 | //often but WriteNode can be used to grow a large tree directly to disk without 78 | //storing it in memory. 79 | func (fw *ForestWriter) WriteNode(n *Node, path string) { 80 | node := fmt.Sprintf("NODE=%v", path) 81 | if n.Pred != "" { 82 | node += fmt.Sprintf(",PRED=%v", n.Pred) 83 | } 84 | 85 | if n.Splitter != nil { 86 | node += fmt.Sprintf(",SPLITTER=%v", n.Splitter.Feature) 87 | switch n.Splitter.Numerical { 88 | case true: 89 | node += fmt.Sprintf(",SPLITTERTYPE=NUMERICAL,LVALUES=%v,RVALUES=%v", n.Splitter.Value, n.Splitter.Value) 90 | case false: 91 | left := fw.DescribeMap(n.Splitter.Left) 92 | node += fmt.Sprintf(",SPLITTERTYPE=CATEGORICAL,LVALUES=%v", left) 93 | } 94 | } 95 | fmt.Fprintln(fw.w, node) 96 | } 97 | 98 | //DescribeMap serializes the "left" map of a categorical splitter. 
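//For example, map[string]bool{"a": true, "b": true} serializes to "a:b" (quotes included); key order is not deterministic because Go map iteration order is randomized.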
99 | func (fw *ForestWriter) DescribeMap(input map[string]bool) string { 100 | keys := make([]string, 0) 101 | for k := range input { 102 | keys = append(keys, k) 103 | } 104 | return "\"" + strings.Join(keys, ":") + "\"" 105 | } 106 | -------------------------------------------------------------------------------- /forestwriterreader_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "io" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestFileFormats(t *testing.T) { 10 | 11 | //Write out a fm and read it back in 12 | pipereader, pipewriter := io.Pipe() 13 | cases := []int{0, 1, 2, 3, 4, 5, 6, 7} 14 | 15 | fm1 := ParseAFM(strings.NewReader(fm)) 16 | 17 | go func() { 18 | fm1.WriteCases(pipewriter, cases) 19 | pipewriter.Close() 20 | }() 21 | 22 | fm := ParseAFM(pipereader) 23 | 24 | if len(fm.Data) != 5 || fm.Data[0].Length() != 8 { 25 | t.Errorf("Parsed feature matrix has %v features and %v cases not 5 and 8", len(fm.Data), fm.Data[0].Length()) 26 | } 27 | 28 | cattarget := fm.Data[1] 29 | config := &ForestConfig{ 30 | NSamples: fm.Data[0].Length(), 31 | MTry: 3, 32 | NTrees: 10, 33 | LeafSize: 1, 34 | } 35 | ff := GrowRandomForest(fm, cattarget.(Feature), config) 36 | 37 | count := 0 38 | for _, tree := range ff.Forest.Trees { 39 | tree.Root.Recurse(func(*Node, []int, int) { count++ }, fm, cases, 0) 40 | } 41 | 42 | if count < 30 { 43 | t.Errorf("Trees before writing to file have only %v nodes.", count) 44 | } 45 | 46 | pipereader, pipewriter = io.Pipe() 47 | 48 | go func() { 49 | fw := NewForestWriter(pipewriter) 50 | fw.WriteForest(ff.Forest) 51 | pipewriter.Close() 52 | }() 53 | 54 | fr := NewForestReader(pipereader) 55 | 56 | forest, err := fr.ReadForest() 57 | if err != nil { 58 | t.Errorf("Error parsing forest from pipe: %v", err) 59 | } 60 | if len(forest.Trees) != 10 { 61 | t.Errorf("Parsed forest has only %v trees.", len(forest.Trees)) 62 | } 63 | 64 | catvotes := NewCatBallotBox(cattarget.Length()) 65 | count2 := 0 66 | for _, tree := range forest.Trees { 67 | tree.Vote(fm, catvotes) 68 | tree.Root.Recurse(func(*Node, []int, int) { count2++ }, fm, cases, 0) 69 | 70 | } 71 | if count != count2 { 72 | t.Errorf("Forest before file has %v nodes differs from %v nodes after.", count, count2) 73 | } 74 | 75 | //TODO(ryan): figure out what is going on with go 1.3 and use more stringent threshold here 76 | score := catvotes.TallyError(cattarget) 77 | if score > 0.4 { 78 | t.Errorf("Error: Classification of simpledataset from sf file had score: %v", score) 79 | } 80 | 81 | } 82 | -------------------------------------------------------------------------------- /gradboostclasstarget.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | ) 7 | 8 | func Logit(x float64) float64 { //log odds; the inverse of Expit 9 | return math.Log(x / (1.0 - x)) 10 | } 11 | 12 | func Expit(x float64) (out float64) { //logistic sigmoid, written in tanh form for numerical stability 13 | //return 1.0 / (1.0 + math.Exp(-1.0*x)) 14 | out = 0.5 * x 15 | out = math.Tanh(out) 16 | out += 1.0 17 | out *= 0.5 18 | return out 19 | } 20 | 21 | /* 22 | GradBoostClassTarget wraps a numerical feature as a target for use in Two Class Gradient Boosting Trees. 23 | 24 | It should be used with SumBallotBox and expit transformed to get class probabilities. 
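The intercept is the prior log odds of the positive class, ln(pos/(n-pos)), and after each round the working response for case i is the gradient yi - Expit(Fi), where Fi is the current additive prediction.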
25 | */ 26 | type GradBoostClassTarget struct { 27 | *GradBoostTarget 28 | Actual NumFeature 29 | Pred NumFeature 30 | LearnRate float64 31 | Prior float64 32 | Pos_class string 33 | } 34 | 35 | func NewGradBoostClassTarget(f CatFeature, learnrate float64, pos_class string) (gbc *GradBoostClassTarget) { 36 | 37 | //fmt.Println("Back: ", f.CatToNum(pos_class), f.(*DenseCatFeature).Back) 38 | 39 | actual := f.EncodeToNum()[0].(*DenseNumFeature) 40 | pred := actual.Copy().(*DenseNumFeature) 41 | // Make sure the encoding has the positive class as 1 42 | for i := 0; i < f.Length(); i++ { 43 | if f.GetStr(i) == pos_class { 44 | actual.Put(i, 1.0) 45 | } else { 46 | actual.Put(i, 0.0) 47 | } 48 | 49 | } 50 | 51 | res := &GradBoostTarget{actual.Copy().(*DenseNumFeature), learnrate, 0.0} 52 | 53 | pos := 0.0 54 | for i := 0; i < actual.Length(); i++ { 55 | pos += actual.Get(i) 56 | } 57 | 58 | // Set the prior to the log odds of the positive class 59 | prior := math.Log(pos / (float64(res.Length()) - pos)) 60 | 61 | for i := 0; i < res.Length(); i++ { 62 | pred.Put(i, prior) 63 | v := actual.Get(i) - Expit(prior) 64 | res.Put(i, v) 65 | } 66 | 67 | //fmt.Println(res.Copy().(*DenseNumFeature).NumData) 68 | 69 | gbc = &GradBoostClassTarget{res, actual, pred, learnrate, prior, pos_class} 70 | return 71 | 72 | } 73 | 74 | func (f *GradBoostClassTarget) Intercept() float64 { 75 | return f.Prior 76 | } 77 | 78 | //BUG(ryan) does GradBoostingTarget need separate residuals and values? 79 | func (f *GradBoostClassTarget) Boost(leaves *[][]int, preds *[]string) (weight float64) { 80 | for i, cases := range *leaves { 81 | f.Update(&cases, ParseFloat((*preds)[i])) 82 | } 83 | return f.LearnRate 84 | 85 | } 86 | 87 | func (f *GradBoostClassTarget) Predicted(cases *[]int) float64 { 88 | //TODO(ryan): update predicted on whole data not just in bag 89 | num := 0.0 90 | denom := 0.0 91 | 92 | for _, c := range *cases { 93 | r := f.Get(c) 94 | num += r 95 | y := f.Actual.Get(c) 96 | denom += (y - r) * (1.0 - y + r) 97 | 98 | } 99 | 100 | return num / denom // 1.0 / (1.0 + math.Exp(-1*meanlogodds)) 101 | } 102 | 103 | func (f *GradBoostClassTarget) FindPredicted(cases []int) (pred string) { 104 | pred = fmt.Sprintf("%v", f.Predicted(&cases)) 105 | return 106 | 107 | } 108 | 109 | //Update adds LearnRate*predicted to the running prediction for the specified cases 110 | //and recomputes their residuals as the gradient Actual - Expit(Pred). 111 | func (f *GradBoostClassTarget) Update(cases *[]int, predicted float64) { 112 | for _, i := range *cases { 113 | pred := f.Pred.Get(i) + f.LearnRate*predicted 114 | f.Pred.Put(i, pred) 115 | 116 | g := f.Actual.Get(i) - Expit(pred) 117 | f.Put(i, g) 118 | 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /gradboosttarget.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | /* 4 | GradBoostTarget wraps a numerical feature as a target for use in Gradient Boosting Trees. 5 | 6 | It should be used with the SumBallotBox. 
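The wrapped feature is rewritten in place to hold residuals: values are initialized to yi - mean(y) and the mean is exposed through Intercept().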
7 | */
8 | type GradBoostTarget struct {
9 | 	NumFeature
10 | 	LearnRate float64
11 | 	Mean      float64
12 | }
13 |
14 | func NewGradBoostTarget(f NumFeature, learnrate float64) (gbc *GradBoostTarget) {
15 |
16 | 	//res := NumFeature.(*DenseNumFeature).Copy().(*DenseNumFeature)
17 | 	sum := 0.0
18 | 	for i := 0; i < f.Length(); i++ {
19 | 		sum += f.Get(i)
20 | 	}
21 |
22 | 	// Set the initial residuals by subtracting off the mean prior
23 | 	prior := sum / float64(f.Length())
24 |
25 | 	for i := 0; i < f.Length(); i++ {
26 | 		v := f.Get(i) - prior
27 | 		f.Put(i, v)
28 | 	}
29 |
30 | 	//fmt.Println(res.Copy().(*DenseNumFeature).NumData)
31 |
32 | 	gbc = &GradBoostTarget{f, learnrate, prior}
33 | 	return
34 |
35 | }
36 |
37 | func (f *GradBoostTarget) Intercept() float64 {
38 | 	return f.Mean
39 | }
40 |
41 | //BUG(ryan) does GradBoostingTarget need separate residuals and values?
42 | func (f *GradBoostTarget) Boost(leaves *[][]int, preds *[]string) (weight float64) {
43 | 	for i, cases := range *leaves {
44 | 		f.Update(&cases, ParseFloat((*preds)[i]))
45 | 	}
46 | 	return f.LearnRate
47 |
48 | }
49 |
50 | //Update updates the underlying numeric data by subtracting learnrate*predicted
51 | //from the value for the specified cases.
52 | func (f *GradBoostTarget) Update(cases *[]int, predicted float64) {
53 | 	for _, i := range *cases {
54 | 		if !f.IsMissing(i) {
55 | 			f.Put(i, f.Get(i)-f.LearnRate*predicted)
56 | 		}
57 | 	}
58 | }
59 |
60 | //Impurity returns the underlying mean squared error vs the mean for a set of cases when
61 | //it is at or below minImp; otherwise -1.0 is returned and split scoring falls to FriedmanScore.
62 | func (target *GradBoostTarget) Impurity(cases *[]int, counter *[]int) (e float64) {
63 | 	e = target.NumFeature.Impurity(cases, counter)
64 | 	if e <= minImp {
65 | 		return e
66 | 	}
67 | 	e = -1.0
68 | 	return e
69 |
70 | }
71 |
72 | func (target *GradBoostTarget) Sum(cases *[]int) (sum float64) {
73 | 	for _, i := range *cases {
74 | 		x := target.Get(i)
75 | 		sum += x
76 | 	}
77 | 	return
78 | }
79 |
80 | func FriedmanScore(allocs *BestSplitAllocs, l, r *[]int) (impurityDecrease float64) {
81 | 	nl := float64(len(*l))
82 | 	nr := float64(len(*r))
83 | 	diff := (allocs.Lsum / nl) - (allocs.Rsum / nr)
84 | 	impurityDecrease = (diff * diff * nl * nr) / (nl + nr)
85 |
86 | 	// if impurityDecrease <= 10e-6 {
87 | 	// 	impurityDecrease = 0.0
88 | 	// }
89 | 	return
90 |
91 | }
92 |
93 | // Friedman MSE split improvement score from equation 35 in "Greedy Function Approximation: A Gradient Boosting Machine"
94 | // TODO: what should the parent impurity be?
95 | func (target *GradBoostTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) {
96 |
97 | 	allocs.Lsum = target.Sum(l)
98 | 	allocs.Rsum = target.Sum(r)
99 |
100 | 	impurityDecrease = FriedmanScore(allocs, l, r)
101 | 	return
102 | }
103 |
104 | func (target *GradBoostTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) {
105 |
106 | 	MVsum := target.Sum(movedRtoL)
107 |
108 | 	allocs.Lsum += MVsum
109 | 	allocs.Rsum -= MVsum
110 |
111 | 	impurityDecrease = FriedmanScore(allocs, l, r)
112 | 	return
113 | }
--------------------------------------------------------------------------------
/hdistancetarget.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 |
3 | import (
4 | 	"fmt"
5 | )
6 |
7 | /*
8 | HDistanceTarget wraps a categorical feature for use in Hellinger Distance tree
9 | growth.
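
As implemented below, the splitting criterion is the squared distance between the
per-class rates on the two sides of the split (the final square root is omitted
since it is monotone):

	d = (l0/n0 - l1/n1)^2 + (r0/n0 - r1/n1)^2

where l0 and r0 are the per-side counts of class 0, n0 is the total count of
class 0, and likewise for class 1. Note that the classical Hellinger-distance
criterion applies a square root to each rate first; this simplified form is what
the code computes.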
10 | */
11 | type HDistanceTarget struct {
12 | 	CatFeature
13 | 	Pos_class string
14 | }
15 |
16 | //NewHDistanceTarget creates an HDistanceTarget for the supplied positive class.
17 | func NewHDistanceTarget(f CatFeature, pos_class string) *HDistanceTarget {
18 | 	return &HDistanceTarget{f, pos_class}
19 | }
20 |
21 | /*
22 | HDistanceTarget.SplitImpurity is a version of Split Impurity that calls HDistanceTarget.Impurity
23 | */
24 | func (target *HDistanceTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) float64 {
25 | 	target.CountPerCat(l, allocs.LCounter)
26 | 	target.CountPerCat(r, allocs.RCounter)
27 |
28 | 	return target.HDist(allocs.LCounter, allocs.RCounter)
29 | }
30 |
31 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l as in learning from numerical variables.
32 | //Here the class counters are updated incrementally for the moved cases rather than recounted.
33 | func (target *HDistanceTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) float64 {
34 | 	target.MoveCountsRtoL(allocs, movedRtoL)
35 | 	return target.HDist(allocs.LCounter, allocs.RCounter)
36 | }
37 |
38 | func (target *HDistanceTarget) HDist(lcounts *[]int, rcounts *[]int) (d float64) {
39 | 	l := *lcounts
40 | 	r := *rcounts
41 |
42 | 	// Squared Hellinger-style distance between the two sides:
43 | 	// (count(0, left)/count(0) - count(1, left)/count(1))^2 +
44 | 	// (count(0, right)/count(0) - count(1, right)/count(1))^2
45 |
46 | 	total_0 := float64(l[0] + r[0])
47 | 	total_1 := float64(l[1] + r[1])
48 |
49 | 	inner := float64(l[0])
50 | 	inner /= total_0
51 | 	inner -= float64(l[1]) / total_1
52 | 	d = inner * inner
53 |
54 | 	inner = float64(r[0])
55 | 	inner /= total_0
56 | 	inner -= float64(r[1]) / total_1
57 | 	d += inner * inner
58 |
59 | 	// the final square root is not needed because it is monotonic
60 | 	// d = math.Sqrt(d)
61 | 	return
62 |
63 | }
64 |
65 | func (target *HDistanceTarget) FindPredicted(cases []int) (pred string) {
66 | 	// TODO(ryan): Laplacian smoothing?
67 | prob_true := 0.0 68 | t := target.CatToNum(target.Pos_class) 69 | 70 | count := 0 71 | for _, i := range cases { 72 | if target.Geti(i) == t { 73 | count++ 74 | } 75 | 76 | } 77 | prob_true = float64(count) / float64(len(cases)) 78 | 79 | return fmt.Sprintf("%v", prob_true) 80 | } 81 | 82 | //HDistanceTarget.Impurity 83 | func (target *HDistanceTarget) Impurity(cases *[]int, counts *[]int) (e float64) { 84 | 85 | return -1.0 86 | 87 | } 88 | -------------------------------------------------------------------------------- /importance_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestImportance(t *testing.T) { 9 | if testing.Short() { 10 | t.Skip("Skipping importance test on boston data set.") 11 | } 12 | boston := strings.NewReader(boston_housing) 13 | 14 | fm := ParseARFF(boston) 15 | 16 | if len(fm.Data) != 14 { 17 | t.Errorf("Boston feature matrix has %v features not 14", len(fm.Data)) 18 | } 19 | 20 | // add artifical contrasts 21 | fm.ContrastAll() 22 | 23 | targeti := fm.Map["class"] 24 | 25 | candidates := make([]int, 0, 0) 26 | 27 | for i := 0; i < len(fm.Data); i++ { 28 | if i != targeti { 29 | candidates = append(candidates, i) 30 | } 31 | } 32 | 33 | numtarget := fm.Data[targeti] 34 | 35 | nTrees := 20 36 | //Brieman's importance definition 37 | imp := func(mean float64, count float64) float64 { 38 | return mean * float64(count) / float64(nTrees) 39 | } 40 | 41 | //standard 42 | config := &ForestConfig{ 43 | NSamples: fm.Data[0].Length(), 44 | MTry: 6, 45 | NTrees: nTrees, 46 | LeafSize: 1, 47 | } 48 | 49 | ff := GrowRandomForest(fm, numtarget.(Feature), config) 50 | 51 | imppnt := ff.Importance 52 | //TODO read importance scores and verify RM and LSTAT come out on top 53 | 54 | roomimp := imp((*imppnt)[fm.Map["RM"]].Read()) 55 | lstatimp := imp((*imppnt)[fm.Map["LSTAT"]].Read()) 56 | beatlstat := 0 57 | beatroom := 0 58 | 59 | for _, rm := range *imppnt { 60 | fimp := imp(rm.Read()) 61 | if fimp > roomimp { 62 | beatroom++ 63 | } 64 | if fimp > lstatimp { 65 | beatlstat++ 66 | } 67 | } 68 | if beatroom > 1 || beatlstat > 1 { 69 | t.Error("RM and LSTAT features not most important in boston data set regression.") 70 | } 71 | 72 | //vetting 73 | config.Vet = true 74 | ff = GrowRandomForest(fm, numtarget.(Feature), config) 75 | imppnt = ff.Importance 76 | //TODO read importance scores and verify RM and LSTAT come out on top 77 | 78 | roomimp = imp((*imppnt)[fm.Map["RM"]].Read()) 79 | lstatimp = imp((*imppnt)[fm.Map["LSTAT"]].Read()) 80 | beatlstat = 0 81 | beatroom = 0 82 | 83 | for _, rm := range *imppnt { 84 | fimp := imp(rm.Read()) 85 | if fimp > roomimp { 86 | beatroom++ 87 | } 88 | if fimp > lstatimp { 89 | beatlstat++ 90 | } 91 | } 92 | if beatroom > 1 || beatlstat > 1 { 93 | t.Error("RM and LSTAT features not most important in vetted boston data set regression.") 94 | } 95 | 96 | //evaloob 97 | //vetting 98 | ff = GrowRandomForest(fm, numtarget.(Feature), config) 99 | imppnt = ff.Importance 100 | //TODO read importance scores and verify RM and LSTAT come out on top 101 | 102 | roomimp = imp((*imppnt)[fm.Map["RM"]].Read()) 103 | lstatimp = imp((*imppnt)[fm.Map["LSTAT"]].Read()) 104 | beatlstat = 0 105 | beatroom = 0 106 | 107 | for _, rm := range *imppnt { 108 | fimp := imp(rm.Read()) 109 | if fimp > roomimp { 110 | beatroom++ 111 | } 112 | if fimp > lstatimp { 113 | beatlstat++ 114 | } 115 | } 116 | if beatroom > 1 || beatlstat > 1 { 117 | 
t.Error("RM and LSTAT features not most important in boston data set regression with eval oob.") 118 | } 119 | 120 | } 121 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | go install github.com/ryanbressler/CloudForest/growforest 2 | go install github.com/ryanbressler/CloudForest/applyforest 3 | go install github.com/ryanbressler/CloudForest/leafcount -------------------------------------------------------------------------------- /l1target.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "math" 5 | ) 6 | 7 | /* 8 | L1Target wraps a numerical feature as a target for us in l1 norm regression. 9 | */ 10 | type L1Target struct { 11 | NumFeature 12 | } 13 | 14 | /* 15 | L1Target.SplitImpurity is an L1 version of SplitImpurity. 16 | */ 17 | func (target *L1Target) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) { 18 | nl := float64(len(*l)) 19 | nr := float64(len(*r)) 20 | nm := 0.0 21 | 22 | impurityDecrease = nl * target.Impurity(l, nil) 23 | impurityDecrease += nr * target.Impurity(r, nil) 24 | if m != nil && len(*m) > 0 { 25 | nm = float64(len(*m)) 26 | impurityDecrease += nm * target.Impurity(m, nil) 27 | } 28 | 29 | impurityDecrease /= nl + nr + nm 30 | return 31 | } 32 | 33 | //UpdateSImpFromAllocs willl be called when splits are being built by moving cases from r to l as in learning from numerical variables. 34 | //Here it just wraps SplitImpurity but it can be implemented to provide further optimization. 35 | func (target *L1Target) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) { 36 | return target.SplitImpurity(l, r, m, allocs) 37 | } 38 | 39 | //L1Target.Impurity is an L1 version of impurity returning L1 instead of squared error. 40 | func (target *L1Target) Impurity(cases *[]int, counter *[]int) (e float64) { 41 | m := target.Mean(cases) 42 | e = target.Error(cases, m) 43 | return 44 | 45 | } 46 | 47 | //L1Target.MeanL1Error returns the Mean L1 norm error of the cases specified vs the predicted 48 | //value. Only non missing cases are considered. 
49 | func (target *L1Target) Error(cases *[]int, predicted float64) (e float64) {
50 | 	e = 0.0
51 | 	n := 0
52 | 	for _, i := range *cases {
53 | 		if !target.IsMissing(i) {
54 | 			e += math.Abs(predicted - target.Get(i))
55 | 			n += 1
56 | 		}
57 |
58 | 	}
59 | 	e = e / float64(n)
60 | 	return
61 |
62 | }
--------------------------------------------------------------------------------
/leafcount/leafcount.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | 	"encoding/json"
5 | 	"flag"
6 | 	"fmt"
7 | 	"log"
8 | 	"os"
9 | 	"runtime"
10 | 	"strings"
11 |
12 | 	"github.com/lytics/CloudForest"
13 | )
14 |
15 | func main() {
16 | 	fm := flag.String("fm", "featurematrix.afm", "AFM formatted feature matrix to use.")
17 | 	rf := flag.String("rfpred", "rface.sf", "A predictor forest.")
18 | 	outf := flag.String("leaves", "leaves.tsv", "a case by case sparse matrix of leaf co-occurrence in tsv format")
19 | 	boutf := flag.String("branches", "", "a case by feature sparse matrix of branch co-occurrence in tsv format")
20 | 	soutf := flag.String("splits", "", "a file to write a JSON record of splits per feature")
21 | 	var threads int
22 | 	flag.IntVar(&threads, "threads", 1, "Parse separate forests in n separate threads.")
23 |
24 | 	flag.Parse()
25 |
26 | 	splits := make(map[string][]string)
27 |
28 | 	//Parse Data
29 | 	data, err := CloudForest.LoadAFM(*fm)
30 | 	if err != nil {
31 | 		log.Fatal(err)
32 | 	}
33 |
34 | 	log.Print("Data file ", len(data.Data), " by ", data.Data[0].Length())
35 |
36 | 	counts := new(CloudForest.SparseCounter)
37 | 	var caseFeatureCounts *CloudForest.SparseCounter
38 | 	if *boutf != "" {
39 | 		caseFeatureCounts = new(CloudForest.SparseCounter)
40 | 	}
41 |
42 | 	files := strings.Split(*rf, ",")
43 |
44 | 	runtime.GOMAXPROCS(threads)
45 |
46 | 	fileChan := make(chan string, 0)
47 | 	doneChan := make(chan int, 0)
48 |
49 | 	go func() {
50 | 		for _, fn := range files {
51 | 			fileChan <- fn
52 | 		}
53 | 	}()
54 |
55 | 	for i := 0; i < threads; i++ {
56 |
57 | 		go func() {
58 | 			for {
59 | 				fn := <-fileChan
60 |
61 | 				forestfile, err := os.Open(fn) // For read access.
62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | defer forestfile.Close() 66 | forestreader := CloudForest.NewForestReader(forestfile) 67 | forest, err := forestreader.ReadForest() 68 | if err != nil { 69 | log.Fatal(err) 70 | } 71 | log.Print("Forest has ", len(forest.Trees), " trees ") 72 | 73 | for i := 0; i < len(forest.Trees); i++ { 74 | fmt.Print(".") 75 | leaves := forest.Trees[i].GetLeaves(data, caseFeatureCounts) 76 | for _, leaf := range leaves { 77 | for j := 0; j < len(leaf.Cases); j++ { 78 | for k := 0; k < len(leaf.Cases); k++ { 79 | 80 | counts.Add(leaf.Cases[j], leaf.Cases[k], 1) 81 | 82 | } 83 | } 84 | } 85 | 86 | if *soutf != "" { 87 | forest.Trees[i].Root.Climb(func(n *CloudForest.Node) { 88 | if n.Splitter != nil { 89 | name := n.Splitter.Feature 90 | _, ok := splits[name] 91 | if !ok { 92 | splits[name] = make([]string, 0, 10) 93 | } 94 | split := "" 95 | switch n.Splitter.Numerical { 96 | case true: 97 | split = fmt.Sprintf("%v", n.Splitter.Value) 98 | case false: 99 | keys := make([]string, 0, len(n.Splitter.Left)) 100 | for k := range n.Splitter.Left { 101 | keys = append(keys, k) 102 | } 103 | split = strings.Join(keys, ",") 104 | } 105 | splits[name] = append(splits[name], split) 106 | } 107 | }) 108 | } 109 | 110 | } 111 | doneChan <- 1 112 | } 113 | }() 114 | 115 | } 116 | 117 | for i := 0; i < len(files); i++ { 118 | <-doneChan 119 | } 120 | 121 | log.Print("Outputting Case Case Co-Occurrence Counts") 122 | outfile, err := os.Create(*outf) 123 | if err != nil { 124 | log.Fatal(err) 125 | } 126 | defer outfile.Close() 127 | counts.WriteTsv(outfile) 128 | 129 | if *boutf != "" { 130 | log.Print("Outputting Case Feature Co-Occurrence Counts") 131 | boutfile, err := os.Create(*boutf) 132 | if err != nil { 133 | log.Fatal(err) 134 | } 135 | defer boutfile.Close() 136 | caseFeatureCounts.WriteTsv(boutfile) 137 | } 138 | 139 | if *soutf != "" { 140 | soutfile, err := os.Create(*soutf) 141 | if err != nil { 142 | log.Fatal(err) 143 | } 144 | defer soutfile.Close() 145 | encoder := json.NewEncoder(soutfile) 146 | encoder.Encode(splits) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /libsvm.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "bufio" 5 | "encoding/csv" 6 | "fmt" 7 | "io" 8 | "log" 9 | "strconv" 10 | "strings" 11 | ) 12 | 13 | func ParseLibSVM(input io.Reader) *FeatureMatrix { 14 | reader := bufio.NewReader(input) 15 | 16 | data := make([]Feature, 0, 100) 17 | lookup := make(map[string]int, 0) 18 | labels := make([]string, 0, 0) 19 | 20 | i := 0 21 | ncases := 0 22 | for { 23 | ncases++ 24 | 25 | line, err := reader.ReadString('\n') 26 | if err == io.EOF { 27 | break 28 | } else if err != nil { 29 | log.Print("Error:", err) 30 | return nil 31 | } 32 | 33 | vals := strings.Fields(line) 34 | 35 | if i == 0 { 36 | name := "0" 37 | lookup[name] = 0 38 | if strings.Contains(vals[0], ".") { 39 | //looks like a float...add dense float64 feature regression 40 | data = append(data, &DenseNumFeature{ 41 | make([]float64, 0, 0), 42 | make([]bool, 0, 0), 43 | name, 44 | false}) 45 | 46 | } else { 47 | //doesn't look like a float...add dense catagorical 48 | data = append(data, &DenseCatFeature{ 49 | NewCatMap(), 50 | make([]int, 0, 0), 51 | make([]bool, 0, 0), 52 | name, 53 | false, 54 | false}) 55 | } 56 | 57 | } 58 | data[0].Append(vals[0]) 59 | 60 | //pad existing features 61 | for _, f := range data[1:] { 62 | f.Append("0") 63 | 
} 64 | 65 | for _, v := range vals[1:] { 66 | parts := strings.Split(v, ":") 67 | xi, err := strconv.Atoi(parts[0]) 68 | if err != nil { 69 | log.Print("Atoi error: ", err, " Line ", i, " Parsing: ", v) 70 | } 71 | //pad out the data to include this feature 72 | for xi >= len(data) { 73 | name := fmt.Sprintf("%v", len(data)) 74 | lookup[name] = len(data) 75 | data = append(data, &DenseNumFeature{ 76 | make([]float64, ncases, ncases), 77 | make([]bool, ncases, ncases), 78 | name, 79 | false}) 80 | 81 | } 82 | data[xi].PutStr(i, parts[1]) 83 | 84 | } 85 | 86 | label := fmt.Sprintf("%v", i) 87 | labels = append(labels, label) 88 | i++ 89 | 90 | } 91 | 92 | fm := &FeatureMatrix{data, lookup, labels} 93 | 94 | return fm 95 | 96 | } 97 | 98 | func WriteLibSvm(data *FeatureMatrix, targetn string, outfile io.Writer) error { 99 | targeti, ok := data.Map[targetn] 100 | if !ok { 101 | return fmt.Errorf("Target '%v' not found in data.", targetn) 102 | } 103 | target := data.Data[targeti] 104 | 105 | //data.Data = append(data.Data[:targeti], data.Data[targeti+1:]...) 106 | 107 | noTargetFm := &FeatureMatrix{make([]Feature, 0, len(data.Data)), make(map[string]int), data.CaseLabels} 108 | 109 | for i, f := range data.Data { 110 | if i != targeti { 111 | noTargetFm.Map[f.GetName()] = len(noTargetFm.Data) 112 | noTargetFm.Data = append(noTargetFm.Data, f.Copy()) 113 | 114 | } 115 | } 116 | 117 | noTargetFm.ImputeMissing() 118 | encodedfm := noTargetFm.EncodeToNum() 119 | 120 | oucsv := csv.NewWriter(outfile) 121 | oucsv.Comma = ' ' 122 | 123 | for i := 0; i < target.Length(); i++ { 124 | entries := make([]string, 0, 10) 125 | switch target.(type) { 126 | case NumFeature: 127 | entries = append(entries, target.GetStr(i)) 128 | case CatFeature: 129 | entries = append(entries, fmt.Sprintf("%v", target.(CatFeature).Geti(i))) 130 | } 131 | 132 | for j, f := range encodedfm.Data { 133 | v := f.(NumFeature).Get(i) 134 | if v != 0.0 { 135 | entries = append(entries, fmt.Sprintf("%v:%v", j+1, v)) 136 | } 137 | } 138 | //fmt.Println(entries) 139 | err := oucsv.Write(entries) 140 | if err != nil { 141 | return err 142 | } 143 | 144 | } 145 | oucsv.Flush() 146 | return nil 147 | } 148 | 149 | func WriteLibSvmCases(data *FeatureMatrix, cases []int, targetn string, outfile io.Writer) error { 150 | targeti, ok := data.Map[targetn] 151 | if !ok { 152 | return fmt.Errorf("Target '%v' not found in data.", targetn) 153 | } 154 | target := data.Data[targeti] 155 | 156 | noTargetFm := &FeatureMatrix{make([]Feature, 0, len(data.Data)), make(map[string]int), data.CaseLabels} 157 | 158 | encode := false 159 | for i, f := range data.Data { 160 | if i != targeti { 161 | if data.Data[i].NCats() > 0 { 162 | encode = true 163 | } 164 | noTargetFm.Map[f.GetName()] = len(noTargetFm.Data) 165 | noTargetFm.Data = append(noTargetFm.Data, f) 166 | 167 | } 168 | } 169 | 170 | noTargetFm.ImputeMissing() 171 | 172 | encodedfm := noTargetFm 173 | if encode { 174 | encodedfm = noTargetFm.EncodeToNum() 175 | } 176 | 177 | oucsv := csv.NewWriter(outfile) 178 | oucsv.Comma = ' ' 179 | 180 | for _, i := range cases { 181 | entries := make([]string, 0, 10) 182 | switch target.(type) { 183 | case NumFeature: 184 | entries = append(entries, fmt.Sprintf("%g", target.(NumFeature).Get(i))) 185 | case CatFeature: 186 | entries = append(entries, fmt.Sprintf("%v", target.(CatFeature).Geti(i))) 187 | } 188 | 189 | for j, f := range encodedfm.Data { 190 | v := f.(NumFeature).Get(i) 191 | if v != 0.0 { 192 | entries = append(entries, fmt.Sprintf("%v:%v", j+1, 
v)) 193 | } 194 | } 195 | //fmt.Println(entries) 196 | err := oucsv.Write(entries) 197 | if err != nil { 198 | return err 199 | } 200 | 201 | } 202 | oucsv.Flush() 203 | return nil 204 | } 205 | -------------------------------------------------------------------------------- /node.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | //Recursable defines a function signature for functions that can be called at every 4 | //down stream node of a tree as Node.Recurse recurses up the tree. The function should 5 | //have two parameters, the current node and an array of ints specifying the cases that 6 | //have not been split away. 7 | type Recursable func(*Node, []int, int) 8 | 9 | type CodedRecursable func(*Node, *[]int, int, int) (int, interface{}, int) 10 | 11 | //A node of a decision tree. 12 | //Pred is a string containing either the category or a representation of a float 13 | //(less then ideal) 14 | type Node struct { 15 | CodedSplit interface{} 16 | Featurei int 17 | Left *Node 18 | Right *Node 19 | Missing *Node 20 | Pred string 21 | Splitter *Splitter 22 | } 23 | 24 | func (n *Node) Copy() *Node { 25 | if n == nil { 26 | return nil 27 | } 28 | 29 | cp := &Node{ 30 | CodedSplit: n.CodedSplit, 31 | Featurei: n.Featurei, 32 | Pred: n.Pred, 33 | Splitter: n.Splitter.Copy(), 34 | } 35 | 36 | cp.Left = n.Left.Copy() 37 | cp.Right = n.Right.Copy() 38 | cp.Missing = n.Missing.Copy() 39 | return cp 40 | } 41 | 42 | //vist each child node with the supplied function 43 | func (n *Node) Climb(c func(*Node)) { 44 | c(n) 45 | if n.Left != nil { 46 | n.Left.Climb(c) 47 | } 48 | if n.Right != nil { 49 | n.Right.Climb(c) 50 | } 51 | if n.Missing != nil { 52 | n.Missing.Climb(c) 53 | } 54 | } 55 | 56 | //Recurse is used to apply a Recursable function at every downstream node as the cases 57 | //specified by case []int are split using the data in fm *Featurematrix. Recursion 58 | //down a branch stops when a a node with n.Splitter == nil is reached. Recursion down 59 | //the Missing branch is only used if n.Missing!=nil. 
60 | //For example votes can be tabulated using code like: 61 | // t.Root.Recurse(func(n *Node, cases []int) { 62 | // if n.Left == nil && n.Right == nil { 63 | // // I'm in a leaf node 64 | // for i := 0; i < len(cases); i++ { 65 | // bb.Vote(cases[i], n.Pred) 66 | // } 67 | // } 68 | // }, fm, cases) 69 | func (n *Node) Recurse(r Recursable, fm *FeatureMatrix, cases []int, depth int) { 70 | r(n, cases, depth) 71 | depth++ 72 | var ls, rs, ms []int 73 | switch { 74 | case n.CodedSplit != nil: 75 | ls, rs, ms = fm.Data[n.Featurei].Split(n.CodedSplit, cases) 76 | case n.Splitter != nil: 77 | ls, rs, ms = n.Splitter.Split(fm, cases) 78 | default: 79 | return 80 | } 81 | 82 | if n.Left != nil { 83 | n.Left.Recurse(r, fm, ls, depth) 84 | } 85 | if n.Right != nil { 86 | n.Right.Recurse(r, fm, rs, depth) 87 | } 88 | if len(ms) > 0 && n.Missing != nil { 89 | n.Missing.Recurse(r, fm, ms, depth) 90 | } 91 | } 92 | 93 | func (n *Node) CodedRecurse(r CodedRecursable, fm *FeatureMatrix, cases *[]int, depth int, nconstantsbefore int) { 94 | fi, codedSplit, nconstants := r(n, cases, depth, nconstantsbefore) 95 | depth++ 96 | if codedSplit != nil { 97 | li, ri := fm.Data[fi].SplitPoints(codedSplit, cases) 98 | cs := (*cases)[:li] 99 | n.Left.CodedRecurse(r, fm, &cs, depth, nconstants) 100 | cs = (*cases)[ri:] 101 | n.Right.CodedRecurse(r, fm, &cs, depth, nconstants) 102 | if li != ri && n.Missing != nil { 103 | cs = (*cases)[li:ri] 104 | n.Missing.CodedRecurse(r, fm, &cs, depth, nconstants) 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /numadaboostingtarget.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "math" 5 | ) 6 | 7 | /* 8 | NumNumAdaBoostTarget wraps a numerical feature as a target for us in (Experimental) Adaptive Boosting 9 | Regression. 10 | */ 11 | type NumAdaBoostTarget struct { 12 | NumFeature 13 | Weights []float64 14 | NormFactor float64 15 | } 16 | 17 | func NewNumAdaBoostTarget(f NumFeature) (abt *NumAdaBoostTarget) { 18 | nCases := f.Length() 19 | abt = &NumAdaBoostTarget{f, make([]float64, nCases), 0.0} 20 | cases := make([]int, nCases) 21 | for i := range abt.Weights { 22 | abt.Weights[i] = 1 / float64(nCases) 23 | cases[i] = i 24 | } 25 | abt.NormFactor = abt.Impurity(&cases, nil) * float64(nCases) 26 | return 27 | } 28 | 29 | /* 30 | NumAdaBoostTarget.SplitImpurity is an AdaBoosting version of SplitImpurity. 31 | */ 32 | func (target *NumAdaBoostTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) { 33 | nl := float64(len(*l)) 34 | nr := float64(len(*r)) 35 | nm := 0.0 36 | 37 | impurityDecrease = nl * target.Impurity(l, allocs.LCounter) 38 | impurityDecrease += nr * target.Impurity(r, allocs.RCounter) 39 | if m != nil && len(*m) > 0 { 40 | nm = float64(len(*m)) 41 | impurityDecrease += nm * target.Impurity(m, allocs.Counter) 42 | } 43 | 44 | impurityDecrease /= nl + nr + nm 45 | return 46 | } 47 | 48 | //UpdateSImpFromAllocs willl be called when splits are being built by moving cases from r to l as in learning from numerical variables. 49 | //Here it just wraps SplitImpurity but it can be implemented to provide further optimization. 
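//An optimized implementation would update running statistics for just the moved
//cases (as RegretTarget.UpdateSImpFromAllocs and the Hellinger target do) rather
//than rescanning both sides; the per-case weights used here make that less direct.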
50 | func (target *NumAdaBoostTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) {
51 | 	return target.SplitImpurity(l, r, m, allocs)
52 | }
53 |
54 | //NumAdaBoostTarget.Impurity is an AdaBoosting impurity that uses the weights specified in NumAdaBoostTarget.Weights.
55 | func (target *NumAdaBoostTarget) Impurity(cases *[]int, counter *[]int) (e float64) {
56 | 	e = 0.0
57 | 	m := target.Predicted(cases)
58 | 	for _, c := range *cases {
59 | 		if target.IsMissing(c) == false {
60 | 			e += target.Weights[c] * target.Norm(c, m)
61 | 		}
62 |
63 | 	}
64 | 	return
65 | }
66 |
67 | //NumAdaBoostTarget.Boost performs numerical adaptive boosting using the specified partition and
68 | //returns the weight that the tree that generated the partition should be given.
69 | //Trees with error greater than the impurity of the total feature (NormFactor) times the number
70 | //of partitions are given zero weight. Other trees have tree weight set to:
71 | //
72 | //	weight = math.Log(1 / norm)
73 | //
74 | //and weights updated to:
75 | //
76 | //	t.Weights[c] = t.Weights[c] * math.Exp(t.Error(&[]int{c}, m)*weight)
77 | //
78 | //These functions are chosen to provide a rough analog to categorical adaptive boosting for
79 | //numerical data with unbounded error.
80 | func (t *NumAdaBoostTarget) Boost(leaves *[][]int) (weight float64) {
81 | 	if len(*leaves) == 0 {
82 | 		return 0.0
83 | 	}
84 | 	imp := 0.0
85 | 	//nCases := 0
86 | 	for _, cases := range *leaves {
87 | 		imp += t.Impurity(&cases, nil)
88 | 		//nCases += len(cases)
89 | 	}
90 | 	norm := t.NormFactor
91 | 	if imp > norm {
92 | 		return 0.0
93 | 	}
94 |
95 | 	weight = math.Log(norm / imp)
96 |
97 | 	for _, cases := range *leaves {
98 | 		m := t.Predicted(&cases)
99 | 		for _, c := range cases {
100 | 			if t.IsMissing(c) == false {
101 | 				t.Weights[c] = t.Weights[c] * math.Exp(weight*(t.Norm(c, m)-imp))
102 | 			}
103 |
104 | 		}
105 | 	}
106 |
107 | 	normfactor := 0.0
108 | 	for _, v := range t.Weights {
109 | 		normfactor += v
110 | 	}
111 | 	for i, v := range t.Weights {
112 | 		t.Weights[i] = v / normfactor
113 | 	}
114 | 	return
115 | }
--------------------------------------------------------------------------------
/numballotbox.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 |
3 | import (
4 | 	"fmt"
5 | 	"math"
6 | 	"strconv"
7 | )
8 |
9 | //NumBallotBox keeps track of votes by trees.
10 | //Voting is thread safe.
11 | type NumBallotBox struct {
12 | 	box []*RunningMean
13 | }
14 |
15 | //NewNumBallotBox builds a new ballot box for the number of cases specified by "size".
16 | func NewNumBallotBox(size int) *NumBallotBox {
17 | 	bb := NumBallotBox{
18 | 		make([]*RunningMean, 0, size)}
19 | 	for i := 0; i < size; i++ {
20 | 		bb.box = append(bb.box, new(RunningMean))
21 | 	}
22 | 	return &bb
23 | }
24 |
25 | //Vote parses the float in the string and votes for it
26 | func (bb *NumBallotBox) Vote(casei int, pred string, weight float64) {
27 | 	v, err := strconv.ParseFloat(pred, 64)
28 | 	if err == nil {
29 | 		bb.box[casei].WeightedAdd(v, weight)
30 | 	}
31 |
32 | }
33 |
34 | //TallyNum tallies the votes for the case specified by i as
35 | //if it is a numerical feature, i.e. it returns the mean of all votes.
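//
//For example (a sketch mirroring the categorical test in forestwriterreader_test.go;
//tree.Vote accepts a ballot box as its tallyer):
//
//	bb := NewNumBallotBox(fm.Data[0].Length())
//	for _, tree := range forest.Trees {
//		tree.Vote(fm, bb)
//	}
//	yhat := bb.TallyNum(i) // mean of the per-tree votes for case i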
36 | func (bb *NumBallotBox) TallyNum(i int) (predicted float64) {
37 | 	predicted, _ = bb.box[i].Read()
38 | 	return
39 | }
40 |
41 | func (bb *NumBallotBox) Tally(i int) (predicted string) {
42 | 	mean, count := bb.box[i].Read()
43 | 	if count > 0 {
44 | 		predicted = fmt.Sprintf("%v", mean)
45 | 	} else {
46 | 		predicted = "NA"
47 | 	}
48 | 	return
49 | }
50 |
51 | //TallySquaredError returns the mean squared error of the votes vs the provided
52 | //numerical feature. (This box only stores numerical votes; use a CatBallotBox
53 | //and its error rate for categorical features.)
54 | //The provided feature must use the same index as the feature matrix
55 | //the ballot box was constructed with.
56 | //Missing values are ignored.
57 | //Gini impurity is not used so this is not for use in rf implementations.
58 | func (bb *NumBallotBox) TallySquaredError(feature Feature) (e float64) {
59 | 	e = 0.0
60 |
61 | 	// Numerical feature. Calculate mean squared error.
62 | 	d := 0.0
63 | 	c := 0
64 | 	for i := 0; i < feature.Length(); i++ {
65 | 		predicted := bb.TallyNum(i)
66 | 		if !feature.IsMissing(i) && !math.IsNaN(predicted) {
67 | 			value := feature.(NumFeature).Get(i)
68 |
69 | 			d = float64(value) - predicted
70 | 			e += d * d
71 | 			c += 1
72 | 		}
73 | 	}
74 | 	if c == 0 {
75 | 		e = math.NaN()
76 | 	} else {
77 | 		e = e / float64(c)
78 | 	}
79 |
80 | 	return
81 |
82 | }
83 |
84 | //TallyError returns the squared error (unexplained variance) divided by the data variance.
85 | func (bb *NumBallotBox) TallyError(feature Feature) (e float64) {
86 | 	mean := 0.0
87 | 	r2 := 0.0
88 | 	total := 0
89 | 	for i := 0; i < feature.Length(); i++ {
90 | 		if !feature.IsMissing(i) {
91 | 			mean += feature.(*DenseNumFeature).Get(i)
92 | 			total++
93 | 		}
94 | 	}
95 | 	mean /= float64(total)
96 |
97 | 	for i := 0; i < feature.Length(); i++ {
98 | 		if !feature.IsMissing(i) {
99 | 			value := feature.(NumFeature).Get(i)
100 |
101 | 			d := float64(value) - mean
102 | 			r2 += d * d
103 | 		}
104 | 	}
105 | 	r2 /= float64(total)
106 |
107 | 	e = bb.TallySquaredError(feature) / r2
108 |
109 | 	return
110 |
111 | }
112 |
113 | //TallyR2Score returns the R2 score or coefficient of determination.
114 | func (bb *NumBallotBox) TallyR2Score(feature Feature) (e float64) {
115 |
116 | 	e = 1 - bb.TallyError(feature)
117 |
118 | 	return
119 |
120 | }
--------------------------------------------------------------------------------
/ordinaltarget.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 |
3 | import (
4 | 	"fmt"
5 | )
6 |
7 | /*
8 | OrdinalTarget wraps a numerical feature as a target for use in ordinal regression.
9 | Data should be represented as positive integers and the Error is inherited from the
10 | embedded NumFeature.
11 | */
12 | type OrdinalTarget struct {
13 | 	NumFeature
14 | 	nClass int
15 | 	max    float64
16 | }
17 |
18 | /*
19 | NewOrdinalTarget creates an ordinal target, determining the number of classes from the maximum value.
20 | */
21 | func NewOrdinalTarget(f NumFeature) (abt *OrdinalTarget) {
22 | 	nCases := f.Length()
23 | 	abt = &OrdinalTarget{f, 0, 0.0}
24 | 	for i := 0; i < nCases; i++ {
25 | 		v := f.Get(i)
26 | 		if v > abt.max {
27 | 			abt.max = v
28 | 		}
29 | 	}
30 |
31 | 	abt.nClass = int(abt.max) + 1
32 | 	return
33 | }
34 |
35 | /*
36 | OrdinalTarget.SplitImpurity is an ordinal version of SplitImpurity.
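
As with the other regression targets, the result is the size-weighted mean of the
per-branch impurities:

	(nl*I(l) + nr*I(r) + nm*I(m)) / (nl + nr + nm)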
37 | */
38 | func (target *OrdinalTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) {
39 | 	nl := float64(len(*l))
40 | 	nr := float64(len(*r))
41 | 	nm := 0.0
42 |
43 | 	impurityDecrease = nl * target.Impurity(l, allocs.LCounter)
44 | 	impurityDecrease += nr * target.Impurity(r, allocs.RCounter)
45 | 	if m != nil && len(*m) > 0 {
46 | 		nm = float64(len(*m))
47 | 		impurityDecrease += nm * target.Impurity(m, allocs.Counter)
48 | 	}
49 |
50 | 	impurityDecrease /= nl + nr + nm
51 | 	return
52 | }
53 |
54 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l as in learning from numerical variables.
55 | //Here it just wraps SplitImpurity but it can be implemented to provide further optimization.
56 | func (target *OrdinalTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) {
57 | 	return target.SplitImpurity(l, r, m, allocs)
58 | }
59 |
60 | func (f *OrdinalTarget) Predicted(cases *[]int) float64 {
61 | 	return f.Mode(cases)
62 | }
63 |
64 | func (f *OrdinalTarget) Mode(cases *[]int) (m float64) {
65 | 	counts := make([]int, f.nClass)
66 | 	for _, i := range *cases {
67 | 		if !f.IsMissing(i) {
68 | 			counts[int(f.Get(i))] += 1
69 | 		}
70 |
71 | 	}
72 | 	max := 0
73 | 	for k, v := range counts {
74 | 		if v > max {
75 | 			m = float64(k)
76 | 			max = v
77 | 		}
78 | 	}
79 | 	return
80 |
81 | }
82 |
83 | //OrdinalTarget.Impurity is an ordinal version of impurity using Mode instead of Mean for prediction.
84 | func (target *OrdinalTarget) Impurity(cases *[]int, counter *[]int) (e float64) {
85 | 	m := target.Predicted(cases)
86 | 	e = target.Error(cases, m)
87 | 	return
88 |
89 | }
90 |
91 | func (target *OrdinalTarget) FindPredicted(cases []int) (pred string) {
92 | 	return fmt.Sprintf("%v", target.Predicted(&cases))
93 | }
--------------------------------------------------------------------------------
/regrettarget.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 |
3 | /*
4 | RegretTarget wraps a categorical feature for use in regret driven classification.
5 | The ith entry in costs should contain the cost of misclassifying a case that actually
6 | has the ith category.
7 |
8 | It is roughly equivalent to the ideas presented in:
9 |
10 | http://machinelearning.wustl.edu/mlpapers/paper_files/icml2004_LingYWZ04.pdf
11 |
12 | "Decision Trees with Minimal Costs"
13 | Charles X. Ling, Qiang Yang, Jianning Wang, Shichao Zhang
14 | */
15 | type RegretTarget struct {
16 | 	CatFeature
17 | 	Costs []float64
18 | }
19 |
20 | //NewRegretTarget creates a RegretTarget and initializes RegretTarget.Costs to the proper length.
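//
//For example (a hedged sketch; the class labels are illustrative):
//
//	rt := NewRegretTarget(catf)
//	rt.SetCosts(map[string]float64{"spam": 1.0, "ham": 10.0})
//	// misclassifying an actual "ham" case now costs 10x a misclassified "spam"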
21 | func NewRegretTarget(f CatFeature) *RegretTarget {
22 | 	return &RegretTarget{f, make([]float64, f.NCats())}
23 | }
24 |
25 | /*RegretTarget.SetCosts puts costs in a map[string]float64 by feature name into the proper
26 | entries in RegretTarget.Costs.*/
27 | func (target *RegretTarget) SetCosts(costmap map[string]float64) {
28 | 	for i := 0; i < target.NCats(); i++ {
29 | 		c := target.NumToCat(i)
30 | 		target.Costs[i] = costmap[c]
31 | 	}
32 | }
33 |
34 | /*
35 | RegretTarget.SplitImpurity is a version of Split Impurity that calls RegretTarget.Impurity
36 | */
37 | func (target *RegretTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) {
38 | 	nl := float64(len(*l))
39 | 	nr := float64(len(*r))
40 | 	nm := 0.0
41 |
42 | 	impurityDecrease = nl * target.Impurity(l, allocs.LCounter)
43 | 	impurityDecrease += nr * target.Impurity(r, allocs.RCounter)
44 | 	if m != nil && len(*m) > 0 {
45 | 		nm = float64(len(*m))
46 | 		impurityDecrease += nm * target.Impurity(m, allocs.Counter)
47 | 	}
48 |
49 | 	impurityDecrease /= nl + nr + nm
50 | 	return
51 | }
52 |
53 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l
54 | //to avoid recalculating the entire split impurity.
55 | func (target *RegretTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) {
56 | 	var cat, i int
57 | 	lcounter := *allocs.LCounter
58 | 	rcounter := *allocs.RCounter
59 | 	for _, i = range *movedRtoL {
60 |
61 | 		//most expensive statement:
62 | 		cat = target.Geti(i)
63 | 		lcounter[cat]++
64 | 		rcounter[cat]--
65 | 		//counter[target.Geti(i)]++
66 |
67 | 	}
68 | 	nl := float64(len(*l))
69 | 	nr := float64(len(*r))
70 | 	nm := 0.0
71 |
72 | 	impurityDecrease = nl * target.ImpFromCounts(len(*l), allocs.LCounter)
73 | 	impurityDecrease += nr * target.ImpFromCounts(len(*r), allocs.RCounter)
74 | 	if m != nil && len(*m) > 0 {
75 | 		nm = float64(len(*m))
76 | 		impurityDecrease += nm * target.ImpFromCounts(len(*m), allocs.Counter)
77 | 	}
78 |
79 | 	impurityDecrease /= nl + nr + nm
80 | 	return
81 | }
82 |
83 | //FindPredicted does a mode calculation with the count of each class weighted
84 | //by its misclassification cost.
85 | func (target *RegretTarget) FindPredicted(cases []int) (pred string) {
86 |
87 | 	mi := 0
88 | 	mc := 0.0
89 | 	counts := make([]int, target.NCats())
90 |
91 | 	target.CountPerCat(&cases, &counts)
92 |
93 | 	for cat, count := range counts {
94 | 		cc := float64(count) * target.Costs[cat]
95 | 		if cc > mc {
96 | 			mi = cat
97 | 			mc = cc
98 | 		}
99 | 	}
100 |
101 | 	return target.NumToCat(mi)
102 |
103 | }
104 |
105 | //ImpFromCounts recalculates the regret impurity from class counts for use in iterative updates.
106 | //t is the total number of cases in the set and is supplied by the caller.
107 | func (target *RegretTarget) ImpFromCounts(t int, counter *[]int) (e float64) {
108 |
109 | 	mi := 0
110 |
111 | 	mc := 0.0
112 |
113 | 	for cat, count := range *counter {
114 | 		cc := float64(count) * target.Costs[cat]
115 |
116 | 		if cc > mc {
117 | 			mi = cat
118 | 			mc = cc
119 | 		}
120 |
121 | 	}
122 |
123 | 	for cat, count := range *counter {
124 | 		//note: t is no longer summed up again here; doing so double counted the cases
125 | 		if cat != mi {
126 | 			e += target.Costs[cat] * float64(count)
127 | 		}
128 |
129 | 	}
130 | 	e /= float64(t)
131 |
132 | 	return
133 |
134 | }
135 |
136 | //Impurity implements an impurity based on misclassification costs.
137 | func (target *RegretTarget) Impurity(cases *[]int, counter *[]int) (e float64) {
138 |
139 | 	target.CountPerCat(cases, counter)
140 | 	t := len(*cases)
141 | 	e = target.ImpFromCounts(t, counter)
142 |
143 | 	return
144 |
145 | }
146 |
147 | //RegretTarget.Impurity implements a simple regret function that finds the average cost of
148 | //a set using the misclassification costs in RegretTarget.Costs.
149 | // func (target *RegretTarget) Impurity(cases *[]int, counter *[]int) (e float64) {
150 | // 	m := target.Modei(cases)
151 | // 	t := 0
152 | // 	for _, c := range *cases {
153 | // 		if target.IsMissing(c) == false {
154 | // 			t += 1
155 | // 			cat := target.Geti(c)
156 | // 			if cat != m {
157 | // 				e += target.Costs[cat]
158 | // 			}
159 | // 		}
160 |
161 | // 	}
162 | // 	e /= float64(t)
163 |
164 | // 	return
165 | // }
--------------------------------------------------------------------------------
/sampling.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 |
3 | import "math/rand"
4 |
5 | type Bagger interface {
6 | 	Sample(samples *[]int, n int)
7 | }
8 |
9 | //BalancedSampler provides for random sampling of integers (usually case indexes)
10 | //in a way that ensures a balanced presence of classes.
11 | type BalancedSampler struct {
12 | 	Cases [][]int
13 | }
14 |
15 | //NewBalancedSampler initializes a balanced sampler that will evenly balance cases
16 | //between the classes present in the provided DenseCatFeature.
17 | func NewBalancedSampler(catf *DenseCatFeature) (s *BalancedSampler) {
18 | 	s = &BalancedSampler{make([][]int, 0, catf.NCats())}
19 |
20 | 	for i := 0; i < catf.NCats(); i++ {
21 | 		s.Cases = append(s.Cases, make([]int, 0, catf.Length()))
22 | 	}
23 |
24 | 	for i, v := range catf.CatData {
25 | 		if !catf.IsMissing(i) {
26 | 			s.Cases[v] = append(s.Cases[v], i)
27 | 		}
28 | 	}
29 | 	return
30 | }
31 |
32 | //Sample samples n integers in a balanced-with-replacement fashion into the provided array.
33 | func (s *BalancedSampler) Sample(samples *[]int, n int) {
34 | 	(*samples) = (*samples)[0:0]
35 | 	nClasses := len(s.Cases)
36 | 	c := 0
37 | 	for i := 0; i < n; i++ {
38 | 		c = rand.Intn(nClasses)
39 | 		(*samples) = append((*samples), s.Cases[c][rand.Intn(len(s.Cases[c]))])
40 | 	}
41 |
42 | }
43 |
44 | //SecondaryBalancedSampler roughly balances the target feature within the classes of another categorical
45 | //feature while roughly preserving the original rate of the secondary feature.
46 | type SecondaryBalancedSampler struct {
47 | 	Total    int
48 | 	Counts   []int
49 | 	Samplers [][][]int
50 | }
51 |
52 | //NewSecondaryBalancedSampler returns an initialized balanced sampler.
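//
//For example, to roughly balance a target across a second feature (a sketch;
//batchFeature is an illustrative name):
//
//	s := NewSecondaryBalancedSampler(target, batchFeature)
//	samples := make([]int, 0, n)
//	s.Sample(&samples, n)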
53 | func NewSecondaryBalancedSampler(target *DenseCatFeature, balanceby *DenseCatFeature) (s *SecondaryBalancedSampler) { 54 | nSecondaryCats := balanceby.NCats() 55 | s = &SecondaryBalancedSampler{0, make([]int, nSecondaryCats, nSecondaryCats), make([][][]int, 0, nSecondaryCats)} 56 | 57 | for i := 0; i < nSecondaryCats; i++ { 58 | s.Samplers = append(s.Samplers, make([][]int, 0, target.NCats())) 59 | for j := 0; j < target.NCats(); j++ { 60 | s.Samplers[i] = append(s.Samplers[i], make([]int, 0, target.Length())) 61 | } 62 | 63 | } 64 | 65 | for i := 0; i < target.Length(); i++ { 66 | if !target.IsMissing(i) && !balanceby.IsMissing(i) { 67 | s.Total += 1 68 | balanceCat := balanceby.Geti(i) 69 | targetCat := target.Geti(i) 70 | s.Counts[balanceCat] += 1 71 | s.Samplers[balanceCat][targetCat] = append(s.Samplers[balanceCat][targetCat], i) 72 | } 73 | } 74 | return 75 | 76 | } 77 | 78 | func (s *SecondaryBalancedSampler) Sample(samples *[]int, n int) { 79 | (*samples) = (*samples)[0:0] 80 | 81 | b := 0 82 | c := 0 83 | for i := 0; i < n; i++ { 84 | b = rand.Intn(s.Total) 85 | for j, v := range s.Counts { 86 | b = b - v 87 | if b < 0 || j == (len(s.Counts)-1) { 88 | b = j 89 | break 90 | } 91 | } 92 | nCases := len(s.Samplers[b]) 93 | c = rand.Intn(nCases) 94 | (*samples) = append((*samples), s.Samplers[b][c][rand.Intn(len(s.Samplers[b][c]))]) 95 | } 96 | 97 | } 98 | 99 | /* 100 | SampleFirstN ensures that the first n entries in the supplied 101 | deck are randomly drawn from all entries without replacement for use in selecting candidate 102 | features to split on. It accepts a pointer to the deck so that it can be used repeatedly on 103 | the same deck avoiding reallocations. 104 | */ 105 | func SampleFirstN(deck *[]int, samples *[]int, n int, nconstants int) { 106 | cards := *deck 107 | length := len(cards) 108 | old := 0 109 | randi := 0 110 | lastSample := 0 111 | nDrawnConstants := 0 112 | nnonconstant := length - nconstants 113 | for i := 0; i < n && i < nnonconstant; i++ { 114 | 115 | randi = lastSample + rand.Intn(length-nDrawnConstants-lastSample) 116 | //randi = lastSample + rand.Intn(nnonconstant-lastSample) 117 | if randi >= nnonconstant { 118 | nDrawnConstants++ 119 | continue 120 | } 121 | 122 | old = cards[lastSample] 123 | cards[lastSample] = cards[randi] 124 | cards[randi] = old 125 | lastSample++ 126 | } 127 | if samples != nil { 128 | (*samples) = cards[:lastSample] 129 | } 130 | 131 | } 132 | 133 | /* 134 | SampleWithReplacment samples nSamples random draws from [0,totalCases) with replacement 135 | for use in selecting cases to grow a tree from. 136 | */ 137 | func SampleWithReplacment(nSamples, totalCases int) []int { 138 | cases := make([]int, nSamples) 139 | for i := 0; i < nSamples; i++ { 140 | cases[i] = rand.Intn(totalCases) 141 | } 142 | return cases 143 | } 144 | 145 | /* 146 | SampleWithoutReplacement samples nSamples random draws from [0, totalCases] w/o replacement 147 | for use in selecting cases to grow a tree from. 
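
For example, a bag of 5 of 10 total cases:

	cases := SampleWithoutReplacement(5, 10) // e.g. [7 2 9 0 4]

The draws come from the half-open range [0, totalCases) via rand.Perm, which
allocates a full permutation of totalCases on every call.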
148 | */ 149 | func SampleWithoutReplacement(nSamples, totalCases int) []int { 150 | return rand.Perm(totalCases)[:nSamples] 151 | } 152 | -------------------------------------------------------------------------------- /sampling_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | ) 7 | 8 | func TestSampleFirstN(t *testing.T) { 9 | deck := []int{0, 1, 2, 3} 10 | var samples []int 11 | SampleFirstN(&deck, &samples, 2, 0) 12 | if len(samples) != 2 { 13 | t.Errorf("Error: sampeling 2 items returned %v samples", len(samples)) 14 | } 15 | deck = []int{0, 1, 2, 3} 16 | SampleFirstN(&deck, &samples, 2, 2) 17 | 18 | if deck[2] != 2 || deck[3] != 3 { 19 | t.Errorf("Sampeling 2 items with 2 constant resulted in %v %v", deck, samples) 20 | } 21 | 22 | deck = []int{0, 1, 2, 3} 23 | SampleFirstN(&deck, &samples, 2, 3) 24 | 25 | if deck[1] != 1 || deck[2] != 2 || deck[3] != 3 { 26 | t.Errorf("Sampeling 2 items with 3 constant resulted in %v %v", deck, samples) 27 | } 28 | 29 | } 30 | 31 | var bfm = `. 0 1 2 3 4 5 6 7 32 | C:1 0 0 1 1 1 1 1 1 33 | C:2 0 1 0 1 0 1 0 1` 34 | 35 | func TestSampeling(t *testing.T) { 36 | fmReader := strings.NewReader(bfm) 37 | 38 | fm := ParseAFM(fmReader) 39 | cases := make([]int, 0, 1000) 40 | 41 | samplers := []Bagger{NewBalancedSampler(fm.Data[0].(*DenseCatFeature)), 42 | NewSecondaryBalancedSampler(fm.Data[0].(*DenseCatFeature), fm.Data[1].(*DenseCatFeature)), 43 | } 44 | 45 | for _, bs := range samplers { 46 | bs.Sample(&cases, 1000) 47 | case0 := 0 48 | case1 := 0 49 | 50 | for _, c := range cases { 51 | if c == 0 { 52 | case0++ 53 | } 54 | if c == 1 { 55 | case1++ 56 | } 57 | } 58 | switch bs.(type) { 59 | case *BalancedSampler: 60 | s := bs.(*BalancedSampler) 61 | if l := len(s.Cases); l != 2 { 62 | t.Errorf("Balanced sampler found %v cases not 2: %v", l, fm.Data[0].(*DenseCatFeature).Back) 63 | } 64 | 65 | case *SecondaryBalancedSampler: 66 | s := bs.(*SecondaryBalancedSampler) 67 | if s.Total != 8 { 68 | t.Errorf("SecondaryBalanced sampler found %v total cases not 8", s.Total) 69 | } 70 | if l := len(s.Samplers); l != 2 { 71 | t.Errorf("SecondaryBalanced sampler found %v cases not 2", l) 72 | } 73 | if l := len(s.Counts); l != 2 { 74 | t.Errorf("SecondaryBalanced sampler found %v cases not 2", l) 75 | } 76 | } 77 | if case0 < 200 || case1 < 200 { 78 | t.Errorf("Cases 0 and 1 underprepresented after balanced sampeling from %T.", bs) 79 | } 80 | 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /sklearn_tree.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | // ScikitNode 4 | // cdef struct Node: 5 | // # Base storage structure for the nodes in a Tree object 6 | 7 | // SIZE_t left_child # id of the left child of the node 8 | // SIZE_t right_child # id of the right child of the node 9 | // SIZE_t feature # Feature used for splitting the node 10 | // DOUBLE_t threshold # Threshold value at the node 11 | // DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) 12 | // SIZE_t n_node_samples # Number of samples at the node 13 | // DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node 14 | 15 | type ScikitNode struct { 16 | LeftChild int `json:"left_child"` 17 | RightChild int `json:"right_child"` 18 | Feature int `json:"feature"` 19 | Threshold float64 `json:"threshold"` 20 | Impurity float64 
`json:"impurity"` //TODO(ryan): support this? 21 | NNodeSamples int `json:"n_node_samples"` //TODO(ryan): support this? 22 | WeightedNNodeSamples float64 `json:"weighted_n_node_samples"` //TODO(ryan): support this? 23 | } 24 | 25 | // AnnotatedTree represents a decision tree in the memory format used by scikit learn. 26 | // cdef class Tree: 27 | // # The Tree object is a binary tree structure constructed by the 28 | // # TreeBuilder. The tree structure is used for predictions and 29 | // # feature importances. 30 | 31 | // # Input/Output layout 32 | // cdef public SIZE_t n_features # Number of features in X 33 | // cdef SIZE_t* n_classes # Number of classes in y[:, k] 34 | // cdef public SIZE_t n_outputs # Number of outputs in y 35 | // cdef public SIZE_t max_n_classes # max(n_classes) 36 | 37 | // # Inner structures: values are stored separately from node structure, 38 | // # since size is determined at runtime. 39 | // cdef public SIZE_t max_depth # Max depth of the tree 40 | // cdef public SIZE_t node_count # Counter for node IDs 41 | // cdef public SIZE_t capacity # Capacity of tree, in terms of nodes 42 | // cdef Node* nodes # Array of nodes 43 | // cdef double* value # (capacity, n_outputs, max_n_classes) array of values 44 | // cdef SIZE_t value_stride # = n_outputs * max_n_classes 45 | type ScikitTree struct { 46 | NFeatures int `json:"n_features"` 47 | NClasses []int `json:"n_classes"` 48 | NOutputs int `json:"n_outputs"` //TODO(ryan): support other values 49 | MaxNClasses int `json:"max_n_classes"` //TODO(ryan): support other values 50 | MaxDepth int `json:"max_depth"` 51 | NodeCount int `json:"node_count"` 52 | Capacity int `json:"capacity"` 53 | Nodes []ScikitNode `json:"nodes"` 54 | Value [][][]float64 `json:"value"` //TODO(ryan): support actual values 55 | ValueStride int `json:"value_stride"` 56 | } 57 | 58 | func NewScikitTree(nFeatures int) *ScikitTree { 59 | tree := &ScikitTree{ 60 | NFeatures: nFeatures, 61 | NClasses: []int{2}, 62 | NOutputs: 1, 63 | MaxNClasses: 2, 64 | MaxDepth: 0, 65 | NodeCount: 0, 66 | Capacity: 0, 67 | Nodes: make([]ScikitNode, 0), 68 | Value: make([][][]float64, 0), 69 | ValueStride: 0} 70 | 71 | return tree 72 | } 73 | 74 | // BuildScikkitTree currentelly only builds the split threshold and node structure of a sickit tree from a 75 | // Cloudforest tree specified by root node 76 | func BuildScikitTree(depth int, n *Node, sktree *ScikitTree) { 77 | if depth > sktree.MaxDepth { 78 | sktree.MaxDepth = depth 79 | } 80 | depth++ 81 | sktree.NodeCount++ 82 | sktree.Capacity++ 83 | skn := ScikitNode{} 84 | pos := len(sktree.Nodes) 85 | // We can't use a pointer here because the array will move and we're building this as an array 86 | // of structs for sklearn memory compatibility later so we use a pos. 
87 | sktree.Nodes = append(sktree.Nodes, skn) 88 | if n.Splitter != nil { 89 | sktree.Nodes[pos].Feature = n.Featurei 90 | sktree.Nodes[pos].Threshold = n.Splitter.Value 91 | sktree.Nodes[pos].LeftChild = sktree.NodeCount 92 | BuildScikitTree(depth, n.Left, sktree) 93 | sktree.Nodes[pos].RightChild = sktree.NodeCount 94 | BuildScikitTree(depth, n.Right, sktree) 95 | 96 | } else { 97 | // Leaf node 98 | sktree.Nodes[pos].LeftChild = -1 99 | sktree.Nodes[pos].RightChild = -1 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /sortablefeature.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "github.com/lytics/CloudForest/sortby" 5 | ) 6 | 7 | /*SortableFeature is a wrapper for a feature and set of cases that satisfies the 8 | sort.Interface interface so that the case indexes in Cases can be sorted using 9 | sort.Sort 10 | */ 11 | type SortableFeature struct { 12 | //Feature NumFeature 13 | Vals []float64 14 | Cases []int 15 | } 16 | 17 | //Sort performs introsort + heapsort using the sortby sub package. 18 | func (sf *SortableFeature) Sort() { 19 | //n := len(sf.Cases) 20 | // maxd := 2 * int(math.Log(float64(n))) 21 | // sf.introsort(0, n, maxd) 22 | sortby.SortBy(&sf.Cases, &sf.Vals) 23 | //sf.heapsort(0, n) 24 | //sort.Sort(sf) 25 | } 26 | 27 | //Len returns the number of cases. 28 | func (sf *SortableFeature) Len() int { 29 | return len(sf.Cases) 30 | } 31 | 32 | //Less determines if the ith case is less then the jth case. 33 | func (sf *SortableFeature) Less(i int, j int) bool { 34 | v := sf.Vals 35 | return v[i] < v[j] 36 | 37 | } 38 | 39 | //Swap exchanges the ith and jth cases. 40 | func (sf *SortableFeature) Swap(i int, j int) { 41 | c := sf.Cases 42 | v := c[i] 43 | c[i] = c[j] 44 | c[j] = v 45 | vs := sf.Vals 46 | w := vs[i] 47 | vs[i] = vs[j] 48 | vs[j] = w 49 | 50 | } 51 | 52 | //Load loads the values of the cases into a cache friendly array. 53 | func (sf *SortableFeature) Load(vals *[]float64, cases *[]int) { 54 | sf.Cases = *cases 55 | sfvals := sf.Vals 56 | vs := *vals 57 | for i, p := range *cases { 58 | sfvals[i] = vs[p] 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /sortby/sortby.go: -------------------------------------------------------------------------------- 1 | /*Package sortby is a hybrid, non stable sort based on go's standard sort but with 2 | all less function and many swaps inlined to sort a list of ints by an acompanying list 3 | of floats as needed in random forest training. It is about 30-40% faster then the 4 | standard sort.*/ 5 | package sortby 6 | 7 | //SortBy will sort the values in cases and vals by the values in vals in increasing order. 8 | //If vals is longer then cases only the coresponding section will be sorted. 9 | func SortBy(cases *[]int, vals *[]float64) { 10 | n := len(*cases) 11 | // Switch to heapsort if depth of 2*ceil(lg(n+1)) is reached. 12 | maxDepth := 0 13 | for i := n; i > 0; i >>= 1 { 14 | maxDepth++ 15 | } 16 | maxDepth *= 2 17 | quickSort(cases, vals, 0, n, maxDepth) 18 | //introsort(cases, vals, 0, n, maxd) 19 | //heapsort(cases, vals, 0, n) 20 | } 21 | 22 | //Swap exchanges the ith and jth cases. 
23 | func swap(cases *[]int, vals *[]float64, i int, j int) { 24 | //swap(cases, vals, 25 | c := *cases 26 | v := c[i] 27 | c[i] = c[j] 28 | c[j] = v 29 | vs := *vals 30 | w := vs[i] 31 | vs[i] = vs[j] 32 | vs[j] = w 33 | 34 | } 35 | 36 | func quickSort(cases *[]int, vals *[]float64, a, b, maxDepth int) { 37 | for b-a > 7 { 38 | if maxDepth == 0 { 39 | heapSort(cases, vals, a, b) 40 | return 41 | } 42 | maxDepth-- 43 | mlo, mhi := doPivot(cases, vals, a, b) 44 | // Avoiding recursion on the larger subproblem guarantees 45 | // a stack depth of at most lg(b-a). 46 | if mlo-a < b-mhi { 47 | quickSort(cases, vals, a, mlo, maxDepth) 48 | a = mhi // i.e., quickSort(cases, vals, mhi, b) 49 | } else { 50 | quickSort(cases, vals, mhi, b, maxDepth) 51 | b = mlo // i.e., quickSort(cases, vals, a, mlo) 52 | } 53 | } 54 | if b-a > 1 { 55 | insertionSort(cases, vals, a, b) 56 | } 57 | } 58 | 59 | func doPivot(cases *[]int, vals *[]float64, lo, hi int) (midlo, midhi int) { 60 | cs := *cases 61 | vs := *vals 62 | var swapi int 63 | var swapf float64 64 | m := lo + (hi-lo)/2 // Written like this to avoid integer overflow. 65 | if hi-lo > 40 { 66 | // Tukey's ``Ninther,'' median of three medians of three. 67 | s := (hi - lo) / 8 68 | medianOfThree(cases, vals, lo, lo+s, lo+2*s) 69 | medianOfThree(cases, vals, m, m-s, m+s) 70 | medianOfThree(cases, vals, hi-1, hi-1-s, hi-1-2*s) 71 | } 72 | medianOfThree(cases, vals, lo, m, hi-1) 73 | 74 | // Invariants are: 75 | // data[lo] = pivot (set up by ChoosePivot) 76 | // data[lo <= i < a] = pivot 77 | // data[a <= i < b] < pivot 78 | // data[b <= i < c] is unexamined 79 | // data[c <= i < d] > pivot 80 | // data[d <= i < hi] = pivot 81 | // 82 | // Once b meets c, can swap the "= pivot" sections 83 | // into the middle of the slice. 
84 | pivotv := vs[lo] 85 | a, b, c, d := lo+1, lo+1, hi, hi 86 | for { 87 | for b < c { 88 | swapf = vs[b] 89 | if swapf < pivotv { 90 | b++ 91 | } else if pivotv == swapf { 92 | 93 | vs[b] = vs[a] 94 | vs[a] = swapf 95 | 96 | swapi = cs[a] 97 | cs[a] = cs[b] 98 | cs[b] = swapi 99 | 100 | a++ 101 | b++ 102 | } else { 103 | break 104 | } 105 | } 106 | for b < c { 107 | c-- 108 | swapf = vs[c] 109 | if pivotv < swapf { 110 | 111 | } else if swapf == pivotv { 112 | d-- 113 | vs[c] = vs[d] 114 | vs[d] = swapf 115 | 116 | swapi = cs[c] 117 | cs[c] = cs[d] 118 | cs[d] = swapi 119 | 120 | } else { 121 | c++ 122 | break 123 | } 124 | // if pivotv > swapf { 125 | // c++ 126 | // break 127 | // } else if pivotv >= swapf { 128 | // d-- 129 | // vs[c] = vs[d] 130 | // vs[d] = swapf 131 | 132 | // swapi = cs[c] 133 | // cs[c] = cs[d] 134 | // cs[d] = swapi 135 | 136 | // } 137 | } 138 | if b >= c { 139 | break 140 | } 141 | 142 | c-- 143 | 144 | swapf = vs[c] 145 | vs[c] = vs[b] 146 | vs[b] = swapf 147 | 148 | swapi = cs[c] 149 | cs[c] = cs[b] 150 | cs[b] = swapi 151 | b++ 152 | 153 | } 154 | 155 | n := min(b-a, a-lo) 156 | 157 | //swapRange(cases, vals, lo, b-n, n) 158 | a2 := lo 159 | b2 := b - n 160 | for i := 0; i < n; i++ { 161 | 162 | swapf = vs[a2] 163 | vs[a2] = vs[b2] 164 | vs[b2] = swapf 165 | 166 | swapi = cs[a2] 167 | cs[a2] = cs[b2] 168 | cs[b2] = swapi 169 | 170 | a2++ 171 | b2++ 172 | } 173 | 174 | n = min(hi-d, d-c) 175 | //swapRange(cases, vals, c, hi-n, n) 176 | a2 = c 177 | b2 = hi - n 178 | for i := 0; i < n; i++ { 179 | 180 | swapf = vs[a2] 181 | vs[a2] = vs[b2] 182 | vs[b2] = swapf 183 | 184 | swapi = cs[a2] 185 | cs[a2] = cs[b2] 186 | cs[b2] = swapi 187 | 188 | a2++ 189 | b2++ 190 | } 191 | 192 | return lo + b - a, hi - (d - c) 193 | } 194 | 195 | // medianOfThree moves the median of the three values data[a], data[b], data[c] into data[a]. 196 | func medianOfThree(cases *[]int, vals *[]float64, a, b, c int) { 197 | vs := *vals 198 | //cs := *cases 199 | m0 := b 200 | m1 := a 201 | m2 := c 202 | // bubble sort on 3 elements 203 | if vs[m1] < vs[m0] { 204 | swap(cases, vals, m1, m0) 205 | } 206 | if vs[m2] < vs[m1] { 207 | swap(cases, vals, m2, m1) 208 | } 209 | if vs[m1] < vs[m0] { 210 | swap(cases, vals, m1, m0) 211 | } 212 | // now data[m0] <= data[m1] <= data[m2] 213 | } 214 | 215 | func swapRange(cases *[]int, vals *[]float64, a, b, n int) { 216 | vs := *vals 217 | cs := *cases 218 | //var api, bpi = a, b 219 | var swapi int 220 | var swapf float64 221 | for i := 0; i < n; i++ { 222 | //swap(cases, vals, a, b+i) 223 | 224 | // vs[a+i], vs[b+i] = vs[b+i], vs[a+i] 225 | // cs[a+i], cs[b+i] = cs[b+i], cs[a+i] 226 | 227 | swapf = vs[a] 228 | vs[a] = vs[b] 229 | vs[b] = swapf 230 | 231 | swapi = cs[a] 232 | cs[a] = cs[b] 233 | cs[b] = swapi 234 | 235 | // vs[a], vs[b] = vs[b], vs[a] 236 | // cs[a], cs[b] = cs[b], cs[a] 237 | a++ 238 | b++ 239 | } 240 | } 241 | 242 | // Insertion sort 243 | func insertionSort(cases *[]int, vals *[]float64, a, b int) { 244 | vs := *vals 245 | //cs := *cases 246 | for i := a + 1; i < b; i++ { 247 | for j := i; j > a && vs[j] < vs[j-1]; j-- { 248 | swap(cases, vals, j, j-1) 249 | } 250 | } 251 | } 252 | 253 | // siftDown implements the heap property on data[lo, hi). 254 | // first is an offset into the array where the root of the heap lies. 
255 | func siftDown(cases *[]int, vals *[]float64, lo, hi, first int) { 256 | vs := *vals 257 | root := lo 258 | for { 259 | child := 2*root + 1 260 | if child >= hi { 261 | break 262 | } 263 | if child+1 < hi && vs[first+child] < vs[first+child+1] { 264 | child++ 265 | } 266 | if vs[first+root] >= vs[first+child] { 267 | return 268 | } 269 | swap(cases, vals, first+root, first+child) 270 | root = child 271 | } 272 | } 273 | 274 | func heapSort(cases *[]int, vals *[]float64, a, b int) { 275 | first := a 276 | lo := 0 277 | hi := b - a 278 | 279 | // Build heap with greatest element at top. 280 | for i := (hi - 1) / 2; i >= 0; i-- { 281 | siftDown(cases, vals, i, hi, first) 282 | } 283 | 284 | // Pop elements, largest first, into end of data. 285 | for i := hi - 1; i >= 0; i-- { 286 | swap(cases, vals, first, first+i) 287 | siftDown(cases, vals, lo, i, first) 288 | } 289 | } 290 | 291 | func min(a, b int) int { 292 | if a < b { 293 | return a 294 | } 295 | return b 296 | } 297 | -------------------------------------------------------------------------------- /sortby/sortby_test.go: -------------------------------------------------------------------------------- 1 | package sortby 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | var cases = []int{10, 1, 6, 3, 4, 9, 8, 7, 2, 0, 5} 8 | var vals = []float64{1.0, 0.1, 0.6, 0.3, 0.4, 0.9, 0.8, 0.7, 0.2, 0.0, 0.5, -1.0} 9 | var binvals = []int{1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0} 10 | 11 | func TestShortSort(t *testing.T) { 12 | 13 | cases := []int{0, 4, 3, 1, 2} 14 | vals := []float64{0.1, 10.1, 10.2, 0, 0} 15 | 16 | SortBy(&cases, &vals) 17 | 18 | for i := 1; i < len(cases); i++ { 19 | 20 | if vals[i] < vals[i-1] { 21 | t.Errorf("Vals weren't sorted \n%v\n%v ", cases, vals) 22 | } 23 | } 24 | } 25 | 26 | func TestSortBy(t *testing.T) { 27 | 28 | if len(cases) != 11 || len(vals) != 12 { 29 | t.Errorf("Cases and vals had wrong length before sort %v and %v not 11 and 12", len(cases), len(vals)) 30 | } 31 | 32 | SortBy(&cases, &vals) 33 | 34 | if len(cases) != 11 || len(vals) != 12 { 35 | t.Errorf("Cases and vals had wrong length after sort %v and %v not 11 and 12", len(cases), len(vals)) 36 | } 37 | 38 | for i := 1; i < len(cases); i++ { 39 | 40 | if cases[i] < cases[i-1] { 41 | t.Errorf("Cases weren't sorted: \n%v\n%v ", cases, vals) 42 | } 43 | if vals[i] < vals[i-1] { 44 | t.Errorf("Vals weren't sorted \n%v\n%v ", cases, vals) 45 | } 46 | if int(10.0*vals[i]) != cases[i] { 47 | t.Errorf("Cases and val's don't match at pos %v : %v and %v", i, vals[i], cases[i]) 48 | } 49 | } 50 | 51 | if vals[11] != -1.0 { 52 | t.Error("Value in vals beyond the range of cases was not left untouched.") 53 | } 54 | 55 | SortBy(&cases, &vals) 56 | 57 | for i := 1; i < len(cases); i++ { 58 | 59 | if vals[i] < vals[i-1] { 60 | t.Errorf("Vals weren't sorted \n%v\n%v ", cases, vals) 61 | } 62 | } 63 | 64 | } 65 | 66 | // func TestHeapSort(t *testing.T) { 67 | // heapsort(&cases, &vals, 0, 11) 68 | 69 | // if len(cases) != 11 || len(vals) != 12 { 70 | // t.Errorf("Cases and vals had wrong length after sort %v and %v not 11 and 12", len(cases), len(vals)) 71 | // } 72 | 73 | // for i := 1; i < len(cases); i++ { 74 | 75 | // if cases[i] < cases[i-1] { 76 | // t.Errorf("Cases weren't sorted: \n%v\n%v ", cases, vals) 77 | // } 78 | // if vals[i] < vals[i-1] { 79 | // t.Errorf("Vals weren't sorted \n%v\n%v ", cases, vals) 80 | // } 81 | // if int(10.0*vals[i]) != cases[i] { 82 | // t.Errorf("Cases and val's don't match at pos %v : %v and %v", i, vals[i], 
cases[i])
83 | // 		}
84 | // 	}
85 | 
86 | // 	if vals[11] != -1.0 {
87 | // 		t.Error("Value in vals beyond the range of cases was not left untouched.")
88 | // 	}
89 | // }
90 | 
91 | // func TestQuickSort(t *testing.T) {
92 | // 	introsort(&cases, &vals, 0, 11, 11)
93 | 
94 | // 	if len(cases) != 11 || len(vals) != 12 {
95 | // 		t.Errorf("Cases and vals had wrong length after sort %v and %v not 11 and 12", len(cases), len(vals))
96 | // 	}
97 | 
98 | // 	for i := 1; i < len(cases); i++ {
99 | 
100 | // 		if cases[i] < cases[i-1] {
101 | // 			t.Errorf("Cases weren't sorted: \n%v\n%v ", cases, vals)
102 | // 		}
103 | // 		if vals[i] < vals[i-1] {
104 | // 			t.Errorf("Vals weren't sorted \n%v\n%v ", cases, vals)
105 | // 		}
106 | // 		if int(10.0*vals[i]) != cases[i] {
107 | // 			t.Errorf("Cases and val's don't match at pos %v : %v and %v", i, vals[i], cases[i])
108 | // 		}
109 | // 	}
110 | 
111 | // 	if vals[11] != -1.0 {
112 | // 		t.Error("Value in vals beyond the range of cases was not left untouched.")
113 | // 	}
114 | // }
115 | 
116 | // func TestIntroSort(t *testing.T) {
117 | 
118 | // 	introsort(&cases, &vals, 0, 11, 2)
119 | 
120 | // 	if len(cases) != 11 || len(vals) != 12 {
121 | // 		t.Errorf("Cases and vals had wrong length after sort %v and %v not 11 and 12", len(cases), len(vals))
122 | // 	}
123 | 
124 | // 	for i := 1; i < len(cases); i++ {
125 | 
126 | // 		if cases[i] < cases[i-1] {
127 | // 			t.Errorf("Cases weren't sorted: \n%v\n%v ", cases, vals)
128 | // 		}
129 | // 		if vals[i] < vals[i-1] {
130 | // 			t.Errorf("Vals weren't sorted \n%v\n%v ", cases, vals)
131 | // 		}
132 | // 		if int(10.0*vals[i]) != cases[i] {
133 | // 			t.Errorf("Cases and val's don't match at pos %v : %v and %v", i, vals[i], cases[i])
134 | // 		}
135 | // 	}
136 | 
137 | // 	if vals[11] != -1.0 {
138 | // 		t.Error("Value in vals beyond the range of cases was not left untouched.")
139 | // 	}
140 | // }
141 | 
--------------------------------------------------------------------------------
/splitallocations.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 | 
3 | //BestSplitAllocs contains reusable allocations for split searching and evaluation.
4 | //Separate instances should be used in each goroutine doing learning.
5 | type BestSplitAllocs struct {
6 | 	L  []int //Allocated to size
7 | 	R  []int
8 | 	M  []int
9 | 	LM []int //Used to point at other array
10 | 	RM []int
11 | 	MM []int
12 | 	Left       *[]int //left cases for potential splits
13 | 	Right      *[]int //right cases for potential splits
14 | 	NonMissing *[]int //non missing cases for potential splits
15 | 	Counter    *[]int //class counter for counting classes in splits, used alone or for missing
16 | 	LCounter   *[]int //left class counter for summarizing (mean) splits
17 | 	RCounter   *[]int //right class counter for summarizing (mean) splits
18 | 	Lsum     float64 //left value for summarizing splits
19 | 	Rsum     float64 //right value for summarizing splits
20 | 	Msum     float64 //missing value for summarizing splits
21 | 	Lsum_sqr float64 //left value for summarizing splits
22 | 	Rsum_sqr float64 //right value for summarizing splits
23 | 	Msum_sqr float64 //missing value for summarizing splits
24 | 	CatVals  []int
25 | 	SortVals []float64
26 | 	Sorter   *SortableFeature //for learning from numerical features
27 | 	ContrastTarget Target
28 | }
29 | 
30 | //NewBestSplitAllocs initializes all of the reusable allocations for split
31 | //searching to the appropriate size. nTotalCases should be the number of total
32 | //cases in the feature matrix being analyzed.
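//
//A minimal usage sketch (illustrative only; fm and target stand for an
//existing *FeatureMatrix and Target):
//
//	allocs := NewBestSplitAllocs(fm.Data[0].Length(), target)
//	//...use allocs in split searches within a single goroutine...
//
//Each goroutine doing learning should construct its own instance.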
33 | func NewBestSplitAllocs(nTotalCases int, target Target) (bsa *BestSplitAllocs) {
34 | 	left := make([]int, 0, nTotalCases)
35 | 	right := make([]int, 0, nTotalCases)
36 | 	nonmissing := make([]int, 0, nTotalCases)
37 | 	counter := make([]int, target.NCats())
38 | 	lcounter := make([]int, target.NCats())
39 | 	rcounter := make([]int, target.NCats())
40 | 	bsa = &BestSplitAllocs{make([]int, 0, nTotalCases),
41 | 		make([]int, 0, nTotalCases),
42 | 		make([]int, 0, nTotalCases),
43 | 		nil,
44 | 		nil,
45 | 		nil,
46 | 		// make([]int, 0, nTotalCases),
47 | 		// make([]int, nTotalCases, nTotalCases),
48 | 		&left,
49 | 		&right,
50 | 		&nonmissing,
51 | 		&counter,
52 | 		&lcounter,
53 | 		&rcounter,
54 | 		0.0,
55 | 		0.0,
56 | 		0.0,
57 | 		0.0,
58 | 		0.0,
59 | 		0.0,
60 | 		make([]int, nTotalCases, nTotalCases),
61 | 		make([]float64, nTotalCases, nTotalCases),
62 | 		&SortableFeature{make([]float64, nTotalCases, nTotalCases),
63 | 			nil},
64 | 		target.(Feature).Copy().(Target)}
65 | 	return
66 | }
--------------------------------------------------------------------------------
/splitter.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 | 
3 | //"fmt"
4 | 
5 | //Splitter contains fields that can be used to split cases by a single feature. The split
6 | //can be either numerical, in which case it is defined by the Value field, or
7 | //categorical, in which case it is defined by the Left field.
8 | type Splitter struct {
9 | 	Feature   string
10 | 	Numerical bool
11 | 	Value     float64
12 | 	Left      map[string]bool
13 | }
14 | 
15 | func (c *Splitter) Copy() *Splitter {
16 | 	if c == nil {
17 | 		return nil
18 | 	}
19 | 
20 | 	lft := make(map[string]bool)
21 | 	if c.Left != nil {
22 | 		for s, b := range c.Left {
23 | 			lft[s] = b
24 | 		}
25 | 	}
26 | 
27 | 	return &Splitter{
28 | 		Feature:   c.Feature,
29 | 		Numerical: c.Numerical,
30 | 		Value:     c.Value,
31 | 		Left:      lft,
32 | 	}
33 | }
34 | 
35 | //func
36 | 
37 | /*
38 | Split splits a slice of cases into left, right and missing slices without allocating
39 | a new underlying array by sorting cases into left, missing, right order and returning
40 | slices that point to the left and right cases.
41 | */
42 | func (s *Splitter) Split(fm *FeatureMatrix, cases []int) (l []int, r []int, m []int) {
43 | 	length := len(cases)
44 | 
45 | 	lastleft := -1
46 | 	lastright := length
47 | 	swaper := 0
48 | 
49 | 	f := fm.Data[fm.Map[s.Feature]]
50 | 
51 | 	//Move left cases to the start and right cases to the end so that missing cases end up
52 | 	//in between.
53 | 	hasmissing := f.MissingVals()
54 | 	for i := 0; i < lastright; i++ {
55 | 		if hasmissing && f.IsMissing(cases[i]) {
56 | 			continue
57 | 		}
58 | 		if f.GoesLeft(cases[i], s) {
59 | 			lastleft++
60 | 			if i != lastleft {
61 | 
62 | 				swaper = cases[i]
63 | 				cases[i] = cases[lastleft]
64 | 				cases[lastleft] = swaper
65 | 				i--
66 | 
67 | 			}
68 | 
69 | 		} else {
70 | 			//Right
71 | 			lastright--
72 | 			swaper = cases[i]
73 | 			cases[i] = cases[lastright]
74 | 			cases[lastright] = swaper
75 | 			i--
76 | 
77 | 		}
78 | 
79 | 	}
80 | 	//fmt.Println(cases, lastleft, lastright)
81 | 	l = cases[:lastleft+1]
82 | 	r = cases[lastright:]
83 | 	m = cases[lastleft+1 : lastright]
84 | 
85 | 	return
86 | }
--------------------------------------------------------------------------------
/stats/stats.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package stats currently only implements a Welch's t-test for importance score analysis
3 | in CloudForest.
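
For reference, the statistic computed by Ttest below is

	t = (mean(A) - mean(B)) / sqrt(var(A)/nA + var(B)/nB)

with the Welch-Satterthwaite degrees of freedom

	v = (var(A)/nA + var(B)/nB)^2 / ((var(A)/nA)^2/(nA-1) + (var(B)/nB)^2/(nB-1))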
4 | */
5 | package stats
6 | 
7 | import (
8 | 	"math"
9 | )
10 | 
11 | //MeanAndVar returns the sample mean, variance and count as float64's.
12 | func MeanAndVar(X *[]float64) (m, v, n float64) {
13 | 	for _, x := range *X {
14 | 		m += x
15 | 		v += x * x
16 | 	}
17 | 	n = float64(len(*X))
18 | 	m /= n
19 | 	v -= n * m * m
20 | 	v /= (n - 1.0)
21 | 	return
22 | }
23 | 
24 | //Ttest performs a Welch's t-test and returns the p-value, t-value, degrees of freedom and the mean of A.
25 | //The p-value is based on the hypothesis that mean(B) > mean(A).
26 | //Based on similar code in rf-ace (Apache 2.0, Timo Erkkilä)
27 | func Ttest(A, B *[]float64) (p, t, v, am float64) {
28 | 
29 | 	// Calculate means and variances for each of two samples.
30 | 	Am, Av, An := MeanAndVar(A)
31 | 	Bm, Bv, Bn := MeanAndVar(B)
32 | 	am = Am
33 | 
34 | 	//Welch's t test
35 | 	As := Av / An
36 | 	Bs := Bv / Bn
37 | 	s := As + Bs
38 | 	t = (Am - Bm) / math.Sqrt(s)
39 | 
40 | 	// Degrees of freedom for Welch's t-test
41 | 	v = s * s / (As*As/(An-1) + Bs*Bs/(Bn-1))
42 | 
43 | 	// Find the tail probability of t
44 | 
45 | 	// Transformed t-test statistic
46 | 	ttrans := v / (t*t + v)
47 | 
48 | 	// This variable will store the integral of the tail of the t-distribution
49 | 	integral := 0.0
50 | 
51 | 	// Comment from rf-ace:
52 | 	// When ttrans > 0.9, we need to recast the integration in order to retain
53 | 	// accuracy. In other words we make use of the following identity:
54 | 	//
55 | 	// I(x,a,b) = 1 - I(1-x,b,a)
56 | 	if ttrans > 0.9 {
57 | 		// Calculate I(x,a,b) as 1 - I(1-x,b,a)
58 | 		integral = 1 - regularizedIncompleteBeta(1-ttrans, 0.5, v/2)
59 | 
60 | 	} else {
61 | 		// Calculate I(x,a,b) directly
62 | 		integral = regularizedIncompleteBeta(ttrans, v/2, 0.5)
63 | 	}
64 | 
65 | 	// Comment from rf-ace:
66 | 	// We need to be careful about which way to calculate the integral so that it represents
67 | 	// the tail of the t-distribution. The sign of the tvalue hints which way to integrate
68 | 	if t > 0.0 {
69 | 		p = (integral / 2)
70 | 	} else {
71 | 		p = (1 - integral/2)
72 | 	}
73 | 	return
74 | }
75 | 
76 | //Based on similar code in rf-ace (Apache 2.0, Timo Erkkilä)
77 | func regularizedIncompleteBeta(x, a, b float64) float64 {
78 | 	i := 50
79 | 	continuedFraction := 1.0
80 | 	m := 0.0
81 | 
82 | 	for i > 0 {
83 | 		m = float64(i)
84 | 		continuedFraction = 1.0 + dE(m, x, a, b)/(1+dO(m, x, a, b)/continuedFraction)
85 | 		i--
86 | 	}
87 | 	return (math.Pow(x, a) * math.Pow(1-x, b) / (a * beta(a, b) * (1 + dO(0, x, a, b)/continuedFraction)))
88 | }
89 | 
90 | /*
91 | Even and odd factors for the infinite continued fraction representation of the
92 | regularized incomplete beta function.
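
In the notation of the code below, the even and odd terms are

	dE(m) = m(b-m)x / ((a+2m-1)(a+2m))
	dO(m) = -(a+m)(a+b+m)x / ((a+2m)(a+2m+1))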
93 | Based on similar code in rf-ace (Apache 2.0, Timo Erkkilä)
94 | */
95 | func dO(m, x, a, b float64) float64 {
96 | 	return (-1.0 * (a + m) * (a + b + m) * x / ((a + 2*m) * (a + 2*m + 1)))
97 | }
98 | 
99 | func dE(m, x, a, b float64) float64 {
100 | 	return (m * (b - m) * x / ((a + 2*m - 1) * (a + 2*m)))
101 | }
102 | 
103 | //Based on similar code in rf-ace (Apache 2.0, Timo Erkkilä)
104 | func lgamma(x float64) float64 {
105 | 	v, _ := math.Lgamma(x)
106 | 	//v := math.Log(math.Abs(math.Gamma(x)))
107 | 	return v
108 | }
109 | 
110 | //Based on similar code in rf-ace (Apache 2.0, Timo Erkkilä)
111 | func beta(a, b float64) float64 {
112 | 	return (math.Exp(lgamma(a) + lgamma(b) - lgamma(a+b)))
113 | }
--------------------------------------------------------------------------------
/stats/welchst_test.go:
--------------------------------------------------------------------------------
1 | package stats
2 | 
3 | import (
4 | 	"math"
5 | 	"testing"
6 | )
7 | 
8 | func notE(a, b float64) bool {
9 | 	return math.Abs(a-b) > 0.001
10 | }
11 | 
12 | func TestTTest(t *testing.T) {
13 | 	/* Simple test case generated with R:
14 | 
15 | 	> x = rnorm(10)
16 | 	> y = rnorm(10)
17 | 	> x
18 | 	[1] -1.96987304  0.51258439 -0.98814832 -1.04462895  0.04199386 -0.74186695
19 | 	[7] -1.76605177 -1.08967410  0.90011966 -0.49636826
20 | 	> y
21 | 	[1] -0.09087432  0.35026448  0.89435080 -1.40248504 -1.14944188  0.23536083
22 | 	[7] -0.45775375  0.24868155 -1.18380814  1.70410704
23 | 	> y
24 | 	[1] -0.09087432  0.35026448  0.89435080 -1.40248504 -1.14944188  0.23536083
25 | 	[7] -0.45775375  0.24868155 -1.18380814  1.70410704
26 | 	> t.test(x,y,alternative="greater")
27 | 
28 | 	Welch Two Sample t-test
29 | 
30 | 	data:  x and y
31 | 	t = -1.3526, df = 17.925, p-value = 0.9035
32 | 	alternative hypothesis: true difference in means is greater than 0
33 | 	95 percent confidence interval:
34 | 	-1.321523       Inf
35 | 	sample estimates:
36 | 	 mean of x   mean of y
37 | 	-0.66419135 -0.08515984
38 | 
39 | 	> mean(x)
40 | 	[1] -0.6641913
41 | 	> var(x)
42 | 	[1] 0.8571537
43 | 
44 | 	> mean(y)
45 | 	[1] -0.08515984
46 | 	> var(y)
47 | 	[1] 0.9754027
48 | 	>
49 | 
50 | 	*/
51 | 
52 | 	x := []float64{-1.96987304, 0.51258439, -0.98814832, -1.04462895, 0.04199386, -0.74186695, -1.76605177, -1.08967410, 0.90011966, -0.49636826}
53 | 	y := []float64{-0.09087432, 0.35026448, 0.89435080, -1.40248504, -1.14944188, 0.23536083, -0.45775375, 0.24868155, -1.18380814, 1.70410704}
54 | 	mean, v, n := MeanAndVar(&x)
55 | 	if notE(mean, -0.6641913) || notE(v, 0.8571537) || n != 10 {
56 | 		t.Errorf("Bad MeanAndVar results %v, %v, %v. not close to -0.6641913, 0.8571537, 10", mean, v, n)
57 | 	}
58 | 
59 | 	p, tv, df, _ := Ttest(&x, &y)
60 | 	if notE(p, 0.9035) {
61 | 		t.Errorf("Bad p value from TTest. %v not close to 0.9035", p)
62 | 	}
63 | 
64 | 	if notE(tv, -1.3526) {
65 | 		t.Errorf("Bad t value from TTest. %v not close to -1.3526", tv)
66 | 	}
67 | 
68 | 	if notE(df, 17.925) {
69 | 		t.Errorf("Bad degrees of freedom from TTest. %v not close to 17.925", df)
70 | 	}
71 | }
--------------------------------------------------------------------------------
/sumballotbox.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 | 
3 | import (
4 | 	"fmt"
5 | 	"strconv"
6 | 	"sync"
7 | )
8 | 
9 | //SumBallot is used inside SumBallotBox to record sum votes in a thread safe
10 | //manner.
11 | type SumBallot struct {
12 | 	Mutex sync.Mutex
13 | 	Sum   float64
14 | }
15 | 
16 | //NewSumBallot returns a pointer to an initialized SumBallot with its Sum set to zero.
17 | func NewSumBallot() (cb *SumBallot) {
18 | 	cb = new(SumBallot)
19 | 	cb.Sum = 0.0
20 | 	return
21 | }
22 | 
23 | //SumBallotBox keeps track of votes by trees in a thread safe manner.
24 | //It should be used with gradient boosting when a sum instead of an average
25 | //or mode is desired.
26 | type SumBallotBox struct {
27 | 	Box []*SumBallot
28 | }
29 | 
30 | //NewSumBallotBox builds a new ballot box for the number of cases specified by "size".
31 | func NewSumBallotBox(size int) *SumBallotBox {
32 | 	bb := SumBallotBox{
33 | 		make([]*SumBallot, 0, size)}
34 | 	for i := 0; i < size; i++ {
35 | 		bb.Box = append(bb.Box, NewSumBallot())
36 | 	}
37 | 	return &bb
38 | }
39 | 
40 | //Vote registers a vote that case "casei" should have pred added to its
41 | //sum.
42 | func (bb *SumBallotBox) Vote(casei int, pred string, weight float64) {
43 | 	v, err := strconv.ParseFloat(pred, 64)
44 | 	if err == nil {
45 | 
46 | 		bb.Box[casei].Mutex.Lock()
47 | 		bb.Box[casei].Sum += v * weight
48 | 		bb.Box[casei].Mutex.Unlock()
49 | 	}
50 | }
51 | 
52 | //Tally tallies the votes for the case specified by i as
53 | //if it were a categorical or boolean feature. I.e., it returns the sum
54 | //of all votes.
55 | func (bb *SumBallotBox) Tally(i int) (predicted string) {
56 | 	predicted = "NA"
57 | 	predicted = fmt.Sprintf("%v", bb.TallyNum(i))
58 | 
59 | 	return
60 | 
61 | }
62 | 
63 | func (bb *SumBallotBox) TallyNum(i int) (predicted float64) {
64 | 	bb.Box[i].Mutex.Lock()
65 | 	predicted = bb.Box[i].Sum
66 | 	bb.Box[i].Mutex.Unlock()
67 | 
68 | 	return
69 | 
70 | }
71 | 
72 | /*
73 | TallyError is non-functional here.
74 | */
75 | func (bb *SumBallotBox) TallyError(feature Feature) (e float64) {
76 | 
77 | 	return 1.0
78 | 
79 | }
--------------------------------------------------------------------------------
/transduction.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 | 
3 | /*
4 | TransTarget is used for semi-supervised transduction trees [1] that combine supervised impurity with
5 | a purely density-based term:
6 | 
7 | 	I = I_supervised + alpha * I_unsupervised
8 | 
9 | I_supervised is called from the embedded CatFeature so that it can be Gini, Entropy, Weighted or any other
10 | of the existing non-boosting impurities. Boosting impurities could be implemented with minimal work.
11 | 
12 | I_unsupervised uses a density estimating term that differs from the one described in [1] and is instead
13 | based on the technique described in [2], which avoids some assumptions and allows a simple implementation.
14 | 
15 | [1] A. Criminisi, J. Shotton, and E. Konukoglu, "Decision Forests for Classification, Regression,
16 | Density Estimation, Manifold Learning and Semi-Supervised Learning"
17 | Microsoft Research technical report TR-2011-114
18 | http://research.microsoft.com/pubs/155552/decisionForests_MSR_TR_2011_114.pdf
19 | 
20 | [2] Parikshit Ram, Alexander G. Gray, "Density Estimation Trees"
21 | 
22 | One difference from [1] is that the unlabeled class is considered a standard class for I_supervised
23 | to allow one-class problems.
24 | */
25 | type TransTarget struct {
26 | 	CatFeature
27 | 	Features  *[]Feature
28 | 	Unlabeled int
29 | 	Alpha     float64
30 | 	Beta      float64
31 | 	N         int
32 | 	MaxCats   int
33 | }
34 | 
35 | /*NewTransTarget returns a TransTarget using the supervised Impurity from the provided CatFeature t,
36 | Density in the specified Features fm (excluding any with the same name as t), considering the class label
37 | provided in "unlabeled" as unlabeled for transduction. Alpha is the weight of the unsupervised term relative to
38 | the supervised one and ncases is the number of cases that will be called at the root of the tree (may be deprecated as it is not needed).
39 | */
40 | func NewTransTarget(t CatFeature, fm *[]Feature, unlabeled string, alpha, beta float64, ncases int) *TransTarget {
41 | 	maxcats := 0
42 | 	for _, f := range *fm {
43 | 		if f.NCats() > maxcats {
44 | 			maxcats = f.NCats()
45 | 		}
46 | 	}
47 | 
48 | 	return &TransTarget{t, fm, t.CatToNum(unlabeled), alpha, beta, ncases, maxcats}
49 | 
50 | }
51 | 
52 | /*
53 | TransTarget.SplitImpurity is a density estimating version of SplitImpurity.
54 | */
55 | func (target *TransTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) {
56 | 	if target.Alpha == 0.0 {
57 | 		impurityDecrease = target.CatFeature.SplitImpurity(l, r, m, allocs)
58 | 	} else {
59 | 		nl := float64(len(*l))
60 | 		nr := float64(len(*r))
61 | 		nm := 0.0
62 | 
63 | 		impurityDecrease = nl * target.Impurity(l, allocs.LCounter)
64 | 		impurityDecrease += nr * target.Impurity(r, allocs.RCounter)
65 | 		if m != nil && len(*m) > 0 {
66 | 			nm = float64(len(*m))
67 | 			impurityDecrease += nm * target.Impurity(m, allocs.Counter)
68 | 		}
69 | 
70 | 		impurityDecrease /= nl + nr + nm
71 | 	}
72 | 	return
73 | }
74 | 
75 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l as in learning from numerical variables.
76 | //Here it just wraps SplitImpurity but it can be implemented to provide further optimization.
77 | func (target *TransTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) {
78 | 	return target.SplitImpurity(l, r, m, allocs)
79 | }
80 | 
81 | func (target *TransTarget) Impurity(cases *[]int, counter *[]int) (e float64) {
82 | 	//TODO: filter out unlabeled cases from the call to target.CatFeature.Impurity, at least for
83 | 	//multiclass problems
84 | 	if target.Alpha == 0.0 {
85 | 		e = target.CatFeature.Impurity(cases, counter)
86 | 	} else {
87 | 		e = target.CatFeature.Impurity(cases, counter) + target.Alpha*target.Density(cases, counter)
88 | 	}
89 | 	return
90 | }
91 | 
92 | /*TransTarget.Density uses an impurity designed to maximize the density within each side of the split,
93 | based on the method in "Density Estimation Trees" by Parikshit Ram and Alexander G. Gray.
94 | It loops over all of the non-target features and for the ones with non-zero span calculates product(span_i)/(t*t)
95 | where t is the number of cases.
96 | 
97 | Refinements to this method might include t*t->t^n where n is the number of features with
98 | non-zero span or other changes to how zero-span features are handled. I also suspect that this method
99 | handles numerical features, for which different splits will have different total spans based on the
100 | distance between the points on either side of the split point, better than categorical
101 | features, for which the total span of a split will always be the number of categories.
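
In terms of the code below: for a node containing t cases this evaluates

	prod(span_f) / (t*t)

over the non-target features f with non-zero span.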
102 | 
103 | The original paper also included N, which is not used here.*/
104 | func (target *TransTarget) Density(cases *[]int, counter *[]int) (e float64) {
105 | 	t := len(*cases)
106 | 	//e = float64(t*t) / float64(target.N*target.N)
107 | 	e = 1 / float64(t*t) // float64(target.N*target.N)
108 | 	span := 0.0
109 | 	bigenoughcounter := make([]int, target.MaxCats, target.MaxCats)
110 | 	targetname := target.GetName()
111 | 
112 | 	for _, f := range *target.Features {
113 | 		if f.GetName() != targetname {
114 | 
115 | 			span = f.Span(cases, &bigenoughcounter)
116 | 
117 | 			if span > 0.0 {
118 | 				e *= span
119 | 			}
120 | 
121 | 			ncats := f.NCats()
122 | 			for i := 0; i < ncats; i++ {
123 | 				bigenoughcounter[i] = 0
124 | 			}
125 | 
126 | 		}
127 | 	}
128 | 
129 | 	return
130 | }
131 | 
132 | //TransTarget.FindPredicted returns the prediction for the specified cases, which is the majority
133 | //class that is not the unlabeled class. A set of cases will only be predicted to be the unlabeled
134 | //class if it has no labeled points.
135 | func (target *TransTarget) FindPredicted(cases []int) string {
136 | 	counts := make([]int, target.NCats())
137 | 	for _, i := range cases {
138 | 
139 | 		counts[target.Geti(i)] += 1
140 | 
141 | 	}
142 | 	max := 0.0
143 | 	vf := 0.0
144 | 	m := target.Unlabeled
145 | 	for k, v := range counts {
146 | 		if k == target.Unlabeled {
147 | 			vf = target.Beta * float64(v)
148 | 		} else {
149 | 			vf = float64(v)
150 | 		}
151 | 		if vf > max {
152 | 			m = k
153 | 			max = vf
154 | 		}
155 | 	}
156 | 
157 | 	// if counts[target.Unlabeled] > 10*max {
158 | 	// 	m = target.Unlabeled
159 | 	// }
160 | 
161 | 	return target.NumToCat(m)
162 | }
--------------------------------------------------------------------------------
/utils.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 | 
3 | import (
4 | 	"fmt"
5 | 	"io"
6 | 	"log"
7 | 	"math"
8 | 	"strconv"
9 | 	"strings"
10 | 	"sync"
11 | )
12 | 
13 | func ParseFloat(s string) float64 {
14 | 	frac, _ := strconv.ParseFloat(s, 64)
15 | 	return frac
16 | 
17 | }
18 | 
19 | //RunningMean is a thread safe struct for keeping track of running means as used in
20 | //importance calculations. (TODO: could this be made lock free?)
21 | type RunningMean struct {
22 | 	mutex sync.Mutex
23 | 	Mean  float64
24 | 	Count float64
25 | }
26 | 
27 | //Add adds the specified value to the running mean with a weight of 1.0 in a thread safe way.
28 | func (rm *RunningMean) Add(val float64) {
29 | 	rm.WeightedAdd(val, 1.0)
30 | }
31 | 
32 | //WeightedAdd adds the specified value with the specified weight to the running mean in a thread safe way.
33 | func (rm *RunningMean) WeightedAdd(val float64, weight float64) {
34 | 	if !math.IsNaN(val) && !math.IsNaN(weight) {
35 | 		rm.mutex.Lock()
36 | 		rm.Mean = (rm.Mean*rm.Count + weight*val) / (rm.Count + weight)
37 | 		rm.Count += weight
38 | 		if rm.Count == 0 {
39 | 			log.Print("WeightedAdd reached 0 count!")
40 | 		}
41 | 		if math.IsNaN(rm.Mean) || math.IsNaN(rm.Count) {
42 | 			log.Print("Weighted add reached nan after adding ", val, weight)
43 | 		}
44 | 
45 | 		rm.mutex.Unlock()
46 | 	}
47 | 
48 | }
49 | 
50 | //Read reads the mean and count
51 | func (rm *RunningMean) Read() (mean float64, count float64) {
52 | 	rm.mutex.Lock()
53 | 	mean = rm.Mean
54 | 	count = rm.Count
55 | 	rm.mutex.Unlock()
56 | 	return
57 | }
58 | 
59 | //NewRunningMeans returns an initialized *[]*RunningMean.
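//
//An illustrative sketch of how the returned means can be used (values are
//examples only):
//
//	means := *NewRunningMeans(3)
//	means[0].Add(1.0)
//	means[0].WeightedAdd(3.0, 2.0)
//	mean, count := means[0].Read() //mean == 7.0/3.0, count == 3.0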
60 | func NewRunningMeans(size int) *[]*RunningMean {
61 | 	importance := make([]*RunningMean, 0, size)
62 | 	for i := 0; i < size; i++ {
63 | 		rm := new(RunningMean)
64 | 		importance = append(importance, rm)
65 | 	}
66 | 	return &importance
67 | 
68 | }
69 | 
70 | //SparseCounter uses maps to track sparse integer counts in a large matrix.
71 | //The matrix is assumed to contain zero values where nothing has been added.
72 | type SparseCounter struct {
73 | 	Map   map[int]map[int]int
74 | 	mutex sync.Mutex
75 | }
76 | 
77 | //Add increases the count in i,j by val.
78 | func (sc *SparseCounter) Add(i int, j int, val int) {
79 | 	sc.mutex.Lock()
80 | 	defer sc.mutex.Unlock()
81 | 	if sc.Map == nil {
82 | 		sc.Map = make(map[int]map[int]int, 0)
83 | 	}
84 | 
85 | 	if v, ok := sc.Map[i]; !ok || v == nil {
86 | 		sc.Map[i] = make(map[int]int, 0)
87 | 	}
88 | 	if _, ok := sc.Map[i][j]; !ok {
89 | 		sc.Map[i][j] = 0
90 | 	}
91 | 	sc.Map[i][j] = sc.Map[i][j] + val
92 | 
93 | }
94 | 
95 | //WriteTsv writes the non-zero counts out into a three column tsv containing i, j, and
96 | //count in the columns.
97 | func (sc *SparseCounter) WriteTsv(writer io.Writer) {
98 | 	sc.mutex.Lock()
99 | 	defer sc.mutex.Unlock()
100 | 	for i := range sc.Map {
101 | 		for j, val := range sc.Map[i] {
102 | 			if _, err := fmt.Fprintf(writer, "%v\t%v\t%v\n", i, j, val); err != nil {
103 | 				log.Println(err)
104 | 				return
105 | 			}
106 | 		}
107 | 	}
108 | 
109 | }
110 | 
111 | /*
112 | ParseAsIntOrFractionOfTotal parses strings that may specify a count or a fraction of
113 | the total for use in specifying parameters.
114 | It parses term as a float if it contains a "." and as an int otherwise. If term is parsed
115 | as a float frac it returns int(math.Ceil(frac * float64(total))).
116 | It returns zero if term == "" or if a parsing error occurs.
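
For example, with total == 200: "10" parses to 10, ".5" parses to
int(math.Ceil(0.5 * 200)) == 100, and "" or an unparsable term parses to 0.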
117 | */ 118 | func ParseAsIntOrFractionOfTotal(term string, total int) (parsed int) { 119 | if term == "" { 120 | return 0 121 | } 122 | 123 | if strings.Contains(term, ".") { 124 | frac, err := strconv.ParseFloat(term, 64) 125 | if err == nil { 126 | parsed = int(math.Ceil(frac * float64(total))) 127 | } else { 128 | parsed = 0 129 | } 130 | } else { 131 | count, err := strconv.ParseInt(term, 0, 0) 132 | if err != nil { 133 | parsed = 0 134 | } else { 135 | parsed = int(count) 136 | } 137 | 138 | } 139 | return 140 | } 141 | -------------------------------------------------------------------------------- /utils/nfold/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/csv" 5 | "flag" 6 | "fmt" 7 | "io" 8 | "log" 9 | "os" 10 | 11 | "github.com/lytics/CloudForest" 12 | ) 13 | 14 | func openfiles(trainfn string, testfn string) (trainW io.WriteCloser, testW io.WriteCloser) { 15 | 16 | trainfo, err := os.Create(trainfn) 17 | if err != nil { 18 | log.Fatal(err) 19 | } 20 | testfo, err := os.Create(testfn) 21 | if err != nil { 22 | log.Fatal(err) 23 | } 24 | trainW = trainfo 25 | testW = testfo 26 | // if zipoutput { 27 | // trainz := zip.NewWriter(trainfo) 28 | // trainW, err = trainz.Create(trainfn) 29 | // if err != nil { 30 | // log.Fatal(err) 31 | // } 32 | // //defer trainz.Close() 33 | // testz := zip.NewWriter(testfo) 34 | // testW, err = testz.Create(testfn) 35 | // if err != nil { 36 | // log.Fatal(err) 37 | // } 38 | // //defer testz.Close() 39 | // } 40 | 41 | return 42 | 43 | } 44 | 45 | func main() { 46 | fm := flag.String("fm", 47 | "featurematrix.afm", "AFM formated feature matrix containing data.") 48 | 49 | blacklist := flag.String("blacklist", 50 | "", "A list of feature id's to exclude from the set of predictors.") 51 | 52 | targetname := flag.String("target", 53 | "", "The row header of the target in the feature matrix.") 54 | train := flag.String("train", 55 | "train_%v.fm", "Format string for training fms.") 56 | test := flag.String("test", 57 | "test_%v.fm", "Format string for testing fms.") 58 | 59 | // var zipoutput bool 60 | // flag.BoolVar(&zipoutput, "zip", false, "Output ziped files.") 61 | var unstratified bool 62 | flag.BoolVar(&unstratified, "unstratified", false, "Force unstratified sampeling of categorical target.") 63 | 64 | var writelibsvm bool 65 | flag.BoolVar(&writelibsvm, "writelibsvm", false, "Output libsvm.") 66 | 67 | var writearff bool 68 | flag.BoolVar(&writearff, "writearff", false, "Output arff.") 69 | 70 | var writeall bool 71 | flag.BoolVar(&writeall, "writeall", false, "Output all three formats.") 72 | 73 | var folds int 74 | flag.IntVar(&folds, "folds", 5, "Number of folds to generate.") 75 | 76 | var maxcats int 77 | flag.IntVar(&maxcats, "maxcats", -1, "Maximum number of categories to allow in a feature.") 78 | 79 | var impute bool 80 | flag.BoolVar(&impute, "impute", false, "Impute missing values to feature mean/mode.") 81 | 82 | var onehot bool 83 | flag.BoolVar(&onehot, "onehot", false, "Do one hot encoding of categorical features to boolean true false.") 84 | 85 | var num bool 86 | flag.BoolVar(&num, "num", false, "Do one hot encoding of categorical features to numerical features.") 87 | 88 | flag.Parse() 89 | 90 | //Parse Data 91 | data, err := CloudForest.LoadAFM(*fm) 92 | if err != nil { 93 | log.Fatal(err) 94 | } 95 | 96 | blacklisted := 0 97 | blacklistis := make([]bool, len(data.Data)) 98 | if *blacklist != "" { 99 | fmt.Printf("Loading blacklist 
from: %v\n", *blacklist) 100 | blackfile, err := os.Open(*blacklist) 101 | if err != nil { 102 | log.Fatal(err) 103 | } 104 | tsv := csv.NewReader(blackfile) 105 | tsv.Comma = '\t' 106 | for { 107 | id, err := tsv.Read() 108 | if err == io.EOF { 109 | break 110 | } else if err != nil { 111 | log.Fatal(err) 112 | } 113 | if id[0] == *targetname { 114 | continue 115 | } 116 | i, ok := data.Map[id[0]] 117 | if !ok { 118 | fmt.Printf("Ignoring blacklist feature not found in data: %v\n", id[0]) 119 | continue 120 | } 121 | if !blacklistis[i] { 122 | blacklisted += 1 123 | blacklistis[i] = true 124 | } 125 | 126 | } 127 | blackfile.Close() 128 | 129 | } 130 | 131 | newdata := make([]CloudForest.Feature, 0, len(data.Data)-blacklisted) 132 | newmap := make(map[string]int, len(data.Data)-blacklisted) 133 | 134 | for i, f := range data.Data { 135 | if !blacklistis[i] && (maxcats == -1 || f.NCats() <= maxcats) { 136 | newmap[f.GetName()] = len(newdata) 137 | newdata = append(newdata, f) 138 | } 139 | } 140 | 141 | data.Data = newdata 142 | data.Map = newmap 143 | 144 | if impute { 145 | fmt.Println("Imputing missing values to feature mean/mode.") 146 | data.ImputeMissing() 147 | } 148 | 149 | if onehot { 150 | fmt.Println("OneHot encoding.") 151 | data.OneHot() 152 | } 153 | 154 | if num { 155 | fmt.Println("Numerical OneHot encoding.") 156 | data = data.EncodeToNum() 157 | } 158 | 159 | foldis := make([][]int, 0, folds) 160 | 161 | foldsize := len(data.CaseLabels) / folds 162 | fmt.Printf("%v cases, foldsize %v\n", len(data.CaseLabels), foldsize) 163 | for i := 0; i < folds; i++ { 164 | foldis = append(foldis, make([]int, 0, foldsize)) 165 | } 166 | 167 | var targetf CloudForest.Feature 168 | 169 | //find the target feature 170 | fmt.Printf("Target : %v\n", *targetname) 171 | targeti, ok := data.Map[*targetname] 172 | if !ok { 173 | fmt.Println("Target not found in data, doing unstratified sampeling.") 174 | unstratified = true 175 | } 176 | 177 | if ok { 178 | targetf = data.Data[targeti] 179 | 180 | switch targetf.(type) { 181 | case *CloudForest.DenseNumFeature: 182 | unstratified = true 183 | } 184 | } 185 | if unstratified { 186 | ncases := len(data.CaseLabels) 187 | cases := make([]int, ncases, ncases) 188 | for i := 0; i < ncases; i++ { 189 | cases[i] = i 190 | } 191 | CloudForest.SampleFirstN(&cases, nil, len(cases), 0) 192 | for j := 0; j < folds; j++ { 193 | for k := j * foldsize; k < (j+1)*foldsize; k++ { 194 | foldis[j] = append(foldis[j], cases[k]) 195 | } 196 | } 197 | 198 | } else { 199 | //sample folds stratified by case 200 | fmt.Printf("Stratifying by %v classes.\n", targetf.(*CloudForest.DenseCatFeature).NCats()) 201 | bSampler := CloudForest.NewBalancedSampler(targetf.(*CloudForest.DenseCatFeature)) 202 | 203 | fmt.Printf("Stratifying by %v classes.\n", len(bSampler.Cases)) 204 | var samples []int 205 | for i := 0; i < len(bSampler.Cases); i++ { 206 | fmt.Printf("%v cases in class %v.\n", len(bSampler.Cases[i]), i) 207 | //shuffle in place 208 | CloudForest.SampleFirstN(&bSampler.Cases[i], &samples, len(bSampler.Cases[i]), 0) 209 | stratFoldSize := len(bSampler.Cases[i]) / folds 210 | for j := 0; j < folds; j++ { 211 | for k := j * stratFoldSize; k < (j+1)*stratFoldSize; k++ { 212 | foldis[j] = append(foldis[j], bSampler.Cases[i][k]) 213 | 214 | } 215 | } 216 | 217 | } 218 | } 219 | encode := false 220 | 221 | for _, f := range data.Data { 222 | if f.NCats() > 0 { 223 | encode = true 224 | } 225 | } 226 | 227 | encoded := data 228 | if encode && (writelibsvm || writeall) { 229 | 
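		// The libsvm format is purely numerical, so categorical features
		// must first be one hot encoded to numerical features.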
encoded = data.EncodeToNum() 230 | } 231 | 232 | trainis := make([]int, 0, foldsize*(folds-1)) 233 | //Write training and testing matrixes 234 | for i := 0; i < folds; i++ { 235 | 236 | trainfn := fmt.Sprintf(*train, i) 237 | testfn := fmt.Sprintf(*test, i) 238 | 239 | trainis = trainis[0:0] 240 | for j := 0; j < folds; j++ { 241 | if i != j { 242 | trainis = append(trainis, foldis[j]...) 243 | } 244 | } 245 | 246 | if writearff || writeall { 247 | trainW, testW := openfiles(trainfn+".arff", testfn+".arff") 248 | CloudForest.WriteArffCases(data, foldis[i], *targetname, testW) 249 | CloudForest.WriteArffCases(data, trainis, *targetname, trainW) 250 | } 251 | 252 | if ((!writelibsvm) && (!writearff)) || writeall { 253 | trainW, testW := openfiles(trainfn, testfn) 254 | data.WriteCases(testW, foldis[i]) 255 | data.WriteCases(trainW, trainis) 256 | } 257 | 258 | if writelibsvm || writeall { 259 | trainW, testW := openfiles(trainfn+".libsvm", testfn+".libsvm") 260 | CloudForest.WriteLibSvmCases(encoded, foldis[i], *targetname, testW) 261 | CloudForest.WriteLibSvmCases(encoded, trainis, *targetname, trainW) 262 | } 263 | 264 | fmt.Printf("Wrote fold %v. %v testing cases and %v training cases.\n", i, len(foldis[i]), len(trainis)) 265 | } 266 | 267 | } 268 | -------------------------------------------------------------------------------- /utils/toafm/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/csv" 5 | "flag" 6 | "fmt" 7 | "io" 8 | "log" 9 | "os" 10 | "strings" 11 | 12 | "github.com/lytics/CloudForest" 13 | ) 14 | 15 | func main() { 16 | fm := flag.String("data", 17 | "", "Data file to read.") 18 | 19 | outfn := flag.String("out", 20 | "", "The name of a file to write feature matrix too.") 21 | 22 | libsvmtarget := flag.String("libsvmtarget", 23 | "", "Output lib svm with the named feature in the first position.") 24 | 25 | anontarget := flag.String("anontarget", 26 | "", "Strip strings with named feature in the first position.") 27 | 28 | blacklist := flag.String("blacklist", 29 | "", "A list of feature id's to exclude from the set of predictors.") 30 | 31 | flag.Parse() 32 | 33 | //Parse Data 34 | data, err := CloudForest.LoadAFM(*fm) 35 | if err != nil { 36 | log.Fatal(err) 37 | } 38 | 39 | blacklisted := 0 40 | blacklistis := make([]bool, len(data.Data)) 41 | if *blacklist != "" { 42 | fmt.Printf("Loading blacklist from: %v\n", *blacklist) 43 | blackfile, err := os.Open(*blacklist) 44 | if err != nil { 45 | log.Fatal(err) 46 | } 47 | tsv := csv.NewReader(blackfile) 48 | tsv.Comma = '\t' 49 | for { 50 | id, err := tsv.Read() 51 | if err == io.EOF { 52 | break 53 | } else if err != nil { 54 | log.Fatal(err) 55 | } 56 | if id[0] == *anontarget || id[0] == *libsvmtarget { 57 | continue 58 | } 59 | i, ok := data.Map[id[0]] 60 | if !ok { 61 | fmt.Printf("Ignoring blacklist feature not found in data: %v\n", id[0]) 62 | continue 63 | } 64 | if !blacklistis[i] { 65 | blacklisted += 1 66 | blacklistis[i] = true 67 | } 68 | 69 | } 70 | blackfile.Close() 71 | 72 | } 73 | 74 | newdata := make([]CloudForest.Feature, 0, len(data.Data)-blacklisted) 75 | newmap := make(map[string]int, len(data.Data)-blacklisted) 76 | 77 | for i, f := range data.Data { 78 | if !blacklistis[i] { 79 | newmap[f.GetName()] = len(newdata) 80 | newdata = append(newdata, f) 81 | } 82 | } 83 | 84 | data.Data = newdata 85 | data.Map = newmap 86 | 87 | if *anontarget != "" { 88 | data.StripStrings(*anontarget) 89 | 90 | } 91 | 92 | //anotate 
with type information 93 | for _, f := range data.Data { 94 | switch f.(type) { 95 | case *CloudForest.DenseNumFeature: 96 | nf := f.(*CloudForest.DenseNumFeature) 97 | if !strings.HasPrefix(nf.Name, "N:") { 98 | nf.Name = "N:" + nf.Name 99 | } 100 | case *CloudForest.DenseCatFeature: 101 | nf := f.(*CloudForest.DenseCatFeature) 102 | if !(strings.HasPrefix(nf.Name, "C:") || strings.HasPrefix(nf.Name, "B:")) { 103 | nf.Name = "C:" + nf.Name 104 | } 105 | 106 | } 107 | } 108 | 109 | ncases := data.Data[0].Length() 110 | cases := make([]int, ncases, ncases) 111 | 112 | for i := 0; i < ncases; i++ { 113 | cases[i] = i 114 | } 115 | 116 | outfile, err := os.Create(*outfn) 117 | if err != nil { 118 | log.Fatal(err) 119 | } 120 | defer outfile.Close() 121 | 122 | if *libsvmtarget == "" { 123 | 124 | err = data.WriteCases(outfile, cases) 125 | if err != nil { 126 | log.Fatal(err) 127 | } 128 | } else { 129 | // targeti, ok := data.Map[*libsvmtarget] 130 | // if !ok { 131 | // log.Fatalf("Target '%v' not found in data.", *libsvmtarget) 132 | // } 133 | // target := data.Data[targeti] 134 | 135 | // data.Data = append(data.Data[:targeti], data.Data[targeti+1:]...) 136 | 137 | // encodedfm := data.EncodeToNum() 138 | 139 | // oucsv := csv.NewWriter(outfile) 140 | // oucsv.Comma = ' ' 141 | 142 | // for i := 0; i < target.Length(); i++ { 143 | // entries := make([]string, 0, 10) 144 | // switch target.(type) { 145 | // case CloudForest.NumFeature: 146 | // entries = append(entries, target.GetStr(i)) 147 | // case CloudForest.CatFeature: 148 | // entries = append(entries, fmt.Sprintf("%v", target.(CloudForest.CatFeature).Geti(i))) 149 | // } 150 | 151 | // for j, f := range encodedfm.Data { 152 | // v := f.(CloudForest.NumFeature).Get(i) 153 | // if v != 0.0 { 154 | // entries = append(entries, fmt.Sprintf("%v:%v", j+1, v)) 155 | // } 156 | // } 157 | // //fmt.Println(entries) 158 | // err := oucsv.Write(entries) 159 | // if err != nil { 160 | // log.Fatalf("Error writing libsvm:\n%v", err) 161 | // } 162 | 163 | // } 164 | // oucsv.Flush() 165 | err = CloudForest.WriteLibSvm(data, *libsvmtarget, outfile) 166 | if err != nil { 167 | log.Fatalf("Error writing libsvm:\n%v", err) 168 | } 169 | 170 | } 171 | 172 | } 173 | -------------------------------------------------------------------------------- /utils_test.go: -------------------------------------------------------------------------------- 1 | package CloudForest 2 | 3 | import ( 4 | "encoding/csv" 5 | "io" 6 | "testing" 7 | ) 8 | 9 | func TestSpareseCounter(t *testing.T) { 10 | sc := new(SparseCounter) 11 | sc.Add(1, 1, 1) 12 | 13 | pipereader, pipewriter := io.Pipe() 14 | 15 | go func() { 16 | sc.WriteTsv(pipewriter) 17 | pipewriter.Close() 18 | }() 19 | 20 | tsv := csv.NewReader(pipereader) 21 | tsv.Comma = '\t' 22 | 23 | records, err := tsv.Read() 24 | if err != nil { 25 | t.Errorf("Error reading tsv output by SpareCOunter %v", err) 26 | } 27 | if l := len(records); l != 3 { 28 | t.Errorf("Sparse counter output tsv with %v records", l) 29 | } 30 | for i, r := range records { 31 | if r != "1" { 32 | t.Errorf("Spares counter out put wrong value %v or field %v", r, i) 33 | } 34 | } 35 | 36 | } 37 | 38 | func TestParseAsIntOrFractionOfTotal(t *testing.T) { 39 | 40 | if p := ParseAsIntOrFractionOfTotal("70", 100); p != 70 { 41 | t.Errorf("ParseAsIntOrFractionOfTotal parsed 70 as %v", p) 42 | } 43 | 44 | if p := ParseAsIntOrFractionOfTotal(".7", 100); p != 70 { 45 | t.Errorf("ParseAsIntOrFractionOfTotal parsed .7 as %v / 100", p) 46 | } 47 | 48 | if 
p := ParseAsIntOrFractionOfTotal("blah.7", 100); p != 0 {
49 | 		t.Errorf("ParseAsIntOrFractionOfTotal parsed blah.7 as %v / 100", p)
50 | 	}
51 | 
52 | 	if p := ParseAsIntOrFractionOfTotal("blah", 100); p != 0 {
53 | 		t.Errorf("ParseAsIntOrFractionOfTotal parsed blah as %v / 100", p)
54 | 	}
55 | 
56 | }
--------------------------------------------------------------------------------
/voter.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 | 
3 | //VoteTallyer is used to tabulate votes by trees and is implemented by feature type specific
4 | //structs like NumBallotBox and CatBallotBox.
5 | //Vote should register a vote that casei should be predicted as pred.
6 | //TallyError returns the error vs the supplied feature.
7 | type VoteTallyer interface {
8 | 	Vote(casei int, pred string, weight float64)
9 | 	TallyError(feature Feature) float64
10 | 	Tally(casei int) string
11 | }
--------------------------------------------------------------------------------
/wrappers/python/CFClassifier.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | import pandas as pd
3 | import subprocess
4 | import numpy as np
5 | 
6 | __doc__="""The CFClassifier module includes the CFClassifier class, which wraps
7 | calls to CloudForest's growforest and applyforest utilities so they can be used as a
8 | scikit-learn classifier.
9 | 
10 | It works by writing uuid-identified temp files to disk in the current working directory, so
11 | it has more overhead than a pure in-memory implementation but can handle problems where the
12 | model is too large to fit in system memory.
13 | """
14 | 
15 | def strtobool (val):
16 |     """Convert a string representation of truth to true (1) or false (0).
17 | 
18 |     True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
19 |     are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
20 |     'val' is anything else.
21 |     """
22 |     val = val.lower()
23 |     if val in ('y', 'yes', 't', 'true', 'on', '1'):
24 |         return 1
25 |     elif val in ('n', 'no', 'f', 'false', 'off', '0'):
26 |         return 0
27 |     else:
28 |         raise ValueError("invalid truth value %r" % (val,))
29 | 
30 | def writearff(fo, df, target="", unique=[]):
31 |     """writearff writes a pandas DataFrame, df, to a file-like object fo"""
32 | 
33 |     fo.write("@RELATION %(target)s\n\n"%{"target":target})
34 | 
35 |     #print df[target]
36 | 
37 |     for col in df.columns:
38 |         #print df[col].dtype
39 |         coltype="NUMERIC"
40 | 
41 |         if df.dtypes[col] == bool:
42 |             coltype = "{True,False}"
43 | 
44 |         if target!="" and col == target:
45 |             coltype = "{%(values)s}"%{"values":",".join([str(v) for v in unique])}
46 | 
47 | 
48 | 
49 | 
50 |         fo.write("@ATTRIBUTE %(name)s %(type)s\n"%{"name":col,"type":coltype})
51 | 
52 |     fo.write("\n@DATA\n")
53 |     df.to_csv(fo, na_rep="NA", index=False, header=False)
54 | 
55 | class CFClassifier:
56 |     """CFClassifier wraps command line calls to CloudForest's growforest
57 |     and applyforest for use as a scikit-learn Classifier.
It will write 58 | temporary files to your workding directory.""" 59 | 60 | options = "" 61 | 62 | 63 | def __init__(self, optionstring): 64 | self.options = optionstring 65 | self.uuid = uuid.uuid1() 66 | 67 | def fit(self, X, y): 68 | df = pd.DataFrame(X).copy() 69 | target = "%(uuid)s.target"%{"uuid":self.uuid} 70 | fn = "%(uuid)s.train.cloudforest.arff"%{"uuid":self.uuid} 71 | self.forest = "%(uuid)s.forest.cloudforest.sf"%{"uuid":self.uuid} 72 | 73 | 74 | self.unique = np.unique(y) 75 | 76 | #print y 77 | df[target] = np.array(y,dtype=bool) 78 | #print df[target] 79 | 80 | 81 | fo = open(fn,"w") 82 | writearff(fo,df,target,self.unique) 83 | fo.close() 84 | 85 | invocation = "growforest -train %(data)s -target %(target)s -rfpred %(forest)s %(options)s"%{"data":fn, 86 | "target":target, 87 | "forest":self.forest, 88 | "options":self.options} 89 | 90 | #print invocation 91 | 92 | subprocess.call(invocation, shell=True) 93 | 94 | def predict(self, X): 95 | df = pd.DataFrame(X) 96 | fn = "%(uuid)s.test.cloudforest.arff"%{"uuid":self.uuid} 97 | preds = "%(uuid)s.preds.cloudforest.tsv"%{"uuid":self.uuid} 98 | 99 | fo = open(fn,"w") 100 | writearff(fo,df) 101 | fo.close() 102 | 103 | invocation = "applyforest -fm %(data)s -rfpred %(forest)s -preds %(preds)s"%{"data":fn, 104 | "forest":self.forest, 105 | "preds": preds} 106 | 107 | subprocess.call(invocation, shell=True) 108 | 109 | fo =open(preds) 110 | predictions = [] 111 | for line in fo: 112 | vs= line.rstrip().split() 113 | predictions.append(vs[1]) 114 | fo.close() 115 | 116 | return np.array(predictions,dtype=int) 117 | 118 | def predict_proba(self, X): 119 | df = pd.DataFrame(X) 120 | fn = "%(uuid)s.test.cloudforest.arff"%{"uuid":self.uuid} 121 | votes = "%(uuid)s.votes.cloudforest.tsv"%{"uuid":self.uuid} 122 | 123 | 124 | fo = open(fn,"w") 125 | writearff(fo,df) 126 | fo.close() 127 | 128 | invocation = "applyforest -fm %(data)s -rfpred %(forest)s -votes %(votes)s"%{"data":fn, 129 | "forest":self.forest, 130 | "votes":votes} 131 | 132 | 133 | 134 | subprocess.call(invocation, shell=True) 135 | 136 | fo =open(votes) 137 | 138 | header = 0 139 | votes = 0 140 | 141 | 142 | line = fo.next() 143 | vs = line.split()[1:] 144 | if vs[0]=="True" or vs[0]=="False": 145 | header = np.array([strtobool(v) for v in vs],dtype=bool) 146 | votes = np.loadtxt(fo, dtype="int") 147 | else: 148 | header = np.array([int(v) for v in vs],dtype=int) 149 | votes = np.loadtxt(fo, dtype="int") 150 | fo.close() 151 | 152 | vote_totals = np.sum(votes[:,1:],axis=1) 153 | 154 | #print vote_totals.shape, votes.shape, self.unique.shape, self.unique 155 | 156 | probs = [] 157 | for v in self.unique: 158 | if v in header: 159 | probs.append(np.array(votes[:,1:][:,header==v],dtype=float).T/np.array(vote_totals,dtype=float)[0]) 160 | else: 161 | probs.append(np.zeros_like(vote_totals)) 162 | 163 | 164 | 165 | return np.dstack(probs)[0] 166 | 167 | -------------------------------------------------------------------------------- /wrappers/python/test_CFClassifier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from sklearn import datasets 3 | from sklearn.utils.validation import check_random_state 4 | from CFClassifier import CFClassifier 5 | import os.path 6 | 7 | import numpy as np 8 | 9 | from sklearn.metrics import roc_auc_score 10 | 11 | class TestCFClassifier(unittest.TestCase): 12 | 13 | def test_iris(self): 14 | """Check consistency on dataset iris.""" 15 | 16 | # also load the iris dataset 17 | # and 
randomly permute it
18 |         iris = datasets.load_iris()
19 |         rng = check_random_state(0)
20 |         perm = rng.permutation(iris.target.size)
21 |         iris.data = iris.data[perm]
22 |         iris.target = iris.target[perm]
23 | 
24 | 
25 | 
26 |         clf = CFClassifier("")
27 |         clf.fit(iris.data, iris.target)
28 | 
29 |         self.assertTrue(os.path.isfile(clf.forest))
30 | 
31 |         preds = clf.predict(iris.data)
32 | 
33 | 
34 |         predicted_ratio = float(np.sum(preds==iris.target))/float(len(iris.target))
35 |         print predicted_ratio
36 | 
37 |         self.assertGreaterEqual(predicted_ratio, .97)
38 | 
39 |         probs = clf.predict_proba(iris.data)
40 | 
41 | 
42 |         bin_idx=iris.target!=2
43 | 
44 |         roc_auc = roc_auc_score(iris.target[bin_idx], probs[bin_idx,1])
45 | 
46 |         self.assertGreaterEqual(roc_auc, .97)
47 | 
48 | 
49 | 
50 | 
51 |         #score = clf.score(iris.data, iris.target)
52 | 
53 |         #assert_greater(score, 0.9, "Failed with criterion %s and score = %f"
54 |         #               % (criterion, score)
55 | 
56 | if __name__ == '__main__':
57 |     unittest.main()
--------------------------------------------------------------------------------
/wrftarget.go:
--------------------------------------------------------------------------------
1 | package CloudForest
2 | 
3 | /*
4 | WRFTarget wraps a categorical feature as a target for use in weighted random forest.
5 | */
6 | type WRFTarget struct {
7 | 	CatFeature
8 | 	Weights []float64
9 | }
10 | 
11 | /*
12 | NewWRFTarget creates a weighted random forest target and initializes its weights.
13 | */
14 | func NewWRFTarget(f CatFeature, weights map[string]float64) (abt *WRFTarget) {
15 | 	abt = &WRFTarget{f, make([]float64, f.NCats())}
16 | 
17 | 	for i := range abt.Weights {
18 | 		abt.Weights[i] = weights[f.NumToCat(i)]
19 | 	}
20 | 
21 | 	return
22 | }
23 | 
24 | /*
25 | SplitImpurity is a weighted RF version of SplitImpurity.
26 | */
27 | func (target *WRFTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs) (impurityDecrease float64) {
28 | 	nl := float64(len(*l))
29 | 	nr := float64(len(*r))
30 | 	nm := 0.0
31 | 
32 | 	impurityDecrease = nl * target.Impurity(l, allocs.LCounter)
33 | 	impurityDecrease += nr * target.Impurity(r, allocs.RCounter)
34 | 	if m != nil && len(*m) > 0 {
35 | 		nm = float64(len(*m))
36 | 		impurityDecrease += nm * target.Impurity(m, allocs.Counter)
37 | 	}
38 | 
39 | 	impurityDecrease /= nl + nr + nm
40 | 	return
41 | }
42 | 
43 | //UpdateSImpFromAllocs will be called when splits are being built by moving cases from r to l
44 | //to avoid recalculating the entire split impurity.
45 | func (target *WRFTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) {
46 | 	var cat, i int
47 | 	lcounter := *allocs.LCounter
48 | 	rcounter := *allocs.RCounter
49 | 	for _, i = range *movedRtoL {
50 | 
51 | 		//most expensive statement:
52 | 		cat = target.Geti(i)
53 | 		lcounter[cat]++
54 | 		rcounter[cat]--
55 | 		//counter[target.Geti(i)]++
56 | 
57 | 	}
58 | 	nl := float64(len(*l))
59 | 	nr := float64(len(*r))
60 | 	nm := 0.0
61 | 
62 | 	impurityDecrease = nl * target.ImpFromCounts(allocs.LCounter)
63 | 	impurityDecrease += nr * target.ImpFromCounts(allocs.RCounter)
64 | 	if m != nil && len(*m) > 0 {
65 | 		nm = float64(len(*m))
66 | 		impurityDecrease += nm * target.ImpFromCounts(allocs.Counter)
67 | 	}
68 | 
69 | 	impurityDecrease /= nl + nr + nm
70 | 	return
71 | }
72 | 
73 | //Impurity is Gini impurity that uses the weights specified in WRFTarget.Weights.
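//In terms of the class counts c_i and weights w_i tallied below, this is
//
//	1 - sum_i (w_i*c_i)^2 / (sum_i w_i*c_i)^2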
74 | func (target *WRFTarget) Impurity(cases *[]int, counter *[]int) (e float64) {
75 | 
76 | 	target.CountPerCat(cases, counter)
77 | 
78 | 	return target.ImpFromCounts(counter)
79 | }
80 | 
81 | //ImpFromCounts recalculates Gini impurity from class counts for use in iterative updates.
82 | func (target *WRFTarget) ImpFromCounts(counter *[]int) (e float64) {
83 | 
84 | 	total := 0.0
85 | 	for i, v := range *counter {
86 | 		w := target.Weights[i]
87 | 		total += float64(v) * w
88 | 
89 | 		e -= float64(v*v) * w * w
90 | 	}
91 | 
92 | 	e /= float64(total * total)
93 | 	e++
94 | 
95 | 	return
96 | 
97 | }
98 | 
99 | //FindPredicted finds the predicted target as the weighted categorical mode.
100 | func (target *WRFTarget) FindPredicted(cases []int) (pred string) {
101 | 
102 | 	counts := make([]int, target.NCats())
103 | 	for _, i := range cases {
104 | 		if !target.IsMissing(i) {
105 | 			counts[target.Geti(i)] += 1
106 | 		}
107 | 
108 | 	}
109 | 	m := 0
110 | 	max := 0.0
111 | 	for k, v := range counts {
112 | 		val := float64(v) * target.Weights[k]
113 | 		if val > max {
114 | 			m = k
115 | 			max = val
116 | 		}
117 | 	}
118 | 
119 | 	pred = target.NumToCat(m)
120 | 
121 | 	return
122 | 
123 | }
--------------------------------------------------------------------------------