├── LICENSE ├── Makefile ├── README ├── boost.cc ├── boost.h ├── boost_test.cc ├── driver.cc ├── io.cc ├── io.h ├── io_test.cc ├── srm_test.h ├── testdata └── breast-cancer-wisconsin.data ├── tree.cc ├── tree.h ├── tree_test.cc └── types.h /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # SYNOPSIS: 2 | # 3 | # make - make everything 4 | # make test - make and run all tests 5 | # make clean - remove all files generated by make 6 | # make driver - make the main executable 7 | 8 | # LIB_DIR should satisfy the following: 9 | # LIB_DIR/include/gflags contains Google Commandline Flags include files 10 | # LIB_DIR/include/glog contains Google Logging include files 11 | # LIB_DIR/include/gtest contains Google Test include files 12 | # LIB_DIR/src contains Google Test source files 13 | # LIB_DIR/lib contains libgflags* and libglog* library files. 14 | LIB_DIR = /usr/local/google/home/usyed/googleopensource 15 | 16 | # Where to find user code. 17 | USER_DIR = . 18 | 19 | # Flags passed to the preprocessor. 20 | CPPFLAGS += -isystem $(LIB_DIR)/include 21 | 22 | # Flags passed to the C++ compiler. Add -O3 for the highest optimization level. 23 | # Add -ggdb for GDB debugging info. 
24 | CXXFLAGS += -Wall -Wextra -pthread -std=c++11 25 | 26 | # All tests produced by this Makefile. Remember to add new tests you 27 | # created to the list. 28 | TESTS = tree_test boost_test io_test 29 | 30 | # All Google Test headers. Usually you shouldn't change this 31 | # definition. 32 | GTEST_HEADERS = $(LIB_DIR)/include/gtest/*.h \ 33 | $(LIB_DIR)/include/gtest/internal/*.h 34 | 35 | # House-keeping build targets. 36 | 37 | test: $(TESTS) 38 | ./tree_test 39 | ./io_test 40 | ./boost_test 41 | clean : 42 | rm -f $(TESTS) gtest_main.a driver *.o 43 | 44 | # Builds gtest_main.a. 45 | 46 | # Usually you shouldn't tweak such internal variables, indicated by a 47 | # trailing _. 48 | GTEST_SRCS_ = $(LIB_DIR)/src/*.cc $(LIB_DIR)/src/*.h $(GTEST_HEADERS) 49 | 50 | # For simplicity and to avoid depending on Google Test's 51 | # implementation details, the dependencies specified below are 52 | # conservative and not optimized. This is fine as Google Test 53 | # compiles fast and for ordinary users its source rarely changes. 54 | gtest-all.o : $(GTEST_SRCS_) 55 | $(CXX) $(CPPFLAGS) -I$(LIB_DIR) $(CXXFLAGS) -c \ 56 | $(LIB_DIR)/src/gtest-all.cc 57 | 58 | gtest_main.o : $(GTEST_SRCS_) 59 | $(CXX) $(CPPFLAGS) -I$(LIB_DIR) $(CXXFLAGS) -c \ 60 | $(LIB_DIR)/src/gtest_main.cc 61 | 62 | gtest_main.a : gtest-all.o gtest_main.o 63 | $(AR) $(ARFLAGS) $@ $^ 64 | 65 | # Builds tests. A test should link with gtest_main.a. 
66 | 67 | tree.o : $(USER_DIR)/tree.cc $(USER_DIR)/tree.h $(GTEST_HEADERS) 68 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/tree.cc 69 | 70 | tree_test.o : $(USER_DIR)/tree_test.cc \ 71 | $(USER_DIR)/tree.h $(GTEST_HEADERS) 72 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/tree_test.cc 73 | 74 | tree_test : tree.o tree_test.o gtest_main.a 75 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -static -lpthread $^ -o $@ -L$(LIB_DIR)/lib -lgflags -lglog 76 | 77 | boost.o : $(USER_DIR)/boost.cc $(USER_DIR)/boost.h $(GTEST_HEADERS) 78 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/boost.cc 79 | 80 | boost_test.o : $(USER_DIR)/boost_test.cc \ 81 | $(USER_DIR)/boost.h $(GTEST_HEADERS) 82 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/boost_test.cc 83 | 84 | boost_test : tree.o boost.o boost_test.o gtest_main.a 85 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -static -lpthread $^ -o $@ -L$(LIB_DIR)/lib -lgflags -lglog 86 | 87 | io.o : $(USER_DIR)/io.cc $(USER_DIR)/io.h $(GTEST_HEADERS) 88 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/io.cc 89 | 90 | io_test.o : $(USER_DIR)/io_test.cc \ 91 | $(USER_DIR)/io.h $(GTEST_HEADERS) 92 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/io_test.cc 93 | 94 | io_test : tree.o io.o io_test.o gtest_main.a 95 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -static -lpthread $^ -o $@ -L$(LIB_DIR)/lib -lgflags -lglog 96 | 97 | # Build the main executable 98 | 99 | driver.o : $(USER_DIR)/driver.cc 100 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $(USER_DIR)/driver.cc 101 | 102 | driver : tree.o boost.o io.o driver.o 103 | $(CXX) $(CPPFLAGS) $(CXXFLAGS) -static -lpthread $^ -o $@ -L$(LIB_DIR)/lib -lgflags -lglog 104 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | Code for DeepBoost algorithm described in: 2 | 3 | Corinna Cortes, Mehryar Mohri, Umar Syed (2014) "Deep Boosting", ICML 2014 4 | 
-------------------------------------------------------------------------------- /boost.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | #include "boost.h" 18 | 19 | #include <math.h> 20 | #include <utility> 21 | 22 | #include "gflags/gflags.h" 23 | #include "glog/logging.h" 24 | #include "tree.h" 25 | 26 | DEFINE_string(loss_type, "", 27 | "Loss type. Required: One of exponential, logistic."); 28 | 29 | float ComputeEta(float wgtd_error, float tree_size, float alpha) { 30 | wgtd_error = fmax(wgtd_error, kTolerance); // Helps with division by zero. 31 | const float error_term = 32 | (1 - wgtd_error) * exp(alpha) - wgtd_error * exp(-alpha); 33 | const float complexity_penalty = ComplexityPenalty(tree_size); 34 | const float ratio = complexity_penalty / wgtd_error; 35 | float eta; 36 | if (fabs(error_term) <= 2 * complexity_penalty) { 37 | eta = -alpha; 38 | } else if (error_term > 2 * complexity_penalty) { 39 | eta = log(-ratio + sqrt(ratio * ratio + (1 - wgtd_error)/wgtd_error)); 40 | } else { 41 | eta = log(ratio + sqrt(ratio * ratio + (1 - wgtd_error)/wgtd_error)); 42 | } 43 | return eta; 44 | } 45 | 46 | // TODO(usyed): examples is passed by non-const reference because the example 47 | // weights need to be changed. This is bad style.
48 | void AddTreeToModel(vector<Example>& examples, Model* model) { 49 | // Initialize normalizer 50 | static float normalizer; 51 | if (model->empty()) { 52 | if (FLAGS_loss_type == "exponential") { 53 | normalizer = exp(1) * static_cast<float>(examples.size()); 54 | } else if (FLAGS_loss_type == "logistic") { 55 | normalizer = 56 | static_cast<float>(examples.size()) / (log(2) * (1 + exp(-1))); 57 | } else { 58 | LOG(FATAL) << "Unexpected loss type: " << FLAGS_loss_type; 59 | } 60 | } 61 | InitializeTreeData(examples, normalizer); 62 | int best_old_tree_idx = -1; 63 | float best_wgtd_error, wgtd_error, gradient, best_gradient = 0; 64 | 65 | // Find best old tree 66 | bool old_tree_is_best = false; 67 | for (int i = 0; i < model->size(); ++i) { 68 | const float alpha = (*model)[i].first; 69 | if (fabs(alpha) < kTolerance) continue; // Skip zeroed-out weights. 70 | const Tree& old_tree = (*model)[i].second; 71 | wgtd_error = EvaluateTreeWgtd(examples, old_tree); 72 | int sign_edge = (wgtd_error >= 0.5) ? 1 : -1; 73 | gradient = Gradient(wgtd_error, old_tree.size(), alpha, sign_edge); 74 | if (fabs(gradient) >= fabs(best_gradient)) { 75 | best_gradient = gradient; 76 | best_wgtd_error = wgtd_error; 77 | best_old_tree_idx = i; 78 | old_tree_is_best = true; 79 | } 80 | } 81 | 82 | // Find best new tree 83 | Tree new_tree = TrainTree(examples); 84 | wgtd_error = EvaluateTreeWgtd(examples, new_tree); 85 | gradient = Gradient(wgtd_error, new_tree.size(), 0, -1); 86 | if (model->empty() || fabs(gradient) > fabs(best_gradient)) { 87 | best_gradient = gradient; 88 | best_wgtd_error = wgtd_error; 89 | old_tree_is_best = false; 90 | } 91 | 92 | // Update model weights 93 | float alpha; 94 | const Tree* tree; 95 | if (old_tree_is_best) { 96 | alpha = (*model)[best_old_tree_idx].first; 97 | tree = &((*model)[best_old_tree_idx].second); 98 | } else { 99 | alpha = 0; 100 | tree = &(new_tree); 101 | } 102 | const float eta = ComputeEta(best_wgtd_error, tree->size(), alpha); 103 | if (old_tree_is_best) {
104 | (*model)[best_old_tree_idx].first += eta; 105 | } else { 106 | model->push_back(make_pair(eta, new_tree)); 107 | } 108 | 109 | // Update example weights and compute normalizer 110 | const float old_normalizer = normalizer; 111 | normalizer = 0; 112 | for (Example& example : examples) { 113 | const float u = eta * example.label * ClassifyExample(example, *tree); 114 | if (FLAGS_loss_type == "exponential") { 115 | example.weight *= exp(-u); 116 | } else if (FLAGS_loss_type == "logistic") { 117 | const float z = (1 - log(2) * example.weight * old_normalizer) / 118 | (log(2) * example.weight * old_normalizer); 119 | example.weight = 1 / (log(2) * (1 + z * exp(u))); 120 | } else { 121 | LOG(FATAL) << "Unexpected loss type: " << FLAGS_loss_type; 122 | } 123 | normalizer += example.weight; 124 | } 125 | 126 | // Renormalize example weights 127 | // TODO(usyed): Two loops are inefficient. 128 | for (Example& example : examples) { 129 | example.weight /= normalizer; 130 | } 131 | } 132 | 133 | Label ClassifyExample(const Example& example, const Model& model) { 134 | float score = 0; 135 | for (const pair<float, Tree>& wgtd_tree : model) { 136 | score += wgtd_tree.first * ClassifyExample(example, wgtd_tree.second); 137 | } 138 | if (score < 0) { 139 | return -1; 140 | } else { 141 | return 1; 142 | } 143 | } 144 | 145 | void EvaluateModel(const vector<Example>& examples, const Model& model, 146 | float* error, float* avg_tree_size, int* num_trees) { 147 | float incorrect = 0; 148 | for (const Example& example : examples) { 149 | if (example.label != ClassifyExample(example, model)) { 150 | ++incorrect; 151 | } 152 | } 153 | *num_trees = 0; 154 | int sum_tree_size = 0; 155 | for (const pair<float, Tree>& wgtd_tree : model) { 156 | if (fabs(wgtd_tree.first) >= kTolerance) { 157 | ++(*num_trees); 158 | sum_tree_size += wgtd_tree.second.size(); 159 | } 160 | } 161 | *error = (incorrect / examples.size()); 162 | *avg_tree_size = static_cast<float>(sum_tree_size) / *num_trees; 163 | } 164 |
-------------------------------------------------------------------------------- /boost.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | #ifndef BOOST_H_ 18 | #define BOOST_H_ 19 | 20 | #include "types.h" 21 | 22 | // Either add a new tree to model or update the weight of an existing tree in 23 | // model. The tree and weight are selected via approximate coordinate descent on 24 | // the objective, where the "approximate" indicates that we do not search all 25 | // trees but instead grow trees greedily. 26 | void AddTreeToModel(vector<Example>& examples, Model* model); 27 | 28 | // Classify example with model. 29 | Label ClassifyExample(const Example& example, const Model& model); 30 | 31 | // Compute the error of model on examples. Also compute the number of trees in 32 | // model and their average size. 33 | void EvaluateModel(const vector<Example>& examples, const Model& model, 34 | float* error, float* avg_tree_size, int* num_trees); 35 | 36 | // Return the optimal weight to add to a tree that will maximally decrease the 37 | // objective.
38 | float ComputeEta(float wgtd_error, float tree_size, float alpha); 39 | 40 | #endif // BOOST_H_ 41 | -------------------------------------------------------------------------------- /boost_test.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | #include <math.h> 18 | 19 | #include "boost.h" 20 | #include "tree.h" // TODO(usyed): Figure out how not to have to include this. 21 | #include "srm_test.h" 22 | 23 | #include "gflags/gflags.h" 24 | #include "gtest/gtest.h" 25 | 26 | DECLARE_int32(tree_depth); 27 | DECLARE_double(beta); 28 | DECLARE_double(lambda); 29 | DECLARE_string(loss_type); 30 | 31 | class BoostTest : public SrmTest { 32 | protected: 33 | virtual void SetUp() { 34 | SrmTest::SetUp(); 35 | InitializeTreeData(examples_, examples_.size()); 36 | } 37 | }; 38 | 39 | TEST_F(BoostTest, TestAddTreeToModel) { 40 | FLAGS_tree_depth = 1; 41 | FLAGS_beta = 0; 42 | FLAGS_lambda = 0; 43 | FLAGS_loss_type = "exponential"; 44 | Model model; 45 | // Train a model with a single tree. The tree's weighted error will be 0.2, 46 | // and it will only get example 3 wrong. 47 | AddTreeToModel(examples_, &model); 48 | // Every example is originally weighted equally. 49 | const float original_wgt = 0.2; 50 | // alpha = 0.5 * log((1 - error) / error), where error = 0.2.
51 | float alpha = 0.69314718056; 52 | // Normalizer is sum of all adjusted weights. 53 | float normalizer = 4 * original_wgt * exp(-alpha) + original_wgt * exp(alpha); 54 | // Adjust weights and normalize. 55 | float correct_wgt = original_wgt * exp(-alpha) / normalizer; 56 | float incorrect_wgt = original_wgt * exp(alpha) / normalizer; 57 | EXPECT_NEAR(correct_wgt, examples_[0].weight, kTolerance); 58 | EXPECT_NEAR(correct_wgt, examples_[1].weight, kTolerance); 59 | EXPECT_NEAR(correct_wgt, examples_[2].weight, kTolerance); 60 | EXPECT_NEAR(incorrect_wgt, examples_[3].weight, kTolerance); 61 | EXPECT_NEAR(correct_wgt, examples_[4].weight, kTolerance); 62 | 63 | // Add another tree to the model. The tree's weighted error will be 0.125, and 64 | // it will only get example 4 wrong. 65 | AddTreeToModel(examples_, &model); 66 | // alpha = 0.5 * log((1 - error) / error), where error = 0.125. 67 | alpha = 0.97295507452; 68 | // Normalizer is sum of all adjusted weights. 69 | normalizer = 3 * correct_wgt * exp(-alpha) + correct_wgt * exp(alpha) + 70 | incorrect_wgt * exp(-alpha); 71 | float both_correct_wgt = correct_wgt * exp(-alpha) / normalizer; 72 | float first_correct_wgt = correct_wgt * exp(alpha) / normalizer; 73 | float second_correct_wgt = incorrect_wgt * exp(-alpha) / normalizer; 74 | EXPECT_NEAR(both_correct_wgt, examples_[0].weight, kTolerance); 75 | EXPECT_NEAR(both_correct_wgt, examples_[1].weight, kTolerance); 76 | EXPECT_NEAR(both_correct_wgt, examples_[2].weight, kTolerance); 77 | EXPECT_NEAR(second_correct_wgt, examples_[3].weight, kTolerance); 78 | EXPECT_NEAR(first_correct_wgt, examples_[4].weight, kTolerance); 79 | } 80 | 81 | TEST_F(BoostTest, TestClassifyExampleDepthOne) { 82 | FLAGS_tree_depth = 1; 83 | FLAGS_beta = 0; 84 | FLAGS_lambda = 0; 85 | FLAGS_loss_type = "exponential"; 86 | Model model; 87 | AddTreeToModel(examples_, &model); 88 | AddTreeToModel(examples_, &model); 89 | // By the previous test, the first tree gets example 3 wrong and 
has weight 90 | // 0.69314718056, and the second tree gets example 4 wrong and has 91 | // weight 0.97295507452. Since 0.97295507452 > 0.69314718056, the second tree 92 | // outvotes the first on all examples, so their combination gets example 4 93 | // wrong. 94 | EXPECT_EQ(examples_[0].label, ClassifyExample(examples_[0], model)); 95 | EXPECT_EQ(examples_[1].label, ClassifyExample(examples_[1], model)); 96 | EXPECT_EQ(examples_[2].label, ClassifyExample(examples_[2], model)); 97 | EXPECT_EQ(examples_[3].label, ClassifyExample(examples_[3], model)); 98 | EXPECT_EQ(-examples_[4].label, ClassifyExample(examples_[4], model)); 99 | } 100 | 101 | TEST_F(BoostTest, TestClassifyExampleDepthTwo) { 102 | FLAGS_tree_depth = 2; 103 | FLAGS_beta = 0; 104 | FLAGS_lambda = 0; 105 | FLAGS_loss_type = "exponential"; 106 | Model model; 107 | AddTreeToModel(examples_, &model); 108 | // Depth 2 trees can classify all examples perfectly. 109 | EXPECT_EQ(examples_[0].label, ClassifyExample(examples_[0], model)); 110 | EXPECT_EQ(examples_[1].label, ClassifyExample(examples_[1], model)); 111 | EXPECT_EQ(examples_[2].label, ClassifyExample(examples_[2], model)); 112 | EXPECT_EQ(examples_[3].label, ClassifyExample(examples_[3], model)); 113 | EXPECT_EQ(examples_[4].label, ClassifyExample(examples_[4], model)); 114 | const float alpha = model[0].first; 115 | // Won't actually add trees, will just increase weight on current tree.
116 | for (int i = 0; i < 99; ++i) { 117 | AddTreeToModel(examples_, &model); 118 | } 119 | EXPECT_EQ(1, model.size()); 120 | EXPECT_NEAR(alpha, model[0].first / 100, kTolerance * 100); 121 | EXPECT_EQ(examples_[0].label, ClassifyExample(examples_[0], model)); 122 | EXPECT_EQ(examples_[1].label, ClassifyExample(examples_[1], model)); 123 | EXPECT_EQ(examples_[2].label, ClassifyExample(examples_[2], model)); 124 | EXPECT_EQ(examples_[3].label, ClassifyExample(examples_[3], model)); 125 | EXPECT_EQ(examples_[4].label, ClassifyExample(examples_[4], model)); 126 | } 127 | 128 | TEST_F(BoostTest, TestClassifyExampleEmptyModel) { 129 | Model model; 130 | // Empty model classifies every example as positive 131 | EXPECT_EQ(1, ClassifyExample(examples_[0], model)); 132 | EXPECT_EQ(1, ClassifyExample(examples_[1], model)); 133 | EXPECT_EQ(1, ClassifyExample(examples_[2], model)); 134 | EXPECT_EQ(1, ClassifyExample(examples_[3], model)); 135 | EXPECT_EQ(1, ClassifyExample(examples_[4], model)); 136 | } 137 | 138 | TEST_F(BoostTest, TestEvaluateModelDepthOne) { 139 | FLAGS_tree_depth = 1; 140 | FLAGS_beta = 0; 141 | FLAGS_lambda = 0; 142 | FLAGS_loss_type = "exponential"; 143 | Model model; 144 | AddTreeToModel(examples_, &model); 145 | AddTreeToModel(examples_, &model); 146 | float error, avg_tree_size; 147 | int num_trees; 148 | EvaluateModel(examples_, model, &error, &avg_tree_size, &num_trees); 149 | EXPECT_NEAR(0.2, error, kTolerance); 150 | EXPECT_EQ(2, num_trees); 151 | EXPECT_NEAR(3, avg_tree_size, kTolerance); 152 | } 153 | 154 | TEST_F(BoostTest, TestEvaluateModelDepthTwo) { 155 | FLAGS_tree_depth = 2; 156 | FLAGS_beta = 0; 157 | FLAGS_lambda = 0; 158 | FLAGS_loss_type = "exponential"; 159 | Model model; 160 | AddTreeToModel(examples_, &model); 161 | float error, avg_tree_size; 162 | int num_trees; 163 | EvaluateModel(examples_, model, &error, &avg_tree_size, &num_trees); 164 | EXPECT_NEAR(0.0, error, kTolerance); 165 | EXPECT_EQ(1, num_trees); 166 | EXPECT_NEAR(5, 
avg_tree_size, kTolerance); 167 | } 168 | 169 | TEST_F(BoostTest, ComputeEtaTest) { 170 | FLAGS_beta = 1; 171 | FLAGS_lambda = 1; 172 | float eta = ComputeEta(1, 10, 1); 173 | EXPECT_NEAR(-1, eta, kTolerance); 174 | 175 | FLAGS_beta = 1; 176 | FLAGS_lambda = 0; 177 | eta = ComputeEta(0.1, 5, 2); 178 | float ratio = ComplexityPenalty(5) / 0.1; 179 | EXPECT_NEAR(log(-ratio + sqrt(ratio * ratio + (0.9 / 0.1))), eta, kTolerance); 180 | 181 | FLAGS_beta = 0; 182 | FLAGS_lambda = 1; 183 | eta = ComputeEta(0.75, 10, -10); 184 | ratio = ComplexityPenalty(10) / 0.75; 185 | EXPECT_NEAR(log(ratio + sqrt(ratio * ratio + (0.25 / 0.75))), 186 | eta, kTolerance); 187 | } 188 | 189 | TEST_F(BoostTest, TestAddTreeToModelLargeBeta) { 190 | // Large beta penalty means two trees with zero weight and depth 0. 191 | FLAGS_tree_depth = 1; 192 | FLAGS_beta = 100; 193 | FLAGS_lambda = 0; 194 | FLAGS_loss_type = "exponential"; 195 | Model model; 196 | AddTreeToModel(examples_, &model); 197 | AddTreeToModel(examples_, &model); 198 | EXPECT_EQ(2, model.size()); 199 | EXPECT_LT(model[0].first, kTolerance); 200 | EXPECT_LT(model[1].first, kTolerance); 201 | EXPECT_EQ(1, model[0].second.size()); 202 | EXPECT_EQ(1, model[1].second.size()); 203 | } 204 | 205 | TEST_F(BoostTest, TestAddTreeToModelLargeLambda) { 206 | // Large lambda penalty means two trees with zero weight and depth 0. 
207 | FLAGS_tree_depth = 1; 208 | FLAGS_beta = 0; 209 | FLAGS_lambda = 100; 210 | FLAGS_loss_type = "exponential"; 211 | Model model; 212 | AddTreeToModel(examples_, &model); 213 | AddTreeToModel(examples_, &model); 214 | EXPECT_EQ(2, model.size()); 215 | EXPECT_LT(model[0].first, kTolerance); 216 | EXPECT_LT(model[1].first, kTolerance); 217 | EXPECT_EQ(1, model[0].second.size()); 218 | EXPECT_EQ(1, model[1].second.size()); 219 | } 220 | 221 | TEST_F(BoostTest, TestAddTreeToModelLogisticLoss) { 222 | FLAGS_tree_depth = 1; 223 | FLAGS_beta = 0; 224 | FLAGS_lambda = 0; 225 | FLAGS_loss_type = "logistic"; 226 | Model model; 227 | // Train a model with a single tree. The tree's weighted error will be 0.2, 228 | // and it will only get example 3 wrong. 229 | AddTreeToModel(examples_, &model); 230 | // alpha1 = 0.5 * log((1 - error) / error), where error = 0.2. 231 | float alpha1 = 0.69314718056; 232 | // Normalizer is sum of all adjusted weights. 233 | float normalizer = 234 | 4 * (1 / (1 + exp(alpha1 - 1))) + (1 / (1 + exp(-alpha1 - 1))); 235 | // Adjust weights and normalize. 236 | float correct_wgt = (1 / (1 + exp(alpha1 - 1))) / normalizer; 237 | float incorrect_wgt = (1 / (1 + exp(-alpha1 - 1))) / normalizer; 238 | EXPECT_NEAR(correct_wgt, examples_[0].weight, kTolerance); 239 | EXPECT_NEAR(correct_wgt, examples_[1].weight, kTolerance); 240 | EXPECT_NEAR(correct_wgt, examples_[2].weight, kTolerance); 241 | EXPECT_NEAR(incorrect_wgt, examples_[3].weight, kTolerance); 242 | EXPECT_NEAR(correct_wgt, examples_[4].weight, kTolerance); 243 | 244 | // Add another tree to the model. The tree's weighted error will be 245 | // 0.182946235, and it will only get example 4 wrong. 246 | AddTreeToModel(examples_, &model); 247 | // alpha2 = 0.5 * log((1 - error) / error), where error = 0.182946235. 248 | float alpha2 = 0.7482563445; 249 | // Normalizer is sum of all adjusted weights. 
250 | normalizer = 3 * (1 / (1 + exp(alpha1 + alpha2 - 1))) + 251 | (1 / (1 + exp(alpha1 - alpha2 - 1))) + 252 | (1 / (1 + exp(-alpha1 + alpha2 - 1))); 253 | float both_correct_wgt = (1 / (1 + exp(alpha1 + alpha2 - 1))) / normalizer; 254 | float first_correct_wgt = (1 / (1 + exp(alpha1 - alpha2 - 1))) / normalizer; 255 | float second_correct_wgt = (1 / (1 + exp(-alpha1 + alpha2 - 1))) / normalizer; 256 | EXPECT_NEAR(both_correct_wgt, examples_[0].weight, kTolerance); 257 | EXPECT_NEAR(both_correct_wgt, examples_[1].weight, kTolerance); 258 | EXPECT_NEAR(both_correct_wgt, examples_[2].weight, kTolerance); 259 | EXPECT_NEAR(second_correct_wgt, examples_[3].weight, kTolerance); 260 | EXPECT_NEAR(first_correct_wgt, examples_[4].weight, kTolerance); 261 | } 262 | -------------------------------------------------------------------------------- /driver.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | #include "gflags/gflags.h" 18 | #include "glog/logging.h" 19 | #include "boost.h" 20 | #include "io.h" 21 | #include "types.h" 22 | 23 | DECLARE_int32(tree_depth); 24 | DECLARE_string(data_set); 25 | DECLARE_string(data_filename); 26 | DECLARE_int32(num_folds); 27 | DECLARE_int32(fold_to_cv); 28 | DECLARE_int32(fold_to_test); 29 | DECLARE_double(beta); 30 | DECLARE_double(lambda); 31 | DECLARE_string(loss_type); 32 | DEFINE_int32(num_iter, -1, 33 | "Number of boosting iterations. Required: num_iter >= 1."); 34 | DEFINE_int32(seed, -1, 35 | "Seed for random number generator. Required: seed >= 0."); 36 | 37 | void ValidateFlags() { 38 | CHECK_GE(FLAGS_tree_depth, 0); 39 | CHECK_GE(FLAGS_num_iter, 1); 40 | CHECK(!FLAGS_data_filename.empty()); 41 | CHECK(FLAGS_data_set == "breastcancer" || FLAGS_data_set == "ionosphere" || 42 | FLAGS_data_set == "ocr17-mnist" || FLAGS_data_set == "ocr49-mnist" || 43 | FLAGS_data_set == "splice" || FLAGS_data_set == "german" || 44 | FLAGS_data_set == "ocr17" || FLAGS_data_set == "ocr49" || 45 | FLAGS_data_set == "diabetes"); 46 | CHECK_GE(FLAGS_num_folds, 3); 47 | CHECK_GE(FLAGS_fold_to_cv, 0); 48 | CHECK_GE(FLAGS_fold_to_test, 0); 49 | CHECK_LE(FLAGS_fold_to_cv, FLAGS_num_folds - 1); 50 | CHECK_LE(FLAGS_fold_to_test, FLAGS_num_folds - 1); 51 | CHECK_GE(FLAGS_seed, 0); 52 | CHECK_GE(FLAGS_beta, 0.0); 53 | CHECK_GE(FLAGS_lambda, 0.0); 54 | CHECK(FLAGS_loss_type == "exponential" || FLAGS_loss_type == "logistic"); 55 | } 56 | 57 | int main(int argc, char** argv) { 58 | gflags::ParseCommandLineFlags(&argc, &argv, true); 59 | google::InitGoogleLogging(argv[0]); 60 | 61 | ValidateFlags(); 62 | 63 | SetSeed(FLAGS_seed); 64 | 65 | vector<Example> train_examples, cv_examples, test_examples; 66 | ReadData(&train_examples, &cv_examples, &test_examples); 67 | 68 | Model model; 69 | for (int iter = 1; iter <= FLAGS_num_iter; ++iter) { 70 | AddTreeToModel(train_examples, &model); 71 | // TODO(usyed): Evaluating every iteration might be very
expensive. Add an 72 | // option to evaluate every K iterations, where K is a command-line 73 | // parameter. 74 | float cv_error, test_error, avg_tree_size; 75 | int num_trees; 76 | EvaluateModel(cv_examples, model, &cv_error, &avg_tree_size, 77 | &num_trees); 78 | EvaluateModel(test_examples, model, &test_error, &avg_tree_size, 79 | &num_trees); 80 | printf("Iteration: %d, test error: %g, cv error: %g, " 81 | "avg tree size: %g, num trees: %d\n", 82 | iter, test_error, cv_error, avg_tree_size, num_trees); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /io.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | #include "io.h" 18 | 19 | #include <algorithm> 20 | #include <fstream> 21 | #include <random> 22 | 23 | #include "gflags/gflags.h" 24 | #include "glog/logging.h" 25 | 26 | DEFINE_string(data_set, "", 27 | "Name of data set. Required: One of breastcancer, ionosphere, " 28 | "ocr17, ocr49, ocr17-mnist, ocr49-mnist, diabetes, german."); 29 | DEFINE_string(data_filename, "", 30 | "Filename containing data. Required: data_filename not empty."); 31 | DEFINE_int32(num_folds, -1, 32 | "(num_folds - 2)/num_folds of data used for training, 1/num_folds " 33 | "of data used for cross-validation, 1/num_folds of data used for " 34 | "testing.
Required: num_folds >= 3."); 35 | DEFINE_int32(fold_to_cv, -1, 36 | "Zero-indexed fold used for cross-validation. Required: " 37 | "0 <= fold_to_cv <= num_folds - 1."); 38 | DEFINE_int32(fold_to_test, -1, 39 | "Zero-indexed fold used for testing. Required: 0 <= fold_to_test " 40 | "<= num_folds - 1."); 41 | DEFINE_double(noise_prob, 0, 42 | "Noise probability. Required: 0 <= noise_prob <= 1."); 43 | 44 | static std::mt19937 rng; 45 | 46 | void SetSeed(uint_fast32_t seed) { rng.seed(seed); } 47 | 48 | void SplitString(const string &text, char sep, vector* tokens) { 49 | int start = 0, end = 0; 50 | string token; 51 | while ((end = text.find(sep, start)) != string::npos) { 52 | token = text.substr(start, end - start); 53 | if (!token.empty()) { 54 | tokens->push_back(token); 55 | } 56 | start = end + 1; 57 | } 58 | token = text.substr(start); 59 | if (!token.empty()) { 60 | tokens->push_back(token); 61 | } 62 | } 63 | 64 | bool ParseLineBreastCancer(const string& line, Example* example) { 65 | example->values.clear(); 66 | vector values; 67 | SplitString(line, ',', &values); 68 | for (int i = 0; i < values.size(); ++i) { 69 | if (i == 0) { 70 | continue; // Skip ID 71 | } else if (i == values.size() - 1) { 72 | if (values[i] == "2") { // Benign 73 | example->label = -1; 74 | } else if (values[i] == "4") { // Malignant 75 | example->label = +1; 76 | } else { 77 | LOG(FATAL) << "Unexpected label: " << values[i]; 78 | } 79 | } else if (values[i] == "?") { 80 | return false; 81 | } else { 82 | float value = atof(values[i].c_str()); 83 | example->values.push_back(value); 84 | } 85 | } 86 | return true; 87 | } 88 | 89 | bool ParseLineIon(const string& line, Example* example) { 90 | example->values.clear(); 91 | vector values; 92 | SplitString(line, ',', &values); 93 | for (int i = 0; i < values.size(); ++i) { 94 | if (i == values.size() - 1) { 95 | if (values[i] == "b") { // Bad 96 | example->label = -1; 97 | } else if (values[i] == "g") { // Good 98 | example->label = +1; 
99 | } else { 100 | LOG(FATAL) << "Unexpected label: " << values[i]; 101 | } 102 | } else { 103 | float value = atof(values[i].c_str()); 104 | example->values.push_back(value); 105 | } 106 | } 107 | return true; 108 | } 109 | 110 | bool ParseLineGerman(const string& line, Example* example) { 111 | example->values.clear(); 112 | vector<string> values; 113 | SplitString(line, ' ', &values); 114 | for (int i = 0; i < values.size(); ++i) { 115 | if (i == values.size() - 1) { 116 | if (values[i] == "1") { // Good 117 | example->label = -1; 118 | } else if (values[i] == "2") { // Bad 119 | example->label = +1; 120 | } else { 121 | LOG(FATAL) << "Unexpected label: " << values[i]; 122 | } 123 | } else { 124 | float value = atof(values[i].c_str()); 125 | example->values.push_back(value); 126 | } 127 | } 128 | return true; 129 | } 130 | 131 | bool ParseLineOcr17(const string& line, Example* example) { 132 | example->values.clear(); 133 | vector<string> values; 134 | SplitString(line, ',', &values); 135 | for (int i = 0; i < values.size(); ++i) { 136 | if (i == values.size() - 1) { 137 | if (values[i] == "1") { // Digit 1 138 | example->label = -1; 139 | } else if (values[i] == "7") { // Digit 7 140 | example->label = +1; 141 | } else { 142 | return false; 143 | } 144 | } else { 145 | float value = atof(values[i].c_str()); 146 | example->values.push_back(value); 147 | } 148 | } 149 | return true; 150 | } 151 | 152 | bool ParseLineOcr49(const string& line, Example* example) { 153 | example->values.clear(); 154 | vector<string> values; 155 | SplitString(line, ',', &values); 156 | for (int i = 0; i < values.size(); ++i) { 157 | if (i == values.size() - 1) { 158 | if (values[i] == "4") { // Digit 4 159 | example->label = -1; 160 | } else if (values[i] == "9") { // Digit 9 161 | example->label = +1; 162 | } else { 163 | return false; 164 | } 165 | } else { 166 | float value = atof(values[i].c_str()); 167 | example->values.push_back(value); 168 | } 169 | } 170 | return true; 171 | } 172 | 173 | bool
ParseLineOcr17Princeton(const string& line, Example* example) { 174 | example->values.clear(); 175 | vector<string> values; 176 | SplitString(line, ' ', &values); 177 | for (int i = 0; i < values.size(); ++i) { 178 | if (i == values.size() - 1) { 179 | if (values[i] == "1") { // Digit 1 180 | example->label = -1; 181 | } else if (values[i] == "7") { // Digit 7 182 | example->label = +1; 183 | } else { 184 | return false; 185 | } 186 | } else { 187 | float value = atof(values[i].c_str()); 188 | example->values.push_back(value); 189 | } 190 | } 191 | return true; 192 | } 193 | 194 | bool ParseLineOcr49Princeton(const string& line, Example* example) { 195 | example->values.clear(); 196 | vector<string> values; 197 | SplitString(line, ' ', &values); 198 | for (int i = 0; i < values.size(); ++i) { 199 | if (i == values.size() - 1) { 200 | if (values[i] == "4") { // Digit 4 201 | example->label = -1; 202 | } else if (values[i] == "9") { // Digit 9 203 | example->label = +1; 204 | } else { 205 | return false; 206 | } 207 | } else { 208 | float value = atof(values[i].c_str()); 209 | example->values.push_back(value); 210 | } 211 | } 212 | return true; 213 | } 214 | 215 | bool ParseLinePima(const string& line, Example* example) { 216 | example->values.clear(); 217 | vector<string> values; 218 | SplitString(line, ',', &values); 219 | for (int i = 0; i < values.size(); ++i) { 220 | if (i == values.size() - 1) { 221 | if (values[i] == "0") { 222 | example->label = -1; 223 | } else if (values[i] == "1") { 224 | example->label = +1; 225 | } else { 226 | LOG(FATAL) << "Unexpected label: " << values[i]; 227 | } 228 | } else { 229 | float value = atof(values[i].c_str()); 230 | example->values.push_back(value); 231 | } 232 | } 233 | return true; 234 | } 235 | 236 | 237 | void ReadData(vector<Example>* train_examples, 238 | vector<Example>* cv_examples, 239 | vector<Example>* test_examples) { 240 | train_examples->clear(); 241 | cv_examples->clear(); 242 | test_examples->clear(); 243 | vector<Example> examples; 244 | std::ifstream
file(FLAGS_data_filename); 245 | CHECK(file.is_open()); 246 | string line; 247 | while (!std::getline(file, line).eof()) { 248 | Example example; 249 | bool keep_example; 250 | if (FLAGS_data_set == "breastcancer") { 251 | keep_example = ParseLineBreastCancer(line, &example); 252 | } else if (FLAGS_data_set == "ionosphere") { 253 | keep_example = ParseLineIon(line, &example); 254 | } else if (FLAGS_data_set == "german") { 255 | keep_example = ParseLineGerman(line, &example); 256 | } else if (FLAGS_data_set == "ocr17-mnist") { 257 | keep_example = ParseLineOcr17(line, &example); 258 | } else if (FLAGS_data_set == "ocr49-mnist") { 259 | keep_example = ParseLineOcr49(line, &example); 260 | } else if (FLAGS_data_set == "ocr17") { 261 | keep_example = ParseLineOcr17Princeton(line, &example); 262 | } else if (FLAGS_data_set == "ocr49") { 263 | keep_example = ParseLineOcr49Princeton(line, &example); 264 | } else if (FLAGS_data_set == "diabetes") { 265 | keep_example = ParseLinePima(line, &example); 266 | } else { 267 | LOG(FATAL) << "Unknown data set: " << FLAGS_data_set; 268 | } 269 | if (keep_example) examples.push_back(example); 270 | } 271 | std::shuffle(examples.begin(), examples.end(), rng); 272 | std::uniform_real_distribution<double> dist; 273 | int fold = 0; 274 | // TODO(usyed): Two loops is inefficient 275 | for (Example& example : examples) { 276 | double r = dist(rng); 277 | if (r < FLAGS_noise_prob) { 278 | example.label = -example.label; 279 | } 280 | if (fold == FLAGS_fold_to_test) { 281 | test_examples->push_back(example); 282 | } else if (fold == FLAGS_fold_to_cv) { 283 | cv_examples->push_back(example); 284 | } else { 285 | train_examples->push_back(example); 286 | } 287 | ++fold; 288 | if (fold == FLAGS_num_folds) fold = 0; 289 | } 290 | const float initial_wgt = 1.0 / train_examples->size(); 291 | // TODO(usyed): Three loops is _really_ inefficient 292 | for (Example& example : *train_examples) { 293 | example.weight = initial_wgt; 294 | } 295 | return; 296 |
} 297 | -------------------------------------------------------------------------------- /io.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | #ifndef IO_H_ 18 | #define IO_H_ 19 | 20 | #include <string> 21 | 22 | #include "types.h" 23 | 24 | using std::string; 25 | 26 | // Split text into the tokens vector, using sep as a delimiter. Consecutive 27 | // delimiters are ignored. 28 | void SplitString(const string &text, char sep, vector<string>* tokens); 29 | 30 | void SetSeed(uint_fast32_t seed); 31 | 32 | // The following functions each parse one line of a data set. 33 | 34 | bool ParseLineBreastCancer(const string& line, Example* example); 35 | 36 | bool ParseLineIon(const string& line, Example* example); 37 | 38 | bool ParseLineGerman(const string& line, Example* example); 39 | 40 | bool ParseLineOcr49(const string& line, Example* example); 41 | 42 | bool ParseLineOcr17(const string& line, Example* example); 43 | 44 | bool ParseLineOcr49Princeton(const string& line, Example* example); 45 | 46 | bool ParseLineOcr17Princeton(const string& line, Example* example); 47 | 48 | bool ParseLinePima(const string& line, Example* example); 49 | 50 | // Read data set into training set, cross-validation set and test set.
51 | void ReadData(vector<Example>* train_examples, 52 | vector<Example>* cv_examples, 53 | vector<Example>* test_examples); 54 | 55 | #endif  // IO_H_ 56 | -------------------------------------------------------------------------------- /io_test.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | #include "srm_test.h" 18 | #include "io.h" 19 | 20 | #include "gflags/gflags.h" 21 | #include "gtest/gtest.h" 22 | 23 | DECLARE_string(data_set); 24 | DECLARE_string(data_filename); 25 | DECLARE_int32(fold_to_cv); 26 | DECLARE_int32(fold_to_test); 27 | DECLARE_int32(num_folds); 28 | DECLARE_double(noise_prob); 29 | 30 | class IoTest : public SrmTest {}; 31 | 32 | TEST_F(IoTest, SplitStringTest) { 33 | string text = ",,2,3,,14,,"; 34 | vector<string> tokens; 35 | SplitString(text, ',', &tokens); 36 | EXPECT_EQ(3, tokens.size()); 37 | EXPECT_EQ("2", tokens[0]); 38 | EXPECT_EQ("3", tokens[1]); 39 | EXPECT_EQ("14", tokens[2]); 40 | 41 | text = " 55 71 90 1 "; 42 | tokens.clear(); 43 | SplitString(text, ' ', &tokens); 44 | EXPECT_EQ(4, tokens.size()); 45 | EXPECT_EQ("55", tokens[0]); 46 | EXPECT_EQ("71", tokens[1]); 47 | EXPECT_EQ("90", tokens[2]); 48 | EXPECT_EQ("1", tokens[3]); 49 | } 50 | 51 | TEST_F(IoTest, ParseLineBreastCancerTest) { 52 | Example example; 53 | string line = "1000025,5,1,1,1,2,1,3,1,1,2"; 54 |
EXPECT_TRUE(ParseLineBreastCancer(line, &example)); 55 | EXPECT_EQ(-1, example.label); 56 | EXPECT_EQ(9, example.values.size()); 57 | // Spot check features 58 | EXPECT_NEAR(5, example.values[0], kTolerance); 59 | EXPECT_NEAR(3, example.values[6], kTolerance); 60 | EXPECT_NEAR(1, example.values[8], kTolerance); 61 | line = "1017122,8,10,10,8,7,10,9,7,1,4"; 62 | EXPECT_TRUE(ParseLineBreastCancer(line, &example)); 63 | EXPECT_EQ(1, example.label); 64 | line = "1057013,8,4,5,1,2,?,7,3,1,4"; 65 | EXPECT_FALSE(ParseLineBreastCancer(line, &example)); 66 | } 67 | 68 | TEST_F(IoTest, ParseLineIonTest) { 69 | Example example; 70 | string line = 71 | "1,0,1,-0.15899,0.72314,0.27686,0.83443,-0.58388,1,-0.28207,1,-0.49863,0." 72 | "79962,-0.12527,0.76837,0.14638,1,0.39337,1,0.26590,0.96354,-0.01891,0." 73 | "92599,-0.91338,1,0.14803,1,-0.11582,1,-0.11129,1,0.53372,1,-0.57758,g"; 74 | EXPECT_TRUE(ParseLineIon(line, &example)); 75 | EXPECT_EQ(1, example.label); 76 | EXPECT_EQ(34, example.values.size()); 77 | // Spot check features 78 | EXPECT_NEAR(1, example.values[0], kTolerance); 79 | EXPECT_NEAR(0.83443, example.values[6], kTolerance); 80 | EXPECT_NEAR(-0.11129, example.values[29], kTolerance); 81 | line = 82 | "1,0,1,-0.18829,0.93035,-0.36156,-0.10868,-0.93597,1,-0.04549,0.50874,-0." 83 | "67743,0.34432,-0.69707,-0.51685,-0.97515,0.05499,-0.62237,0.33109,-1,-0." 84 | "13151,-0.45300,-0.18056,-0.35734,-0.20332,-0.26569,-0.20468,-0.18401,-0." 
85 | "19040,-0.11593,-0.16626,-0.06288,-0.13738,-0.02447,b"; 86 | EXPECT_TRUE(ParseLineIon(line, &example)); 87 | EXPECT_EQ(-1, example.label); 88 | } 89 | 90 | TEST_F(IoTest, ParseLineGermanTest) { 91 | Example example; 92 | string line = 93 | " 2 48 2 60 1 3 2 2 1 22 3 1 1 1 1 0 0 " 94 | "1 0 0 1 0 0 1 2 "; 95 | EXPECT_TRUE(ParseLineGerman(line, &example)); 96 | EXPECT_EQ(1, example.label); 97 | EXPECT_EQ(24, example.values.size()); 98 | line = 99 | " 1 6 4 12 5 5 3 4 1 67 3 2 1 2 1 0 0 " 100 | "1 0 0 1 0 0 1 1 "; 101 | EXPECT_TRUE(ParseLineGerman(line, &example)); 102 | EXPECT_EQ(-1, example.label); 103 | // Spot check features 104 | EXPECT_NEAR(1, example.values[0], kTolerance); 105 | EXPECT_NEAR(3, example.values[6], kTolerance); 106 | EXPECT_NEAR(1, example.values[8], kTolerance); 107 | EXPECT_NEAR(67, example.values[9], kTolerance); 108 | } 109 | 110 | TEST_F(IoTest, ParseLineOcr17Test) { 111 | Example example; 112 | string line = 113 | "0,0,0,3,16,11,1,0,0,0,0,8,16,16,1,0,0,0,0,9,16,14,0,0,0,1,7,16,16,11,0," 114 | "0,0,9,16,16,16,8,0,0,0,1,8,6,16,7,0,0,0,0,0,5,16,9,0,0,0,0,0,2,14,14,1," 115 | "0,1"; 116 | EXPECT_TRUE(ParseLineOcr17(line, &example)); 117 | EXPECT_EQ(-1, example.label); 118 | EXPECT_EQ(64, example.values.size()); 119 | line = 120 | "0,0,8,15,16,13,0,0,0,1,11,9,11,16,1,0,0,0,0,0,7,14,0,0,0,0,3,4,14,12,2," 121 | "0,0,1,16,16,16,16,10,0,0,2,12,16,10,0,0,0,0,0,2,16,4,0,0,0,0,0,9,14,0,0," 122 | "0,0,7"; 123 | EXPECT_TRUE(ParseLineOcr17(line, &example)); 124 | EXPECT_EQ(1, example.label); 125 | // Spot check features 126 | EXPECT_NEAR(0, example.values[0], kTolerance); 127 | EXPECT_NEAR(0, example.values[6], kTolerance); 128 | EXPECT_NEAR(0, example.values[8], kTolerance); 129 | EXPECT_NEAR(9, example.values[58], kTolerance); 130 | line = 131 | "0,0,0,3,11,16,0,0,0,0,5,16,11,13,7,0,0,3,15,8,1,15,6,0,0,11,16,16,16,16," 132 | "10,0,0,1,4,4,13,10,2,0,0,0,0,0,15,4,0,0,0,0,0,3,16,0,0,0,0,0,0,1,15,2,0," 133 | "0,4"; 134 | 
EXPECT_FALSE(ParseLineOcr17(line, &example)); 135 | } 136 | 137 | TEST_F(IoTest, ParseLineOcr49Test) { 138 | Example example; 139 | string line = 140 | "0,0,0,3,11,16,0,0,0,0,5,16,11,13,7,0,0,3,15,8,1,15,6,0,0,11,16,16,16,16," 141 | "10,0,0,1,4,4,13,10,2,0,0,0,0,0,15,4,0,0,0,0,0,3,16,0,0,0,0,0,0,1,15,2,0," 142 | "0,4"; 143 | EXPECT_TRUE(ParseLineOcr49(line, &example)); 144 | EXPECT_EQ(-1, example.label); 145 | // Spot check features 146 | EXPECT_NEAR(0, example.values[0], kTolerance); 147 | EXPECT_NEAR(0, example.values[6], kTolerance); 148 | EXPECT_NEAR(0, example.values[8], kTolerance); 149 | EXPECT_NEAR(0, example.values[58], kTolerance); 150 | EXPECT_NEAR(15, example.values[60], kTolerance); 151 | EXPECT_EQ(64, example.values.size()); 152 | line = 153 | "0,0,0,4,13,16,16,3,0,0,8,16,9,12,16,4,0,7,16,3,3,15,13,0,0,9,15,14,16," 154 | "16,6,0,0,1,8,7,12,15,0,0,0,0,0,0,13,10,0,0,0,0,0,3,15,6,0,0,0,0,0,5,15," 155 | "4,0,0,9"; 156 | EXPECT_TRUE(ParseLineOcr49(line, &example)); 157 | EXPECT_EQ(1, example.label); 158 | line = 159 | "0,0,8,15,16,13,0,0,0,1,11,9,11,16,1,0,0,0,0,0,7,14,0,0,0,0,3,4,14,12,2," 160 | "0,0,1,16,16,16,16,10,0,0,2,12,16,10,0,0,0,0,0,2,16,4,0,0,0,0,0,9,14,0,0," 161 | "0,0,7"; 162 | EXPECT_FALSE(ParseLineOcr49(line, &example)); 163 | } 164 | 165 | TEST_F(IoTest, ParseLineOcr17PrincetonTest) { 166 | Example example; 167 | string line = 168 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 " 169 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1 3 3 3 3 3 3 0 0 0 0 0 " 170 | "0 0 3 2 0 0 0 3 1 0 0 0 0 0 0 0 0 0 0 0 3 2 0 0 0 0 0 0 0 0 0 0 0 2 3 0 " 171 | "0 0 0 0 0 0 0 0 0 0 2 3 0 0 0 0 0 0 0 0 0 0 0 1 3 0 0 0 0 0 0 0 0 0 0 0 " 172 | "0 3 1 0 0 0 0 0 0 0 0 0 0 0 2 2 0 0 0 0 0 0 0 0 0 0 0 1 3 0 0 0 0 0 0 0 " 173 | "0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 7"; 174 | EXPECT_TRUE(ParseLineOcr17Princeton(line, &example)); 175 | EXPECT_EQ(1, example.label); 176 | EXPECT_EQ(196, example.values.size()); 177 | line = 178 | "0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 " 179 | "0 0 0 0 0 0 0 0 0 0 0 0 1 3 0 0 0 0 0 0 0 0 0 0 0 0 1 3 0 0 0 0 0 0 0 0 " 180 | "0 0 0 0 2 3 0 0 0 0 0 0 0 0 0 0 0 0 2 3 0 0 0 0 0 0 0 0 0 0 0 0 2 3 0 0 " 181 | "0 0 0 0 0 0 0 0 0 0 3 3 0 0 0 0 0 0 0 0 0 0 0 0 3 2 0 0 0 0 0 0 0 0 0 0 " 182 | "0 0 3 2 0 0 0 0 0 0 0 0 0 0 0 0 3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 " 183 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1"; 184 | EXPECT_TRUE(ParseLineOcr17Princeton(line, &example)); 185 | EXPECT_EQ(-1, example.label); 186 | // Spot check features 187 | EXPECT_NEAR(0, example.values[0], kTolerance); 188 | EXPECT_NEAR(0, example.values[6], kTolerance); 189 | EXPECT_NEAR(0, example.values[8], kTolerance); 190 | EXPECT_NEAR(3, example.values[35], kTolerance); 191 | line = 192 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 " 193 | "0 0 0 0 0 0 0 0 0 0 0 0 1 3 0 0 0 0 0 0 0 0 0 0 0 0 1 3 0 0 0 0 0 0 0 0 " 194 | "0 0 0 0 2 3 0 0 0 0 0 0 0 0 0 0 0 0 2 3 0 0 0 0 0 0 0 0 0 0 0 0 2 3 0 0 " 195 | "0 0 0 0 0 0 0 0 0 0 3 3 0 0 0 0 0 0 0 0 0 0 0 0 3 2 0 0 0 0 0 0 0 0 0 0 " 196 | "0 0 3 2 0 0 0 0 0 0 0 0 0 0 0 0 3 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 " 197 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4"; 198 | EXPECT_FALSE(ParseLineOcr17Princeton(line, &example)); 199 | } 200 | 201 | TEST_F(IoTest, ParseLineOcr49PrincetonTest) { 202 | Example example; 203 | string line = 204 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 " 205 | "0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 2 3 3 3 1 0 0 0 0 0 0 " 206 | "0 0 2 3 1 0 2 3 0 0 0 0 0 0 0 0 2 1 0 1 3 3 0 0 0 0 0 0 0 0 0 2 3 3 3 2 " 207 | "0 0 0 0 0 0 0 0 0 0 1 3 3 0 0 0 0 0 0 0 0 0 0 0 1 3 1 0 0 0 0 0 0 0 0 0 " 208 | "0 0 2 3 0 0 0 0 0 0 0 0 0 0 0 0 3 1 0 0 0 0 0 0 0 0 0 0 0 2 3 0 0 0 0 0 " 209 | "0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 9"; 210 | EXPECT_TRUE(ParseLineOcr49Princeton(line, &example)); 211 | EXPECT_EQ(1, example.label); 212 | EXPECT_EQ(196, example.values.size()); 213 | line = 214 | 
"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 " 215 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 1 0 0 0 0 0 0 0 2 0 0 1 3 2 0 0 0 0 " 216 | "0 0 1 3 1 0 2 3 2 0 0 0 0 0 0 2 3 2 0 2 3 2 1 1 0 0 0 0 1 3 3 3 3 3 3 3 " 217 | "3 2 0 0 0 0 0 2 2 3 3 3 1 0 0 0 0 0 0 0 0 0 0 3 2 0 0 0 0 0 0 0 0 0 0 0 " 218 | "0 3 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 " 219 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4"; 220 | EXPECT_TRUE(ParseLineOcr49Princeton(line, &example)); 221 | EXPECT_EQ(-1, example.label); 222 | // Spot check features 223 | EXPECT_NEAR(0, example.values[0], kTolerance); 224 | EXPECT_NEAR(0, example.values[6], kTolerance); 225 | EXPECT_NEAR(0, example.values[8], kTolerance); 226 | EXPECT_NEAR(1, example.values[54], kTolerance); 227 | EXPECT_NEAR(0, example.values[58], kTolerance); 228 | line = 229 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 " 230 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 1 0 0 0 0 0 0 0 2 0 0 1 3 2 0 0 0 0 " 231 | "0 0 1 3 1 0 2 3 2 0 0 0 0 0 0 2 3 2 0 2 3 2 1 1 0 0 0 0 1 3 3 3 3 3 3 3 " 232 | "3 2 0 0 0 0 0 2 2 3 3 3 1 0 0 0 0 0 0 0 0 0 0 3 2 0 0 0 0 0 0 0 0 0 0 0 " 233 | "0 3 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 " 234 | "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7"; 235 | EXPECT_FALSE(ParseLineOcr49Princeton(line, &example)); 236 | } 237 | 238 | TEST_F(IoTest, ParseLinePimaTest) { 239 | Example example; 240 | string line = "6,148,72,35,0,33.6,0.627,50,1"; 241 | EXPECT_TRUE(ParseLinePima(line, &example)); 242 | EXPECT_EQ(1, example.label); 243 | EXPECT_EQ(8, example.values.size()); 244 | // Spot check features 245 | EXPECT_NEAR(6, example.values[0], kTolerance); 246 | EXPECT_NEAR(0.627, example.values[6], kTolerance); 247 | line = "1,85,66,29,0,26.6,0.351,31,0"; 248 | EXPECT_TRUE(ParseLinePima(line, &example)); 249 | EXPECT_EQ(-1, example.label); 250 | } 251 | 252 | TEST_F(IoTest, ReadDataTest) { 253 | FLAGS_data_set = "breastcancer"; 254 | 
FLAGS_data_filename = "./testdata/breast-cancer-wisconsin.data"; 255 | FLAGS_num_folds = 4; 256 | FLAGS_fold_to_cv = 1; 257 | FLAGS_fold_to_test = 0; 258 | SetSeed(123456); 259 | 260 | vector<Example> train_examples, cv_examples, test_examples; 261 | ReadData(&train_examples, &cv_examples, &test_examples); 262 | EXPECT_EQ(2, train_examples.size()); 263 | EXPECT_EQ(1, cv_examples.size()); 264 | EXPECT_EQ(1, test_examples.size()); 265 | EXPECT_NEAR(0.5, train_examples[0].weight, kTolerance); 266 | EXPECT_NEAR(0.5, train_examples[1].weight, kTolerance); 267 | } 268 | 269 | TEST_F(IoTest, ReadDataTestWithNoise) { 270 | FLAGS_data_set = "breastcancer"; 271 | FLAGS_data_filename = "./testdata/breast-cancer-wisconsin.data"; 272 | FLAGS_num_folds = 4; 273 | FLAGS_fold_to_cv = 1; 274 | FLAGS_fold_to_test = 0; 275 | FLAGS_noise_prob = 0; 276 | SetSeed(123456); 277 | 278 | vector<Example> train_examples, cv_examples, test_examples; 279 | ReadData(&train_examples, &cv_examples, &test_examples); 280 | Label train_label_0 = train_examples[0].label; 281 | Label train_label_1 = train_examples[1].label; 282 | Label cv_label_0 = cv_examples[0].label; 283 | Label test_label_0 = test_examples[0].label; 284 | 285 | FLAGS_noise_prob = 1; 286 | ReadData(&train_examples, &cv_examples, &test_examples); 287 | EXPECT_EQ(-train_label_0, train_examples[0].label); 288 | EXPECT_EQ(-train_label_1, train_examples[1].label); 289 | EXPECT_EQ(-cv_label_0, cv_examples[0].label); 290 | EXPECT_EQ(-test_label_0, test_examples[0].label); 291 | 292 | FLAGS_noise_prob = 0.5; 293 | const int kIterations = 100; 294 | double sum_labels = 0.0; 295 | for (int i = 0; i < kIterations; ++i) { 296 | ReadData(&train_examples, &cv_examples, &test_examples); 297 | sum_labels += train_examples[0].label; 298 | sum_labels += train_examples[1].label; 299 | sum_labels += cv_examples[0].label; 300 | sum_labels += test_examples[0].label; 301 | } 302 | // The average of uniformly random +1/-1 labels should be about 0 303 | EXPECT_NEAR(0,
sum_labels / (4 * kIterations), 1e-2); 304 | } 305 | -------------------------------------------------------------------------------- /srm_test.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | #ifndef SRM_TEST_H_ 18 | #define SRM_TEST_H_ 19 | 20 | #include "types.h" 21 | #include "gtest/gtest.h" 22 | 23 | class SrmTest : public ::testing::Test { 24 | protected: 25 | virtual void SetUp() { 26 | // Three positive examples, two negative examples, three features. Best 27 | // split for feature 0 is useless, and for the other features divides 28 | // examples into the ratios +3/-1 (left) and +0/-1 (right). Two splits (on 29 | // features 1 and 2) perfectly classify all examples. 30 | // TODO(usyed): For more interesting tests, weights should be different from 31 | // uniform.
32 | Example examples_arr[5]; 33 | examples_arr[0].values = {1.0, 0.1, 11.0}; 34 | examples_arr[0].label = 1; 35 | examples_arr[0].weight = 0.2; 36 | examples_arr[1].values = {3.0, 0.3, 11.0}; 37 | examples_arr[1].label = 1; 38 | examples_arr[1].weight = 0.2; 39 | examples_arr[2].values = {5.0, 0.4, 11.0}; 40 | examples_arr[2].label = 1; 41 | examples_arr[2].weight = 0.2; 42 | examples_arr[3].values = {2.0, 0.2, 22.0}; 43 | examples_arr[3].label = -1; 44 | examples_arr[3].weight = 0.2; 45 | examples_arr[4].values = {4.0, 0.5, 11.0}; 46 | examples_arr[4].label = -1; 47 | examples_arr[4].weight = 0.2; 48 | examples_.assign(examples_arr, examples_arr + 5); 49 | } 50 | 51 | vector<Example> examples_; 52 | }; 53 | 54 | #endif // SRM_TEST_H_ 55 | -------------------------------------------------------------------------------- /testdata/breast-cancer-wisconsin.data: -------------------------------------------------------------------------------- 1 | 1000025,5,1,1,1,2,1,3,1,1,2 2 | 1057013,8,4,5,1,2,?,7,3,1,4 3 | 1002945,5,4,4,5,7,10,3,2,1,2 4 | 1015425,3,1,1,1,2,2,3,1,1,2 5 | 1033078,4,2,1,1,2,1,2,1,1,2 6 | -------------------------------------------------------------------------------- /tree.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | */ 16 | 17 | #include <cmath> 18 | 19 | #include "tree.h" 20 | 21 | #include "gflags/gflags.h" 22 | #include "glog/logging.h" 23 | 24 | DEFINE_double(beta, -1.0, "beta parameter for gradient."); 25 | DEFINE_double(lambda, -1.0, "lambda parameter for gradient."); 26 | DEFINE_int32(tree_depth, -1, 27 | "Maximum depth of each decision tree. The root node has depth 0. " 28 | "Required: tree_depth >= 0."); 29 | 30 | // TODO(usyed): Global variables are bad style. 31 | static int num_features; 32 | static int num_examples; 33 | static float the_normalizer; 34 | static bool is_initialized = false; 35 | 36 | void InitializeTreeData(const vector<Example>& examples, float normalizer) { 37 | CHECK_GE(examples.size(), 1); 38 | num_examples = examples.size(); 39 | num_features = examples[0].values.size(); 40 | the_normalizer = normalizer; 41 | is_initialized = true; 42 | } 43 | 44 | Node MakeRootNode(const vector<Example>& examples) { 45 | Node root; 46 | root.examples = examples; 47 | root.positive_weight = root.negative_weight = 0; 48 | for (const Example& example : examples) { 49 | if (example.label == 1) { 50 | root.positive_weight += example.weight; 51 | } else { // label == -1 52 | root.negative_weight += example.weight; 53 | } 54 | } 55 | root.leaf = true; 56 | root.depth = 0; 57 | return root; 58 | } 59 | 60 | map<Value, pair<Weight, Weight>> MakeValueToWeightsMap(const Node& node, 61 | Feature feature) { 62 | map<Value, pair<Weight, Weight>> value_to_weights; 63 | for (const Example& example : node.examples) { 64 | if (example.label == 1) { 65 | value_to_weights[example.values[feature]].first += example.weight; 66 | } else { // label == -1 67 | value_to_weights[example.values[feature]].second += example.weight; 68 | } 69 | } 70 | return value_to_weights; 71 | } 72 | 73 | void BestSplitValue(const map<Value, pair<Weight, Weight>>& value_to_weights, 74 | const Node& node, int tree_size, Value* split_value, 75 | float* delta_gradient) { 76 | *delta_gradient = 0; 77 | Weight left_positive_weight = 0, left_negative_weight = 0, 78 | right_positive_weight = node.positive_weight,
79 | right_negative_weight = node.negative_weight; 80 | float old_error = fmin(left_positive_weight + right_positive_weight, 81 | left_negative_weight + right_negative_weight); 82 | float old_gradient = Gradient(old_error, tree_size, 0, -1); 83 | for (const pair<Value, pair<Weight, Weight>>& elem : value_to_weights) { 84 | left_positive_weight += elem.second.first; 85 | right_positive_weight -= elem.second.first; 86 | left_negative_weight += elem.second.second; 87 | right_negative_weight -= elem.second.second; 88 | float new_error = fmin(left_positive_weight, left_negative_weight) + 89 | fmin(right_positive_weight, right_negative_weight); 90 | float new_gradient = Gradient(new_error, tree_size + 2, 0, -1); 91 | if (fabs(new_gradient) - fabs(old_gradient) > 92 | *delta_gradient + kTolerance) { 93 | *delta_gradient = fabs(new_gradient) - fabs(old_gradient); 94 | *split_value = elem.first; 95 | } 96 | } 97 | } 98 | 99 | void MakeChildNodes(Feature split_feature, Value split_value, Node* parent, 100 | Tree* tree) { 101 | parent->split_feature = split_feature; 102 | parent->split_value = split_value; 103 | parent->leaf = false; 104 | Node left_child, right_child; 105 | left_child.depth = right_child.depth = parent->depth + 1; 106 | left_child.leaf = right_child.leaf = true; 107 | left_child.positive_weight = left_child.negative_weight = 108 | right_child.positive_weight = right_child.negative_weight = 0; 109 | for (const Example& example : parent->examples) { 110 | Node* child; 111 | if (example.values[split_feature] <= split_value) { 112 | child = &left_child; 113 | } else { 114 | child = &right_child; 115 | } 116 | // TODO(usyed): Moving examples around is inefficient.
117 | child->examples.push_back(example); 118 | if (example.label == 1) { 119 | child->positive_weight += example.weight; 120 | } else { // label == -1 121 | child->negative_weight += example.weight; 122 | } 123 | } 124 | parent->left_child_id = tree->size(); 125 | parent->right_child_id = tree->size() + 1; 126 | tree->push_back(left_child); 127 | tree->push_back(right_child); 128 | } 129 | 130 | Tree TrainTree(const vector<Example>& examples) { 131 | CHECK(is_initialized); 132 | Tree tree; 133 | tree.push_back(MakeRootNode(examples)); 134 | NodeId node_id = 0; 135 | while (node_id < tree.size()) { 136 | Node& node = tree[node_id]; // TODO(usyed): Too bad this can't be const. 137 | Feature best_split_feature; 138 | Value best_split_value; 139 | float best_delta_gradient = 0; 140 | for (Feature split_feature = 0; split_feature < num_features; 141 | ++split_feature) { 142 | const map<Value, pair<Weight, Weight>> value_to_weights = 143 | MakeValueToWeightsMap(node, split_feature); 144 | Value split_value; 145 | float delta_gradient; 146 | BestSplitValue(value_to_weights, node, tree.size(), &split_value, 147 | &delta_gradient); 148 | if (delta_gradient > best_delta_gradient + kTolerance) { 149 | best_delta_gradient = delta_gradient; 150 | best_split_feature = split_feature; 151 | best_split_value = split_value; 152 | } 153 | } 154 | if (node.depth < FLAGS_tree_depth && best_delta_gradient > kTolerance) { 155 | MakeChildNodes(best_split_feature, best_split_value, &node, &tree); 156 | } 157 | ++node_id; 158 | } 159 | return tree; 160 | } 161 | 162 | Label ClassifyExample(const Example& example, const Tree& tree) { 163 | CHECK_GE(tree.size(), 1); 164 | const Node* node = &tree[0]; 165 | while (node->leaf == false) { 166 | if (example.values[node->split_feature] <= node->split_value) { 167 | node = &tree[node->left_child_id]; 168 | } else { 169 | node = &tree[node->right_child_id]; 170 | } 171 | } 172 | if (node->positive_weight >= node->negative_weight) { 173 | return 1; 174 | } else { 175 | return -1; 176 |
} 177 | } 178 | 179 | float Gradient(float wgtd_error, int tree_size, float alpha, int sign_edge) { 180 | // TODO(usyed): Can we make some mild assumptions and get rid of sign_edge? 181 | const float complexity_penalty = ComplexityPenalty(tree_size); 182 | const float edge = wgtd_error - 0.5; 183 | const int sign_alpha = (alpha >= 0) ? 1 : -1; 184 | if (fabs(alpha) > kTolerance) { 185 | return edge + sign_alpha * complexity_penalty; 186 | } else if (fabs(edge) <= complexity_penalty) { 187 | return 0; 188 | } else { 189 | return edge - sign_edge * complexity_penalty; 190 | } 191 | } 192 | 193 | float EvaluateTreeWgtd(const vector<Example>& examples, const Tree& tree) { 194 | float wgtd_error = 0; 195 | for (const Example& example : examples) { 196 | if (ClassifyExample(example, tree) != example.label) { 197 | wgtd_error += example.weight; 198 | } 199 | } 200 | return wgtd_error; 201 | } 202 | 203 | float ComplexityPenalty(int tree_size) { 204 | CHECK(is_initialized); 205 | float rademacher = 206 | sqrt(((2 * tree_size + 1) * (log(num_features + 2) / log(2)) * 207 | log(num_examples)) / 208 | num_examples); 209 | return ((FLAGS_lambda * rademacher + FLAGS_beta) * num_examples) / 210 | (2 * the_normalizer); 211 | } 212 | -------------------------------------------------------------------------------- /tree.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | #ifndef TREE_H_ 18 | #define TREE_H_ 19 | 20 | #include "types.h" 21 | 22 | // Initialize some global variables. 23 | void InitializeTreeData(const vector<Example>& examples, float normalizer); 24 | 25 | // Return root node for a tree. 26 | Node MakeRootNode(const vector<Example>& examples); 27 | 28 | // Return a tree trained on examples. 29 | Tree TrainTree(const vector<Example>& examples); 30 | 31 | // Make child nodes using split feature/value and add them to the tree. Also 32 | // update info in the parent node, like child pointers. 33 | void MakeChildNodes(Feature split_feature, Value split_value, Node* parent, 34 | Tree* tree); 35 | 36 | // Return a map from each value of feature to a pair of weights. The first 37 | // weight in the pair is the total weight of positive examples at node that have 38 | // that value for feature, and the second weight in the pair is the total weight 39 | // of negative examples at node that have that value for feature. This map is 40 | // used to determine the best split feature/value. 41 | map<Value, pair<Weight, Weight>> MakeValueToWeightsMap(const Node& node, 42 | Feature feature); 43 | 44 | // Given a value-to-weights map for a feature (constructed by 45 | // MakeValueToWeightsMap()), determine the best split value for the feature and 46 | // the improvement in the gradient of the objective if we split on that value. 47 | // Note that delta_gradient <= 0 indicates that we should not split on this 48 | // feature. 49 | void BestSplitValue(const map<Value, pair<Weight, Weight>>& value_to_weights, 50 | const Node& node, int tree_size, Value* split_value, 51 | float* delta_gradient); 52 | 53 | // Given an example and a tree, classify the example with the tree.
54 | // NB: This function assumes that if an example has a feature value that is 55 | // _less than or equal to_ a node's split value then the example should be sent 56 | // to the left child, and otherwise sent to the right child. 57 | Label ClassifyExample(const Example& example, const Tree& tree); 58 | 59 | // Return the (sub)gradient of the objective with respect to a tree. 60 | float Gradient(float wgtd_error, int tree_size, float alpha, int sign_edge); 61 | 62 | // Given a set of examples and a tree, return the weighted error of tree on 63 | // the examples. 64 | float EvaluateTreeWgtd(const vector<Example>& examples, const Tree& tree); 65 | 66 | // Return complexity penalty. 67 | float ComplexityPenalty(int tree_size); 68 | 69 | #endif // TREE_H_ 70 | -------------------------------------------------------------------------------- /tree_test.cc: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | */ 16 | 17 | #include "srm_test.h" 18 | #include "tree.h" 19 | 20 | #include "gflags/gflags.h" 21 | #include "gtest/gtest.h" 22 | 23 | DECLARE_int32(tree_depth); 24 | DECLARE_double(beta); 25 | DECLARE_double(lambda); 26 | 27 | class TreeTest : public SrmTest { 28 | protected: 29 | virtual void SetUp() { 30 | SrmTest::SetUp(); 31 | InitializeTreeData(examples_, examples_.size()); 32 | } 33 | }; 34 | 35 | TEST_F(TreeTest, TestMakeRootNode) { 36 | Node root = MakeRootNode(examples_); 37 | EXPECT_EQ(5, root.examples.size()); 38 | EXPECT_NEAR(0.6, root.positive_weight, kTolerance); 39 | EXPECT_NEAR(0.4, root.negative_weight, kTolerance); 40 | EXPECT_TRUE(root.leaf); 41 | EXPECT_EQ(0, root.depth); 42 | } 43 | 44 | TEST_F(TreeTest, TestMakeValueToWeightsMap) { 45 | Node root = MakeRootNode(examples_); 46 | map<Value, pair<Weight, Weight>> value_to_weights; 47 | 48 | // Sort by first feature 49 | value_to_weights = MakeValueToWeightsMap(root, 0); 50 | vector<Value> values_for_0 = {1.0, 2.0, 3.0, 4.0, 5.0}; 51 | vector<Weight> positive_weights_for_0 = {0.2, 0.0, 0.2, 0.0, 0.2}; 52 | vector<Weight> negative_weights_for_0 = {0.0, 0.2, 0.0, 0.2, 0.0}; 53 | int i = 0; 54 | for (const pair<Value, pair<Weight, Weight>>& elem : value_to_weights) { 55 | EXPECT_NEAR(values_for_0[i], elem.first, kTolerance); 56 | EXPECT_NEAR(positive_weights_for_0[i], elem.second.first, kTolerance); 57 | EXPECT_NEAR(negative_weights_for_0[i], elem.second.second, kTolerance); 58 | ++i; 59 | } 60 | 61 | // Sort by second feature 62 | value_to_weights = MakeValueToWeightsMap(root, 1); 63 | vector<Value> values_for_1 = {0.1, 0.2, 0.3, 0.4, 0.5}; 64 | vector<Weight> positive_weights_for_1 = {0.2, 0.0, 0.2, 0.2, 0.0}; 65 | vector<Weight> negative_weights_for_1 = {0.0, 0.2, 0.0, 0.0, 0.2}; 66 | i = 0; 67 | for (const pair<Value, pair<Weight, Weight>>& elem : value_to_weights) { 68 | EXPECT_NEAR(values_for_1[i], elem.first, kTolerance); 69 | EXPECT_NEAR(positive_weights_for_1[i], elem.second.first, kTolerance); 70 | EXPECT_NEAR(negative_weights_for_1[i], elem.second.second, kTolerance); 71 | ++i; 72 | } 73 | } 74 | 75 |
TEST_F(TreeTest, TestBestSplitValue) { 76 | Node root = MakeRootNode(examples_); 77 | map<Value, pair<Weight, Weight>> value_to_weights; 78 | Value split_value; 79 | float delta_gradient; 80 | 81 | FLAGS_tree_depth = 1; 82 | FLAGS_lambda = 0; 83 | FLAGS_beta = 0; 84 | 85 | // Split on first feature, which is useless. 86 | value_to_weights = MakeValueToWeightsMap(root, 0); 87 | BestSplitValue(value_to_weights, root, 1, &split_value, &delta_gradient); 88 | EXPECT_NEAR(0, delta_gradient, kTolerance); 89 | 90 | // Split on second feature, which is useful. 91 | value_to_weights = MakeValueToWeightsMap(root, 1); 92 | BestSplitValue(value_to_weights, root, 1, &split_value, &delta_gradient); 93 | EXPECT_NEAR(0.2, delta_gradient, kTolerance); 94 | EXPECT_NEAR(0.4, split_value, kTolerance); 95 | 96 | // Don't split on second feature if complexity penalty is very high. 97 | FLAGS_lambda = 100; 98 | value_to_weights = MakeValueToWeightsMap(root, 1); 99 | BestSplitValue(value_to_weights, root, 1, &split_value, &delta_gradient); 100 | EXPECT_NEAR(0, delta_gradient, kTolerance); 101 | } 102 | 103 | TEST_F(TreeTest, TestMakeChildNodes) { 104 | Node root = MakeRootNode(examples_); 105 | Tree tree; 106 | 107 | tree.push_back(root); 108 | MakeChildNodes(0, 3.0, &tree[0], &tree); 109 | EXPECT_EQ(3, tree.size()); 110 | // Check root node 111 | EXPECT_EQ(5, tree[0].examples.size()); 112 | EXPECT_EQ(0, tree[0].split_feature); 113 | EXPECT_EQ(1, tree[0].left_child_id); 114 | EXPECT_EQ(2, tree[0].right_child_id); 115 | EXPECT_NEAR(3.0, tree[0].split_value, kTolerance); 116 | EXPECT_NEAR(0.6, tree[0].positive_weight, kTolerance); 117 | EXPECT_NEAR(0.4, tree[0].negative_weight, kTolerance); 118 | EXPECT_FALSE(tree[0].leaf); 119 | EXPECT_EQ(0, tree[0].depth); 120 | // Check left child node 121 | EXPECT_EQ(3, tree[1].examples.size()); 122 | EXPECT_NEAR(0.4, tree[1].positive_weight, kTolerance); 123 | EXPECT_NEAR(0.2, tree[1].negative_weight, kTolerance); 124 | EXPECT_TRUE(tree[1].leaf); 125 | EXPECT_EQ(1, tree[1].depth);
126 | // Check right child node 127 | EXPECT_EQ(2, tree[2].examples.size()); 128 | EXPECT_NEAR(0.2, tree[2].positive_weight, kTolerance); 129 | EXPECT_NEAR(0.2, tree[2].negative_weight, kTolerance); 130 | EXPECT_TRUE(tree[2].leaf); 131 | EXPECT_EQ(1, tree[2].depth); 132 | 133 | tree.clear(); 134 | tree.push_back(root); 135 | MakeChildNodes(1, 0.4, &tree[0], &tree); 136 | EXPECT_EQ(3, tree.size()); 137 | // Check root node 138 | EXPECT_EQ(5, tree[0].examples.size()); 139 | EXPECT_EQ(1, tree[0].split_feature); 140 | EXPECT_NEAR(0.4, tree[0].split_value, kTolerance); 141 | EXPECT_EQ(1, tree[0].left_child_id); 142 | EXPECT_EQ(2, tree[0].right_child_id); 143 | EXPECT_NEAR(0.6, tree[0].positive_weight, kTolerance); 144 | EXPECT_NEAR(0.4, tree[0].negative_weight, kTolerance); 145 | EXPECT_FALSE(tree[0].leaf); 146 | EXPECT_EQ(0, tree[0].depth); 147 | // Check left child node 148 | EXPECT_EQ(4, tree[1].examples.size()); 149 | EXPECT_NEAR(0.6, tree[1].positive_weight, kTolerance); 150 | EXPECT_NEAR(0.2, tree[1].negative_weight, kTolerance); 151 | EXPECT_TRUE(tree[1].leaf); 152 | EXPECT_EQ(1, tree[1].depth); 153 | // Check right child node 154 | EXPECT_EQ(1, tree[2].examples.size()); 155 | EXPECT_NEAR(0.0, tree[2].positive_weight, kTolerance); 156 | EXPECT_NEAR(0.2, tree[2].negative_weight, kTolerance); 157 | EXPECT_TRUE(tree[2].leaf); 158 | EXPECT_EQ(1, tree[2].depth); 159 | } 160 | 161 | TEST_F(TreeTest, TestTrainTree) { 162 | FLAGS_beta = 0; 163 | FLAGS_lambda = 0; 164 | 165 | FLAGS_tree_depth = 1; 166 | Tree tree = TrainTree(examples_); 167 | EXPECT_EQ(3, tree.size()); 168 | 169 | FLAGS_tree_depth = 2; 170 | tree = TrainTree(examples_); 171 | EXPECT_EQ(5, tree.size()); 172 | 173 | // Check all the nodes 174 | // Node 0 175 | EXPECT_EQ(1, tree[0].split_feature); 176 | EXPECT_NEAR(0.4, tree[0].split_value, kTolerance); 177 | EXPECT_EQ(1, tree[0].left_child_id); 178 | EXPECT_EQ(2, tree[0].right_child_id); 179 | EXPECT_NEAR(0.6, tree[0].positive_weight, kTolerance); 180 | 
EXPECT_NEAR(0.4, tree[0].negative_weight, kTolerance); 181 | EXPECT_FALSE(tree[0].leaf); 182 | EXPECT_EQ(0, tree[0].depth); 183 | // Node 1 184 | EXPECT_EQ(2, tree[1].split_feature); 185 | EXPECT_NEAR(11.0, tree[1].split_value, kTolerance); 186 | EXPECT_EQ(3, tree[1].left_child_id); 187 | EXPECT_EQ(4, tree[1].right_child_id); 188 | EXPECT_NEAR(0.6, tree[1].positive_weight, kTolerance); 189 | EXPECT_NEAR(0.2, tree[1].negative_weight, kTolerance); 190 | EXPECT_FALSE(tree[1].leaf); 191 | EXPECT_EQ(1, tree[1].depth); 192 | // Node 2 193 | EXPECT_NEAR(0.0, tree[2].positive_weight, kTolerance); 194 | EXPECT_NEAR(0.2, tree[2].negative_weight, kTolerance); 195 | EXPECT_TRUE(tree[2].leaf); 196 | EXPECT_EQ(1, tree[2].depth); 197 | // Node 3 198 | EXPECT_NEAR(0.6, tree[3].positive_weight, kTolerance); 199 | EXPECT_NEAR(0.0, tree[3].negative_weight, kTolerance); 200 | EXPECT_TRUE(tree[3].leaf); 201 | EXPECT_EQ(2, tree[3].depth); 202 | // Node 4 203 | EXPECT_NEAR(0.0, tree[4].positive_weight, kTolerance); 204 | EXPECT_NEAR(0.2, tree[4].negative_weight, kTolerance); 205 | EXPECT_TRUE(tree[4].leaf); 206 | EXPECT_EQ(2, tree[4].depth); 207 | 208 | // Very high complexity penalty causes tree to never split 209 | FLAGS_lambda = 100; 210 | tree = TrainTree(examples_); 211 | EXPECT_EQ(1, tree.size()); 212 | } 213 | 214 | TEST_F(TreeTest, TestComplexityPenalty) { 215 | FLAGS_beta = 1; 216 | FLAGS_lambda = 1; 217 | 218 | float complexity_penalty = ComplexityPenalty(10); 219 | EXPECT_NEAR(2.48087078356, complexity_penalty, kTolerance); 220 | } 221 | 222 | TEST_F(TreeTest, GradientTest) { 223 | FLAGS_beta = 0; 224 | FLAGS_lambda = 0; 225 | 226 | float gradient = Gradient(0.25, 100, 4, -1); 227 | EXPECT_NEAR(0.25 - 0.5, gradient, kTolerance); 228 | 229 | FLAGS_beta = 1; 230 | FLAGS_lambda = 1; 231 | 232 | gradient = Gradient(0.25, 10, 1, 1); 233 | EXPECT_NEAR(0.25 - 0.5 + ComplexityPenalty(10), gradient, kTolerance); 234 | 235 | gradient = Gradient(0.25, 10, -1, 1); 236 | EXPECT_NEAR(0.25 - 
0.5 - ComplexityPenalty(10), gradient, kTolerance); 237 | 238 | gradient = Gradient(0.25, 10, 0, 1); 239 | EXPECT_NEAR(0, gradient, kTolerance); 240 | 241 | FLAGS_beta = 0; 242 | FLAGS_lambda = 0.1; 243 | 244 | gradient = Gradient(0.2, 10, 0, 1); 245 | EXPECT_NEAR(0.2 - 0.5 - ComplexityPenalty(10), gradient, kTolerance); 246 | 247 | gradient = Gradient(0.2, 10, 0, -1); 248 | EXPECT_NEAR(0.2 - 0.5 + ComplexityPenalty(10), gradient, kTolerance); 249 | } 250 | 251 | TEST_F(TreeTest, TestClassifyExample) { 252 | FLAGS_beta = 0; 253 | FLAGS_lambda = 0; 254 | FLAGS_tree_depth = 2; 255 | Tree tree = TrainTree(examples_); 256 | 257 | EXPECT_EQ(1, ClassifyExample(examples_[0], tree)); 258 | EXPECT_EQ(1, ClassifyExample(examples_[1], tree)); 259 | EXPECT_EQ(1, ClassifyExample(examples_[2], tree)); 260 | EXPECT_EQ(-1, ClassifyExample(examples_[3], tree)); 261 | EXPECT_EQ(-1, ClassifyExample(examples_[4], tree)); 262 | } 263 | 264 | TEST_F(TreeTest, TestEvaluateTreeWgtd) { 265 | FLAGS_beta = 0; 266 | FLAGS_lambda = 0; 267 | FLAGS_tree_depth = 1; 268 | Tree tree = TrainTree(examples_); 269 | EXPECT_NEAR(0.2, EvaluateTreeWgtd(examples_, tree), kTolerance); 270 | } 271 | -------------------------------------------------------------------------------- /types.h: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2015 Google Inc. All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | */ 16 | 17 | #ifndef TYPES_H_ 18 | #define TYPES_H_ 19 | 20 | #include <map> 21 | #include <vector> 22 | 23 | using std::map; 24 | using std::pair; 25 | using std::vector; 26 | 27 | // Used in many places as the minimum possible difference between two distinct 28 | // numbers. Helps make code stable, tests predictable, etc. 29 | static const float kTolerance = 1e-7; 30 | 31 | typedef int Feature; 32 | typedef int Label; 33 | typedef int NodeId; 34 | typedef float Value; 35 | typedef float Weight; 36 | 37 | // An example consists of a vector of feature values, a label and a weight. 38 | // Note that this is a dense feature representation; the value of every 39 | // feature is contained in the vector, listed in a canonical order. 40 | typedef struct Example { 41 | vector<Value> values; 42 | Label label; 43 | Weight weight; 44 | } Example; 45 | 46 | // A tree node. 47 | typedef struct Node { 48 | vector<Example> examples; // Examples at this node. 49 | Feature split_feature; // Split feature. 50 | Value split_value; // Split value. 51 | NodeId left_child_id; // Pointer to left child, if any. 52 | NodeId right_child_id; // Pointer to right child, if any. 53 | Weight positive_weight; // Total weight of positive examples at this node. 54 | Weight negative_weight; // Total weight of negative examples at this node. 55 | bool leaf; // Is this node a leaf? 56 | int depth; // Depth of the node in the tree. Root node has depth 0. 57 | } Node; 58 | 59 | // A tree is a vector of nodes. 60 | typedef vector<Node> Tree; 61 | 62 | // A model is a vector of (weight, tree) pairs, i.e., a weighted combination of 63 | // trees. 64 | typedef vector<pair<Weight, Tree>> Model; 65 | 66 | #endif // TYPES_H_ 67 | --------------------------------------------------------------------------------