├── .github └── workflows │ └── main.yml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── README_PyPI.md ├── build.py ├── config.json ├── doc ├── build.md └── faqs.md ├── experiments └── datasets │ ├── car_evaluation │ └── data.csv │ ├── chudi │ └── data.csv │ ├── compas │ ├── binned.csv │ ├── compas-2016.csv │ ├── data.csv │ ├── original.csv │ └── processed.csv │ ├── coupon │ ├── .DS_Store │ ├── bar-7.csv │ ├── data.csv │ └── original.csv │ ├── fico │ ├── data.csv │ ├── fico-binary.csv │ ├── fico.names │ └── original.csv │ ├── gaussian │ ├── hundred.csv │ ├── ten.csv │ ├── ten_thousand.csv │ └── thousand.csv │ ├── iris │ ├── data.csv │ ├── iris.names │ ├── setosa.csv │ ├── versicolor.csv │ └── virginica.csv │ ├── monk_1 │ └── data.csv │ ├── monk_2 │ └── data.csv │ ├── monk_3 │ └── data.csv │ ├── sine │ ├── hundred.csv │ ├── ten.csv │ ├── ten_thousand.csv │ └── thousand.csv │ ├── summary.csv │ └── tic-tac-toe │ ├── Index.txt │ ├── data.csv │ ├── tic-tac-toe.csv │ └── tic-tac-toe.names ├── include ├── csv │ └── csv.h └── json │ └── json.hpp ├── log └── .gitignore ├── output.json ├── pyproject.toml ├── setup.py ├── src ├── .dirstamp ├── additive_metrics.cpp ├── additive_metrics.hpp ├── bitmask.cpp ├── bitmask.hpp ├── cart_it.hpp ├── configuration.cpp ├── configuration.hpp ├── dataset.cpp ├── dataset.hpp ├── encoder.cpp ├── encoder.hpp ├── gosdt.cpp ├── gosdt.hpp ├── graph.cpp ├── graph.hpp ├── index.cpp ├── index.hpp ├── integrity_violation.hpp ├── local_state.cpp ├── local_state.hpp ├── main.cpp ├── main.hpp ├── memusage.h ├── message.cpp ├── message.hpp ├── model.cpp ├── model.hpp ├── model_set.cpp ├── model_set.hpp ├── optimizer.cpp ├── optimizer.hpp ├── optimizer │ ├── diagnosis │ │ ├── false_convergence.hpp │ │ ├── non_convergence.hpp │ │ ├── trace.hpp │ │ └── tree.hpp │ ├── dispatch │ │ └── dispatch.hpp │ └── extraction │ │ ├── models.hpp │ │ └── rash_models.hpp ├── python_extension.cpp ├── python_extension.hpp ├── queue.cpp ├── queue.hpp ├── state.cpp ├── state.hpp ├── task.cpp ├── task.hpp ├── tile.cpp ├── tile.hpp ├── trie.cpp ├── trie.hpp ├── types.hpp └── version.hpp ├── test ├── .dirstamp ├── fixtures │ ├── binary_sepal.csv │ ├── binary_sepal.json │ ├── dataset.csv │ ├── dataset.json │ ├── sepal.csv │ ├── sepal.json │ ├── sequences.csv │ ├── sequences.json │ ├── small.csv │ ├── tree.csv │ └── tree.json ├── test.cpp ├── test.hpp ├── test_bitmask.hpp ├── test_consistency.hpp ├── test_index.hpp ├── test_queue.hpp └── test_trie.hpp └── treefarms ├── __init__.py ├── example.py ├── model ├── __init__.py ├── encoder.py ├── imbalance │ ├── __init__.py │ ├── osdt_imb_v9.py │ ├── osdt_sup.py │ └── rule.py ├── model_set.py ├── threshold_guess.py ├── tree_classifier.py └── treefarms.py └── tutorial.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Python Ignores 2 | notebooks/.ipynb_checkpoints/* 3 | *.pyc 4 | *.pyo 5 | 6 | # Build Objects 7 | *.o 8 | *.tmp 9 | gosdt 10 | gosdt_test 11 | 12 | # Autobuild Ignores 13 | MANIFEST 14 | Makefile 15 | config.log 16 | config.status 17 | autom4te.cache/* 18 | aclocal.m4 19 | m4 20 | libtool 21 | .deps 22 | gosdt 23 | gosdt_test 24 | 25 | .vscode -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | # Project `Name` and `Language` 4 | project(TREEFARMS) 5 | 6 | # Set the language standard to `c++11` 7 | 
set(CMAKE_CXX_STANDARD 11) 8 | 9 | # Set the compiler flags 10 | if (MSVC) 11 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 12 | set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /DDEBUG") 13 | set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE}") 14 | 15 | # `#define NOMINMAX` prevents expansion of min and max macros on Windows, 16 | # otherwise `std::numeric_limits::max()/min()` leads to MSVC compiler errors. 17 | # Reference: https://stackoverflow.com/questions/27442885/syntax-error-with-stdnumeric-limitsmax 18 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj /w /DNOMINMAX") 19 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /DDEBUG") 20 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") 21 | else() 22 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra") 23 | set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O0 -DDEBUG") 24 | set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE}") 25 | 26 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") 27 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -DDEBUG") 28 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") 29 | endif() 30 | 31 | # 32 | # MARK: - Targets 33 | # 34 | 35 | # Target Definition 36 | set(TARGET_CLI "treefarms") 37 | set(TARGET_TESTS "treefarms_tests") 38 | set(TARGET_PYTHON "libgosdt") 39 | 40 | # Target: TREEFARMS CLI 41 | file(GLOB_RECURSE SOURCE_FILES_CLI src/*.cpp) 42 | list(REMOVE_ITEM SOURCE_FILES_CLI "${CMAKE_SOURCE_DIR}/src/python_extension.cpp") 43 | add_executable(${TARGET_CLI} ${SOURCE_FILES_CLI}) 44 | target_include_directories(${TARGET_CLI} PUBLIC ${CMAKE_SOURCE_DIR}/include) 45 | 46 | # Target: TREEFARMS Tests 47 | file(GLOB_RECURSE SOURCE_FILES_TESTS src/*.cpp) 48 | list(REMOVE_ITEM SOURCE_FILES_TESTS "${CMAKE_SOURCE_DIR}/src/python_extension.cpp") 49 | list(REMOVE_ITEM SOURCE_FILES_TESTS "${CMAKE_SOURCE_DIR}/src/main.cpp") 50 | list(INSERT SOURCE_FILES_TESTS 0 "${CMAKE_SOURCE_DIR}/test/test.cpp") 51 | add_executable(${TARGET_TESTS} ${SOURCE_FILES_TESTS}) 52 | target_include_directories(${TARGET_TESTS} PUBLIC ${CMAKE_SOURCE_DIR}/include) 53 | 54 | # 55 | # MARK: - Dependencies 56 | # 57 | 58 | # Dependencies: Intel TBB 59 | find_package(TBB REQUIRED) 60 | target_link_libraries(${TARGET_CLI} PRIVATE TBB::tbb) 61 | target_link_libraries(${TARGET_CLI} PRIVATE TBB::tbbmalloc) 62 | target_link_libraries(${TARGET_TESTS} PRIVATE TBB::tbb) 63 | target_link_libraries(${TARGET_TESTS} PRIVATE TBB::tbbmalloc) 64 | 65 | # Dependencies: GMP 66 | find_package(PkgConfig REQUIRED) 67 | pkg_check_modules(GMP REQUIRED IMPORTED_TARGET gmp) 68 | target_link_libraries(${TARGET_CLI} PRIVATE PkgConfig::GMP) 69 | target_include_directories(${TARGET_CLI} PRIVATE ${GMP_INCLUDE_DIRS}) 70 | target_link_libraries(${TARGET_TESTS} PRIVATE PkgConfig::GMP) 71 | target_include_directories(${TARGET_TESTS} PRIVATE ${GMP_INCLUDE_DIRS}) 72 | 73 | # Dependencies: Threads (pthread on macOS and Ubuntu, win32 thread on Windows) 74 | # This is needed because the CentOS docker provided by manylinux reports linker errors 75 | set(THREADS_PREFER_PTHREAD_FLAG ON) 76 | find_package(Threads REQUIRED) 77 | target_link_libraries(${TARGET_CLI} PRIVATE Threads::Threads) 78 | target_link_libraries(${TARGET_TESTS} PRIVATE Threads::Threads) 79 | 80 | # Target: TREEFARMS Python Module 81 | if (SKBUILD) 82 | message(STATUS "TREEFARMS is built using scikit-build. 
Will build the Python module.") 83 | # Find the Python 3 development environment 84 | if (NOT DEFINED Python3_INCLUDE_DIR) 85 | message(FATAL_ERROR "The CMake variable Python3_INCLUDE_DIR should have been defined by scikit-build.") 86 | endif() 87 | # Create the list of source files needed to build the Python extension 88 | file(GLOB_RECURSE SOURCE_FILES_PY src/*.cpp) 89 | list(REMOVE_ITEM SOURCE_FILES_PY "${CMAKE_SOURCE_DIR}/src/main.cpp") 90 | # Define the CMake target for the Python extension 91 | add_library(${TARGET_PYTHON} MODULE ${SOURCE_FILES_PY}) 92 | target_include_directories(${TARGET_PYTHON} PRIVATE ${CMAKE_SOURCE_DIR}/include ${Python3_INCLUDE_DIR} ${GMP_INCLUDE_DIRS}) 93 | target_link_libraries(${TARGET_PYTHON} TBB::tbb TBB::tbbmalloc PkgConfig::GMP Threads::Threads) 94 | # Set up the Python extension 95 | find_package(PythonExtensions REQUIRED) 96 | ## Use the suffix `.abi3.so` or `.pyd` so that Python 3 on other platforms can find the dylib and import it properly 97 | message(STATUS "The current Python extension suffix is \"${PYTHON_EXTENSION_MODULE_SUFFIX}\".") 98 | if (WIN32) 99 | set(PYTHON_EXTENSION_MODULE_SUFFIX ".pyd") 100 | else() 101 | set(PYTHON_EXTENSION_MODULE_SUFFIX ".abi3.so") 102 | endif() 103 | message(STATUS "The new Python extension suffix is \"${PYTHON_EXTENSION_MODULE_SUFFIX}\".") 104 | ## Define the Python extension module target 105 | python_extension_module(${TARGET_PYTHON}) 106 | # Install `libgosdt` to the root directory of the Python extension package 107 | install(TARGETS ${TARGET_PYTHON} LIBRARY DESTINATION .) 108 | endif() 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, 2022 Chudi Zhong, Hayden McTavish, Jimmy Lin, Margo Seltzer, 4 | Cynthia Rudin, The University of British Columbia, Duke University 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | -------------------------------------------------------------------------------- /build.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import platform 3 | import distro 4 | import subprocess 5 | import sys 6 | import os 7 | import shutil 8 | 9 | 10 | def setup(args) -> None: 11 | """ 12 | Interface of the cross-platform scikit-build script 13 | :param args: A list of arguments passed to `setup.py` 14 | """ 15 | subprocess.run([sys.executable, "setup.py"] + args).check_returncode() 16 | 17 | 18 | def delocate_wheel(args) -> None: 19 | """ 20 | Interface of `delocate-wheel` that repairs wheel files on macOS 21 | :param args: A list of arguments passed to `delocate-wheel` 22 | """ 23 | subprocess.run(["delocate-wheel"] + args).check_returncode() 24 | 25 | 26 | def auditwheel(args) -> None: 27 | """ 28 | Interface of `auditwheel` that repairs wheel files on Ubuntu 29 | :param args: A list of arguments passed to "auditwheel" 30 | """ 31 | subprocess.run(["auditwheel"] + args).check_returncode() 32 | 33 | 34 | def delvewheel(args) -> None: 35 | """ 36 | Interface of `delvewheel` that repairs wheel files on Windows 37 | :param args: A list of arguments passed to `delvewheel` 38 | """ 39 | subprocess.run([sys.executable, "-m", "delvewheel"] + args).check_returncode() 40 | 41 | 42 | def repair_wheel(wheel) -> None: 43 | """ 44 | Repair the generated wheel file by copying all required dynamic libraries to it 45 | :param wheel: Path to the generated wheel file 46 | """ 47 | # Fetch the current operating system 48 | system = platform.system() 49 | # Repair the wheel 50 | if system == "Darwin": 51 | delocate_wheel(["-w", "dist", "-v", wheel]) 52 | elif system == "Linux": 53 | distribution = distro.id() 54 | if distribution == "ubuntu": 55 | auditwheel(["repair", "-w", "dist", "--plat", "linux_x86_64", wheel]) 56 | elif distribution == "centos": 57 | auditwheel(["repair", "-w", "dist", "--plat", "manylinux_2_17_x86_64", wheel]) 58 | # Remove the original wheel 59 | # The fixed wheel has a difference file name 60 | os.remove(wheel) 61 | else: 62 | print("Linux distribution {} is not supported by this script.".format(distribution)) 63 | raise EnvironmentError 64 | elif system == "Windows": 65 | search_path = str(pathlib.Path(os.getenv("VCPKG_INSTALLATION_ROOT")) / "installed\\x64-windows\\bin") 66 | delvewheel(["repair", "--no-mangle-all", "--add-path", search_path, wheel, "-w", "dist"]) 67 | else: 68 | print("{} is not supported.".format(system)) 69 | raise EnvironmentError 70 | 71 | 72 | def remove_dir_if_exists(str) -> None: 73 | if os.path.exists(str): 74 | shutil.rmtree(str) 75 | 76 | 77 | if __name__ == '__main__': 78 | try: 79 | print(">> Cleaning the garbage...") 80 | remove_dir_if_exists("dist") 81 | remove_dir_if_exists("treefarms.egg-info") 82 | setup(["clean"]) 83 | 84 | print(">> Rebuilding the project from scratch...") 85 | # `--py-limited-api=cp37` is needed otherwise installing the wheel file produced by the CentOS 7 docker 86 | # on other Linux distributions, such as Ubuntu, leads to an error `Not a supported wheel on this platform`. 87 | # On Windows, the value must match the version of the Python that builds the wheel file, 88 | # otherwise, wheel files will have the same name, despite being generated by different Python installations. 
89 | api_version = "cp{}{}".format(sys.version_info.major, sys.version_info.minor) 90 | print("Using the CPython API {}.".format(api_version)) 91 | setup(["bdist_wheel", "--py-limited-api={}".format(api_version), "--build-type=Release", 92 | "-G", "Ninja", "--", "--", "-j{}".format(os.cpu_count())]) 93 | 94 | print(">> Adding required dynamic libraries to the wheel file...") 95 | wheels = os.listdir("dist") 96 | assert len(wheels) == 1, "The number of generated wheels is not 1. All wheels: {}.".format(wheels) 97 | wheel = "dist/{}".format(wheels[0]) 98 | print("Wheel file to be fixed: {}.".format(wheel)) 99 | repair_wheel(wheel) 100 | 101 | print("All done.") 102 | exit(0) 103 | except (EnvironmentError, subprocess.CalledProcessError): 104 | exit(1) 105 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "balance": false, 3 | "cancellation": true, 4 | "look_ahead": true, 5 | "similar_support": false, 6 | "feature_exchange": false, 7 | "continuous_feature_exchange": false, 8 | "rule_list": false, 9 | 10 | "diagnostics": false, 11 | "verbose": true, 12 | 13 | "regularization": 0.05, 14 | "uncertainty_tolerance": 0.0, 15 | "upperbound": 0.0, 16 | 17 | "model_limit": 10000, 18 | "precision_limit": 0, 19 | "stack_limit": 0, 20 | "tile_limit": 0, 21 | "time_limit": 0, 22 | "worker_limit": 1, 23 | 24 | "costs": "", 25 | "model": "", 26 | "rashomon_model": "", 27 | "rashomon_trie": "", 28 | "rashomon_model_set_suffix": "", 29 | "profile": "", 30 | "timing": "", 31 | "trace": "", 32 | "tree": "", 33 | "datatset_encoding": "", 34 | 35 | "depth_budget": 0, 36 | 37 | "minimum_captured_points": 0, 38 | 39 | "memory_checkpoints": [], 40 | 41 | "output_objective_model_set": false, 42 | "output_covered_sets": [], 43 | "covered_sets_thresholds": [], 44 | 45 | "rashomon": true, 46 | "rashomon_bound": 0, 47 | "rashomon_bound_multiplier": 0.3, 48 | "rashomon_bound_adder": 0, 49 | "rashomon_ignore_trivial_extensions": true 50 | } -------------------------------------------------------------------------------- /doc/faqs.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions 2 | - [Does GOSDT (implicitly) restrict the depth of the resulting tree?](#does-gosdt-implicitly-restrict-the-depth-of-the-resulting-tree) 3 | - [Why does GOSDT run for a long time when the regularization parameter (lambda) is set to zero?](#why-does-gosdt-run-for-a-long-time-when-the-regularization-parameter-lambda-is-set-to-zero) 4 | - [Is there a way to limit the size of the produced tree?](#is-there-a-way-to-limit-the-size-of-the-produced-tree) 5 | - [In general, how does GOSDT set the regularization parameter?](#in-general-how-does-gosdt-set-the-regularization-parameter) 6 | 7 | --- 8 | 9 | ## Does GOSDT (implicitly) restrict the depth of the resulting tree? 10 | 11 | As of 2022, GOSDT With Guesses can restrict the depth of the resulting tree both implicitly and explicitly. Our primary sparsity constraint is the regularization parameter (lambda), which penalizes the number of leaves. As lambda becomes smaller, the generated trees have more leaves, but the number of leaves does not determine the depth of the tree, since GOSDT generates trees of any shape. Since our 2022 AAAI paper, however, users can also restrict the depth of the tree explicitly. This provides more control over the tree's shape and reduces the runtime. However, the depth constraint is not a substitute for a nonzero regularization parameter: our algorithm achieves a better sparsity-accuracy tradeoff, and saves time, with a well-chosen lambda.
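For illustration, here is a minimal sketch of a configuration that combines a nonzero lambda with an explicit depth constraint. The `regularization` and `depth_budget` keys mirror the repository's `config.json`; the output file name and the reading of `depth_budget` as a cap on tree depth are assumptions made for this example only.

```python
import json

# Minimal sketch: pair the primary sparsity control (a nonzero regularization
# parameter) with an explicit depth constraint. Key names mirror config.json;
# the interpretation of depth_budget as a depth cap is assumed for illustration.
config = {
    "regularization": 0.05,  # nonzero lambda remains the primary sparsity control
    "depth_budget": 5,       # explicit limit on tree depth (assumed semantics)
}

# Save the configuration so it can be passed to the solver (e.g. via its CLI).
with open("my_config.json", "w") as f:
    json.dump(config, f, indent=2)
```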
12 | 13 | ## Why does GOSDT run for a long time when the regularization parameter (lambda) is set to zero? 14 | 15 | The running time depends on the dataset itself and on the regularization parameter (lambda). In general, setting lambda to 0 makes the running time longer: it effectively deactivates the branch-and-bound pruning in GOSDT. In other words, the search becomes close to a brute-force enumeration of the whole space without effective pruning, although dynamic programming still helps through computational reuse. 16 | In GOSDT, we compare the difference between the upper and lower bound of a subproblem with lambda to determine whether the subproblem needs to be split further. If lambda = 0, every subproblem can always be split, so the search takes more time. It usually does not make sense to set lambda smaller than 1/n, where n is the number of samples. 17 | 18 | ## Is there a way to limit the size of the produced tree? 19 | 20 | The regularization parameter (lambda) is used to limit the size of the produced tree (specifically, in GOSDT, it limits the number of leaves of the produced tree). We usually set lambda to values such as 0.1, 0.05, 0.01, 0.005, or 0.001, but the right value really depends on the dataset. One helpful way to choose it is to consider how many samples each leaf node should capture. Suppose you want each leaf node to contain at least 10 samples; then setting the regularization parameter to 10/n is reasonable. In general, the larger the value of lambda, the sparser the tree you will get. 21 | 22 | 23 | ## In general, how does GOSDT set the regularization parameter? 24 | 25 | GOSDT aims to find an optimal tree that minimizes the training loss plus a penalty on the number of leaves; the objective is to minimize loss + lambda * (number of leaves). When we run GOSDT, we usually try several different non-zero values of lambda, usually not smaller than 1/n. In Appendix I.6 (page 31) of [our ICML paper](https://arxiv.org/abs/2006.08690), we provide detailed information about the configuration we used to run the accuracy vs. sparsity experiments.
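As a worked illustration of the heuristics above (not code from this repository), the sketch below computes a candidate lambda for a dataset of n samples and evaluates the penalized objective, loss + lambda * (number of leaves), for a hypothetical tree; the function names are made up for this example.

```python
# Worked example of the lambda heuristic and the penalized objective described above.

def suggest_lambda(n_samples: int, min_samples_per_leaf: int = 10) -> float:
    # Aim for at least `min_samples_per_leaf` samples per leaf, and never go
    # below 1/n, since smaller values rarely make sense.
    return max(min_samples_per_leaf / n_samples, 1.0 / n_samples)

def penalized_objective(misclassified: int, n_samples: int, n_leaves: int, lam: float) -> float:
    # Training loss (misclassification rate) plus lambda times the number of leaves.
    return misclassified / n_samples + lam * n_leaves

n = 1000
lam = suggest_lambda(n)                          # 10 / 1000 = 0.01
# A 4-leaf tree that misclassifies 80 of the 1000 training samples:
print(lam, penalized_objective(80, n, 4, lam))   # 0.01, and 0.08 + 0.01 * 4 = 0.12
```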
26 | -------------------------------------------------------------------------------- /experiments/datasets/chudi/data.csv: -------------------------------------------------------------------------------- 1 | f1,f2,label 2 | 0.5,0.5,0 3 | 0.25,2.33,0 4 | 0.25,2.67,0 5 | 0.5,2.33,0 6 | 0.5,2.67,0 7 | 0.75,2.25,0 8 | 0.75,2.5,0 9 | 0.75,2.75,0 10 | 0.14,4.33,0 11 | 0.14,4.67,0 12 | 0.29,4.33,0 13 | 0.29,4.67,0 14 | 0.43,4.33,0 15 | 0.43,4.67,0 16 | 0.57,4.5,0 17 | 0.71,4.33,0 18 | 0.71,4.67,0 19 | 0.86,4.33,0 20 | 0.86,4.67,0 21 | 1.25,1.2,1 22 | 1.25,1.4,1 23 | 1.25,1.6,1 24 | 1.25,1.8,1 25 | 1.5,1.2,1 26 | 1.5,1.4,1 27 | 1.5,1.6,1 28 | 1.5,1.8,1 29 | 1.75,1.2,1 30 | 1.75,1.4,1 31 | 1.75,1.6,1 32 | 1.75,1.8,1 33 | 1.5,3.5,1 34 | 2.5,0.5,0 35 | 2.25,2.33,0 36 | 2.25,2.67,0 37 | 2.5,2.25,0 38 | 2.5,2.5,0 39 | 2.5,2.75,0 40 | 2.75,2.25,0 41 | 2.75,2.5,0 42 | 2.75,2.75,0 43 | 2.25,4.33,0 44 | 2.25,4.67,0 45 | 2.5,4.33,0 46 | 2.5,4.67,0 47 | 2.75,4.33,0 48 | 2.75,4.67,0 49 | 3.25,1.25,1 50 | 3.25,1.5,1 51 | 3.25,1.75,1 52 | 3.5,3.5,1 53 | 3.75,3.2,1 54 | 3.75,3.4,1 55 | 3.75,3.6,1 56 | 3.75,3.8,1 57 | 4.25,0.33,0 58 | 4.25,0.67,0 59 | 4.33,0.5,0 60 | 4.5,0.33,0 61 | 4.5,0.67,0 62 | 4.75,0.33,0 63 | 4.75,0.67,0 64 | 4.25,2.33,0 65 | 4.25,2.67,0 66 | 4.5,2.5,0 67 | 4.75,2.33,0 68 | 4.75,2.67,0 69 | 4.5,4.5,0 70 | 5.5,1.5,1 71 | 5.33,3.2,1 72 | 5.33,3.4,1 73 | 5.33,3.6,1 74 | 5.33,3.8,1 75 | 5.67,3.2,1 76 | 5.67,3.4,1 77 | 5.67,3.6,1 78 | 5.67,3.8,1 79 | 80 | -------------------------------------------------------------------------------- /experiments/datasets/coupon/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-systopia/treeFarms/5508f6063cbcac5124e164c72aafa9afc31f683d/experiments/datasets/coupon/.DS_Store -------------------------------------------------------------------------------- /experiments/datasets/fico/fico.names: -------------------------------------------------------------------------------- 1 | "Variable Names","Description","Monotonicity Constraint (with respect to probability of bad = 1)","Role","","","","","","","","" 2 | "RiskPerformance","Paid as negotiated flag (12-36 Months). String of Good and Bad","","target","","","","","","","","" 3 | "ExternalRiskEstimate","Consolidated version of risk markers","Monotonically Decreasing","predictor","","","","","","","","" 4 | "MSinceOldestTradeOpen","Months Since Oldest Trade Open","Monotonically Decreasing","predictor","","","","","","","","" 5 | "MSinceMostRecentTradeOpen","Months Since Most Recent Trade Open","Monotonically Decreasing","predictor","","","","","","","","" 6 | "AverageMInFile","Average Months in File","Monotonically Decreasing","predictor","","","","","","","","" 7 | "NumSatisfactoryTrades","Number Satisfactory Trades","Monotonically Decreasing","predictor","","","","","","","","" 8 | "NumTrades60Ever2DerogPubRec","Number Trades 60+ Ever","Monotonically Increasing","predictor","","","","","","","","" 9 | "NumTrades90Ever2DerogPubRec","Number Trades 90+ Ever","Monotonically Increasing","predictor","","","","","","","","" 10 | "PercentTradesNeverDelq","Percent Trades Never Delinquent","Monotonically Decreasing","predictor","","","","","","","","" 11 | "MSinceMostRecentDelq","Months Since Most Recent Delinquency","Monotonically Decreasing","predictor","","","","","","","","" 12 | "MaxDelq2PublicRecLast12M","Max Delq/Public Records Last 12 Months. 
See tab ""MaxDelq"" for each category","Values 0-7 are monotonically decreasing","predictor","","","","","","","","" 13 | "MaxDelqEver","Max Delinquency Ever. See tab ""MaxDelq"" for each category","Values 2-8 are monotonically decreasing","predictor","","","","","","","","" 14 | "NumTotalTrades","Number of Total Trades (total number of credit accounts)","No constraint","predictor","","","","","","","","" 15 | "NumTradesOpeninLast12M","Number of Trades Open in Last 12 Months","Monotonically Increasing","predictor","","","","","","","","" 16 | "PercentInstallTrades","Percent Installment Trades","No constraint","predictor","","","","","","","","" 17 | "MSinceMostRecentInqexcl7days","Months Since Most Recent Inq excl 7days","Monotonically Decreasing","predictor","","","","","","","","" 18 | "NumInqLast6M","Number of Inq Last 6 Months","Monotonically Increasing","predictor","","","","","","","","" 19 | "NumInqLast6Mexcl7days","Number of Inq Last 6 Months excl 7days. Excluding the last 7 days removes inquiries that are likely due to price comparision shopping.","Monotonically Increasing","predictor","","","","","","","","" 20 | "NetFractionRevolvingBurden","Net Fraction Revolving Burden. This is revolving balance divided by credit limit","Monotonically Increasing","predictor","","","","","","","","" 21 | "NetFractionInstallBurden","Net Fraction Installment Burden. This is installment balance divided by original loan amount","Monotonically Increasing","predictor","","","","","","","","" 22 | "NumRevolvingTradesWBalance","Number Revolving Trades with Balance","No constraint","predictor","","","","","","","","" 23 | "NumInstallTradesWBalance","Number Installment Trades with Balance","No constraint","predictor","","","","","","","","" 24 | "NumBank2NatlTradesWHighUtilization","Number Bank/Natl Trades w high utilization ratio","Monotonically Increasing","predictor","","","","","","","","" 25 | "PercentTradesWBalance","Percent Trades with Balance","No constraint","predictor","","","","","","","","" 26 | "","","","","","","","","","","","" 27 | "*For a more detailed example of the monotonicity contraint, please see the ""challenge rules"" page","","","","","","","","","","","" 28 | -------------------------------------------------------------------------------- /experiments/datasets/gaussian/hundred.csv: -------------------------------------------------------------------------------- 1 | x,label 2 | -4.0,0 3 | -3.919191919191919,0 4 | -3.8383838383838382,0 5 | -3.757575757575758,0 6 | -3.676767676767677,0 7 | -3.595959595959596,0 8 | -3.515151515151515,0 9 | -3.4343434343434343,0 10 | -3.3535353535353534,0 11 | -3.2727272727272725,0 12 | -3.191919191919192,0 13 | -3.111111111111111,0 14 | -3.0303030303030303,0 15 | -2.9494949494949494,0 16 | -2.8686868686868685,0 17 | -2.787878787878788,0 18 | -2.7070707070707067,0 19 | -2.6262626262626263,0 20 | -2.5454545454545454,0 21 | -2.4646464646464645,0 22 | -2.3838383838383836,0 23 | -2.3030303030303028,0 24 | -2.2222222222222223,0 25 | -2.141414141414141,0 26 | -2.0606060606060606,0 27 | -1.9797979797979797,0 28 | -1.8989898989898988,0 29 | -1.818181818181818,0 30 | -1.737373737373737,0 31 | -1.6565656565656566,0 32 | -1.5757575757575757,1 33 | -1.4949494949494948,0 34 | -1.414141414141414,0 35 | -1.333333333333333,1 36 | -1.2525252525252522,1 37 | -1.1717171717171713,1 38 | -1.0909090909090908,1 39 | -1.01010101010101,1 40 | -0.9292929292929291,0 41 | -0.8484848484848482,0 42 | -0.7676767676767673,0 43 | -0.6868686868686864,1 44 | -0.606060606060606,1 45 | 
-0.5252525252525251,1 46 | -0.4444444444444442,1 47 | -0.3636363636363633,1 48 | -0.28282828282828243,0 49 | -0.20202020202020154,1 50 | -0.1212121212121211,1 51 | -0.04040404040404022,1 52 | 0.040404040404040664,1 53 | 0.12121212121212199,1 54 | 0.20202020202020243,1 55 | 0.2828282828282829,1 56 | 0.3636363636363642,1 57 | 0.44444444444444464,0 58 | 0.525252525252526,1 59 | 0.6060606060606064,0 60 | 0.6868686868686869,1 61 | 0.7676767676767682,1 62 | 0.8484848484848486,0 63 | 0.92929292929293,1 64 | 1.0101010101010104,1 65 | 1.0909090909090917,0 66 | 1.1717171717171722,1 67 | 1.2525252525252526,1 68 | 1.333333333333334,1 69 | 1.4141414141414144,0 70 | 1.4949494949494957,0 71 | 1.5757575757575761,1 72 | 1.6565656565656575,0 73 | 1.737373737373738,0 74 | 1.8181818181818183,1 75 | 1.8989898989898997,1 76 | 1.9797979797979801,0 77 | 2.0606060606060614,0 78 | 2.141414141414142,0 79 | 2.2222222222222223,0 80 | 2.3030303030303036,0 81 | 2.383838383838384,0 82 | 2.4646464646464654,0 83 | 2.545454545454546,0 84 | 2.626262626262627,0 85 | 2.7070707070707076,0 86 | 2.787878787878788,0 87 | 2.8686868686868694,0 88 | 2.94949494949495,0 89 | 3.030303030303031,0 90 | 3.1111111111111116,0 91 | 3.191919191919193,0 92 | 3.2727272727272734,0 93 | 3.353535353535354,0 94 | 3.434343434343435,0 95 | 3.5151515151515156,0 96 | 3.595959595959597,0 97 | 3.6767676767676774,0 98 | 3.757575757575758,0 99 | 3.838383838383839,0 100 | 3.9191919191919196,0 101 | 4.0,0 102 | -------------------------------------------------------------------------------- /experiments/datasets/gaussian/ten.csv: -------------------------------------------------------------------------------- 1 | x,label 2 | -3.0,0 3 | -2.3333333333333335,0 4 | -1.6666666666666667,0 5 | -1.0,1 6 | -0.3333333333333335,1 7 | 0.33333333333333304,1 8 | 1.0,0 9 | 1.666666666666666,0 10 | 2.333333333333333,0 11 | 3.0,0 12 | -------------------------------------------------------------------------------- /experiments/datasets/iris/data.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,petal_length,petal_width,class 2 | 5.1,3.5,1.4,0.2,Iris-setosa 3 | 4.9,3.0,1.4,0.2,Iris-setosa 4 | 4.7,3.2,1.3,0.2,Iris-setosa 5 | 4.6,3.1,1.5,0.2,Iris-setosa 6 | 5.0,3.6,1.4,0.2,Iris-setosa 7 | 5.4,3.9,1.7,0.4,Iris-setosa 8 | 4.6,3.4,1.4,0.3,Iris-setosa 9 | 5.0,3.4,1.5,0.2,Iris-setosa 10 | 4.4,2.9,1.4,0.2,Iris-setosa 11 | 4.9,3.1,1.5,0.1,Iris-setosa 12 | 5.4,3.7,1.5,0.2,Iris-setosa 13 | 4.8,3.4,1.6,0.2,Iris-setosa 14 | 4.8,3.0,1.4,0.1,Iris-setosa 15 | 4.3,3.0,1.1,0.1,Iris-setosa 16 | 5.8,4.0,1.2,0.2,Iris-setosa 17 | 5.7,4.4,1.5,0.4,Iris-setosa 18 | 5.4,3.9,1.3,0.4,Iris-setosa 19 | 5.1,3.5,1.4,0.3,Iris-setosa 20 | 5.7,3.8,1.7,0.3,Iris-setosa 21 | 5.1,3.8,1.5,0.3,Iris-setosa 22 | 5.4,3.4,1.7,0.2,Iris-setosa 23 | 5.1,3.7,1.5,0.4,Iris-setosa 24 | 4.6,3.6,1.0,0.2,Iris-setosa 25 | 5.1,3.3,1.7,0.5,Iris-setosa 26 | 4.8,3.4,1.9,0.2,Iris-setosa 27 | 5.0,3.0,1.6,0.2,Iris-setosa 28 | 5.0,3.4,1.6,0.4,Iris-setosa 29 | 5.2,3.5,1.5,0.2,Iris-setosa 30 | 5.2,3.4,1.4,0.2,Iris-setosa 31 | 4.7,3.2,1.6,0.2,Iris-setosa 32 | 4.8,3.1,1.6,0.2,Iris-setosa 33 | 5.4,3.4,1.5,0.4,Iris-setosa 34 | 5.2,4.1,1.5,0.1,Iris-setosa 35 | 5.5,4.2,1.4,0.2,Iris-setosa 36 | 4.9,3.1,1.5,0.1,Iris-setosa 37 | 5.0,3.2,1.2,0.2,Iris-setosa 38 | 5.5,3.5,1.3,0.2,Iris-setosa 39 | 4.9,3.1,1.5,0.1,Iris-setosa 40 | 4.4,3.0,1.3,0.2,Iris-setosa 41 | 5.1,3.4,1.5,0.2,Iris-setosa 42 | 5.0,3.5,1.3,0.3,Iris-setosa 43 | 4.5,2.3,1.3,0.3,Iris-setosa 44 | 4.4,3.2,1.3,0.2,Iris-setosa 45 
| 5.0,3.5,1.6,0.6,Iris-setosa 46 | 5.1,3.8,1.9,0.4,Iris-setosa 47 | 4.8,3.0,1.4,0.3,Iris-setosa 48 | 5.1,3.8,1.6,0.2,Iris-setosa 49 | 4.6,3.2,1.4,0.2,Iris-setosa 50 | 5.3,3.7,1.5,0.2,Iris-setosa 51 | 5.0,3.3,1.4,0.2,Iris-setosa 52 | 7.0,3.2,4.7,1.4,Iris-versicolor 53 | 6.4,3.2,4.5,1.5,Iris-versicolor 54 | 6.9,3.1,4.9,1.5,Iris-versicolor 55 | 5.5,2.3,4.0,1.3,Iris-versicolor 56 | 6.5,2.8,4.6,1.5,Iris-versicolor 57 | 5.7,2.8,4.5,1.3,Iris-versicolor 58 | 6.3,3.3,4.7,1.6,Iris-versicolor 59 | 4.9,2.4,3.3,1.0,Iris-versicolor 60 | 6.6,2.9,4.6,1.3,Iris-versicolor 61 | 5.2,2.7,3.9,1.4,Iris-versicolor 62 | 5.0,2.0,3.5,1.0,Iris-versicolor 63 | 5.9,3.0,4.2,1.5,Iris-versicolor 64 | 6.0,2.2,4.0,1.0,Iris-versicolor 65 | 6.1,2.9,4.7,1.4,Iris-versicolor 66 | 5.6,2.9,3.6,1.3,Iris-versicolor 67 | 6.7,3.1,4.4,1.4,Iris-versicolor 68 | 5.6,3.0,4.5,1.5,Iris-versicolor 69 | 5.8,2.7,4.1,1.0,Iris-versicolor 70 | 6.2,2.2,4.5,1.5,Iris-versicolor 71 | 5.6,2.5,3.9,1.1,Iris-versicolor 72 | 5.9,3.2,4.8,1.8,Iris-versicolor 73 | 6.1,2.8,4.0,1.3,Iris-versicolor 74 | 6.3,2.5,4.9,1.5,Iris-versicolor 75 | 6.1,2.8,4.7,1.2,Iris-versicolor 76 | 6.4,2.9,4.3,1.3,Iris-versicolor 77 | 6.6,3.0,4.4,1.4,Iris-versicolor 78 | 6.8,2.8,4.8,1.4,Iris-versicolor 79 | 6.7,3.0,5.0,1.7,Iris-versicolor 80 | 6.0,2.9,4.5,1.5,Iris-versicolor 81 | 5.7,2.6,3.5,1.0,Iris-versicolor 82 | 5.5,2.4,3.8,1.1,Iris-versicolor 83 | 5.5,2.4,3.7,1.0,Iris-versicolor 84 | 5.8,2.7,3.9,1.2,Iris-versicolor 85 | 6.0,2.7,5.1,1.6,Iris-versicolor 86 | 5.4,3.0,4.5,1.5,Iris-versicolor 87 | 6.0,3.4,4.5,1.6,Iris-versicolor 88 | 6.7,3.1,4.7,1.5,Iris-versicolor 89 | 6.3,2.3,4.4,1.3,Iris-versicolor 90 | 5.6,3.0,4.1,1.3,Iris-versicolor 91 | 5.5,2.5,4.0,1.3,Iris-versicolor 92 | 5.5,2.6,4.4,1.2,Iris-versicolor 93 | 6.1,3.0,4.6,1.4,Iris-versicolor 94 | 5.8,2.6,4.0,1.2,Iris-versicolor 95 | 5.0,2.3,3.3,1.0,Iris-versicolor 96 | 5.6,2.7,4.2,1.3,Iris-versicolor 97 | 5.7,3.0,4.2,1.2,Iris-versicolor 98 | 5.7,2.9,4.2,1.3,Iris-versicolor 99 | 6.2,2.9,4.3,1.3,Iris-versicolor 100 | 5.1,2.5,3.0,1.1,Iris-versicolor 101 | 5.7,2.8,4.1,1.3,Iris-versicolor 102 | 6.3,3.3,6.0,2.5,Iris-virginica 103 | 5.8,2.7,5.1,1.9,Iris-virginica 104 | 7.1,3.0,5.9,2.1,Iris-virginica 105 | 6.3,2.9,5.6,1.8,Iris-virginica 106 | 6.5,3.0,5.8,2.2,Iris-virginica 107 | 7.6,3.0,6.6,2.1,Iris-virginica 108 | 4.9,2.5,4.5,1.7,Iris-virginica 109 | 7.3,2.9,6.3,1.8,Iris-virginica 110 | 6.7,2.5,5.8,1.8,Iris-virginica 111 | 7.2,3.6,6.1,2.5,Iris-virginica 112 | 6.5,3.2,5.1,2.0,Iris-virginica 113 | 6.4,2.7,5.3,1.9,Iris-virginica 114 | 6.8,3.0,5.5,2.1,Iris-virginica 115 | 5.7,2.5,5.0,2.0,Iris-virginica 116 | 5.8,2.8,5.1,2.4,Iris-virginica 117 | 6.4,3.2,5.3,2.3,Iris-virginica 118 | 6.5,3.0,5.5,1.8,Iris-virginica 119 | 7.7,3.8,6.7,2.2,Iris-virginica 120 | 7.7,2.6,6.9,2.3,Iris-virginica 121 | 6.0,2.2,5.0,1.5,Iris-virginica 122 | 6.9,3.2,5.7,2.3,Iris-virginica 123 | 5.6,2.8,4.9,2.0,Iris-virginica 124 | 7.7,2.8,6.7,2.0,Iris-virginica 125 | 6.3,2.7,4.9,1.8,Iris-virginica 126 | 6.7,3.3,5.7,2.1,Iris-virginica 127 | 7.2,3.2,6.0,1.8,Iris-virginica 128 | 6.2,2.8,4.8,1.8,Iris-virginica 129 | 6.1,3.0,4.9,1.8,Iris-virginica 130 | 6.4,2.8,5.6,2.1,Iris-virginica 131 | 7.2,3.0,5.8,1.6,Iris-virginica 132 | 7.4,2.8,6.1,1.9,Iris-virginica 133 | 7.9,3.8,6.4,2.0,Iris-virginica 134 | 6.4,2.8,5.6,2.2,Iris-virginica 135 | 6.3,2.8,5.1,1.5,Iris-virginica 136 | 6.1,2.6,5.6,1.4,Iris-virginica 137 | 7.7,3.0,6.1,2.3,Iris-virginica 138 | 6.3,3.4,5.6,2.4,Iris-virginica 139 | 6.4,3.1,5.5,1.8,Iris-virginica 140 | 6.0,3.0,4.8,1.8,Iris-virginica 141 | 
6.9,3.1,5.4,2.1,Iris-virginica 142 | 6.7,3.1,5.6,2.4,Iris-virginica 143 | 6.9,3.1,5.1,2.3,Iris-virginica 144 | 5.8,2.7,5.1,1.9,Iris-virginica 145 | 6.8,3.2,5.9,2.3,Iris-virginica 146 | 6.7,3.3,5.7,2.5,Iris-virginica 147 | 6.7,3.0,5.2,2.3,Iris-virginica 148 | 6.3,2.5,5.0,1.9,Iris-virginica 149 | 6.5,3.0,5.2,2.0,Iris-virginica 150 | 6.2,3.4,5.4,2.3,Iris-virginica 151 | 5.9,3.0,5.1,1.8,Iris-virginica -------------------------------------------------------------------------------- /experiments/datasets/iris/iris.names: -------------------------------------------------------------------------------- 1 | 1. Title: Iris Plants Database 2 | Updated Sept 21 by C.Blake - Added discrepency information 3 | 4 | 2. Sources: 5 | (a) Creator: R.A. Fisher 6 | (b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov) 7 | (c) Date: July, 1988 8 | 9 | 3. Past Usage: 10 | - Publications: too many to mention!!! Here are a few. 11 | 1. Fisher,R.A. "The use of multiple measurements in taxonomic problems" 12 | Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions 13 | to Mathematical Statistics" (John Wiley, NY, 1950). 14 | 2. Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis. 15 | (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218. 16 | 3. Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System 17 | Structure and Classification Rule for Recognition in Partially Exposed 18 | Environments". IEEE Transactions on Pattern Analysis and Machine 19 | Intelligence, Vol. PAMI-2, No. 1, 67-71. 20 | -- Results: 21 | -- very low misclassification rates (0% for the setosa class) 22 | 4. Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE 23 | Transactions on Information Theory, May 1972, 431-433. 24 | -- Results: 25 | -- very low misclassification rates again 26 | 5. See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II 27 | conceptual clustering system finds 3 classes in the data. 28 | 29 | 4. Relevant Information: 30 | --- This is perhaps the best known database to be found in the pattern 31 | recognition literature. Fisher's paper is a classic in the field 32 | and is referenced frequently to this day. (See Duda & Hart, for 33 | example.) The data set contains 3 classes of 50 instances each, 34 | where each class refers to a type of iris plant. One class is 35 | linearly separable from the other 2; the latter are NOT linearly 36 | separable from each other. 37 | --- Predicted attribute: class of iris plant. 38 | --- This is an exceedingly simple domain. 39 | --- This data differs from the data presented in Fishers article 40 | (identified by Steve Chadwick, spchadwick@espeedaz.net ) 41 | The 35th sample should be: 4.9,3.1,1.5,0.2,"Iris-setosa" 42 | where the error is in the fourth feature. 43 | The 38th sample: 4.9,3.6,1.4,0.1,"Iris-setosa" 44 | where the errors are in the second and third features. 45 | 46 | 5. Number of Instances: 150 (50 in each of three classes) 47 | 48 | 6. Number of Attributes: 4 numeric, predictive attributes and the class 49 | 50 | 7. Attribute Information: 51 | 1. sepal length in cm 52 | 2. sepal width in cm 53 | 3. petal length in cm 54 | 4. petal width in cm 55 | 5. class: 56 | -- Iris Setosa 57 | -- Iris Versicolour 58 | -- Iris Virginica 59 | 60 | 8. Missing Attribute Values: None 61 | 62 | Summary Statistics: 63 | Min Max Mean SD Class Correlation 64 | sepal length: 4.3 7.9 5.84 0.83 0.7826 65 | sepal width: 2.0 4.4 3.05 0.43 -0.4194 66 | petal length: 1.0 6.9 3.76 1.76 0.9490 (high!) 
67 | petal width: 0.1 2.5 1.20 0.76 0.9565 (high!) 68 | 69 | 9. Class Distribution: 33.3% for each of 3 classes. 70 | -------------------------------------------------------------------------------- /experiments/datasets/iris/setosa.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,petal_length,petal_width,class 2 | 5.1,3.5,1.4,0.2,Iris-setosa 3 | 4.9,3.0,1.4,0.2,Iris-setosa 4 | 4.7,3.2,1.3,0.2,Iris-setosa 5 | 4.6,3.1,1.5,0.2,Iris-setosa 6 | 5.0,3.6,1.4,0.2,Iris-setosa 7 | 5.4,3.9,1.7,0.4,Iris-setosa 8 | 4.6,3.4,1.4,0.3,Iris-setosa 9 | 5.0,3.4,1.5,0.2,Iris-setosa 10 | 4.4,2.9,1.4,0.2,Iris-setosa 11 | 4.9,3.1,1.5,0.1,Iris-setosa 12 | 5.4,3.7,1.5,0.2,Iris-setosa 13 | 4.8,3.4,1.6,0.2,Iris-setosa 14 | 4.8,3.0,1.4,0.1,Iris-setosa 15 | 4.3,3.0,1.1,0.1,Iris-setosa 16 | 5.8,4.0,1.2,0.2,Iris-setosa 17 | 5.7,4.4,1.5,0.4,Iris-setosa 18 | 5.4,3.9,1.3,0.4,Iris-setosa 19 | 5.1,3.5,1.4,0.3,Iris-setosa 20 | 5.7,3.8,1.7,0.3,Iris-setosa 21 | 5.1,3.8,1.5,0.3,Iris-setosa 22 | 5.4,3.4,1.7,0.2,Iris-setosa 23 | 5.1,3.7,1.5,0.4,Iris-setosa 24 | 4.6,3.6,1.0,0.2,Iris-setosa 25 | 5.1,3.3,1.7,0.5,Iris-setosa 26 | 4.8,3.4,1.9,0.2,Iris-setosa 27 | 5.0,3.0,1.6,0.2,Iris-setosa 28 | 5.0,3.4,1.6,0.4,Iris-setosa 29 | 5.2,3.5,1.5,0.2,Iris-setosa 30 | 5.2,3.4,1.4,0.2,Iris-setosa 31 | 4.7,3.2,1.6,0.2,Iris-setosa 32 | 4.8,3.1,1.6,0.2,Iris-setosa 33 | 5.4,3.4,1.5,0.4,Iris-setosa 34 | 5.2,4.1,1.5,0.1,Iris-setosa 35 | 5.5,4.2,1.4,0.2,Iris-setosa 36 | 4.9,3.1,1.5,0.1,Iris-setosa 37 | 5.0,3.2,1.2,0.2,Iris-setosa 38 | 5.5,3.5,1.3,0.2,Iris-setosa 39 | 4.9,3.1,1.5,0.1,Iris-setosa 40 | 4.4,3.0,1.3,0.2,Iris-setosa 41 | 5.1,3.4,1.5,0.2,Iris-setosa 42 | 5.0,3.5,1.3,0.3,Iris-setosa 43 | 4.5,2.3,1.3,0.3,Iris-setosa 44 | 4.4,3.2,1.3,0.2,Iris-setosa 45 | 5.0,3.5,1.6,0.6,Iris-setosa 46 | 5.1,3.8,1.9,0.4,Iris-setosa 47 | 4.8,3.0,1.4,0.3,Iris-setosa 48 | 5.1,3.8,1.6,0.2,Iris-setosa 49 | 4.6,3.2,1.4,0.2,Iris-setosa 50 | 5.3,3.7,1.5,0.2,Iris-setosa 51 | 5.0,3.3,1.4,0.2,Iris-setosa 52 | 7.0,3.2,4.7,1.4,Iris-other 53 | 6.4,3.2,4.5,1.5,Iris-other 54 | 6.9,3.1,4.9,1.5,Iris-other 55 | 5.5,2.3,4.0,1.3,Iris-other 56 | 6.5,2.8,4.6,1.5,Iris-other 57 | 5.7,2.8,4.5,1.3,Iris-other 58 | 6.3,3.3,4.7,1.6,Iris-other 59 | 4.9,2.4,3.3,1.0,Iris-other 60 | 6.6,2.9,4.6,1.3,Iris-other 61 | 5.2,2.7,3.9,1.4,Iris-other 62 | 5.0,2.0,3.5,1.0,Iris-other 63 | 5.9,3.0,4.2,1.5,Iris-other 64 | 6.0,2.2,4.0,1.0,Iris-other 65 | 6.1,2.9,4.7,1.4,Iris-other 66 | 5.6,2.9,3.6,1.3,Iris-other 67 | 6.7,3.1,4.4,1.4,Iris-other 68 | 5.6,3.0,4.5,1.5,Iris-other 69 | 5.8,2.7,4.1,1.0,Iris-other 70 | 6.2,2.2,4.5,1.5,Iris-other 71 | 5.6,2.5,3.9,1.1,Iris-other 72 | 5.9,3.2,4.8,1.8,Iris-other 73 | 6.1,2.8,4.0,1.3,Iris-other 74 | 6.3,2.5,4.9,1.5,Iris-other 75 | 6.1,2.8,4.7,1.2,Iris-other 76 | 6.4,2.9,4.3,1.3,Iris-other 77 | 6.6,3.0,4.4,1.4,Iris-other 78 | 6.8,2.8,4.8,1.4,Iris-other 79 | 6.7,3.0,5.0,1.7,Iris-other 80 | 6.0,2.9,4.5,1.5,Iris-other 81 | 5.7,2.6,3.5,1.0,Iris-other 82 | 5.5,2.4,3.8,1.1,Iris-other 83 | 5.5,2.4,3.7,1.0,Iris-other 84 | 5.8,2.7,3.9,1.2,Iris-other 85 | 6.0,2.7,5.1,1.6,Iris-other 86 | 5.4,3.0,4.5,1.5,Iris-other 87 | 6.0,3.4,4.5,1.6,Iris-other 88 | 6.7,3.1,4.7,1.5,Iris-other 89 | 6.3,2.3,4.4,1.3,Iris-other 90 | 5.6,3.0,4.1,1.3,Iris-other 91 | 5.5,2.5,4.0,1.3,Iris-other 92 | 5.5,2.6,4.4,1.2,Iris-other 93 | 6.1,3.0,4.6,1.4,Iris-other 94 | 5.8,2.6,4.0,1.2,Iris-other 95 | 5.0,2.3,3.3,1.0,Iris-other 96 | 5.6,2.7,4.2,1.3,Iris-other 97 | 5.7,3.0,4.2,1.2,Iris-other 98 | 5.7,2.9,4.2,1.3,Iris-other 99 | 
6.2,2.9,4.3,1.3,Iris-other 100 | 5.1,2.5,3.0,1.1,Iris-other 101 | 5.7,2.8,4.1,1.3,Iris-other 102 | 6.3,3.3,6.0,2.5,Iris-other 103 | 5.8,2.7,5.1,1.9,Iris-other 104 | 7.1,3.0,5.9,2.1,Iris-other 105 | 6.3,2.9,5.6,1.8,Iris-other 106 | 6.5,3.0,5.8,2.2,Iris-other 107 | 7.6,3.0,6.6,2.1,Iris-other 108 | 4.9,2.5,4.5,1.7,Iris-other 109 | 7.3,2.9,6.3,1.8,Iris-other 110 | 6.7,2.5,5.8,1.8,Iris-other 111 | 7.2,3.6,6.1,2.5,Iris-other 112 | 6.5,3.2,5.1,2.0,Iris-other 113 | 6.4,2.7,5.3,1.9,Iris-other 114 | 6.8,3.0,5.5,2.1,Iris-other 115 | 5.7,2.5,5.0,2.0,Iris-other 116 | 5.8,2.8,5.1,2.4,Iris-other 117 | 6.4,3.2,5.3,2.3,Iris-other 118 | 6.5,3.0,5.5,1.8,Iris-other 119 | 7.7,3.8,6.7,2.2,Iris-other 120 | 7.7,2.6,6.9,2.3,Iris-other 121 | 6.0,2.2,5.0,1.5,Iris-other 122 | 6.9,3.2,5.7,2.3,Iris-other 123 | 5.6,2.8,4.9,2.0,Iris-other 124 | 7.7,2.8,6.7,2.0,Iris-other 125 | 6.3,2.7,4.9,1.8,Iris-other 126 | 6.7,3.3,5.7,2.1,Iris-other 127 | 7.2,3.2,6.0,1.8,Iris-other 128 | 6.2,2.8,4.8,1.8,Iris-other 129 | 6.1,3.0,4.9,1.8,Iris-other 130 | 6.4,2.8,5.6,2.1,Iris-other 131 | 7.2,3.0,5.8,1.6,Iris-other 132 | 7.4,2.8,6.1,1.9,Iris-other 133 | 7.9,3.8,6.4,2.0,Iris-other 134 | 6.4,2.8,5.6,2.2,Iris-other 135 | 6.3,2.8,5.1,1.5,Iris-other 136 | 6.1,2.6,5.6,1.4,Iris-other 137 | 7.7,3.0,6.1,2.3,Iris-other 138 | 6.3,3.4,5.6,2.4,Iris-other 139 | 6.4,3.1,5.5,1.8,Iris-other 140 | 6.0,3.0,4.8,1.8,Iris-other 141 | 6.9,3.1,5.4,2.1,Iris-other 142 | 6.7,3.1,5.6,2.4,Iris-other 143 | 6.9,3.1,5.1,2.3,Iris-other 144 | 5.8,2.7,5.1,1.9,Iris-other 145 | 6.8,3.2,5.9,2.3,Iris-other 146 | 6.7,3.3,5.7,2.5,Iris-other 147 | 6.7,3.0,5.2,2.3,Iris-other 148 | 6.3,2.5,5.0,1.9,Iris-other 149 | 6.5,3.0,5.2,2.0,Iris-other 150 | 6.2,3.4,5.4,2.3,Iris-other 151 | 5.9,3.0,5.1,1.8,Iris-other -------------------------------------------------------------------------------- /experiments/datasets/iris/versicolor.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,petal_length,petal_width,class 2 | 5.1,3.5,1.4,0.2,Iris-other 3 | 4.9,3.0,1.4,0.2,Iris-other 4 | 4.7,3.2,1.3,0.2,Iris-other 5 | 4.6,3.1,1.5,0.2,Iris-other 6 | 5.0,3.6,1.4,0.2,Iris-other 7 | 5.4,3.9,1.7,0.4,Iris-other 8 | 4.6,3.4,1.4,0.3,Iris-other 9 | 5.0,3.4,1.5,0.2,Iris-other 10 | 4.4,2.9,1.4,0.2,Iris-other 11 | 4.9,3.1,1.5,0.1,Iris-other 12 | 5.4,3.7,1.5,0.2,Iris-other 13 | 4.8,3.4,1.6,0.2,Iris-other 14 | 4.8,3.0,1.4,0.1,Iris-other 15 | 4.3,3.0,1.1,0.1,Iris-other 16 | 5.8,4.0,1.2,0.2,Iris-other 17 | 5.7,4.4,1.5,0.4,Iris-other 18 | 5.4,3.9,1.3,0.4,Iris-other 19 | 5.1,3.5,1.4,0.3,Iris-other 20 | 5.7,3.8,1.7,0.3,Iris-other 21 | 5.1,3.8,1.5,0.3,Iris-other 22 | 5.4,3.4,1.7,0.2,Iris-other 23 | 5.1,3.7,1.5,0.4,Iris-other 24 | 4.6,3.6,1.0,0.2,Iris-other 25 | 5.1,3.3,1.7,0.5,Iris-other 26 | 4.8,3.4,1.9,0.2,Iris-other 27 | 5.0,3.0,1.6,0.2,Iris-other 28 | 5.0,3.4,1.6,0.4,Iris-other 29 | 5.2,3.5,1.5,0.2,Iris-other 30 | 5.2,3.4,1.4,0.2,Iris-other 31 | 4.7,3.2,1.6,0.2,Iris-other 32 | 4.8,3.1,1.6,0.2,Iris-other 33 | 5.4,3.4,1.5,0.4,Iris-other 34 | 5.2,4.1,1.5,0.1,Iris-other 35 | 5.5,4.2,1.4,0.2,Iris-other 36 | 4.9,3.1,1.5,0.1,Iris-other 37 | 5.0,3.2,1.2,0.2,Iris-other 38 | 5.5,3.5,1.3,0.2,Iris-other 39 | 4.9,3.1,1.5,0.1,Iris-other 40 | 4.4,3.0,1.3,0.2,Iris-other 41 | 5.1,3.4,1.5,0.2,Iris-other 42 | 5.0,3.5,1.3,0.3,Iris-other 43 | 4.5,2.3,1.3,0.3,Iris-other 44 | 4.4,3.2,1.3,0.2,Iris-other 45 | 5.0,3.5,1.6,0.6,Iris-other 46 | 5.1,3.8,1.9,0.4,Iris-other 47 | 4.8,3.0,1.4,0.3,Iris-other 48 | 5.1,3.8,1.6,0.2,Iris-other 49 | 4.6,3.2,1.4,0.2,Iris-other 50 | 
5.3,3.7,1.5,0.2,Iris-other 51 | 5.0,3.3,1.4,0.2,Iris-other 52 | 7.0,3.2,4.7,1.4,Iris-versicolor 53 | 6.4,3.2,4.5,1.5,Iris-versicolor 54 | 6.9,3.1,4.9,1.5,Iris-versicolor 55 | 5.5,2.3,4.0,1.3,Iris-versicolor 56 | 6.5,2.8,4.6,1.5,Iris-versicolor 57 | 5.7,2.8,4.5,1.3,Iris-versicolor 58 | 6.3,3.3,4.7,1.6,Iris-versicolor 59 | 4.9,2.4,3.3,1.0,Iris-versicolor 60 | 6.6,2.9,4.6,1.3,Iris-versicolor 61 | 5.2,2.7,3.9,1.4,Iris-versicolor 62 | 5.0,2.0,3.5,1.0,Iris-versicolor 63 | 5.9,3.0,4.2,1.5,Iris-versicolor 64 | 6.0,2.2,4.0,1.0,Iris-versicolor 65 | 6.1,2.9,4.7,1.4,Iris-versicolor 66 | 5.6,2.9,3.6,1.3,Iris-versicolor 67 | 6.7,3.1,4.4,1.4,Iris-versicolor 68 | 5.6,3.0,4.5,1.5,Iris-versicolor 69 | 5.8,2.7,4.1,1.0,Iris-versicolor 70 | 6.2,2.2,4.5,1.5,Iris-versicolor 71 | 5.6,2.5,3.9,1.1,Iris-versicolor 72 | 5.9,3.2,4.8,1.8,Iris-versicolor 73 | 6.1,2.8,4.0,1.3,Iris-versicolor 74 | 6.3,2.5,4.9,1.5,Iris-versicolor 75 | 6.1,2.8,4.7,1.2,Iris-versicolor 76 | 6.4,2.9,4.3,1.3,Iris-versicolor 77 | 6.6,3.0,4.4,1.4,Iris-versicolor 78 | 6.8,2.8,4.8,1.4,Iris-versicolor 79 | 6.7,3.0,5.0,1.7,Iris-versicolor 80 | 6.0,2.9,4.5,1.5,Iris-versicolor 81 | 5.7,2.6,3.5,1.0,Iris-versicolor 82 | 5.5,2.4,3.8,1.1,Iris-versicolor 83 | 5.5,2.4,3.7,1.0,Iris-versicolor 84 | 5.8,2.7,3.9,1.2,Iris-versicolor 85 | 6.0,2.7,5.1,1.6,Iris-versicolor 86 | 5.4,3.0,4.5,1.5,Iris-versicolor 87 | 6.0,3.4,4.5,1.6,Iris-versicolor 88 | 6.7,3.1,4.7,1.5,Iris-versicolor 89 | 6.3,2.3,4.4,1.3,Iris-versicolor 90 | 5.6,3.0,4.1,1.3,Iris-versicolor 91 | 5.5,2.5,4.0,1.3,Iris-versicolor 92 | 5.5,2.6,4.4,1.2,Iris-versicolor 93 | 6.1,3.0,4.6,1.4,Iris-versicolor 94 | 5.8,2.6,4.0,1.2,Iris-versicolor 95 | 5.0,2.3,3.3,1.0,Iris-versicolor 96 | 5.6,2.7,4.2,1.3,Iris-versicolor 97 | 5.7,3.0,4.2,1.2,Iris-versicolor 98 | 5.7,2.9,4.2,1.3,Iris-versicolor 99 | 6.2,2.9,4.3,1.3,Iris-versicolor 100 | 5.1,2.5,3.0,1.1,Iris-versicolor 101 | 5.7,2.8,4.1,1.3,Iris-versicolor 102 | 6.3,3.3,6.0,2.5,Iris-other 103 | 5.8,2.7,5.1,1.9,Iris-other 104 | 7.1,3.0,5.9,2.1,Iris-other 105 | 6.3,2.9,5.6,1.8,Iris-other 106 | 6.5,3.0,5.8,2.2,Iris-other 107 | 7.6,3.0,6.6,2.1,Iris-other 108 | 4.9,2.5,4.5,1.7,Iris-other 109 | 7.3,2.9,6.3,1.8,Iris-other 110 | 6.7,2.5,5.8,1.8,Iris-other 111 | 7.2,3.6,6.1,2.5,Iris-other 112 | 6.5,3.2,5.1,2.0,Iris-other 113 | 6.4,2.7,5.3,1.9,Iris-other 114 | 6.8,3.0,5.5,2.1,Iris-other 115 | 5.7,2.5,5.0,2.0,Iris-other 116 | 5.8,2.8,5.1,2.4,Iris-other 117 | 6.4,3.2,5.3,2.3,Iris-other 118 | 6.5,3.0,5.5,1.8,Iris-other 119 | 7.7,3.8,6.7,2.2,Iris-other 120 | 7.7,2.6,6.9,2.3,Iris-other 121 | 6.0,2.2,5.0,1.5,Iris-other 122 | 6.9,3.2,5.7,2.3,Iris-other 123 | 5.6,2.8,4.9,2.0,Iris-other 124 | 7.7,2.8,6.7,2.0,Iris-other 125 | 6.3,2.7,4.9,1.8,Iris-other 126 | 6.7,3.3,5.7,2.1,Iris-other 127 | 7.2,3.2,6.0,1.8,Iris-other 128 | 6.2,2.8,4.8,1.8,Iris-other 129 | 6.1,3.0,4.9,1.8,Iris-other 130 | 6.4,2.8,5.6,2.1,Iris-other 131 | 7.2,3.0,5.8,1.6,Iris-other 132 | 7.4,2.8,6.1,1.9,Iris-other 133 | 7.9,3.8,6.4,2.0,Iris-other 134 | 6.4,2.8,5.6,2.2,Iris-other 135 | 6.3,2.8,5.1,1.5,Iris-other 136 | 6.1,2.6,5.6,1.4,Iris-other 137 | 7.7,3.0,6.1,2.3,Iris-other 138 | 6.3,3.4,5.6,2.4,Iris-other 139 | 6.4,3.1,5.5,1.8,Iris-other 140 | 6.0,3.0,4.8,1.8,Iris-other 141 | 6.9,3.1,5.4,2.1,Iris-other 142 | 6.7,3.1,5.6,2.4,Iris-other 143 | 6.9,3.1,5.1,2.3,Iris-other 144 | 5.8,2.7,5.1,1.9,Iris-other 145 | 6.8,3.2,5.9,2.3,Iris-other 146 | 6.7,3.3,5.7,2.5,Iris-other 147 | 6.7,3.0,5.2,2.3,Iris-other 148 | 6.3,2.5,5.0,1.9,Iris-other 149 | 6.5,3.0,5.2,2.0,Iris-other 150 | 6.2,3.4,5.4,2.3,Iris-other 151 | 
5.9,3.0,5.1,1.8,Iris-other -------------------------------------------------------------------------------- /experiments/datasets/iris/virginica.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,petal_length,petal_width,class 2 | 5.1,3.5,1.4,0.2,Iris-other 3 | 4.9,3.0,1.4,0.2,Iris-other 4 | 4.7,3.2,1.3,0.2,Iris-other 5 | 4.6,3.1,1.5,0.2,Iris-other 6 | 5.0,3.6,1.4,0.2,Iris-other 7 | 5.4,3.9,1.7,0.4,Iris-other 8 | 4.6,3.4,1.4,0.3,Iris-other 9 | 5.0,3.4,1.5,0.2,Iris-other 10 | 4.4,2.9,1.4,0.2,Iris-other 11 | 4.9,3.1,1.5,0.1,Iris-other 12 | 5.4,3.7,1.5,0.2,Iris-other 13 | 4.8,3.4,1.6,0.2,Iris-other 14 | 4.8,3.0,1.4,0.1,Iris-other 15 | 4.3,3.0,1.1,0.1,Iris-other 16 | 5.8,4.0,1.2,0.2,Iris-other 17 | 5.7,4.4,1.5,0.4,Iris-other 18 | 5.4,3.9,1.3,0.4,Iris-other 19 | 5.1,3.5,1.4,0.3,Iris-other 20 | 5.7,3.8,1.7,0.3,Iris-other 21 | 5.1,3.8,1.5,0.3,Iris-other 22 | 5.4,3.4,1.7,0.2,Iris-other 23 | 5.1,3.7,1.5,0.4,Iris-other 24 | 4.6,3.6,1.0,0.2,Iris-other 25 | 5.1,3.3,1.7,0.5,Iris-other 26 | 4.8,3.4,1.9,0.2,Iris-other 27 | 5.0,3.0,1.6,0.2,Iris-other 28 | 5.0,3.4,1.6,0.4,Iris-other 29 | 5.2,3.5,1.5,0.2,Iris-other 30 | 5.2,3.4,1.4,0.2,Iris-other 31 | 4.7,3.2,1.6,0.2,Iris-other 32 | 4.8,3.1,1.6,0.2,Iris-other 33 | 5.4,3.4,1.5,0.4,Iris-other 34 | 5.2,4.1,1.5,0.1,Iris-other 35 | 5.5,4.2,1.4,0.2,Iris-other 36 | 4.9,3.1,1.5,0.1,Iris-other 37 | 5.0,3.2,1.2,0.2,Iris-other 38 | 5.5,3.5,1.3,0.2,Iris-other 39 | 4.9,3.1,1.5,0.1,Iris-other 40 | 4.4,3.0,1.3,0.2,Iris-other 41 | 5.1,3.4,1.5,0.2,Iris-other 42 | 5.0,3.5,1.3,0.3,Iris-other 43 | 4.5,2.3,1.3,0.3,Iris-other 44 | 4.4,3.2,1.3,0.2,Iris-other 45 | 5.0,3.5,1.6,0.6,Iris-other 46 | 5.1,3.8,1.9,0.4,Iris-other 47 | 4.8,3.0,1.4,0.3,Iris-other 48 | 5.1,3.8,1.6,0.2,Iris-other 49 | 4.6,3.2,1.4,0.2,Iris-other 50 | 5.3,3.7,1.5,0.2,Iris-other 51 | 5.0,3.3,1.4,0.2,Iris-other 52 | 7.0,3.2,4.7,1.4,Iris-other 53 | 6.4,3.2,4.5,1.5,Iris-other 54 | 6.9,3.1,4.9,1.5,Iris-other 55 | 5.5,2.3,4.0,1.3,Iris-other 56 | 6.5,2.8,4.6,1.5,Iris-other 57 | 5.7,2.8,4.5,1.3,Iris-other 58 | 6.3,3.3,4.7,1.6,Iris-other 59 | 4.9,2.4,3.3,1.0,Iris-other 60 | 6.6,2.9,4.6,1.3,Iris-other 61 | 5.2,2.7,3.9,1.4,Iris-other 62 | 5.0,2.0,3.5,1.0,Iris-other 63 | 5.9,3.0,4.2,1.5,Iris-other 64 | 6.0,2.2,4.0,1.0,Iris-other 65 | 6.1,2.9,4.7,1.4,Iris-other 66 | 5.6,2.9,3.6,1.3,Iris-other 67 | 6.7,3.1,4.4,1.4,Iris-other 68 | 5.6,3.0,4.5,1.5,Iris-other 69 | 5.8,2.7,4.1,1.0,Iris-other 70 | 6.2,2.2,4.5,1.5,Iris-other 71 | 5.6,2.5,3.9,1.1,Iris-other 72 | 5.9,3.2,4.8,1.8,Iris-other 73 | 6.1,2.8,4.0,1.3,Iris-other 74 | 6.3,2.5,4.9,1.5,Iris-other 75 | 6.1,2.8,4.7,1.2,Iris-other 76 | 6.4,2.9,4.3,1.3,Iris-other 77 | 6.6,3.0,4.4,1.4,Iris-other 78 | 6.8,2.8,4.8,1.4,Iris-other 79 | 6.7,3.0,5.0,1.7,Iris-other 80 | 6.0,2.9,4.5,1.5,Iris-other 81 | 5.7,2.6,3.5,1.0,Iris-other 82 | 5.5,2.4,3.8,1.1,Iris-other 83 | 5.5,2.4,3.7,1.0,Iris-other 84 | 5.8,2.7,3.9,1.2,Iris-other 85 | 6.0,2.7,5.1,1.6,Iris-other 86 | 5.4,3.0,4.5,1.5,Iris-other 87 | 6.0,3.4,4.5,1.6,Iris-other 88 | 6.7,3.1,4.7,1.5,Iris-other 89 | 6.3,2.3,4.4,1.3,Iris-other 90 | 5.6,3.0,4.1,1.3,Iris-other 91 | 5.5,2.5,4.0,1.3,Iris-other 92 | 5.5,2.6,4.4,1.2,Iris-other 93 | 6.1,3.0,4.6,1.4,Iris-other 94 | 5.8,2.6,4.0,1.2,Iris-other 95 | 5.0,2.3,3.3,1.0,Iris-other 96 | 5.6,2.7,4.2,1.3,Iris-other 97 | 5.7,3.0,4.2,1.2,Iris-other 98 | 5.7,2.9,4.2,1.3,Iris-other 99 | 6.2,2.9,4.3,1.3,Iris-other 100 | 5.1,2.5,3.0,1.1,Iris-other 101 | 5.7,2.8,4.1,1.3,Iris-other 102 | 6.3,3.3,6.0,2.5,Iris-virginica 103 | 
5.8,2.7,5.1,1.9,Iris-virginica 104 | 7.1,3.0,5.9,2.1,Iris-virginica 105 | 6.3,2.9,5.6,1.8,Iris-virginica 106 | 6.5,3.0,5.8,2.2,Iris-virginica 107 | 7.6,3.0,6.6,2.1,Iris-virginica 108 | 4.9,2.5,4.5,1.7,Iris-virginica 109 | 7.3,2.9,6.3,1.8,Iris-virginica 110 | 6.7,2.5,5.8,1.8,Iris-virginica 111 | 7.2,3.6,6.1,2.5,Iris-virginica 112 | 6.5,3.2,5.1,2.0,Iris-virginica 113 | 6.4,2.7,5.3,1.9,Iris-virginica 114 | 6.8,3.0,5.5,2.1,Iris-virginica 115 | 5.7,2.5,5.0,2.0,Iris-virginica 116 | 5.8,2.8,5.1,2.4,Iris-virginica 117 | 6.4,3.2,5.3,2.3,Iris-virginica 118 | 6.5,3.0,5.5,1.8,Iris-virginica 119 | 7.7,3.8,6.7,2.2,Iris-virginica 120 | 7.7,2.6,6.9,2.3,Iris-virginica 121 | 6.0,2.2,5.0,1.5,Iris-virginica 122 | 6.9,3.2,5.7,2.3,Iris-virginica 123 | 5.6,2.8,4.9,2.0,Iris-virginica 124 | 7.7,2.8,6.7,2.0,Iris-virginica 125 | 6.3,2.7,4.9,1.8,Iris-virginica 126 | 6.7,3.3,5.7,2.1,Iris-virginica 127 | 7.2,3.2,6.0,1.8,Iris-virginica 128 | 6.2,2.8,4.8,1.8,Iris-virginica 129 | 6.1,3.0,4.9,1.8,Iris-virginica 130 | 6.4,2.8,5.6,2.1,Iris-virginica 131 | 7.2,3.0,5.8,1.6,Iris-virginica 132 | 7.4,2.8,6.1,1.9,Iris-virginica 133 | 7.9,3.8,6.4,2.0,Iris-virginica 134 | 6.4,2.8,5.6,2.2,Iris-virginica 135 | 6.3,2.8,5.1,1.5,Iris-virginica 136 | 6.1,2.6,5.6,1.4,Iris-virginica 137 | 7.7,3.0,6.1,2.3,Iris-virginica 138 | 6.3,3.4,5.6,2.4,Iris-virginica 139 | 6.4,3.1,5.5,1.8,Iris-virginica 140 | 6.0,3.0,4.8,1.8,Iris-virginica 141 | 6.9,3.1,5.4,2.1,Iris-virginica 142 | 6.7,3.1,5.6,2.4,Iris-virginica 143 | 6.9,3.1,5.1,2.3,Iris-virginica 144 | 5.8,2.7,5.1,1.9,Iris-virginica 145 | 6.8,3.2,5.9,2.3,Iris-virginica 146 | 6.7,3.3,5.7,2.5,Iris-virginica 147 | 6.7,3.0,5.2,2.3,Iris-virginica 148 | 6.3,2.5,5.0,1.9,Iris-virginica 149 | 6.5,3.0,5.2,2.0,Iris-virginica 150 | 6.2,3.4,5.4,2.3,Iris-virginica 151 | 5.9,3.0,5.1,1.8,Iris-virginica -------------------------------------------------------------------------------- /experiments/datasets/monk_1/data.csv: -------------------------------------------------------------------------------- 1 | a1_1,a1_2,a2_1,a2_2,a3_1,a4_1,a4_2,a5_1,a5_2,a5_3,a6_1,class_1 2 | 1,0,1,0,1,1,0,0,0,1,1,1 3 | 1,0,1,0,1,1,0,0,0,1,0,1 4 | 1,0,1,0,1,0,0,0,1,0,1,1 5 | 1,0,1,0,1,0,0,0,0,1,0,1 6 | 1,0,1,0,0,1,0,0,1,0,1,1 7 | 1,0,1,0,0,1,0,0,1,0,0,1 8 | 1,0,1,0,0,0,1,0,0,1,1,1 9 | 1,0,1,0,0,0,1,0,0,0,1,1 10 | 1,0,1,0,0,0,0,1,0,0,0,1 11 | 1,0,0,1,1,1,0,1,0,0,0,1 12 | 1,0,0,1,1,1,0,0,1,0,1,0 13 | 1,0,0,1,1,1,0,0,0,1,1,0 14 | 1,0,0,1,1,1,0,0,0,0,0,0 15 | 1,0,0,1,1,0,1,1,0,0,1,1 16 | 1,0,0,1,1,0,1,0,0,1,1,0 17 | 1,0,0,1,1,0,1,0,0,1,0,0 18 | 1,0,0,1,1,0,1,0,0,0,0,0 19 | 1,0,0,1,1,0,0,0,1,0,1,0 20 | 1,0,0,1,1,0,0,0,0,0,0,0 21 | 1,0,0,1,0,1,0,0,1,0,0,0 22 | 1,0,0,1,0,0,1,0,0,1,0,0 23 | 1,0,0,1,0,0,1,0,0,0,1,0 24 | 1,0,0,1,0,0,1,0,0,0,0,0 25 | 1,0,0,1,0,0,0,0,1,0,0,0 26 | 1,0,0,1,0,0,0,0,0,1,1,0 27 | 1,0,0,1,0,0,0,0,0,1,0,0 28 | 1,0,0,0,1,1,0,0,1,0,1,0 29 | 1,0,0,0,1,1,0,0,0,0,1,0 30 | 1,0,0,0,1,0,1,0,1,0,1,0 31 | 1,0,0,0,1,0,1,0,0,0,1,0 32 | 1,0,0,0,1,0,0,1,0,0,0,1 33 | 1,0,0,0,1,0,0,0,1,0,0,0 34 | 1,0,0,0,1,0,0,0,0,1,1,0 35 | 1,0,0,0,1,0,0,0,0,0,1,0 36 | 1,0,0,0,1,0,0,0,0,0,0,0 37 | 1,0,0,0,0,1,0,0,1,0,0,0 38 | 1,0,0,0,0,0,1,1,0,0,0,1 39 | 1,0,0,0,0,0,1,0,1,0,0,0 40 | 1,0,0,0,0,0,1,0,0,1,0,0 41 | 1,0,0,0,0,0,1,0,0,0,1,0 42 | 1,0,0,0,0,0,1,0,0,0,0,0 43 | 1,0,0,0,0,0,0,1,0,0,1,1 44 | 1,0,0,0,0,0,0,0,1,0,1,0 45 | 1,0,0,0,0,0,0,0,0,0,1,0 46 | 1,0,0,0,0,0,0,0,0,0,0,0 47 | 0,1,1,0,1,1,0,0,0,1,1,0 48 | 0,1,1,0,1,1,0,0,0,1,0,0 49 | 0,1,1,0,1,0,1,1,0,0,1,1 50 | 0,1,1,0,1,0,1,1,0,0,0,1 51 | 0,1,1,0,1,0,1,0,1,0,0,0 52 | 0,1,1,0,1,0,1,0,0,1,1,0 53 | 
0,1,1,0,1,0,1,0,0,0,1,0 54 | 0,1,1,0,1,0,1,0,0,0,0,0 55 | 0,1,1,0,1,0,0,0,0,0,1,0 56 | 0,1,1,0,0,1,0,0,1,0,0,0 57 | 0,1,1,0,0,1,0,0,0,1,1,0 58 | 0,1,1,0,0,1,0,0,0,0,0,0 59 | 0,1,1,0,0,0,1,0,0,1,1,0 60 | 0,1,1,0,0,0,1,0,0,0,0,0 61 | 0,1,1,0,0,0,0,0,1,0,0,0 62 | 0,1,1,0,0,0,0,0,0,0,1,0 63 | 0,1,0,1,1,1,0,0,1,0,1,1 64 | 0,1,0,1,1,1,0,0,1,0,0,1 65 | 0,1,0,1,1,1,0,0,0,1,1,1 66 | 0,1,0,1,1,0,1,0,0,1,0,1 67 | 0,1,0,1,1,0,0,1,0,0,1,1 68 | 0,1,0,1,1,0,0,1,0,0,0,1 69 | 0,1,0,1,1,0,0,0,1,0,0,1 70 | 0,1,0,1,1,0,0,0,0,1,0,1 71 | 0,1,0,1,1,0,0,0,0,0,0,1 72 | 0,1,0,1,0,1,0,1,0,0,1,1 73 | 0,1,0,1,0,1,0,0,0,1,0,1 74 | 0,1,0,1,0,1,0,0,0,0,1,1 75 | 0,1,0,1,0,1,0,0,0,0,0,1 76 | 0,1,0,1,0,0,1,0,1,0,1,1 77 | 0,1,0,1,0,0,0,0,0,0,1,1 78 | 0,1,0,0,1,1,0,1,0,0,1,1 79 | 0,1,0,0,1,0,1,1,0,0,1,1 80 | 0,1,0,0,1,0,1,0,0,1,1,0 81 | 0,1,0,0,1,0,0,1,0,0,0,1 82 | 0,1,0,0,1,0,0,0,0,1,1,0 83 | 0,1,0,0,1,0,0,0,0,0,0,0 84 | 0,1,0,0,0,1,0,0,0,1,0,0 85 | 0,1,0,0,0,0,1,1,0,0,1,1 86 | 0,1,0,0,0,0,1,1,0,0,0,1 87 | 0,1,0,0,0,0,1,0,1,0,1,0 88 | 0,1,0,0,0,0,0,0,0,1,0,0 89 | 0,0,1,0,1,1,0,1,0,0,1,1 90 | 0,0,1,0,1,1,0,1,0,0,0,1 91 | 0,0,1,0,1,0,1,1,0,0,1,1 92 | 0,0,1,0,1,0,1,0,1,0,0,0 93 | 0,0,1,0,1,0,0,0,1,0,0,0 94 | 0,0,1,0,0,1,0,1,0,0,1,1 95 | 0,0,1,0,0,1,0,0,1,0,0,0 96 | 0,0,1,0,0,0,1,0,1,0,0,0 97 | 0,0,1,0,0,0,1,0,0,1,0,0 98 | 0,0,1,0,0,0,0,0,1,0,0,0 99 | 0,0,0,1,1,1,0,1,0,0,1,1 100 | 0,0,0,1,1,1,0,0,0,0,0,0 101 | 0,0,0,1,1,0,1,1,0,0,0,1 102 | 0,0,0,1,1,0,1,0,0,0,0,0 103 | 0,0,0,1,0,1,0,1,0,0,1,1 104 | 0,0,0,1,0,1,0,1,0,0,0,1 105 | 0,0,0,1,0,1,0,0,0,1,0,0 106 | 0,0,0,1,0,0,0,1,0,0,1,1 107 | 0,0,0,1,0,0,0,0,1,0,1,0 108 | 0,0,0,1,0,0,0,0,0,0,1,0 109 | 0,0,0,0,1,1,0,1,0,0,1,1 110 | 0,0,0,0,1,1,0,0,1,0,1,1 111 | 0,0,0,0,1,1,0,0,0,0,0,1 112 | 0,0,0,0,1,0,1,0,0,1,0,1 113 | 0,0,0,0,1,0,1,0,0,0,0,1 114 | 0,0,0,0,1,0,0,1,0,0,0,1 115 | 0,0,0,0,1,0,0,0,1,0,1,1 116 | 0,0,0,0,1,0,0,0,1,0,0,1 117 | 0,0,0,0,1,0,0,0,0,0,0,1 118 | 0,0,0,0,0,1,0,1,0,0,1,1 119 | 0,0,0,0,0,1,0,0,0,1,0,1 120 | 0,0,0,0,0,1,0,0,0,0,1,1 121 | 0,0,0,0,0,1,0,0,0,0,0,1 122 | 0,0,0,0,0,0,0,1,0,0,0,1 123 | 0,0,0,0,0,0,0,0,1,0,0,1 124 | 0,0,0,0,0,0,0,0,0,1,0,1 125 | 0,0,0,0,0,0,0,0,0,0,0,1 -------------------------------------------------------------------------------- /experiments/datasets/monk_2/data.csv: -------------------------------------------------------------------------------- 1 | a1_1,a1_2,a2_1,a2_2,a3_1,a4_1,a4_2,a5_1,a5_2,a5_3,a6_1,class_1 2 | 1,0,1,0,1,1,0,0,1,0,0,0 3 | 1,0,1,0,1,1,0,0,0,0,1,0 4 | 1,0,1,0,1,0,1,1,0,0,1,0 5 | 1,0,1,0,1,0,1,1,0,0,0,0 6 | 1,0,1,0,1,0,1,0,1,0,1,0 7 | 1,0,1,0,1,0,1,0,0,1,1,0 8 | 1,0,1,0,1,0,1,0,0,0,1,0 9 | 1,0,1,0,1,0,0,0,1,0,1,0 10 | 1,0,1,0,1,0,0,0,0,0,1,0 11 | 1,0,1,0,0,1,0,1,0,0,1,0 12 | 1,0,1,0,0,1,0,1,0,0,0,0 13 | 1,0,1,0,0,0,1,0,0,1,1,0 14 | 1,0,1,0,0,0,1,0,0,0,1,0 15 | 1,0,1,0,0,0,1,0,0,0,0,1 16 | 1,0,1,0,0,0,0,1,0,0,0,0 17 | 1,0,1,0,0,0,0,0,1,0,0,1 18 | 1,0,0,1,1,1,0,1,0,0,0,0 19 | 1,0,0,1,1,0,1,1,0,0,0,0 20 | 1,0,0,1,1,0,1,0,1,0,0,1 21 | 1,0,0,1,1,0,1,0,0,1,1,0 22 | 1,0,0,1,1,0,1,0,0,1,0,1 23 | 1,0,0,1,1,0,1,0,0,0,1,0 24 | 1,0,0,1,1,0,0,1,0,0,1,0 25 | 1,0,0,1,1,0,0,1,0,0,0,0 26 | 1,0,0,1,1,0,0,0,1,0,0,1 27 | 1,0,0,1,1,0,0,0,0,1,1,0 28 | 1,0,0,1,1,0,0,0,0,1,0,1 29 | 1,0,0,1,1,0,0,0,0,0,1,0 30 | 1,0,0,1,1,0,0,0,0,0,0,1 31 | 1,0,0,1,0,1,0,0,1,0,1,0 32 | 1,0,0,1,0,1,0,0,0,0,1,0 33 | 1,0,0,1,0,0,1,0,0,1,1,1 34 | 1,0,0,1,0,0,1,0,0,0,1,1 35 | 1,0,0,1,0,0,0,1,0,0,1,0 36 | 1,0,0,1,0,0,0,1,0,0,0,1 37 | 1,0,0,1,0,0,0,0,0,1,1,1 38 | 1,0,0,1,0,0,0,0,0,1,0,0 39 | 1,0,0,1,0,0,0,0,0,0,1,1 40 | 1,0,0,1,0,0,0,0,0,0,0,0 41 | 
1,0,0,0,1,1,0,1,0,0,0,0 42 | 1,0,0,0,1,1,0,0,1,0,0,0 43 | 1,0,0,0,1,1,0,0,0,1,1,0 44 | 1,0,0,0,1,1,0,0,0,1,0,0 45 | 1,0,0,0,1,0,1,0,1,0,1,0 46 | 1,0,0,0,1,0,1,0,1,0,0,1 47 | 1,0,0,0,1,0,1,0,0,1,0,1 48 | 1,0,0,0,1,0,1,0,0,0,1,0 49 | 1,0,0,0,1,0,0,0,1,0,0,1 50 | 1,0,0,0,1,0,0,0,0,1,1,0 51 | 1,0,0,0,1,0,0,0,0,0,0,1 52 | 1,0,0,0,0,1,0,0,0,1,1,0 53 | 1,0,0,0,0,1,0,0,0,1,0,1 54 | 1,0,0,0,0,1,0,0,0,0,1,0 55 | 1,0,0,0,0,0,1,1,0,0,0,1 56 | 1,0,0,0,0,0,1,0,0,1,0,0 57 | 1,0,0,0,0,0,1,0,0,0,0,0 58 | 1,0,0,0,0,0,0,0,1,0,1,1 59 | 0,1,1,0,1,1,0,1,0,0,1,0 60 | 0,1,1,0,1,1,0,0,1,0,0,0 61 | 0,1,1,0,1,1,0,0,0,1,1,0 62 | 0,1,1,0,1,0,1,0,1,0,0,1 63 | 0,1,1,0,1,0,0,1,0,0,0,0 64 | 0,1,1,0,1,0,0,0,1,0,0,1 65 | 0,1,1,0,1,0,0,0,0,1,0,1 66 | 0,1,1,0,1,0,0,0,0,0,1,0 67 | 0,1,1,0,0,1,0,1,0,0,1,0 68 | 0,1,1,0,0,1,0,0,1,0,0,1 69 | 0,1,1,0,0,1,0,0,0,0,1,0 70 | 0,1,1,0,0,0,1,0,1,0,1,1 71 | 0,1,1,0,0,0,1,0,0,0,0,0 72 | 0,1,1,0,0,0,0,1,0,0,1,0 73 | 0,1,1,0,0,0,0,1,0,0,0,1 74 | 0,1,1,0,0,0,0,0,1,0,0,0 75 | 0,1,1,0,0,0,0,0,0,1,0,0 76 | 0,1,1,0,0,0,0,0,0,0,0,0 77 | 0,1,0,1,1,1,0,0,0,1,1,0 78 | 0,1,0,1,1,1,0,0,0,0,0,1 79 | 0,1,0,1,1,0,1,1,0,0,1,0 80 | 0,1,0,1,1,0,1,0,0,1,1,1 81 | 0,1,0,1,1,0,0,0,0,1,1,1 82 | 0,1,0,1,1,0,0,0,0,1,0,0 83 | 0,1,0,1,1,0,0,0,0,0,1,1 84 | 0,1,0,1,0,1,0,1,0,0,1,0 85 | 0,1,0,1,0,1,0,0,1,0,0,0 86 | 0,1,0,1,0,1,0,0,0,1,0,0 87 | 0,1,0,1,0,1,0,0,0,0,1,1 88 | 0,1,0,1,0,1,0,0,0,0,0,0 89 | 0,1,0,1,0,0,1,1,0,0,1,1 90 | 0,1,0,1,0,0,1,0,1,0,0,0 91 | 0,1,0,1,0,0,1,0,0,1,1,0 92 | 0,1,0,1,0,0,0,1,0,0,1,1 93 | 0,1,0,1,0,0,0,0,1,0,1,0 94 | 0,1,0,1,0,0,0,0,1,0,0,0 95 | 0,1,0,1,0,0,0,0,0,0,0,0 96 | 0,1,0,0,1,1,0,1,0,0,1,0 97 | 0,1,0,0,1,1,0,1,0,0,0,0 98 | 0,1,0,0,1,1,0,0,0,1,0,1 99 | 0,1,0,0,1,0,1,1,0,0,1,0 100 | 0,1,0,0,1,0,1,0,0,1,1,1 101 | 0,1,0,0,1,0,1,0,0,1,0,0 102 | 0,1,0,0,1,0,1,0,0,0,0,0 103 | 0,1,0,0,1,0,0,1,0,0,0,1 104 | 0,1,0,0,1,0,0,0,1,0,1,1 105 | 0,1,0,0,1,0,0,0,0,0,1,1 106 | 0,1,0,0,0,1,0,1,0,0,0,1 107 | 0,1,0,0,0,1,0,0,1,0,1,1 108 | 0,1,0,0,0,1,0,0,0,1,1,1 109 | 0,1,0,0,0,1,0,0,0,0,0,0 110 | 0,1,0,0,0,0,1,1,0,0,1,1 111 | 0,1,0,0,0,0,1,0,1,0,1,0 112 | 0,1,0,0,0,0,1,0,0,1,0,0 113 | 0,1,0,0,0,0,0,0,0,1,1,0 114 | 0,1,0,0,0,0,0,0,0,1,0,0 115 | 0,1,0,0,0,0,0,0,0,0,0,0 116 | 0,0,1,0,1,1,0,0,0,0,1,0 117 | 0,0,1,0,1,0,1,1,0,0,0,0 118 | 0,0,1,0,1,0,1,0,1,0,0,1 119 | 0,0,1,0,1,0,1,0,0,1,0,1 120 | 0,0,1,0,1,0,1,0,0,0,1,0 121 | 0,0,1,0,1,0,1,0,0,0,0,1 122 | 0,0,1,0,1,0,0,1,0,0,1,0 123 | 0,0,1,0,1,0,0,1,0,0,0,0 124 | 0,0,1,0,1,0,0,0,1,0,0,1 125 | 0,0,1,0,1,0,0,0,0,1,0,1 126 | 0,0,1,0,0,1,0,1,0,0,1,0 127 | 0,0,1,0,0,1,0,0,1,0,0,1 128 | 0,0,1,0,0,1,0,0,0,1,1,0 129 | 0,0,1,0,0,1,0,0,0,1,0,1 130 | 0,0,1,0,0,1,0,0,0,0,1,0 131 | 0,0,1,0,0,1,0,0,0,0,0,1 132 | 0,0,1,0,0,0,1,0,1,0,1,1 133 | 0,0,1,0,0,0,0,1,0,0,0,1 134 | 0,0,1,0,0,0,0,0,1,0,1,1 135 | 0,0,1,0,0,0,0,0,1,0,0,0 136 | 0,0,1,0,0,0,0,0,0,0,0,0 137 | 0,0,0,1,1,1,0,1,0,0,0,0 138 | 0,0,0,1,1,1,0,0,1,0,0,1 139 | 0,0,0,1,1,1,0,0,0,1,1,0 140 | 0,0,0,1,1,1,0,0,0,1,0,1 141 | 0,0,0,1,1,0,1,1,0,0,0,1 142 | 0,0,0,1,1,0,1,0,1,0,1,1 143 | 0,0,0,1,1,0,0,1,0,0,1,0 144 | 0,0,0,1,1,0,0,0,1,0,1,1 145 | 0,0,0,1,1,0,0,0,0,1,1,1 146 | 0,0,0,1,1,0,0,0,0,1,0,0 147 | 0,0,0,1,0,1,0,1,0,0,1,0 148 | 0,0,0,1,0,1,0,0,1,0,0,0 149 | 0,0,0,1,0,1,0,0,0,1,1,1 150 | 0,0,0,1,0,1,0,0,0,1,0,0 151 | 0,0,0,1,0,0,1,1,0,0,1,1 152 | 0,0,0,1,0,0,1,0,1,0,1,0 153 | 0,0,0,1,0,0,1,0,1,0,0,0 154 | 0,0,0,1,0,0,1,0,0,1,0,0 155 | 0,0,0,1,0,0,0,1,0,0,1,1 156 | 0,0,0,1,0,0,0,0,0,1,0,0 157 | 0,0,0,1,0,0,0,0,0,0,0,0 158 | 0,0,0,0,1,1,0,1,0,0,1,0 159 | 0,0,0,0,1,1,0,0,1,0,1,0 160 | 0,0,0,0,1,1,0,0,0,1,1,0 161 | 
0,0,0,0,1,1,0,0,0,1,0,1 162 | 0,0,0,0,1,0,1,0,0,1,0,0 163 | 0,0,0,0,0,1,0,1,0,0,1,0 164 | 0,0,0,0,0,0,1,1,0,0,1,1 165 | 0,0,0,0,0,0,1,0,1,0,1,0 166 | 0,0,0,0,0,0,1,0,0,1,1,0 167 | 0,0,0,0,0,0,1,0,0,1,0,0 168 | 0,0,0,0,0,0,0,1,0,0,1,1 169 | 0,0,0,0,0,0,0,0,1,0,1,0 170 | 0,0,0,0,0,0,0,0,0,0,0,0 -------------------------------------------------------------------------------- /experiments/datasets/monk_3/data.csv: -------------------------------------------------------------------------------- 1 | a1_1,a1_2,a2_1,a2_2,a3_1,a4_1,a4_2,a5_1,a5_2,a5_3,a6_1,class_1 2 | 1,0,1,0,1,1,0,1,0,0,0,1 3 | 1,0,1,0,1,1,0,0,1,0,1,1 4 | 1,0,1,0,1,1,0,0,1,0,0,1 5 | 1,0,1,0,1,1,0,0,0,1,1,0 6 | 1,0,1,0,1,1,0,0,0,0,1,0 7 | 1,0,1,0,1,0,1,1,0,0,1,1 8 | 1,0,1,0,1,0,1,0,1,0,0,1 9 | 1,0,1,0,1,0,1,0,0,0,0,0 10 | 1,0,1,0,0,1,0,0,1,0,0,1 11 | 1,0,1,0,0,1,0,0,0,0,0,0 12 | 1,0,1,0,0,0,1,0,1,0,0,1 13 | 1,0,1,0,0,0,1,0,0,0,1,0 14 | 1,0,1,0,0,0,1,0,0,0,0,0 15 | 1,0,1,0,0,0,0,1,0,0,1,1 16 | 1,0,1,0,0,0,0,1,0,0,0,1 17 | 1,0,1,0,0,0,0,0,0,1,1,1 18 | 1,0,1,0,0,0,0,0,0,1,0,1 19 | 1,0,0,1,1,1,0,0,0,1,1,1 20 | 1,0,0,1,1,0,1,0,1,0,1,1 21 | 1,0,0,1,1,0,1,0,1,0,0,1 22 | 1,0,0,1,1,0,1,0,0,1,1,0 23 | 1,0,0,1,1,0,0,1,0,0,1,1 24 | 1,0,0,1,1,0,0,1,0,0,0,1 25 | 1,0,0,1,1,0,0,0,1,0,1,1 26 | 1,0,0,1,1,0,0,0,1,0,0,1 27 | 1,0,0,1,1,0,0,0,0,1,0,1 28 | 1,0,0,1,1,0,0,0,0,0,1,0 29 | 1,0,0,1,0,1,0,0,0,1,1,1 30 | 1,0,0,1,0,1,0,0,0,0,0,0 31 | 1,0,0,1,0,0,1,1,0,0,1,1 32 | 1,0,0,1,0,0,1,0,1,0,1,1 33 | 1,0,0,1,0,0,1,0,1,0,0,1 34 | 1,0,0,1,0,0,0,1,0,0,1,1 35 | 1,0,0,1,0,0,0,0,1,0,1,1 36 | 1,0,0,1,0,0,0,0,1,0,0,1 37 | 1,0,0,0,1,1,0,0,1,0,1,0 38 | 1,0,0,0,1,1,0,0,0,0,1,0 39 | 1,0,0,0,1,0,1,0,0,1,0,0 40 | 1,0,0,0,1,0,1,0,0,0,1,0 41 | 1,0,0,0,1,0,0,1,0,0,1,0 42 | 1,0,0,0,1,0,0,0,0,1,1,0 43 | 1,0,0,0,0,1,0,1,0,0,1,0 44 | 1,0,0,0,0,1,0,1,0,0,0,0 45 | 1,0,0,0,0,1,0,0,1,0,1,0 46 | 1,0,0,0,0,1,0,0,0,0,0,0 47 | 1,0,0,0,0,0,1,0,0,1,0,0 48 | 1,0,0,0,0,0,1,0,0,0,0,0 49 | 1,0,0,0,0,0,0,0,0,0,1,0 50 | 0,1,1,0,1,1,0,1,0,0,1,1 51 | 0,1,1,0,1,1,0,1,0,0,0,1 52 | 0,1,1,0,1,1,0,0,0,0,1,0 53 | 0,1,1,0,1,1,0,0,0,0,0,0 54 | 0,1,1,0,1,0,1,1,0,0,1,1 55 | 0,1,1,0,1,0,1,1,0,0,0,1 56 | 0,1,1,0,1,0,0,0,1,0,0,1 57 | 0,1,1,0,1,0,0,0,0,1,0,1 58 | 0,1,1,0,1,0,0,0,0,0,1,0 59 | 0,1,1,0,0,1,0,0,1,0,0,1 60 | 0,1,1,0,0,0,1,0,0,0,1,0 61 | 0,1,1,0,0,0,0,1,0,0,0,1 62 | 0,1,0,1,1,1,0,0,0,1,0,1 63 | 0,1,0,1,1,1,0,0,0,0,0,0 64 | 0,1,0,1,1,0,1,1,0,0,0,1 65 | 0,1,0,1,1,0,1,0,1,0,1,0 66 | 0,1,0,1,1,0,0,1,0,0,1,1 67 | 0,1,0,1,1,0,0,0,1,0,0,1 68 | 0,1,0,1,1,0,0,0,0,1,1,0 69 | 0,1,0,1,1,0,0,0,0,1,0,0 70 | 0,1,0,1,1,0,0,0,0,0,0,0 71 | 0,1,0,1,0,1,0,0,1,0,0,1 72 | 0,1,0,1,0,0,1,1,0,0,0,1 73 | 0,1,0,1,0,0,1,0,0,1,1,1 74 | 0,1,0,1,0,0,1,0,0,1,0,1 75 | 0,1,0,1,0,0,0,0,0,0,1,0 76 | 0,1,0,0,1,1,0,0,0,1,1,1 77 | 0,1,0,0,1,0,1,1,0,0,1,0 78 | 0,1,0,0,1,0,1,0,1,0,1,0 79 | 0,1,0,0,1,0,1,0,1,0,0,0 80 | 0,1,0,0,1,0,1,0,0,1,0,0 81 | 0,1,0,0,1,0,0,0,0,1,1,0 82 | 0,1,0,0,0,1,0,1,0,0,0,0 83 | 0,1,0,0,0,1,0,0,1,0,0,0 84 | 0,1,0,0,0,1,0,0,0,0,1,0 85 | 0,1,0,0,0,0,1,0,0,1,1,0 86 | 0,1,0,0,0,0,1,0,0,0,0,0 87 | 0,1,0,0,0,0,0,1,0,0,1,0 88 | 0,1,0,0,0,0,0,0,1,0,1,0 89 | 0,1,0,0,0,0,0,0,0,0,0,0 90 | 0,0,1,0,1,1,0,1,0,0,1,1 91 | 0,0,1,0,1,1,0,0,1,0,1,1 92 | 0,0,1,0,1,1,0,0,0,1,1,1 93 | 0,0,1,0,1,0,1,0,0,0,0,0 94 | 0,0,1,0,1,0,0,1,0,0,0,1 95 | 0,0,1,0,1,0,0,0,0,0,0,0 96 | 0,0,1,0,0,1,0,0,1,0,1,1 97 | 0,0,1,0,0,0,1,0,0,1,0,1 98 | 0,0,1,0,0,0,1,0,0,0,0,0 99 | 0,0,1,0,0,0,0,1,0,0,1,1 100 | 0,0,0,1,1,1,0,0,1,0,0,1 101 | 0,0,0,1,1,1,0,0,0,0,1,0 102 | 0,0,0,1,1,0,1,0,0,1,1,1 103 | 0,0,0,1,1,0,0,1,0,0,0,1 104 | 0,0,0,1,0,1,0,0,1,0,0,1 105 | 
0,0,0,1,0,1,0,0,0,1,0,1 106 | 0,0,0,1,0,0,1,1,0,0,0,1 107 | 0,0,0,1,0,0,0,1,0,0,1,1 108 | 0,0,0,1,0,0,0,0,0,1,0,1 109 | 0,0,0,1,0,0,0,0,0,0,1,0 110 | 0,0,0,0,1,1,0,0,0,1,0,1 111 | 0,0,0,0,1,1,0,0,0,0,1,1 112 | 0,0,0,0,1,0,1,0,0,0,0,0 113 | 0,0,0,0,1,0,0,1,0,0,1,0 114 | 0,0,0,0,1,0,0,0,1,0,1,0 115 | 0,0,0,0,1,0,0,0,1,0,0,0 116 | 0,0,0,0,1,0,0,0,0,0,1,0 117 | 0,0,0,0,0,1,0,1,0,0,1,0 118 | 0,0,0,0,0,1,0,1,0,0,0,0 119 | 0,0,0,0,0,0,1,0,1,0,0,0 120 | 0,0,0,0,0,0,1,0,0,1,0,0 121 | 0,0,0,0,0,0,0,1,0,0,1,0 122 | 0,0,0,0,0,0,0,0,0,1,0,0 123 | 0,0,0,0,0,0,0,0,0,0,0,0 -------------------------------------------------------------------------------- /experiments/datasets/sine/hundred.csv: -------------------------------------------------------------------------------- 1 | x,label 2 | 0.0,0 3 | 0.10101010101010101,1 4 | 0.20202020202020202,0 5 | 0.30303030303030304,0 6 | 0.40404040404040403,1 7 | 0.5050505050505051,1 8 | 0.6060606060606061,1 9 | 0.7070707070707071,1 10 | 0.8080808080808081,1 11 | 0.9090909090909091,1 12 | 1.0101010101010102,1 13 | 1.1111111111111112,1 14 | 1.2121212121212122,1 15 | 1.3131313131313131,1 16 | 1.4141414141414141,1 17 | 1.5151515151515151,1 18 | 1.6161616161616161,1 19 | 1.7171717171717171,1 20 | 1.8181818181818181,1 21 | 1.9191919191919191,1 22 | 2.0202020202020203,1 23 | 2.121212121212121,1 24 | 2.2222222222222223,1 25 | 2.323232323232323,1 26 | 2.4242424242424243,1 27 | 2.525252525252525,0 28 | 2.6262626262626263,1 29 | 2.727272727272727,1 30 | 2.8282828282828283,1 31 | 2.929292929292929,1 32 | 3.0303030303030303,0 33 | 3.131313131313131,0 34 | 3.2323232323232323,1 35 | 3.3333333333333335,0 36 | 3.4343434343434343,0 37 | 3.5353535353535355,0 38 | 3.6363636363636362,0 39 | 3.7373737373737375,0 40 | 3.8383838383838382,1 41 | 3.9393939393939394,0 42 | 4.040404040404041,0 43 | 4.141414141414141,0 44 | 4.242424242424242,0 45 | 4.343434343434343,0 46 | 4.444444444444445,0 47 | 4.545454545454545,0 48 | 4.646464646464646,0 49 | 4.747474747474747,0 50 | 4.848484848484849,0 51 | 4.94949494949495,0 52 | 5.05050505050505,0 53 | 5.151515151515151,0 54 | 5.252525252525253,0 55 | 5.353535353535354,0 56 | 5.454545454545454,0 57 | 5.555555555555555,0 58 | 5.656565656565657,0 59 | 5.757575757575758,1 60 | 5.858585858585858,0 61 | 5.959595959595959,1 62 | 6.0606060606060606,0 63 | 6.161616161616162,0 64 | 6.262626262626262,0 65 | 6.363636363636363,1 66 | 6.4646464646464645,1 67 | 6.565656565656566,0 68 | 6.666666666666667,1 69 | 6.767676767676767,1 70 | 6.8686868686868685,1 71 | 6.96969696969697,1 72 | 7.070707070707071,1 73 | 7.171717171717171,0 74 | 7.2727272727272725,1 75 | 7.373737373737374,1 76 | 7.474747474747475,1 77 | 7.575757575757575,1 78 | 7.6767676767676765,1 79 | 7.777777777777778,1 80 | 7.878787878787879,1 81 | 7.979797979797979,1 82 | 8.080808080808081,1 83 | 8.181818181818182,1 84 | 8.282828282828282,1 85 | 8.383838383838384,1 86 | 8.484848484848484,1 87 | 8.585858585858587,1 88 | 8.686868686868687,1 89 | 8.787878787878787,1 90 | 8.88888888888889,1 91 | 8.98989898989899,1 92 | 9.09090909090909,1 93 | 9.191919191919192,1 94 | 9.292929292929292,1 95 | 9.393939393939394,0 96 | 9.494949494949495,1 97 | 9.595959595959595,1 98 | 9.696969696969697,0 99 | 9.797979797979798,1 100 | 9.8989898989899,0 101 | 10.0,0 102 | -------------------------------------------------------------------------------- /experiments/datasets/sine/ten.csv: -------------------------------------------------------------------------------- 1 | x,label 2 | 0.0,0 3 | 0.1111111111111111,0 4 | 0.2222222222222222,1 
5 | 0.3333333333333333,1 6 | 0.4444444444444444,1 7 | 0.5555555555555556,1 8 | 0.6666666666666666,1 9 | 0.7777777777777777,0 10 | 0.8888888888888888,1 11 | 1.0,1 12 | -------------------------------------------------------------------------------- /experiments/datasets/summary.csv: -------------------------------------------------------------------------------- 1 | dataset,number_of_samples,number_of_features,number_of_binary_features,number_of_classes,type 2 | iris,150,4,118,3,classification 3 | compas,12381,22,628,2,classification 4 | fico,10459,23,1917,2,classification 5 | 6 | hcv,1385,28,10548,4,classification 7 | nyclu,11008,78,41420,undetermined,classification 8 | broward,93643,18,undetermined,undetermined,regression 9 | adult,32560,14,22130,2,classification 10 | sleep,37,113,undetermined,undetermined,undetermined 11 | -------------------------------------------------------------------------------- /experiments/datasets/tic-tac-toe/Index.txt: -------------------------------------------------------------------------------- 1 | Index of tic-tac-toe 2 | 3 | 02 Dec 1996 126 Index 4 | 19 Aug 1991 25866 tic-tac-toe.data 5 | 19 Aug 1991 3244 tic-tac-toe.names 6 | -------------------------------------------------------------------------------- /experiments/datasets/tic-tac-toe/tic-tac-toe.names: -------------------------------------------------------------------------------- 1 | 1. Title: Tic-Tac-Toe Endgame database 2 | 3 | 2. Source Information 4 | -- Creator: David W. Aha (aha@cs.jhu.edu) 5 | -- Donor: David W. Aha (aha@cs.jhu.edu) 6 | -- Date: 19 August 1991 7 | 8 | 3. Known Past Usage: 9 | 1. Matheus,~C.~J., \& Rendell,~L.~A. (1989). Constructive 10 | induction on decision trees. In {\it Proceedings of the 11 | Eleventh International Joint Conference on Artificial Intelligence} 12 | (pp. 645--650). Detroit, MI: Morgan Kaufmann. 13 | -- CITRE was applied to 100-instance training and 200-instance test 14 | sets. In a study using various amounts of domain-specific 15 | knowledge, its highest average accuracy was 76.7% (using the 16 | final decision tree created for testing). 17 | 18 | 2. Matheus,~C.~J. (1990). Adding domain knowledge to SBL through 19 | feature construction. In {\it Proceedings of the Eighth National 20 | Conference on Artificial Intelligence} (pp. 803--808). 21 | Boston, MA: AAAI Press. 22 | -- Similar experiments with CITRE, includes learning curves up 23 | to 500-instance training sets but used _all_ instances in the 24 | database for testing. Accuracies reached above 90%, but specific 25 | values are not given (see Chris's dissertation for more details). 26 | 27 | 3. Aha,~D.~W. (1991). Incremental constructive induction: An instance-based 28 | approach. In {\it Proceedings of the Eighth International Workshop 29 | on Machine Learning} (pp. 117--121). Evanston, ILL: Morgan Kaufmann. 30 | -- Used 70% for training, 30% of the instances for testing, evaluated 31 | over 10 trials. Results reported for six algorithms: 32 | -- NewID: 84.0% 33 | -- CN2: 98.1% 34 | -- MBRtalk: 88.4% 35 | -- IB1: 98.1% 36 | -- IB3: 82.0% 37 | -- IB3-CI: 99.1% 38 | -- Results also reported when adding an additional 10 irrelevant 39 | ternary-valued attributes; similar _relative_ results except that 40 | IB1's performance degraded more quickly than the others. 41 | 42 | 4. Relevant Information: 43 | 44 | This database encodes the complete set of possible board configurations 45 | at the end of tic-tac-toe games, where "x" is assumed to have played 46 | first. 
The target concept is "win for x" (i.e., true when "x" has one 47 | of 8 possible ways to create a "three-in-a-row"). 48 | 49 | Interestingly, this raw database gives a stripped-down decision tree 50 | algorithm (e.g., ID3) fits. However, the rule-based CN2 algorithm, the 51 | simple IB1 instance-based learning algorithm, and the CITRE 52 | feature-constructing decision tree algorithm perform well on it. 53 | 54 | 5. Number of Instances: 958 (legal tic-tac-toe endgame boards) 55 | 56 | 6. Number of Attributes: 9, each corresponding to one tic-tac-toe square 57 | 58 | 7. Attribute Information: (x=player x has taken, o=player o has taken, b=blank) 59 | 60 | 1. top-left-square: {x,o,b} 61 | 2. top-middle-square: {x,o,b} 62 | 3. top-right-square: {x,o,b} 63 | 4. middle-left-square: {x,o,b} 64 | 5. middle-middle-square: {x,o,b} 65 | 6. middle-right-square: {x,o,b} 66 | 7. bottom-left-square: {x,o,b} 67 | 8. bottom-middle-square: {x,o,b} 68 | 9. bottom-right-square: {x,o,b} 69 | 10. Class: {positive,negative} 70 | 71 | 8. Missing Attribute Values: None 72 | 73 | 9. Class Distribution: About 65.3% are positive (i.e., wins for "x") 74 | -------------------------------------------------------------------------------- /log/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /output.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-systopia/treeFarms/5508f6063cbcac5124e164c72aafa9afc31f683d/output.json -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "scikit-build", 5 | "setuptools", 6 | "wheel", 7 | "distro", 8 | "attrs", 9 | "packaging>=20.9", 10 | "editables==0.2", 11 | "pandas", 12 | "scikit-learn", 13 | "sortedcontainers", 14 | "gmpy2", 15 | "matplotlib", 16 | "tqdm", 17 | "timbertrek" 18 | ] 19 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import os 3 | import pathlib 4 | import distro 5 | 6 | from setuptools import find_packages 7 | from skbuild import setup 8 | 9 | cmake_args = [] 10 | 11 | if platform.system() == "Windows" or (platform.system() == "Linux" and distro.id() == "centos"): 12 | assert "VCPKG_INSTALLATION_ROOT" in os.environ, \ 13 | "The environment variable \"VCPKG_INSTALLATION_ROOT\" must be set before running this script." 
14 | toolchain_path = pathlib.Path(os.getenv("VCPKG_INSTALLATION_ROOT")) / "scripts/buildsystems/vcpkg.cmake" 15 | cmake_args.append("-DCMAKE_TOOLCHAIN_FILE={}".format(toolchain_path)) 16 | 17 | print("Additional CMake Arguments = {}".format(cmake_args)) 18 | 19 | setup( 20 | name="treefarms", 21 | version="0.2.4", 22 | description="Implementation of Trees FAst RashoMon Sets", 23 | author="UBC Systopia Research Lab", 24 | url="https://github.com/ubc-systopia/treeFarms", 25 | license="BSD 3-Clause", 26 | packages=find_packages(where='.'), 27 | cmake_install_dir="treefarms", 28 | cmake_args=cmake_args, 29 | python_requires=">=3.7", 30 | long_description=pathlib.Path("README_PyPI.md").read_text(encoding="utf-8"), 31 | long_description_content_type="text/markdown", 32 | install_requires=["setuptools", 33 | "wheel", 34 | "attrs", 35 | "packaging>=20.9", 36 | "editables==0.2;python_version>'3.0'", 37 | "pandas", 38 | "scikit-learn", 39 | "sortedcontainers", 40 | "gmpy2", 41 | "matplotlib", 42 | "tqdm", 43 | "timbertrek"] 44 | ) 45 | -------------------------------------------------------------------------------- /src/.dirstamp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-systopia/treeFarms/5508f6063cbcac5124e164c72aafa9afc31f683d/src/.dirstamp -------------------------------------------------------------------------------- /src/additive_metrics.cpp: -------------------------------------------------------------------------------- 1 | #include "additive_metrics.hpp" 2 | #include "state.hpp" 3 | #include "configuration.hpp" 4 | 5 | Objective::Objective(const int falses, const int regularization) 6 | : falses(falses), regularization(regularization), 7 | objective(falses * State::dataset.get_mismatch_cost() + 8 | regularization * Configuration::regularization){}; -------------------------------------------------------------------------------- /src/additive_metrics.hpp: -------------------------------------------------------------------------------- 1 | #ifndef ADDITIVE_METRICS_H 2 | #define ADDITIVE_METRICS_H 3 | #include 4 | #include 5 | #include 6 | 7 | struct Objective { 8 | Objective(const int falses, const int regularization); 9 | Objective() = default; 10 | 11 | int falses; 12 | int regularization; 13 | float objective; 14 | 15 | Objective operator+(const Objective &other) const { 16 | return Objective(falses + other.falses, 17 | regularization + other.regularization); 18 | } 19 | 20 | std::tuple to_tuple() const { 21 | return std::make_tuple(objective, falses, regularization); 22 | } 23 | 24 | bool operator==(const Objective &other) const { 25 | return objective == other.objective; 26 | } 27 | 28 | bool operator<(const Objective &other) const { 29 | return objective < other.objective; 30 | } 31 | 32 | bool operator<(const float &other) const { return objective < other; } 33 | bool operator<=(const float &other) const { return objective <= other; } 34 | bool operator>(const float &other) const { return objective > other; } 35 | bool operator>=(const float &other) const { return objective >= other; } 36 | }; 37 | 38 | struct ValuesOfInterest { 39 | ValuesOfInterest() = default; 40 | ValuesOfInterest(const int TP, const int TN, const int regularization) 41 | : TP(TP), TN(TN), regularization(regularization){}; 42 | 43 | int TP; 44 | int TN; 45 | int regularization; 46 | 47 | ValuesOfInterest operator+(const ValuesOfInterest &other) const { 48 | return ValuesOfInterest(TP + other.TP, TN + other.TN, 49 | regularization + 
other.regularization); 50 | } 51 | 52 | bool operator==(const ValuesOfInterest &other) const { 53 | return TP == other.TP && TN == other.TN && 54 | regularization == other.regularization; 55 | } 56 | 57 | size_t hash() const { 58 | size_t seed = 0; 59 | // boost::hash_combine(result, TP); 60 | // boost::hash_combine(result, TN); 61 | // boost::hash_combine(result, regularization); 62 | seed ^= TP + 0x9e3779b9 + (seed << 6) + (seed >> 2); 63 | seed ^= TN + 0x9e3779b9 + (seed << 6) + (seed >> 2); 64 | seed ^= regularization + 0x9e3779b9 + (seed << 6) + (seed >> 2); 65 | return seed; 66 | } 67 | 68 | std::tuple to_tuple() const { 69 | return std::make_tuple(TP, TN, regularization); 70 | } 71 | }; 72 | 73 | struct ObjectiveHash { 74 | std::size_t operator()(const Objective &k) const { 75 | return std::hash{}(k.objective); 76 | } 77 | }; 78 | 79 | struct ObjectiveLess { 80 | bool operator()(const Objective &left, const Objective &right) const { 81 | return left.objective < right.objective; 82 | } 83 | }; 84 | #endif -------------------------------------------------------------------------------- /src/cart_it.hpp: -------------------------------------------------------------------------------- 1 | 2 | #ifndef CART_IT 3 | #define CART_IT 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | using namespace std; 14 | 15 | // typedef int T; 16 | 17 | // Compute the cartesian product of input 18 | // {S_1, S_2, ..., S_n} --> S_1 x S_2 x ... S_n 19 | // For a {L_1, L_2, ..., L_n} vector where L_i is the legnth of the ith inner 20 | // vector, generate L_1 * ... * L_n vectors of length n 21 | template class CartIt { 22 | public: 23 | typedef vector value_type; 24 | typedef vector &reference; 25 | typedef vector *pointer; 26 | 27 | // Should really do a const iterator but whatever lol 28 | class iterator { 29 | public: 30 | typedef iterator self_type; 31 | typedef std::forward_iterator_tag iterator_category; 32 | typedef int difference_type; 33 | iterator(CartIt *base) : base(base) { idx = 0; } 34 | iterator(CartIt *base, long long idx) : base(base), idx(idx) {} 35 | self_type operator++() { 36 | self_type i = *this; 37 | idx++; 38 | return i; 39 | } 40 | self_type operator++(int junk) { 41 | idx++; 42 | return *this; 43 | } 44 | value_type operator*() { return base->access_idx(idx); } 45 | // pointer operator->() { return ptr_; } 46 | bool operator==(const self_type &rhs) { return idx == rhs.idx; } 47 | bool operator!=(const self_type &rhs) { return idx != rhs.idx; } 48 | 49 | private: 50 | CartIt *base; 51 | long long int idx; 52 | }; 53 | 54 | CartIt(const vector> &item) : v(item) { 55 | auto product = [](long long int a, const vector &b) { 56 | return a * b.size(); 57 | }; 58 | N = accumulate(v.begin(), v.end(), 1LL, product); 59 | } 60 | 61 | ~CartIt() {} 62 | 63 | // size_type size() const { return size_; } 64 | 65 | // T& operator[](size_type index) 66 | // { 67 | // assert(index < size_); 68 | // return data_[index]; 69 | // } 70 | 71 | // const T& operator[](size_type index) const 72 | // { 73 | // assert(index < size_); 74 | // return data_[index]; 75 | // } 76 | 77 | iterator begin() { return iterator(this); } 78 | 79 | iterator end() { return iterator(this, N); } 80 | 81 | // int size() { 82 | // if (clipped_size < 0) { 83 | // return keys.size(); 84 | // } else { 85 | // return clipped_size; 86 | // } 87 | // } 88 | 89 | // float scope; 90 | 91 | private: 92 | const vector> &v; 93 | long long N; 94 | // From 95 | // 
https://stackoverflow.com/questions/5279051/how-can-i-create-cartesian-product-of-vector-of-vectors 96 | value_type access_idx(const long long n) { 97 | value_type u(v.size()); 98 | lldiv_t q{n, 0}; 99 | for (long long i = v.size() - 1; 0 <= i; --i) { 100 | q = div(q.quot, v[i].size()); 101 | u[i] = v[i][q.rem]; 102 | } 103 | return u; 104 | } 105 | }; 106 | #endif -------------------------------------------------------------------------------- /src/configuration.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CONFIGURATION_H 2 | #define CONFIGURATION_H 3 | 4 | #include 5 | #include 6 | 7 | using json = nlohmann::json; 8 | 9 | enum CoveredSetExtraction {F1, BACC, AUC}; 10 | 11 | // Static configuration object used to modify the algorithm behaviour 12 | // By design, all running instances of the algorithm within the same process must share the same configuration 13 | class Configuration { 14 | public: 15 | static void configure(std::istream & configuration); 16 | static void configure(json source); 17 | static std::string to_string(unsigned int spacing = 0); 18 | 19 | static float uncertainty_tolerance; // The maximum allowed global optimality gap before the optimization can terminate 20 | static float regularization; // The penalty incurred for each leaf in the model 21 | static float upperbound; // Upperbound on the root problem for pruning problems using a greedy model 22 | 23 | static unsigned int time_limit; // The maximum allowed runtime (seconds). 0 means unlimited. 24 | static unsigned int worker_limit; // The maximum allowed worker threads. 0 means match number of available cores 25 | static unsigned int stack_limit; // The maximum amount of stack space (bytes) allowed to be used as buffers 26 | static unsigned int precision_limit; // The maximum number of significant figures considered for each ordinal feature 27 | static unsigned int model_limit; // The maximum number of models extracted 28 | 29 | static bool verbose; // Flag for printing status to standard output 30 | static bool diagnostics; // Flag for printing diagnosis to standard output if a bug is detected 31 | 32 | static unsigned char depth_budget; // The maximum tree depth for solutions, counting a tree with just the root node as depth 1. 0 means unlimited.
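A note on the cartesian-product iterator in cart_it.hpp above: access_idx treats the flat index as a mixed-radix number whose i-th digit has radix |S_i|, so every index in [0, L_1 * ... * L_n) decodes to exactly one tuple, with the last coordinate varying fastest. The sketch below reproduces that decoding outside the template so it can be compiled and checked in isolation; decode and the example sets are illustrative names, not part of the library.

#include <cstdlib>
#include <iostream>
#include <vector>

// Decode flat index n into one tuple of the cartesian product of v, mirroring
// CartIt::access_idx: divide by the size of each inner vector from right to
// left and use the remainder as the coordinate for that position.
std::vector<int> decode(const std::vector<std::vector<int>> &v, long long n) {
    std::vector<int> u(v.size());
    long long quotient = n;
    for (long long i = static_cast<long long>(v.size()) - 1; i >= 0; --i) {
        std::lldiv_t q = std::lldiv(quotient, static_cast<long long>(v[i].size()));
        u[i] = v[i][q.rem];
        quotient = q.quot;
    }
    return u;
}

int main() {
    // {1,2} x {10,20,30} has 2 * 3 = 6 tuples; index 4 decodes to {2, 20}.
    std::vector<std::vector<int>> sets = {{1, 2}, {10, 20, 30}};
    for (long long n = 0; n < 6; ++n) {
        std::vector<int> tuple = decode(sets, n);
        std::cout << n << " -> {" << tuple[0] << ", " << tuple[1] << "}\n";
    }
    return 0;
}

This is the same order in which the iterator above visits tuples, which is why begin() and end() only need to track a single integer index.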
33 | 34 | static unsigned int minimum_captured_points; // The minimum number of captured points for any leaf 35 | 36 | static std::vector memory_checkpoints; // Memory checkpoints at which to dump the trie generated by the current adjacency graph 37 | 38 | static bool output_accuracy_model_set; 39 | static std::vector output_covered_sets; 40 | static std::vector covered_sets_thresholds; 41 | static std::string covered_set_type_to_string(CoveredSetExtraction type) { 42 | switch (type) { 43 | case F1: 44 | return "f1"; 45 | case BACC: 46 | return "bacc"; 47 | case AUC: 48 | return "auc"; 49 | default: 50 | throw std::invalid_argument("Unknown type"); 51 | } 52 | }; 53 | static double computeScore(CoveredSetExtraction type, unsigned int P, unsigned int N, double TP, double TN) { 54 | switch (type) { 55 | case F1: 56 | return TP / (TP + 0.5 * (P - TP + N - TN)); 57 | case BACC: 58 | return (TP / P + TN / N) / 2; 59 | case AUC: 60 | return (TN * TP + 0.5 * (TP * (N - TN) + TN * (P - TP))) / (N * P); 61 | default: 62 | throw std::invalid_argument("Unknown type"); 63 | } 64 | }; 65 | 66 | static bool balance; // Flag for adjusting the importance of each row to equalize the total importance of each class (overrides weight) 67 | static bool look_ahead; // Flag for enabling the one-step look-ahead bound implemented via scopes 68 | static bool similar_support; // Flag for enabling the similar support bound implemented via the distance index 69 | static bool cancellation; // Flag for enabling upward propagation of cancelled subproblems 70 | static bool continuous_feature_exchange; // Flag for enabling the pruning of neighbouring thresholds using subset comparison 71 | static bool feature_exchange; // Flag for enabling the pruning of pairs of features using subset comparison 72 | static bool feature_transform; // Flag for enabling the equivalence discovery through simple feature transformations 73 | static bool rule_list; // Flag for enabling rule-list constraints on models 74 | static bool non_binary; // Flag for enabling non-binary encoding 75 | 76 | static std::string costs; // Path to file containing cost matrix 77 | static std::string model; // Path to file used to store the extracted models 78 | static std::string rashomon_model; // Path to directory used to store the Rashomon set 79 | static std::string rashomon_model_set_suffix; // Path to directory used to store the Rashomon model set 80 | static std::string rashomon_trie; // Path to directory used to store the Rashomon trie 81 | static std::string timing; // Path to file used to store training time 82 | static std::string trace; // Path to directory used to store traces 83 | static std::string tree; // Path to directory used to store tree-traces 84 | static std::string profile; // Path to file used to log runtime statistics 85 | static std::string datatset_encoding; // Path to file used to store the dataset encoding 86 | 87 | static bool rashomon; // Flag for extracting Rashomon set 88 | 89 | // The following three are mutually exclusive.
Please only set one to be true 90 | static float rashomon_bound; // If the rashomon bound is known, setting this would skip finding the optimal tree 91 | static float rashomon_bound_multiplier; // Setting the Rashomon bound to be (Optimal Objective Value) * (1 + rashomon_bound_multiplier) 92 | static float rashomon_bound_adder; // Setting the Rashomon bound to be (Optimal Objective Value) + rashomon_bound_adder 93 | 94 | static bool rashomon_ignore_trivial_extensions; 95 | 96 | 97 | }; 98 | 99 | #endif -------------------------------------------------------------------------------- /src/dataset.hpp: -------------------------------------------------------------------------------- 1 | #ifndef DATASET_H 2 | #define DATASET_H 3 | 4 | #define CL_SILENCE_DEPRECATION 5 | #define __CL_ENABLE_EXCEPTIONS 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #ifdef INCLUDE_OPENCL 20 | #include 21 | #endif 22 | 23 | #include 24 | #include 25 | 26 | class Dataset; 27 | 28 | #include "bitmask.hpp" 29 | #include "configuration.hpp" 30 | #include "encoder.hpp" 31 | #include "index.hpp" 32 | //#include "state.hpp" // FIREWOLF: Circular References: Moved to cpp. 33 | #include "tile.hpp" 34 | 35 | using json = nlohmann::json; 36 | 37 | // Contain the dataset and any preprocessed values 38 | class Dataset { 39 | public: 40 | // The encoder used in converting between non-binary and binary 41 | Encoder encoder; 42 | 43 | Dataset(void); 44 | // @param data_source: byte stream of csv format which will be automatically encoded into a binary dataset 45 | // @note see encoder documentation for data source formatting preconditions and encoding semantics 46 | Dataset(std::istream & data_source); 47 | ~Dataset(void); 48 | 49 | // @modifies loads data from data stream 50 | void load(std::istream & data_source); 51 | 52 | // @modifies resets dataset to initial state 53 | void clear(void); 54 | 55 | // @returns the sample size of the data set 56 | unsigned int size(void) const; 57 | // @returns the physical number of rows needed to represent the data set 58 | unsigned int height(void) const; 59 | // @returns the number of binary non-target features used to represent the data set 60 | unsigned int width(void) const; 61 | // @returns the number of unique target values in the dataset 62 | unsigned int depth(void) const; 63 | 64 | // @param capture_set: The indicator for each equivalent groups are contained by this problem 65 | // @param id: Index of the local state entry used when a column buffer is needed 66 | // @modifies info: The alkaike information critierion of this set w.r.t the target distribution 67 | // @modifies potential: The maximum reduction in loss if all equivalent classes are relabelled (without considering complexity penalty) 68 | // @modifies min_loss: The minimal loss incurred if all equivalent classes are optimally labelled without considering complexity penalty 69 | // @modifies max_loss: The loss incurred if the capture set is left unsplit and the best single label is chosen 70 | // @modifies target_index: The label to choose if left unsplit 71 | void summary(Bitmask const & capture_set, float & info, float & potential, float & min_loss, float & max_loss, unsigned int & target_index, unsigned int id) const; 72 | 73 | void get_TP_TN(Bitmask const & capture_set, unsigned int id, unsigned int target_index, unsigned int & TP, unsigned int & TN); 74 | 75 | void get_total_P_N(unsigned int & P, 
unsigned int & N); 76 | 77 | // @param feature_index: the index of the binary feature to use bisect the set 78 | // @param positive: if true, modifies set to reflect the part of the bisection that responds positive to the binary feature 79 | // if false, the other part of the bisection is used 80 | // @param set: indicates the captured set of samples to be bisected 81 | // @modifies set: the captured set will be overwritten to reflect the subset extracted from the bisection 82 | // this can be either the positive or negative subset depending on the positive argument 83 | void subset(unsigned int feature_index, bool positive, Bitmask & set) const; 84 | // Convenient alias for performing both negative and positive tests 85 | void subset(unsigned int feature_index, Bitmask & negative, Bitmask & positive) const; 86 | 87 | // @param set: The indicator for each equivalent groups are contained by this problem 88 | // @param buffer: a buffer used for bitwise operations 89 | // @param i: feature index for pairwise comparison 90 | // @param j: other feature index for pairwise comparison 91 | // @return distance: The maximum change in objective value if feature i is swapped for j or vice versa 92 | float distance(Bitmask const & set, unsigned int i, unsigned int j, unsigned int id) const; 93 | 94 | void tile(Bitmask const & filter, Bitmask const & selector, Tile & tile_output, std::vector< int > & order, unsigned int id) const; 95 | 96 | float get_mismatch_cost() const; 97 | 98 | private: 99 | static bool index_comparator(const std::pair< unsigned int, unsigned int > & left, const std::pair< unsigned int, unsigned int > & right); 100 | 101 | // The dimensions of the dataset 102 | // Dim-0 = Number of samples 103 | // Dim-1 = Number of binary features 104 | // Dim-2 = Number of classes 105 | std::tuple< unsigned int, unsigned int, unsigned int > shape; 106 | unsigned int _size; // shortcut for number of samples 107 | 108 | // std::vector< Bitmask > columns; // Binary representation of columns 109 | // std::vector< std::vector< float > > distributions; // Class distributions for each row 110 | 111 | std::vector< Bitmask > features; // Binary representation of columns 112 | std::vector< Bitmask > targets; // Binary representation of columns 113 | std::vector< Bitmask > rows; // Binary representation of rows 114 | std::vector< Bitmask > feature_rows; // Binary representation of rows 115 | std::vector< Bitmask > target_rows; // Binary representation of rows 116 | Bitmask majority; // Binary representation of columns 117 | std::vector< std::vector< float > > costs; // Cost matrix for each type of misprediction 118 | std::vector< float > match_costs; // Cost matrix for each type of misprediction 119 | std::vector< float > mismatch_costs; // Cost matrix for each type of misprediction 120 | std::vector< float > max_costs; // Cost matrix for each type of misprediction 121 | std::vector< float > min_costs; // Cost matrix for each type of misprediction 122 | std::vector< float > diff_costs; // Cost matrix for each type of misprediction 123 | 124 | // Index index; // Index for calculating summaries 125 | // Index distance_index; // Index for calculating feature distances 126 | 127 | void construct_bitmasks(std::istream & input_stream); 128 | void construct_cost_matrix(void); 129 | void parse_cost_matrix(std::istream & input_stream); 130 | void aggregate_cost_matrix(void); 131 | void construct_majority(void); 132 | }; 133 | 134 | #endif 
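Capture sets and binary feature columns are both bitmasks over the samples, so the bisection performed by subset above reduces to two bitwise intersections: the positive child keeps the captured samples where the feature is 1 and the negative child keeps the rest. Below is a deliberately simplified sketch of that idea using std::bitset in place of the project's Bitmask class, with made-up sample values.

#include <bitset>
#include <cstddef>
#include <iostream>

int main() {
    const std::size_t samples = 8;
    std::bitset<samples> capture("11110110");  // samples captured by the current subproblem
    std::bitset<samples> feature("10101010");  // column of one binary feature

    // Bisect the capture set on the feature, as Dataset::subset does with Bitmask.
    std::bitset<samples> positive = capture & feature;   // feature == 1 side
    std::bitset<samples> negative = capture & ~feature;  // feature == 0 side

    std::cout << "capture  " << capture << "\n"
              << "positive " << positive << "\n"
              << "negative " << negative << "\n";
    // The two halves are disjoint and (positive | negative) == capture.
    return 0;
}

In the real implementation the feature column comes from the private features vector and the result is written back into the caller's Bitmask (the set argument), but the partitioning invariant is the same: the two children are disjoint and their union is the parent capture set.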
-------------------------------------------------------------------------------- /src/gosdt.hpp: -------------------------------------------------------------------------------- 1 | #ifndef GOSDT_H 2 | #define GOSDT_H 3 | 4 | #include "graph.hpp" 5 | #define SIMDPP_ARCH_X86_SSE4_1 6 | 7 | #include 8 | 9 | #include 10 | // #include 11 | // #include 12 | // #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | // #include 19 | 20 | #include 21 | 22 | #include "encoder.hpp" 23 | #include "dataset.hpp" 24 | #include "integrity_violation.hpp" 25 | #include "model.hpp" 26 | #include "optimizer.hpp" 27 | 28 | using json = nlohmann::json; 29 | 30 | // The main interface of the library 31 | // Note that the algorithm behaviour is modified using the static configuration object using the Configuration class 32 | class GOSDT { 33 | public: 34 | GOSDT(void); 35 | ~GOSDT(void); 36 | 37 | static float time; 38 | static unsigned int size; 39 | static unsigned int iterations; 40 | static unsigned int status; 41 | 42 | // @param config_source: string stream containing a JSON object of configuration parameters 43 | // @note: See the Configuration class for details about each parameter 44 | static void configure(std::istream & config_source); 45 | 46 | // @require: The CSV must contain a header. 47 | // @require: Scientific notation is currently not supported by the parser, use long form decimal notation 48 | // @require: All rows must have the same number of entries 49 | // @require: all entries are comma-separated 50 | // @require: Wrapping quotations are not stripped 51 | // @param data_source: string containing a CSV of training_data 52 | void fit(std::istream & data_source); 53 | 54 | // @require: The CSV must contain a header. 55 | // @require: Scientific notation is currently not supported by the parser, use long form decimal notation 56 | // @require: All rows must have the same number of entries 57 | // @require: all entries are comma-separated 58 | // @require: Wrapping quotations are not stripped 59 | // @param data_source: string containing a CSV of training_data 60 | // @modifies result: Contains a JSON array of all optimal models extracted 61 | void fit(std::istream & data_source, std::string & result); 62 | 63 | // @require: The CSV must contain a header. 
64 | // @require: Scientific notation is currently not supported by the parser, use long form decimal notation 65 | // @require: All rows must have the same number of entries 66 | // @require: all entries are comma-separated 67 | // @require: Wrapping quotations are not stripped 68 | // @param data_source: string containing a CSV of training_data 69 | // @modifies results: Set of models extracted from the optimization 70 | void fit(std::istream & data_source, results_t & results); 71 | 72 | // for Rashomon set construction 73 | // @param rashomon_bound: 74 | void fit_rashomon(Optimizer & optimizer, float rashomon_bound, results_t &results); 75 | void process_rashomon_result(results_t &results); 76 | 77 | // for finding the optimal tree 78 | void fit_gosdt(Optimizer & optimizer, std::unordered_set< Model > & models); 79 | 80 | private: 81 | // @param id: The worker ID of the current thread 82 | // @param optimizer: optimizer object which will assign work to the thread 83 | // @modifies return_reference: reference for returning values to the main thread 84 | static void work(int const id, Optimizer & optimizer, int & return_reference); 85 | }; 86 | 87 | #endif 88 | -------------------------------------------------------------------------------- /src/graph.hpp: -------------------------------------------------------------------------------- 1 | #ifndef GRAPH_H 2 | #define GRAPH_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | class Graph; 16 | 17 | #include "bitmask.hpp" 18 | #include "task.hpp" 19 | #include "tile.hpp" 20 | #include "additive_metrics.hpp" 21 | 22 | // #include "sorted_map.hpp" 23 | 24 | class ModelSet; 25 | typedef Tile key_type; 26 | typedef Task value_type; 27 | typedef std::vector translation_type; 28 | 29 | // Additional Hash Implementation for tbb::concurrent_hash_table 30 | // These delegate to the already implemented hash functions and equality operators 31 | class GraphVertexHashComparator { 32 | public: 33 | static size_t hash(key_type const & key) { 34 | return key.hash(); 35 | } 36 | static bool equal(key_type const & left, key_type const & right) { 37 | return left == right; 38 | } 39 | }; 40 | 41 | // class GraphTranslationHashComparator { 42 | // public: 43 | // static size_t hash(std::pair const & key) { 44 | // size_t seed = key.first.hash(); 45 | // seed ^= key.second.hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); 46 | // return seed; 47 | // } 48 | // static bool equal(std::pair const & left, std::pair const & right) { 49 | // return left == right; 50 | // } 51 | // }; 52 | 53 | class GraphChildHashComparator { 54 | public: 55 | static size_t hash(std::pair const & key) { 56 | size_t seed = key.second; 57 | seed ^= key.first.hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); 58 | return seed; 59 | } 60 | static bool equal(std::pair const & left, std::pair const & right) { 61 | return left == right; 62 | } 63 | }; 64 | 65 | 66 | typedef tbb::concurrent_hash_map< // Table for storing forward edges 67 | std::pair, key_type, GraphChildHashComparator, 68 | tbb::scalable_allocator const, key_type>>> child_table; 69 | 70 | typedef tbb::concurrent_hash_map< // Table for storing tile-orderings 71 | std::pair, translation_type, GraphChildHashComparator, 72 | tbb::scalable_allocator const, translation_type>>> translation_table; 73 | 74 | typedef tbb::concurrent_hash_map< // Table for storing vertices 75 | key_type, value_type, GraphVertexHashComparator, 76 | 
tbb::scalable_allocator>> vertex_table; 77 | 78 | typedef tbb::concurrent_unordered_map< // Set of parents for a single vertex 79 | key_type, std::pair, std::hash, std::equal_to, 80 | tbb::scalable_allocator>>> adjacency_set; 81 | 82 | typedef tbb::concurrent_hash_map< // Table of all adjacency sets 83 | key_type, adjacency_set, GraphVertexHashComparator, 84 | tbb::scalable_allocator>> adjacency_table; 85 | 86 | typedef tbb::concurrent_vector, tbb::scalable_allocator>> bound_list; // List of split-bounds for a single vertex 87 | 88 | // A collection of model sets, representing the result or solution of a 89 | // subproblem. The first entry is a *sorted* list of objective values whereas 90 | // the second entry represents a storage 91 | typedef std::pair, std::unordered_map, ObjectiveHash>> results_t; 92 | typedef std::tuple< float, results_t > scoped_result_t; // A collection of model sets with an associated scope 93 | 94 | typedef tbb::concurrent_hash_map< // Table of all bound lists 95 | key_type, bound_list, GraphVertexHashComparator, 96 | tbb::scalable_allocator>> bound_table; 97 | 98 | typedef tbb::concurrent_hash_map< // Table of all saved models 99 | key_type, scoped_result_t, GraphVertexHashComparator, 100 | tbb::scalable_allocator>> models_table; 101 | 102 | typedef vertex_table::const_accessor const_vertex_accessor; 103 | typedef vertex_table::accessor vertex_accessor; 104 | 105 | typedef translation_table::const_accessor const_translation_accessor; 106 | typedef translation_table::accessor translation_accessor; 107 | 108 | typedef child_table::const_accessor const_child_accessor; 109 | typedef child_table::accessor child_accessor; 110 | 111 | typedef adjacency_table::const_accessor const_adjacency_accessor; 112 | typedef adjacency_table::accessor adjacency_accessor; 113 | 114 | typedef adjacency_set::const_iterator const_adjacency_iterator; 115 | typedef adjacency_set::iterator adjacency_iterator; 116 | 117 | // typedef adjacency_set::const_indicator const_adjacency_indicator; 118 | // typedef adjacency_set::indicator adjacency_indicator; 119 | 120 | typedef bound_table::const_accessor const_bound_accessor; 121 | typedef bound_table::accessor bound_accessor; 122 | 123 | typedef bound_list::const_iterator const_bound_iterator; 124 | typedef bound_list::iterator bound_iterator; 125 | 126 | typedef models_table::const_accessor const_models_accessor; 127 | typedef models_table::accessor models_accessor; 128 | 129 | // Container for storing the dependency graph 130 | // The vertices of his graph act as a memoization table of subproblems 131 | // Entries in the table are not necessarily complete, some are still running, paused, or cancelled. 
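All of the tables above follow the same tbb::concurrent_hash_map recipe: a key type that exposes hash() and operator==, a small HashCompare class whose static hash/equal delegate to them, and accessor objects that hold a per-entry lock while the entry is read or written. Here is a minimal standalone illustration of that recipe, using a made-up ToyKey in place of Tile and assuming the classic tbb/concurrent_hash_map.h include path.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <tbb/concurrent_hash_map.h>

// Toy key standing in for Tile: it provides hash() and operator==,
// which the HashCompare below delegates to.
struct ToyKey {
    int id;
    std::size_t hash() const { return std::hash<int>{}(id); }
    bool operator==(const ToyKey &other) const { return id == other.id; }
};

// Same shape as GraphVertexHashComparator: static hash and equal.
struct ToyKeyHashCompare {
    static std::size_t hash(const ToyKey &key) { return key.hash(); }
    static bool equal(const ToyKey &left, const ToyKey &right) { return left == right; }
};

typedef tbb::concurrent_hash_map<ToyKey, std::string, ToyKeyHashCompare> toy_table;

int main() {
    toy_table table;

    {   // insert() takes a write accessor; the entry stays locked until
        // the accessor is released (here, by scope exit).
        toy_table::accessor accessor;
        table.insert(accessor, ToyKey{42});
        accessor->second = "subproblem result";
    }

    {   // Lookups use a const_accessor; find() returns false when absent.
        toy_table::const_accessor accessor;
        if (table.find(accessor, ToyKey{42})) {
            std::cout << accessor->first.id << " -> " << accessor->second << "\n";
        }
    }
    return 0;
}

The accessor-based API is what lets Graph hand out references into vertices, children, and the other tables without a global lock; locking is per entry and is released as soon as the accessor goes out of scope.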
132 | class Graph { 133 | public: 134 | translation_table translations; 135 | child_table children; 136 | vertex_table vertices; 137 | adjacency_table edges; 138 | bound_table bounds; 139 | models_table models; 140 | 141 | Graph(void); 142 | ~Graph(void); 143 | 144 | // bool exists(key_type const & key) const; 145 | 146 | // bool insert(key_type const & key, value_type const & value); 147 | // bool insert(std::pair< key_type, value_type > const & pair); 148 | // bool connect(key_type const & parent, key_type const & child, float scope); 149 | 150 | // bool find(const_vertex_accessor & accessor, key_type const & key) const; 151 | // bool find(vertex_accessor & accessor, key_type const & key) const; 152 | 153 | // bool find(const_adjacency_accessor & accessor, key_type const & key, bool forward = true) const; 154 | // bool find(adjacency_accessor & accessor, key_type const & key, bool forward = true) const; 155 | 156 | // bool find_or_create(const_vertex_accessor & accessor, key_type const & key, 157 | // Bitmask & buffer_1, Bitmask & buffer_2, Bitmask & buffer_3, 158 | // Task const & task, unsigned int index, bool condition); 159 | 160 | // bool find_or_create(vertex_accessor & accessor, key_type const & key, 161 | // Bitmask & buffer_1, Bitmask & buffer_2, Bitmask & buffer_3, 162 | // Task const & task, unsigned int index, bool condition); 163 | 164 | bool erase(key_type const & key, bool disconnect = true); 165 | bool disconnect(key_type const & arent, key_type const & child); 166 | void clear(void); 167 | 168 | unsigned int size(void) const; 169 | }; 170 | 171 | #endif 172 | -------------------------------------------------------------------------------- /src/index.hpp: -------------------------------------------------------------------------------- 1 | #ifndef INDEX_H 2 | #define INDEX_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | //#include 9 | //#include 10 | //#include 11 | #include 12 | //#include 13 | 14 | #include "integrity_violation.hpp" 15 | #include "bitmask.hpp" 16 | 17 | // @note: vector type used for representing GPU floating point vector 18 | //typedef boost::numeric::ublas::vector< float > blasvector; 19 | // @note: vector type used for representing GPU mask 20 | //typedef boost::numeric::ublas::vector< bitblock > blasmask; 21 | 22 | // Container used to store prefix sums of vectors which help accelerate our calculations 23 | // multiple vectors are stored in this container so that ranges don't need to be recomputed for each vector 24 | class Index { 25 | public: 26 | static void precompute(void); 27 | void benchmark(void) const; 28 | 29 | Index(void); 30 | // @param source: vector of floating points to sum over (efficiently) 31 | Index(std::vector< std::vector< float > > const & source); 32 | ~Index(void); 33 | 34 | // @param indicator: mask of bits indicating which elements are relevant to the vector sum 35 | // @returns the total of all elements associated to bits that were set to 1 36 | void sum(Bitmask const & indicator, float * accumulator) const; 37 | 38 | // @returns string representation of original floating points (used for inspection) 39 | std::string to_string(void) const; 40 | 41 | private: 42 | // Copy of the original floating points 43 | std::vector< float > source; 44 | // precomputed representation of the floating point vector 45 | std::vector< std::vector< float > > prefixes; 46 | // Number of floating points represented in the source vector 47 | unsigned int size; 48 | unsigned int width; 49 | // Number of blocks expected in bitmask 50 
| unsigned int num_blocks; 51 | 52 | // @param indicator: array of blocks of bits indicating which elements are relevant to the vector sum 53 | // @returns the total of all elements associated to bits that were set to 1 54 | // @note: This implementation uses look-up to precomputed values of the run-length-code for fast sums 55 | void block_sequential_sum(bitblock * blocks, float * accumulator) const; 56 | void block_sequential_sum(rangeblock block, unsigned int offset, float * accumulator) const; 57 | 58 | // @param indicator: array of blocks of bits indicating which elements are relevant to the vector sum 59 | // @returns the total of all elements associated to bits that were set to 1 60 | // @note: This implementation computes run-length-code for fast sums 61 | void bit_sequential_sum(Bitmask const &indicator, float *accumulator) const; 62 | 63 | // @param source: The original vector of floats used in computation 64 | // @modifies prefixes: writes the prefix sums into this vector 65 | void build_prefixes(std::vector< std::vector< float > > const & source, std::vector< std::vector< float > > & prefixes); 66 | }; 67 | #endif 68 | -------------------------------------------------------------------------------- /src/integrity_violation.hpp: -------------------------------------------------------------------------------- 1 | #ifndef INTEGRITY_VIOLATION_H 2 | #define INTEGRITY_VIOLATION_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | // Implementation of run time error for displaying any detected integrity violations during the algorithm 9 | // These exceptions indicate that a logical error in the code has caused the algorithm to reach an incorrect state 10 | // The correct response to an integrity violation is to report any diagnosis, then terminate the program. 
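A sketch of the handling that comment describes: catch the exception at the top level, print its diagnosis, and terminate. The check, error context, and reason strings below are made up for illustration; in the project the throw sites live inside the optimizer and related classes rather than in a helper like this.

#include <cstdlib>
#include <iostream>

#include "integrity_violation.hpp"

// Stand-in for an internal consistency check that has detected a logic error.
void illustrative_check() {
    bool invariant_holds = false;  // illustrative only
    if (!invariant_holds) {
        throw IntegrityViolation("ExampleCheck", "invariant violated (illustrative only)");
    }
}

int main() {
    try {
        illustrative_check();
    } catch (IntegrityViolation const &violation) {
        // Report the diagnosis, then terminate, as the comment above prescribes.
        std::cout << violation.to_string() << std::endl;
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}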
11 | class IntegrityViolation : public std::runtime_error { 12 | public: 13 | IntegrityViolation(std::string error, std::string reason) : std::runtime_error(error), error(error), reason(reason) {} 14 | std::string error; 15 | std::string reason; 16 | std::string to_string(void) const { 17 | std::stringstream message; 18 | message << "\033[1;31mIntegrityViolation Detected during Optimization:\n" 19 | << " ErrorContext: " << this -> error << "\n" 20 | << " Reason: " << this -> reason << "\033[0m" << std::endl; 21 | return message.str(); 22 | } 23 | }; 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /src/local_state.cpp: -------------------------------------------------------------------------------- 1 | #include "local_state.hpp" 2 | 3 | LocalState::LocalState(void) {} 4 | 5 | void LocalState::initialize(unsigned int _samples, unsigned int _features, unsigned int _targets) { 6 | this -> samples = _samples; 7 | this -> features = _features; 8 | this -> targets = _targets; 9 | 10 | this -> inbound_message.initialize(_samples, _features, _targets); 11 | this -> outbound_message.initialize(_samples, _features, _targets); 12 | 13 | this -> neighbourhood.resize(2 * (this -> features)); 14 | 15 | unsigned int buffer_count = 4; 16 | unsigned int row_size = this -> features + this -> targets; 17 | unsigned int column_size = this -> samples; 18 | unsigned int max_tile_size = row_size * column_size; 19 | 20 | for (unsigned int i = 0; i < buffer_count; ++i) { 21 | this -> rows.emplace_back(row_size); 22 | this -> columns.emplace_back(column_size); 23 | } 24 | } 25 | 26 | LocalState & LocalState::operator=(LocalState const & source) { 27 | this -> neighbourhood = source.neighbourhood; 28 | this -> rows = source.rows; 29 | this -> columns = source.columns; 30 | return * this; 31 | } 32 | 33 | 34 | LocalState::~LocalState(void) { 35 | this -> neighbourhood.clear(); 36 | this -> rows.clear(); 37 | this -> columns.clear(); 38 | } 39 | -------------------------------------------------------------------------------- /src/local_state.hpp: -------------------------------------------------------------------------------- 1 | #ifndef LOCAL_STATE_H 2 | #define LOCAL_STATE_H 3 | 4 | class LocalState; 5 | 6 | #include "bitmask.hpp" 7 | #include "message.hpp" 8 | #include "task.hpp" 9 | #include "tile.hpp" 10 | 11 | // Container of all data structures the local state owned by each thread 12 | class LocalState { 13 | public: 14 | LocalState(void); 15 | ~LocalState(void); 16 | 17 | // @param samples: The number of samples a column bit mask has to represent 18 | // @param features: The number of features a row bit mask has to represent 19 | // @param targets: The number of targets a row bit mask has to represent 20 | // @note: bit masks are used to represent rows and columns of a data set 21 | // samples refer to the number of independent samples in a dataset 22 | // features refer to the number of binary features that will be available at prediction time 23 | // targets refer to the number of different classes that a sample can fall under 24 | void initialize(unsigned int samples, unsigned int features, unsigned int targets); 25 | 26 | LocalState & operator=(LocalState const & source); 27 | 28 | std::vector< Task > neighbourhood; // Memory buffer for storing children of a node 29 | Message inbound_message; // Memory buffer for storing a messages from the priority queue 30 | Message outbound_message; // Memory buffer for storing a messages from the priority queue 31 | 
// std::vector< Tile > keys; // Memory buffer for storing a tile representation of a node identifier 32 | std::vector< Bitmask > rows; // Memory buffer for storing a bitmask representation of a feature set + target set 33 | std::vector< Bitmask > columns; // Memory buffer for storing a bitmask representation of a capture set 34 | 35 | // Bitmask dirty; // Mask indicating which items in the neighbourhood needs to be written? 36 | 37 | unsigned int samples; 38 | unsigned int features; 39 | unsigned int targets; 40 | }; 41 | 42 | #endif -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include "main.hpp" 2 | #include "version.hpp" 3 | 4 | int main(int argc, char *argv[]) { 5 | 6 | // struct pollfd file_descriptors; 7 | // file_descriptors.fd = 0; /* this is STDIN */ 8 | // file_descriptors.events = POLLIN; 9 | // bool standard_input = poll(& file_descriptors, 1, 0) == 1; 10 | bool standard_input = false; 11 | 12 | std::cout << "gosdt-" << BUILD_GIT_REV << " (" << BUILD_DATE << " on " << BUILD_HOST << ")" << std::endl; 13 | 14 | // Check program input 15 | if ((standard_input && (argc < 1 || argc > 2)) || (!standard_input && (argc < 2 || argc > 3))) { 16 | std::cout << "Usage: gosdt [path to feature set] ?[path to config]" << std::endl; 17 | return 0; 18 | } 19 | if (argc >= 2 && !std::ifstream(argv[1]).good()) { 20 | std::cout << "File Not Found: " << argv[1] << std::endl; 21 | return 1; 22 | } 23 | if (argc >= 3 && !std::ifstream(argv[2]).good()) { 24 | std::cout << "File Not Found: " << argv[2] << std::endl; 25 | return 1; 26 | } 27 | 28 | if ((standard_input && argc == 2) || (!standard_input && argc == 3)) { 29 | // Use custom configuration if provided 30 | std::ifstream configuration(argv[argc - 1]); 31 | Configuration::configure(configuration); 32 | } 33 | 34 | // Print messages to help user ensure they've provided the correct inputs 35 | if (Configuration::verbose) { 36 | std::cout << "Generalized Optimal Sparse Decision Tree" << std::endl; 37 | std::cout << "Using data set: " << argv[1] << std::endl; 38 | } 39 | GOSDT model; 40 | if (standard_input) { 41 | model.fit(std::cin); 42 | } else { 43 | std::ifstream data(argv[1]); 44 | model.fit(data); 45 | } 46 | return 0; 47 | } 48 | -------------------------------------------------------------------------------- /src/main.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MAIN_H 2 | #define MAIN_H 3 | 4 | #include 5 | //#include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | #include "configuration.hpp" 16 | #include "gosdt.hpp" 17 | 18 | using json = nlohmann::json; 19 | 20 | // Main entry point for the CLI 21 | // The data set is entered either through the standard input stream or as a file path in the first argument 22 | // The configuration file is entered as a file path in the first argument if the data set is entered through standard input 23 | // otherwise the configuration file is entered as a file path in the second argument 24 | // Example: 25 | // cat data.csv | gosdt config.json 26 | // or 27 | // gosdt data.csv config.json 28 | int main(int argc, char *argv[]); 29 | 30 | #endif -------------------------------------------------------------------------------- /src/memusage.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Author: David Robert 
Nadeau 3 | * Site: http://NadeauSoftware.com/ 4 | * License: Creative Commons Attribution 3.0 Unported License 5 | * http://creativecommons.org/licenses/by/3.0/deed.en_US 6 | */ 7 | 8 | #if defined(_WIN32) 9 | #include 10 | #include 11 | 12 | #elif defined(__unix__) || defined(__unix) || defined(unix) || (defined(__APPLE__) && defined(__MACH__)) 13 | #include 14 | #include 15 | 16 | #if defined(__APPLE__) && defined(__MACH__) 17 | #include 18 | 19 | #elif (defined(_AIX) || defined(__TOS__AIX__)) || (defined(__sun__) || defined(__sun) || defined(sun) && (defined(__SVR4) || defined(__svr4__))) 20 | #include 21 | #include 22 | 23 | #elif defined(__linux__) || defined(__linux) || defined(linux) || defined(__gnu_linux__) 24 | #include 25 | 26 | #endif 27 | 28 | #else 29 | #error "Cannot define getPeakRSS( ) or getCurrentRSS( ) for an unknown OS." 30 | #endif 31 | 32 | 33 | 34 | 35 | 36 | /** 37 | * Returns the peak (maximum so far) resident set size (physical 38 | * memory use) measured in bytes, or zero if the value cannot be 39 | * determined on this OS. 40 | */ 41 | inline size_t getPeakRSS( ) 42 | { 43 | #if defined(_WIN32) 44 | /* Windows -------------------------------------------------- */ 45 | // PROCESS_MEMORY_COUNTERS info; 46 | // GetProcessMemoryInfo( GetCurrentProcess( ), &info, sizeof(info) ); 47 | // return (size_t)info.PeakWorkingSetSize; 48 | return (size_t)0L; /* Unsupported. */ 49 | 50 | #elif (defined(_AIX) || defined(__TOS__AIX__)) || (defined(__sun__) || defined(__sun) || defined(sun) && (defined(__SVR4) || defined(__svr4__))) 51 | /* AIX and Solaris ------------------------------------------ */ 52 | struct psinfo psinfo; 53 | int fd = -1; 54 | if ( (fd = open( "/proc/self/psinfo", O_RDONLY )) == -1 ) 55 | return (size_t)0L; /* Can't open? */ 56 | if ( read( fd, &psinfo, sizeof(psinfo) ) != sizeof(psinfo) ) 57 | { 58 | close( fd ); 59 | return (size_t)0L; /* Can't read? */ 60 | } 61 | close( fd ); 62 | return (size_t)(psinfo.pr_rssize * 1024L); 63 | 64 | #elif defined(__unix__) || defined(__unix) || defined(unix) || (defined(__APPLE__) && defined(__MACH__)) 65 | /* BSD, Linux, and OSX -------------------------------------- */ 66 | struct rusage rusage; 67 | getrusage( RUSAGE_SELF, &rusage ); 68 | #if defined(__APPLE__) && defined(__MACH__) 69 | return (size_t)rusage.ru_maxrss; 70 | #else 71 | return (size_t)(rusage.ru_maxrss * 1024L); 72 | #endif 73 | 74 | #else 75 | /* Unknown OS ----------------------------------------------- */ 76 | return (size_t)0L; /* Unsupported. */ 77 | #endif 78 | } 79 | 80 | 81 | 82 | 83 | 84 | /** 85 | * Returns the current resident set size (physical memory use) measured 86 | * in bytes, or zero if the value cannot be determined on this OS. 87 | */ 88 | inline size_t getCurrentRSS( ) 89 | { 90 | #if defined(_WIN32) 91 | /* Windows -------------------------------------------------- */ 92 | // PROCESS_MEMORY_COUNTERS info; 93 | // GetProcessMemoryInfo( GetCurrentProcess( ), &info, sizeof(info) ); 94 | // return (size_t)info.WorkingSetSize; 95 | return (size_t)0L; /* Unsupported. */ 96 | 97 | #elif defined(__APPLE__) && defined(__MACH__) 98 | /* OSX ------------------------------------------------------ */ 99 | struct mach_task_basic_info info; 100 | mach_msg_type_number_t infoCount = MACH_TASK_BASIC_INFO_COUNT; 101 | if ( task_info( mach_task_self( ), MACH_TASK_BASIC_INFO, 102 | (task_info_t)&info, &infoCount ) != KERN_SUCCESS ) 103 | return (size_t)0L; /* Can't access? 
*/ 104 | return (size_t)info.resident_size; 105 | 106 | #elif defined(__linux__) || defined(__linux) || defined(linux) || defined(__gnu_linux__) 107 | /* Linux ---------------------------------------------------- */ 108 | long rss = 0L; 109 | FILE* fp = NULL; 110 | if ( (fp = fopen( "/proc/self/statm", "r" )) == NULL ) 111 | return (size_t)0L; /* Can't open? */ 112 | if ( fscanf( fp, "%*s%ld", &rss ) != 1 ) 113 | { 114 | fclose( fp ); 115 | return (size_t)0L; /* Can't read? */ 116 | } 117 | fclose( fp ); 118 | return (size_t)rss * (size_t)sysconf( _SC_PAGESIZE); 119 | 120 | #else 121 | /* AIX, BSD, Solaris, and Unknown OS ------------------------ */ 122 | return (size_t)0L; /* Unsupported. */ 123 | #endif 124 | } -------------------------------------------------------------------------------- /src/message.cpp: -------------------------------------------------------------------------------- 1 | #include "message.hpp" 2 | 3 | Message::Message(void){}; 4 | 5 | Message::~Message(void) {}; 6 | 7 | void Message::initialize(unsigned int height, unsigned int width, unsigned int depth) { 8 | this -> sender_tile.resize(height * (width + depth)); 9 | this -> recipient_tile.resize(height * (width + depth)); 10 | this -> recipient_capture.resize(height); 11 | this -> recipient_feature.resize(width); 12 | this -> features.resize(width); 13 | this -> signs.resize(width); 14 | } 15 | 16 | void Message::exploration(Tile const & sender, Bitmask const & recipient_capture, Bitmask const & recipient_feature, int feature, float scope, float primary, float secondary, float tertiary) { 17 | this -> sender_tile = sender; 18 | this -> recipient_capture = recipient_capture; 19 | this -> recipient_feature = recipient_feature; 20 | 21 | if (feature != 0) { 22 | this -> features.clear(); 23 | this -> features.set(std::abs(feature) - 1, true); 24 | this -> signs.clear(); 25 | this -> signs.set(std::abs(feature) - 1, feature > 0); 26 | } 27 | 28 | this -> scope = scope; 29 | 30 | this -> code = Message::exploration_message; 31 | 32 | this -> _primary = primary; 33 | this -> _secondary = secondary; 34 | this -> _tertiary = tertiary; 35 | } 36 | 37 | void Message::exploitation(Tile const & sender, Tile const & recipient, Bitmask const & features, float primary, float secondary, float tertiary) { 38 | this -> sender_tile = sender; 39 | this -> recipient_tile = recipient; 40 | 41 | this -> features = features; 42 | this -> code = Message::exploitation_message; 43 | 44 | this -> _primary = primary; 45 | this -> _secondary = secondary; 46 | this -> _tertiary = tertiary; 47 | } 48 | 49 | Message & Message::operator=(Message const & other) { 50 | this -> sender_tile = other.sender_tile; 51 | this -> recipient_tile = other.recipient_tile; 52 | this -> recipient_capture = other.recipient_capture; 53 | this -> recipient_feature = other.recipient_feature; 54 | this -> feature = other.feature; 55 | this -> features = other.features; 56 | this -> signs = other.signs; 57 | this -> scope = other.scope; 58 | this -> code = other.code; 59 | this -> _primary = other._primary; 60 | this -> _secondary = other._secondary; 61 | this -> _tertiary = other._tertiary; 62 | return * this; 63 | }; 64 | 65 | bool Message::operator<(Message const & other) const { 66 | if (this -> _primary != other._primary) { 67 | return this -> _primary < other._primary; 68 | } else if (this -> _secondary != other._secondary) { 69 | return this -> _secondary < other._secondary; 70 | } else if (this -> _tertiary != other._tertiary) { 71 | return this -> _tertiary < 
other._tertiary; 72 | } 73 | return false; 74 | } 75 | 76 | bool Message::operator>(Message const & other) const { 77 | if (this -> _primary != other._primary) { 78 | return this -> _primary > other._primary; 79 | } else if (this -> _secondary != other._secondary) { 80 | return this -> _secondary > other._secondary; 81 | } else if (this -> _tertiary != other._tertiary) { 82 | return this -> _tertiary > other._tertiary; 83 | } 84 | return false; 85 | } 86 | 87 | bool Message::operator<=(Message const & other) const { 88 | if (this -> _primary != other._primary) { 89 | return this -> _primary < other._primary; 90 | } else if (this -> _secondary != other._secondary) { 91 | return this -> _secondary < other._secondary; 92 | } else if (this -> _tertiary != other._tertiary) { 93 | return this -> _tertiary < other._tertiary; 94 | } 95 | return true; 96 | } 97 | 98 | bool Message::operator>=(Message const & other) const { 99 | if (this -> _primary != other._primary) { 100 | return this -> _primary > other._primary; 101 | } else if (this -> _secondary != other._secondary) { 102 | return this -> _secondary > other._secondary; 103 | } else if (this -> _tertiary != other._tertiary) { 104 | return this -> _tertiary > other._tertiary; 105 | } 106 | return true; 107 | } 108 | 109 | bool Message::operator==(Message const & other) const { 110 | if (this -> code != other.code) { return false; } 111 | switch (this -> code) { 112 | case Message::exploration_message: { 113 | return this -> sender_tile == other.sender_tile 114 | && this -> recipient_capture == other.recipient_capture; 115 | // && this -> features == other.features 116 | // && this -> signs == other.signs 117 | // && this -> scope == other.scope; 118 | break; 119 | } 120 | case Message::exploitation_message: { 121 | // return this -> features == other.features 122 | // && this -> recipient_tile == other.recipient_tile; 123 | 124 | return this -> recipient_tile == other.recipient_tile; 125 | break; 126 | } 127 | default: { 128 | return false; 129 | break; 130 | } 131 | } 132 | } 133 | 134 | size_t Message::hash(void) const { 135 | size_t seed = 0; 136 | switch (this -> code) { 137 | case Message::exploration_message: { 138 | seed ^= this -> sender_tile.hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); 139 | seed ^= this -> recipient_capture.hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); 140 | // seed ^= this -> feature + 0x9e3779b9 + (seed << 6) + (seed >> 2); 141 | break; 142 | } 143 | case Message::exploitation_message: { 144 | // seed ^= this -> features.hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); 145 | seed ^= this -> recipient_tile.hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); 146 | break; 147 | } 148 | default: { 149 | break; 150 | } 151 | } 152 | return seed; 153 | } -------------------------------------------------------------------------------- /src/message.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MESSAGE_H 2 | #define MESSAGE_H 3 | 4 | #include 5 | 6 | #include "bitmask.hpp" 7 | #include "tile.hpp" 8 | 9 | // Container for messages in the priority queue 10 | // Messages priority dictates which vertex in the dependency graph will be worked on next 11 | class Message { 12 | public: 13 | static const char exploration_message = 0b00000000; // 14 | static const char exploitation_message = 0b00000001; // 15 | 16 | Message(void); 17 | ~Message(void); 18 | 19 | void initialize(unsigned int height, unsigned int width, unsigned int depth); 20 | 21 | // @param sender: A tile used to 
identify the key to a parent vertex 22 | // @param recipient_capture: A bitmask indicating the captured points of a child vertex 23 | // @param recipient_feature: A bitmask indicating the initial features of a child vertex 24 | // @param feature: an integer indicating the feature used by the parent to produce the child 25 | // @param scope: a float used to specify the risk tolerance of the parent to the child 26 | // @param primary, secondar, tertiary: hierarchical priority values used to order messages 27 | void exploration( 28 | Tile const & sender, 29 | Bitmask const & recipient_capture, 30 | Bitmask const & recipient_feature, 31 | int feature, 32 | float scope, 33 | float primary = 0, float secondary = 0, float tertiary = 0); 34 | 35 | // @param sender: A tile used to identify the key to a child vertex 36 | // @param recipient: A tile used to identify the key to a parent vertex 37 | // @param feature: an integer indicating the feature used by the parent to produce the child 38 | // @param primary, secondar, tertiary: hierarchical priority values used to order messages 39 | void exploitation( 40 | Tile const & sender, 41 | Tile const & recipient, 42 | Bitmask const & features, 43 | float primary = 0, float secondary = 0, float tertiary = 0); 44 | 45 | // Assignment operator used to transfer ownership of message data 46 | Message & operator=(Message const & other); 47 | 48 | // Comparison operators used to order messages in the priority queue 49 | bool operator==(Message const & other) const; 50 | bool operator<(Message const & other) const; 51 | bool operator>(Message const & other) const; 52 | bool operator<=(Message const & other) const; 53 | bool operator>=(Message const & other) const; 54 | 55 | size_t hash(void) const; 56 | 57 | Tile sender_tile; 58 | Tile recipient_tile; 59 | Bitmask recipient_capture; 60 | Bitmask recipient_feature; 61 | 62 | int feature; 63 | Bitmask features; 64 | Bitmask signs; 65 | float scope; 66 | 67 | char code; 68 | 69 | private: 70 | 71 | float _primary; 72 | float _secondary; 73 | float _tertiary; 74 | }; 75 | 76 | #endif -------------------------------------------------------------------------------- /src/model.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MODEL_H 2 | #define MODEL_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include "configuration.hpp" 16 | #include "encoder.hpp" 17 | #include "dataset.hpp" 18 | #include "graph.hpp" 19 | #include "state.hpp" 20 | 21 | using json = nlohmann::json; 22 | 23 | // Container for holding classification model extracted from the dependency graph 24 | class Model { 25 | public: 26 | Model(void); 27 | // Constructor for terminal node in a model 28 | // @param set: shared pointer to a bitmask that identifies the captured set of data points 29 | Model(std::shared_ptr set); 30 | 31 | // Constructor for non-terminal node in a model 32 | // @param binary_feature_index: the index of the feature used for splitting (after encoding) 33 | // @param negative: shared pointer to the model acting as the left subtree 34 | // @param positive: shared pointer to the model acting as the right subtree 35 | Model(unsigned int binary_feature_index, std::shared_ptr negative, std::shared_ptr positive); 36 | 37 | ~Model(void); 38 | 39 | // Hash generated from the leaf set of model 40 | size_t hash(void) const; 41 | 42 | void identify(key_type const & indentifier); 43 | bool 
identified(void); 44 | 45 | void translate_self(translation_type const & translation); 46 | void translate_negatives(translation_type const & translation); 47 | void translate_positives(translation_type const & translation); 48 | 49 | // Equality operator implemented by comparing the set of addresses of the bitmask of each leaf 50 | // @param other: other model to compare against 51 | // @returns true if the two models are provably equivalent 52 | // @note the equality comparison assumes that leaf bitmasks are not duplicated 53 | // this assumes that identical bitmasks are only copy by reference, not by value 54 | bool operator==(Model const & other) const; 55 | 56 | // @param sample: bitmask of binary features (encoded) used to make the prediction 57 | // @modifies prediction: string representation of the class that is predicted 58 | void predict(Bitmask const & sample, std::string & prediction) const; 59 | 60 | // @returns: the training loss incurred by this model 61 | float loss(void) const; 62 | 63 | // @returns: the complexity penalty incurred by this model 64 | float complexity(void) const; 65 | 66 | // @modifies node: JSON object representation of this model 67 | void to_json(json &node) const; 68 | void _to_json(json &node) const; 69 | 70 | void decode_json(json & node) const; 71 | void translate_json(json & node, translation_type const & main, translation_type const & alternative) const; 72 | 73 | void summarize(json & node) const; 74 | void intersect(json & src, json & dest) const; 75 | 76 | // @param spacing: number of spaces to used in the indentation format 77 | // @modifies serialization: string representation of the JSON object representation of this model 78 | void serialize(std::string & serialization, int const spacing = 0) const; 79 | 80 | key_type identifier; // Identifier for association to graph vertex 81 | 82 | bool terminal = false; // Flag specifying whether the node is terminal 83 | 84 | std::shared_ptr get_negative() const; // left subtree 85 | std::shared_ptr get_positive() const; // right subtree 86 | unsigned int get_feature() const; // index of the encoded feature 87 | unsigned int get_binary_target() const; // index of the encoded prediction 88 | std::string get_prediction() const; // prediction of terminal node 89 | 90 | private: 91 | // Addresses of the bitmasks of the leaf set 92 | void _partitions(std::vector< Bitmask * > & addresses) const; 93 | void partitions(std::vector< Bitmask * > & addresses) const; 94 | 95 | mutable size_t _hash = 0; 96 | 97 | mutable float cached_loss = -1; 98 | mutable float cached_complexity = -1; 99 | 100 | // Common members for both Terminal and Non-terminal instances 101 | std::string name; // Name of the decoded feature or decoded target 102 | std::string type; // Type of the decoded feature or decoded target 103 | 104 | // Non-terminal members 105 | unsigned int feature; // index of the decoded feature 106 | unsigned int binary_feature; // index of the encoded feature 107 | std::string relation; // relational operator to apply to the decoded feature 108 | std::string reference; // reference value to compare with the decoded feature 109 | std::shared_ptr negative; // left subtree 110 | std::shared_ptr positive; // right subtree 111 | translation_type self_translator; // self feature reordering 112 | translation_type negative_translator; // left subtree feature reordering 113 | translation_type positive_translator; // right subtree feature reordering 114 | 115 | // Terminal members 116 | unsigned int binary_target; // index 
of the encoded prediction 117 | std::string prediction; // string representation of the predicted value 118 | float _loss; // loss incurred by this leaf 119 | float _complexity; // complexity penalty incurred by this leaf 120 | std::shared_ptr< Bitmask > capture_set; // indicator specifying the points captured by this leaf 121 | }; 122 | 123 | inline std::shared_ptr Model::get_negative() const { 124 | return negative; 125 | }; // left subtree 126 | inline std::shared_ptr Model::get_positive() const { 127 | return positive; 128 | }; // right subtree 129 | inline unsigned int Model::get_binary_target() const { 130 | return binary_target; 131 | }; // index of the encoded prediction 132 | inline std::string Model::get_prediction() const { 133 | return prediction; 134 | }; // prediction of terminal node 135 | inline unsigned int Model::get_feature() const { 136 | return binary_feature; 137 | }; // index of the encoded feature 138 | 139 | namespace std { 140 | template <> 141 | struct hash< Model > { 142 | std::size_t operator()(Model const & model) const { 143 | return model.hash(); 144 | } 145 | }; 146 | 147 | template <> 148 | struct hash< Model * > { 149 | std::size_t operator()(Model * const model) const { 150 | return model -> hash(); 151 | } 152 | }; 153 | 154 | template <> 155 | struct equal_to< Model * > { 156 | bool operator()(Model * const left, Model * const right) const { return (* left) == (* right); } 157 | }; 158 | } 159 | 160 | namespace std { 161 | 162 | } 163 | 164 | #endif -------------------------------------------------------------------------------- /src/model_set.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MODEL_SET_H 2 | #define MODEL_SET_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include "configuration.hpp" 17 | #include "graph.hpp" 18 | #include "state.hpp" 19 | 20 | #include "cart_it.hpp" 21 | #include "additive_metrics.hpp" 22 | 23 | using json = nlohmann::json; 24 | 25 | class ModelSet; 26 | 27 | typedef std::shared_ptr model_set_p; 28 | // typedef std::vector< std::pair< model_set_p, model_set_p> > children_set_t; 29 | 30 | 31 | 32 | 33 | // 34 | // typedef std::tuple values_of_interest_t; 35 | typedef ValuesOfInterest values_of_interest_t; 36 | 37 | struct key_hash { 38 | std::size_t operator()(const values_of_interest_t &k) const { 39 | return k.hash(); 40 | } 41 | }; 42 | 43 | // Count of each values of interest 44 | typedef std::unordered_map values_of_interest_count_t; 45 | typedef std::unordered_map values_of_interest_mapping_t; 46 | 47 | 48 | enum ModelSetType { 49 | CLUSTERED_BY_OBJ, 50 | CLUSTERED_BY_TUPLE, 51 | }; 52 | 53 | // Container for holding classification model extracted from the dependency graph 54 | class ModelSet { 55 | public: 56 | ModelSet(ModelSetType type=CLUSTERED_BY_OBJ); 57 | // Constructor for terminal node in a model 58 | // @param set: shared pointer to a bitmask that identifies the captured set of data points 59 | ModelSet(std::shared_ptr set); 60 | 61 | // Constructor for terminal node for switching node type 62 | ModelSet(ModelSet* source); 63 | 64 | ~ModelSet(void); 65 | 66 | // Hash generated from the leaf set of model 67 | size_t hash(void) const; 68 | 69 | // Equality operator implemented by comparing the set of addresses of the bitmask of each leaf 70 | // @param other: other model to compare against 71 | // @returns true if the two models are provably equivalent 
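    // NOTE (illustrative aside; the Key type below is hypothetical, not from this project):
    // a closely related value-through-pointer idiom appears at the end of model.hpp (above),
    // where std::hash< Model * > and std::equal_to< Model * > are specialized so that unordered
    // containers keyed by Model pointers compare the pointed-to models structurally rather than
    // by address. A minimal standalone sketch of that pattern:
    //
    //   #include <cstddef>
    //   #include <functional>
    //   #include <unordered_set>
    //
    //   struct Key {
    //       int a, b;
    //       std::size_t hash(void) const { return std::hash< int >()(a) ^ (std::hash< int >()(b) << 1); }
    //       bool operator==(Key const & other) const { return a == other.a && b == other.b; }
    //   };
    //
    //   namespace std {
    //       template <> struct hash< Key * > {
    //           std::size_t operator()(Key * const key) const { return key -> hash(); }
    //       };
    //       template <> struct equal_to< Key * > {
    //           bool operator()(Key * const left, Key * const right) const { return (* left) == (* right); }
    //       };
    //   }
    //
    //   // std::unordered_set< Key * > now deduplicates structurally equal keys stored behind pointers.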
72 | // @note the equality comparison assumes that leaf bitmasks are not duplicated 73 | // this assumes that identical bitmasks are only copy by reference, not by value 74 | bool operator==(ModelSet const & other) const; 75 | 76 | // @returns: the training loss incurred by this model 77 | float loss(void) const; 78 | 79 | // @returns: the complexity penalty incurred by this model 80 | float complexity(void) const; 81 | 82 | // @returns: the objective value of this model 83 | // float objective_value(void) const; 84 | 85 | void insert(int feature, model_set_p & positive, model_set_p & negative); 86 | 87 | void merge_with_leaf(ModelSet* other); 88 | void merge_with_leaf(model_set_p & other); 89 | 90 | void merge(model_set_p & other); 91 | 92 | unsigned int get_binary_target() { 93 | return binary_target; 94 | }; 95 | 96 | // boost::multiprecision::uint128_t get_stored_model_count(); 97 | long long unsigned int get_stored_model_count(); 98 | 99 | // bad coding practice but whatever lol 100 | values_of_interest_t merge_values_of_interest_with_self(values_of_interest_t other) { 101 | assert(terminal); 102 | return values_of_interest + other; 103 | }; 104 | 105 | values_of_interest_count_t& get_values_of_interest_count(); 106 | values_of_interest_mapping_t& get_values_of_interest_mapping(); 107 | 108 | void construct_values_of_interest_count(); 109 | void construct_values_of_interest_mapping(); 110 | 111 | static void serialize(results_t source, std::string & serialization, int const spacing); 112 | static void serialize(values_of_interest_mapping_t source, std::string & serialization, int const spacing); 113 | static void convert_ptr_and_to_json(model_set_p const source, json & storage_arr, std::unordered_map & pointer_dictionary); 114 | static json convert_values_of_interest_to_array(values_of_interest_t values_of_interest); 115 | 116 | bool terminal = false; 117 | const ModelSetType type; 118 | Objective objective; 119 | 120 | private: 121 | 122 | values_of_interest_count_t values_of_interest_count; 123 | values_of_interest_mapping_t values_of_interest_mapping; 124 | 125 | 126 | // Non-terminal members 127 | std::unordered_map >> mapping; 128 | 129 | 130 | // boost::multiprecision::uint128_t stored_model_count = 0; 131 | long long unsigned int stored_model_count = 0; 132 | 133 | // Terminal members 134 | unsigned int binary_target; // index of the encoded prediction 135 | float _loss; // loss incurred by this leaf 136 | float _complexity; // complexity penalty incurred by this leaf 137 | ValuesOfInterest values_of_interest; 138 | 139 | friend class Trie; 140 | }; 141 | 142 | namespace std { 143 | 144 | } 145 | 146 | #endif -------------------------------------------------------------------------------- /src/optimizer/diagnosis/false_convergence.hpp: -------------------------------------------------------------------------------- 1 | void Optimizer::diagnose_false_convergence(void) { 2 | // diagnose_false_convergence(this -> root_set); 3 | return; 4 | } 5 | bool Optimizer::diagnose_false_convergence(key_type const & key) { 6 | // if (Configuration::diagnostics == false) { return false; } 7 | // std::unordered_set< Model * > results; 8 | // models(key, results); 9 | // if (results.size() > 0) { return false; } 10 | 11 | // // unsigned int m = State::dataset.width(); 12 | 13 | // // float epsilon = std::numeric_limits::epsilon(); 14 | // vertex_accessor task; 15 | // State::graph.vertices.find(task, key); 16 | 17 | // std::cout 18 | // << "Task(" << task -> second.capture_set().to_string() << 
") is falsely convergent." 19 | // << " Bounds = " << "[" << task -> second.lowerbound() << ", " << task -> second.upperbound() << "]" 20 | // << ", Base = " << task -> second.base_objective() << std::endl; 21 | 22 | // bound_accessor bounds; 23 | // State::graph.bounds.find(bounds, task -> second.identifier()); 24 | // for (bound_iterator iterator = bounds -> second.begin(); iterator != bounds -> second.end(); ++iterator) { 25 | // int feature = std::get<0>(* iterator); 26 | // bool ready; 27 | // float lower = 0.0, upper = 0.0; 28 | // for (int sign = -1; sign <= 1; sign += 2) { 29 | // vertex_accessor child; 30 | // child_accessor key; 31 | // ready = ready && State::graph.children.find(key, std::make_pair(task -> second.identifier(), sign * (feature + 1))) 32 | // && State::graph.vertices.find(child, key -> second); 33 | // if (ready) { 34 | // lower += child -> second.lowerbound(); 35 | // upper += child -> second.upperbound(); 36 | // } 37 | // } 38 | // if (ready) { 39 | // std::get<1>(* iterator) = lower; 40 | // std::get<2>(* iterator) = upper; 41 | // } 42 | 43 | // if (std::get<2>(* iterator) > task -> second.upperbound() + std::numeric_limits::epsilon()) { continue; } 44 | 45 | // std::cout << "Task(" << key.to_string() << ")'s upper bound points to Feature " << feature << std::endl; 46 | 47 | // { 48 | // vertex_accessor child; 49 | // child_accessor key; 50 | // if (State::graph.children.find(key, std::make_pair(task -> second.identifier(), (feature + 1)))) { 51 | // diagnose_false_convergence(key -> second); 52 | // } 53 | // } 54 | // { 55 | // vertex_accessor child; 56 | // child_accessor key; 57 | // if (State::graph.children.find(key, std::make_pair(task -> second.identifier(), -(feature + 1)))) { 58 | // diagnose_false_convergence(key -> second); 59 | // } 60 | // } 61 | // } 62 | return false; 63 | } -------------------------------------------------------------------------------- /src/optimizer/diagnosis/trace.hpp: -------------------------------------------------------------------------------- 1 | 2 | void Optimizer::diagnostic_trace(int iteration, key_type const & focal_point) { 3 | json tracer = json::object(); 4 | tracer["directed"] = true; 5 | tracer["multigraph"] = false; 6 | tracer["graph"] = json::object(); 7 | tracer["graph"]["name"] = "GOSDT Trace"; 8 | tracer["links"] = json::array(); 9 | tracer["nodes"] = json::array(); 10 | diagnostic_trace(this -> root, tracer, focal_point); 11 | 12 | int indentation = 2; 13 | 14 | std::stringstream trace_name; 15 | trace_name << Configuration::trace << "/" << iteration << ".gml"; 16 | std::string trace_result = tracer.dump(indentation); 17 | std::ofstream out(trace_name.str()); 18 | out << trace_result; 19 | out.close(); 20 | return; 21 | } 22 | bool Optimizer::diagnostic_trace(key_type const & identifier, json & tracer, key_type const & focal_point) { 23 | vertex_accessor task_accessor; 24 | if (State::graph.vertices.find(task_accessor, identifier) == false) { return false; } 25 | Task & task = task_accessor -> second; 26 | 27 | json node = json::object(); 28 | node["id"] = task.identifier().to_string(); 29 | node["name"] = task.capture_set().to_string(); 30 | node["lowerbound"] = task.lowerbound(); 31 | node["upperbound"] = task.upperbound(); 32 | node["explored"] = true; // new version no longer stores unexplored nodes 33 | node["resolved"] = task.lowerbound() == task.upperbound(); 34 | node["focused"] = (task.identifier() == focal_point); 35 | tracer["nodes"].push_back(node); 36 | 37 | bound_accessor bounds; 38 | 
State::graph.bounds.find(bounds, task.identifier()); 39 | for (bound_iterator iterator = bounds -> second.begin(); iterator != bounds -> second.end(); ++iterator) { 40 | int feature = std::get<0>(* iterator); 41 | 42 | child_accessor left_key, right_key; 43 | if (State::graph.children.find(left_key, std::make_pair(identifier, -(feature + 1)))) { 44 | json link = json::object(); 45 | link["source"] = identifier.to_string(); 46 | link["target"] = left_key -> second.to_string(); 47 | link["feature"] = feature; 48 | link["condition"] = false; 49 | tracer["links"].push_back(link); 50 | diagnostic_trace(left_key -> second, tracer, focal_point); 51 | left_key.release(); 52 | } 53 | if (State::graph.children.find(right_key, std::make_pair(identifier, feature + 1))) { 54 | json link = json::object(); 55 | link["source"] = identifier.to_string(); 56 | link["target"] = right_key -> second.to_string(); 57 | link["feature"] = feature; 58 | link["condition"] = true; 59 | tracer["links"].push_back(link); 60 | diagnostic_trace(right_key -> second, tracer, focal_point); 61 | right_key.release(); 62 | } 63 | } 64 | return true; 65 | } 66 | -------------------------------------------------------------------------------- /src/optimizer/diagnosis/tree.hpp: -------------------------------------------------------------------------------- 1 | 2 | void Optimizer::diagnostic_tree(int iteration) { 3 | json tracer = json::object(); 4 | tracer["directed"] = true; 5 | tracer["multigraph"] = false; 6 | tracer["graph"] = json::object(); 7 | tracer["graph"]["name"] = "GOSDT Trace"; 8 | tracer["links"] = json::array(); 9 | tracer["nodes"] = json::array(); 10 | diagnostic_tree(this -> root, tracer); 11 | 12 | int indentation = 2; 13 | 14 | std::stringstream trace_name; 15 | trace_name << Configuration::tree << "/" << iteration << ".gml"; 16 | std::string trace_result = tracer.dump(indentation); 17 | std::ofstream out(trace_name.str()); 18 | out << trace_result; 19 | out.close(); 20 | 21 | return; 22 | } 23 | bool Optimizer::diagnostic_tree(key_type const & identifier, json & tracer) { 24 | vertex_accessor task_accessor; 25 | if (State::graph.vertices.find(task_accessor, identifier) == false) { return false; } 26 | Task & task = task_accessor -> second; 27 | 28 | json node = json::object(); 29 | node["id"] = identifier.to_string(); 30 | node["capture"] = task.capture_set().to_string(); 31 | node["support"] = task.support(); 32 | node["terminal"] = task.lowerbound() == task.upperbound(); 33 | 34 | 35 | if (task.lowerbound() == task.base_objective()) { 36 | tracer["nodes"].push_back(node); 37 | return true; 38 | } 39 | 40 | json scores = json::object(); 41 | 42 | unsigned int m = State::dataset.width(); 43 | unsigned int k = 0; 44 | float score_k = std::numeric_limits::max(); 45 | 46 | bound_accessor bounds; 47 | State::graph.bounds.find(bounds, task.identifier()); 48 | for (bound_iterator iterator = bounds -> second.begin(); iterator != bounds -> second.end(); ++iterator) { 49 | int feature = std::get<0>(* iterator); 50 | 51 | std::string type, relation, reference; 52 | State::dataset.encoder.encoding(feature, type, relation, reference); 53 | float upper = std::get<2>(* iterator); 54 | scores[reference] = upper; 55 | if (upper < score_k) { 56 | score_k = upper; 57 | k = feature; 58 | } 59 | } 60 | unsigned int decoded_index; 61 | std::string type, relation, reference; 62 | State::dataset.encoder.decode(k, & decoded_index); 63 | State::dataset.encoder.encoding(k, type, relation, reference); 64 | node["threshold"] = reference; 
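    // NOTE (illustrative aside; argmin() below is a hypothetical standalone helper, not from this project):
    // the loop above is a plain argmin over the per-feature upper bounds: it keeps the feature k whose
    // stored upper bound is smallest and then asks the encoder for that feature's human-readable
    // threshold. The same selection in minimal standalone form:
    //
    //   #include <limits>
    //   #include <vector>
    //
    //   // returns the index of the smallest value, or -1 for an empty input
    //   inline int argmin(std::vector< float > const & upper_bounds) {
    //       int best = -1;
    //       float best_value = std::numeric_limits< float >::max();
    //       for (unsigned int i = 0; i < upper_bounds.size(); ++i) {
    //           if (upper_bounds[i] < best_value) { best_value = upper_bounds[i]; best = (int) i; }
    //       }
    //       return best;
    //   }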
65 | node["scores"] = scores; 66 | tracer["nodes"].push_back(node); 67 | if (score_k < std::numeric_limits::max()) { 68 | child_accessor left_key, right_key; 69 | if (State::graph.children.find(left_key, std::make_pair(identifier, -(k + 1)))) { 70 | diagnostic_tree(left_key -> second, tracer); 71 | left_key.release(); 72 | } 73 | if (State::graph.children.find(right_key, std::make_pair(identifier, k + 1))) { 74 | diagnostic_tree(right_key -> second, tracer); 75 | right_key.release(); 76 | } 77 | } 78 | 79 | return true; 80 | } 81 | -------------------------------------------------------------------------------- /src/python_extension.cpp: -------------------------------------------------------------------------------- 1 | #include "python_extension.hpp" 2 | 3 | // @param args: contains a single string object which is a JSON string containing the algorithm configuration 4 | static PyObject * configure(PyObject * self, PyObject * args) { 5 | const char * configuration; 6 | if (!PyArg_ParseTuple(args, "s", & configuration)) { return NULL; } 7 | 8 | std::istringstream config_stream(configuration); 9 | GOSDT::configure(config_stream); 10 | 11 | return Py_BuildValue(""); 12 | } 13 | 14 | // @param args: contains a single string object which contains the training data in CSV form 15 | // @returns a string object containing a JSON array of all resulting models 16 | static PyObject * fit(PyObject * self, PyObject * args) { 17 | const char * dataset; 18 | if (!PyArg_ParseTuple(args, "s", & dataset)) { return NULL; } 19 | 20 | std::istringstream data_stream(dataset); 21 | GOSDT model; 22 | std::string result; 23 | model.fit(data_stream, result); 24 | 25 | return Py_BuildValue("s", result.c_str()); 26 | } 27 | 28 | // @returns the number of seconds spent training 29 | static PyObject * time(PyObject * self, PyObject * args) { return Py_BuildValue("f", GOSDT::time); } 30 | 31 | // //@ returns the system time elapsed 32 | // static PyObject * stime(PyObject * self, PyObject * args) { return Py_BuildValue("f", GOSDT::ru_stime); } 33 | 34 | // //@ returns the user time elapsed 35 | // static PyObject * utime(PyObject * self, PyObject * args) { return Py_BuildValue("f", GOSDT::ru_utime); } 36 | 37 | // //@ returns the maximum memory usage 38 | // static PyObject * maxmem(PyObject * self, PyObject * args) { return Py_BuildValue("i", GOSDT::ru_maxrss); } 39 | 40 | // //@ retursn the number of swaps 41 | // static PyObject * numswap(PyObject * self, PyObject * args) { return Py_BuildValue("i", GOSDT::ru_nswap); } 42 | 43 | // //@ returns the number of context switches 44 | // static PyObject * numctxtswitch(PyObject * self, PyObject * args) { return Py_BuildValue("i", GOSDT::ru_nivcsw); } 45 | 46 | // @returns the number of iterations spent training 47 | static PyObject * iterations(PyObject * self, PyObject * args) { return Py_BuildValue("i", GOSDT::iterations); } 48 | 49 | // // @returns the global lower bound at the end of training 50 | // static PyObject * lower_bound(PyObject * self, PyObject * args) { return Py_BuildValue("d", GOSDT::lower_bound); } 51 | 52 | // // @returns the global upper bound at the end of training 53 | // static PyObject * upper_bound(PyObject * self, PyObject * args) { return Py_BuildValue("d", GOSDT::upper_bound); } 54 | 55 | // // @returns the loss of the tree found at the end of training (or trees - if more than one was found they should have the same loss) 56 | // static PyObject * model_loss(PyObject * self, PyObject * args) { return Py_BuildValue("f", GOSDT::model_loss); } 57 
| 58 | // @returns the number of vertices in the depency graph 59 | static PyObject * size(PyObject * self, PyObject * args) { return Py_BuildValue("i", GOSDT::size); } 60 | 61 | // @returns the current status code 62 | static PyObject * status(PyObject * self, PyObject * args) { return Py_BuildValue("i", GOSDT::status); } 63 | 64 | // // @returns the current git revision of the build 65 | // static PyObject * build_version(PyObject *self, PyObject *args) { return Py_BuildValue("s", BUILD_GIT_REV);} 66 | // // @returns the date on which the modules has been built 67 | // static PyObject * build_date(PyObject *self, PyObject *args) { return Py_BuildValue("s", BUILD_DATE);} 68 | // // @returns the host name where the module have been build 69 | // static PyObject * build_host(PyObject *self, PyObject *args) { return Py_BuildValue("s", BUILD_HOST);} 70 | 71 | // Define the list of methods for a module 72 | static PyMethodDef libgosdt_methods[] = { 73 | // { method name, method pointer, method parameter format, method description } 74 | {"configure", configure, METH_VARARGS, "Configures the algorithm using an input JSON string"}, 75 | {"fit", fit, METH_VARARGS, "Trains the model using an input CSV string"}, 76 | {"time", time, METH_NOARGS, "Number of seconds spent training"}, 77 | {"iterations", iterations, METH_NOARGS, "Number of iterations spent training"}, 78 | {"size", size, METH_NOARGS, "Number of vertices in the depency graph"}, 79 | {"status", status, METH_NOARGS, "Check the status code of the algorithm"}, 80 | // {"stime", stime, METH_NOARGS, "System time (sec) spent training"}, 81 | // {"utime", utime, METH_NOARGS, "User-time (sec) spent training"}, 82 | // {"maxmem", maxmem, METH_NOARGS, "Maximum memory used during training"}, 83 | // {"numswap", numswap, METH_NOARGS, "number of swaps during training"}, 84 | // {"numctxtswitch", numctxtswitch, METH_NOARGS, "number of context switches in training"}, 85 | // {"lower_bound", lower_bound, METH_NOARGS, "Check the lower_bound code of the algorithm"}, 86 | // {"upper_bound", upper_bound, METH_NOARGS, "Check the upper_bound code of the algorithm"}, 87 | // {"model_loss", model_loss, METH_NOARGS, "Check the model_loss code of the algorithm"}, 88 | // {"build_version", build_version, METH_NOARGS, "the build git revision" }, 89 | // {"build_date", build_date, METH_NOARGS, "the build the build date" }, 90 | // {"build_host", build_host, METH_NOARGS, "the build host" }, 91 | 92 | {NULL, NULL, 0, NULL} 93 | }; 94 | 95 | // Define the module 96 | static struct PyModuleDef libgosdt = { 97 | PyModuleDef_HEAD_INIT, 98 | "libgosdt", // Module Name 99 | "Trees FAst RashoMon Sets", // Module Description 100 | -1, // Size of per-interpreter state 101 | libgosdt_methods // Module methods 102 | }; 103 | 104 | // Initialize the module 105 | PyMODINIT_FUNC PyInit_libgosdt(void) { 106 | return PyModule_Create(&libgosdt); 107 | } -------------------------------------------------------------------------------- /src/python_extension.hpp: -------------------------------------------------------------------------------- 1 | #define PY_SSIZE_T_CLEAN 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | #include "gosdt.hpp" -------------------------------------------------------------------------------- /src/queue.cpp: -------------------------------------------------------------------------------- 1 | #include "queue.hpp" 2 | 3 | Queue::Queue(void) { 4 | return; 5 | } 6 | 7 | Queue::~Queue(void) { 8 | return; 9 | } 10 | 11 | bool 
Queue::push(Message const & message) { 12 | message_type * internal_message = new message_type(); 13 | * internal_message = message; 14 | 15 | // Attempt to copy content into membership set 16 | if (this -> membership.insert(std::make_pair(internal_message, true))) { 17 | this -> queue.push(internal_message); 18 | return true; 19 | } else { 20 | delete internal_message; 21 | return false; 22 | } 23 | } 24 | 25 | bool Queue::empty(void) const { return size() == 0; } 26 | 27 | unsigned int Queue::size(void) const { return this -> queue.size(); } 28 | 29 | 30 | bool Queue::pop(Message & message) { 31 | message_type * internal_message; 32 | if (this -> queue.try_pop(internal_message)) { 33 | this -> membership.erase(internal_message); // remove membership 34 | message = * internal_message; 35 | 36 | delete internal_message; 37 | return true; 38 | } else { 39 | return false; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/queue.hpp: -------------------------------------------------------------------------------- 1 | #ifndef QUEUE_H 2 | #define QUEUE_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "bitmask.hpp" 14 | #include "configuration.hpp" 15 | #include "message.hpp" 16 | 17 | typedef Message message_type; 18 | 19 | class PriorityKeyComparator { 20 | public: 21 | // Note that the tbb::concurrent_priority_queue is implemented to pop the item with the highest priority value 22 | bool operator()(message_type const * left, message_type const * right) { 23 | return (* left) < (* right); 24 | } 25 | }; 26 | 27 | struct MembershipKeyHashCompare { 28 | static size_t hash(message_type * message) { 29 | return message -> hash(); 30 | } 31 | static bool equal(message_type * left, message_type * right) { 32 | if ((* left) == (* right)) { 33 | left -> features.bit_or(right -> features); 34 | right -> features.bit_or(left -> features); 35 | left -> signs.bit_or(right -> signs); 36 | right -> signs.bit_or(left -> signs); 37 | left -> scope = std::max(left -> scope, right -> scope); 38 | right -> scope = std::max(left -> scope, right -> scope); 39 | return true; 40 | } else { 41 | return false; 42 | } 43 | } 44 | }; 45 | 46 | typedef tbb::concurrent_priority_queue< message_type *, PriorityKeyComparator, 47 | tbb::scalable_allocator< message_type * > > queue_type; 48 | 49 | typedef tbb::concurrent_hash_map< message_type *, bool, MembershipKeyHashCompare, 50 | tbb::scalable_allocator>> membership_table_type; // FIREWOLF: Fix the static assertion error 51 | 52 | class Queue { 53 | public: 54 | Queue(void); 55 | ~Queue(void); 56 | 57 | // @param message: a message to be sent from one vertex to another 58 | // @returns true if the message was successfully enqueued and not rejected by the membership filter 59 | // @note higher priority comes before lower priority 60 | bool push(Message const & message); 61 | 62 | // @returns whether queue is empty 63 | bool empty(void) const; 64 | 65 | // @returns the size of the queue 66 | unsigned int size(void) const; 67 | 68 | // @requires message: the 4th item of message must contain an instance of identifier_type with the correct amount of pre-allocated memory to copy assign into 69 | // @param message: a tuple containing the bitmask address, blocks 70 | // @param index: the particular channel (one of the queues) to pop from 71 | // @modifes message: message will be overwritten with a copy the content of the received message 72 | bool 
pop(Message & message); 73 | 74 | private: 75 | // map containing uniquely identified messages that are currently in queue 76 | membership_table_type membership; 77 | 78 | queue_type queue; // queue containing pending messages 79 | }; 80 | 81 | #endif -------------------------------------------------------------------------------- /src/state.cpp: -------------------------------------------------------------------------------- 1 | #include "state.hpp" 2 | 3 | Dataset State::dataset = Dataset(); 4 | Graph State::graph = Graph(); 5 | Queue State::queue = Queue(); 6 | std::vector< LocalState > State::locals = std::vector< LocalState >(); 7 | int State::status = 0; 8 | 9 | void State::initialize(std::istream & data_source, unsigned int workers) { 10 | State::dataset.load(data_source); 11 | State::graph = Graph(); 12 | State::queue = Queue(); 13 | State::locals.resize(workers); 14 | for (unsigned int i = 0; i < workers; ++i) { 15 | State::locals[i].initialize(dataset.height(), dataset.width(), dataset.depth()); 16 | } 17 | } 18 | 19 | 20 | void State::reset(void) { 21 | State::graph = Graph(); 22 | State::queue = Queue(); 23 | State::locals.clear(); 24 | State::dataset.clear(); 25 | } 26 | 27 | void State::reset_except_dataset(void) { 28 | State::graph = Graph(); 29 | State::queue = Queue(); 30 | for (unsigned int i = 0; i < locals.size(); ++i) { 31 | State::locals[i].initialize(dataset.height(), dataset.width(), dataset.depth()); 32 | } 33 | } -------------------------------------------------------------------------------- /src/state.hpp: -------------------------------------------------------------------------------- 1 | #ifndef STATE_H 2 | #define STATE_H 3 | 4 | class State; 5 | 6 | #include "dataset.hpp" 7 | #include "graph.hpp" 8 | #include "queue.hpp" 9 | #include "local_state.hpp" 10 | 11 | // Container of all data structures capturing the state of the optimization 12 | // Here we separate the memory used by the algorithm into two spaces: Global and Local 13 | // Global space is memory that all threads have access to, but is either read-only or is protected by locking mechanisms 14 | // Local space is memory that is partitioned components such each thread has unrestriced access but only to one component 15 | 16 | // Local space acts as an "extension" of the stack in a sense that the stack memory semantically belongs to a particular thread. 17 | // The actual location is stored on heap, although hypothetically we can store this on the stack 18 | 19 | class State { 20 | public: 21 | 22 | // Global state to which all thread shares access 23 | static Dataset dataset; 24 | static Graph graph; 25 | static Queue queue; 26 | static int status; 27 | 28 | // Local state to which each thread has exclusive access to a single entry 29 | static std::vector< LocalState > locals; 30 | 31 | static void initialize(std::istream & data_source, unsigned int workers = 1); 32 | static void reset(void); 33 | static void reset_except_dataset(void); 34 | 35 | }; 36 | 37 | #endif -------------------------------------------------------------------------------- /src/task.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TASK_H 2 | #define TASK_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | class Task; 14 | 15 | #include "bitmask.hpp" 16 | #include "configuration.hpp" 17 | #include "dataset.hpp" 18 | //#include "graph.hpp" // FIREWOLF: Circular references: Moved to cpp. 
19 | #include "integrity_violation.hpp" 20 | #include "queue.hpp" 21 | //#include "state.hpp" // FIREWOLF: Circular references: Moved to cpp. 22 | #include "types.hpp" 23 | 24 | class Task { 25 | public: 26 | Task(void); 27 | 28 | // @param capture_set: indicator for which data points are captured 29 | // @param feature_set: indicator for which features are still active 30 | Task(Bitmask const & capture_set, Bitmask const & feature_set, unsigned int id, bool rashomon_flag = false); 31 | 32 | // @returns the support of the this task 33 | float support(void) const; 34 | 35 | // @returns the objective lowerbound of this task 36 | float lowerbound(void) const; 37 | 38 | // @return the objective upperbound of this task 39 | float upperbound(void) const; 40 | 41 | float lowerscope(void) const; 42 | float upperscope(void) const; 43 | void scope(float new_scope); 44 | 45 | float rashomon_bound(void) const; 46 | void set_rashomon_flag(void); 47 | void set_rashomon_bound(float); 48 | 49 | 50 | // @return the objective optimality gap of this task 51 | float uncertainty(void) const; 52 | 53 | // @return the objective risk of not splitting 54 | float base_objective(void) const; 55 | 56 | // @return the Alkaike information of the captured data 57 | float information(void) const; 58 | 59 | // @return a bitmask representing the points captured by this task 60 | Bitmask const & capture_set(void) const; 61 | 62 | // @return a bitmask representing the features that are not yet pruned 63 | Bitmask const & feature_set(void) const; 64 | 65 | Tile & identifier(void); 66 | Tile & parent(void); 67 | std::vector & order(void); 68 | 69 | // @modifies: prunes features 70 | void prune_feature(unsigned int id); 71 | 72 | // @modifies: inserts children into the cache based on the currently non-pruned features 73 | void create_children(unsigned int id); 74 | 75 | // @modifies: prunes features 76 | void prune_features(unsigned int id); 77 | 78 | // @modifies: prunes features based on the indifference bound within adjacent thresholds of ordinal features 79 | void continuous_feature_exchange(unsigned int id); 80 | 81 | // @modifes: prunes features based on the indifference bound for all feature pairs 82 | void feature_exchange(unsigned int id); 83 | 84 | void send_explorers(float scope, unsigned int id); 85 | 86 | void send_explorer(Task const & child, float scope, int feature, unsigned int id); 87 | 88 | bool update(float lower, float upper, int optimal_feature); 89 | 90 | // observer method used for debugging 91 | std::string inspect(void) const; 92 | float maximum_scope = 0; 93 | private: 94 | Tile _identifier; 95 | Bitmask _capture_set; 96 | Bitmask _feature_set; 97 | 98 | std::vector _order; 99 | 100 | float _support; 101 | float _base_objective; 102 | float _information; 103 | 104 | float _lowerbound = -std::numeric_limits::max(); 105 | float _upperbound = std::numeric_limits::max(); 106 | 107 | float _context_lowerbound = 0.0; 108 | float _context_upperbound = 0.0; 109 | 110 | float _lowerscope = -std::numeric_limits::max(); 111 | float _upperscope = std::numeric_limits::max(); 112 | float _coverage = -std::numeric_limits::max(); 113 | 114 | float _rashomon_bound = std::numeric_limits::max(); 115 | 116 | int _optimal_feature = -1; // Feature index set if part of the oracle model 117 | 118 | bool _rashomon_flag = false; // 119 | }; 120 | 121 | #endif -------------------------------------------------------------------------------- /src/tile.cpp: 
-------------------------------------------------------------------------------- 1 | #include "tile.hpp" 2 | 3 | Tile::Tile(Bitmask const & content, unsigned int width) : _content(content), _width(width) {} 4 | 5 | Tile::Tile(Bitmask const & samples, Bitmask const & features, unsigned int id) { 6 | 7 | } 8 | 9 | Tile::Tile(void) {}; 10 | Tile::~Tile(void) {}; 11 | 12 | Tile & Tile::operator=(Tile const & other) { 13 | this -> _content = other._content; 14 | this -> _width = other._width; 15 | return * this; 16 | }; 17 | 18 | bool Tile::operator==(Tile const & other) const { 19 | return (this -> _width == other._width) && (this -> _content == other._content); 20 | } 21 | 22 | bool Tile::operator!=(Tile const & other) const { 23 | return !(* this == other); 24 | } 25 | 26 | size_t Tile::hash(void) const { 27 | size_t seed = this -> _width; 28 | seed ^= this -> _content.hash() + 0x9e3779b9 + (seed << 6) + (seed >> 2); 29 | return seed; 30 | } 31 | 32 | unsigned int Tile::size(void) const { return this -> _content.size(); } 33 | 34 | void Tile::resize(unsigned int new_size) { this -> _content.resize(new_size); } 35 | 36 | Bitmask & Tile::content(void) { return this -> _content; } 37 | void Tile::content(Bitmask const & _new_content) { this -> _content = _new_content; } 38 | 39 | unsigned int Tile::width(void) const { return this -> _width; } 40 | void Tile::width(unsigned int _new_width) { this -> _width = _new_width; } 41 | 42 | std::string Tile::to_string(void) const { 43 | if (this -> _content.size() == 0) { return "Empty"; } 44 | 45 | std::stringstream stream; 46 | // for (unsigned int i = 0; i < this -> _content.size(); ++i) { 47 | // stream << (int)(this -> _content.get(i)); 48 | // if (((i + 1) % this -> _width) == 0 && i < this -> _content.size() - 1) { 49 | // stream << std::endl; 50 | // } 51 | // } 52 | stream << this -> _width; 53 | stream << " : "; 54 | stream << this -> _content.to_string(); 55 | return stream.str(); 56 | } -------------------------------------------------------------------------------- /src/tile.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TILE_H 2 | #define TILE_H 3 | 4 | #include 5 | #include 6 | 7 | #include "bitmask.hpp" 8 | 9 | // Container for tiles which represent an equivalence class of data sets 10 | class Tile { 11 | public: 12 | // @param content: A bitmask containing the bits of a binary matrix in linearized format 13 | // @param width: The width of the matrix, used for delinearization 14 | Tile(Bitmask const & content, unsigned int width); 15 | 16 | // @param samples: an indicator of samples that this tile must capture 17 | // @param features: an indicator of features that this tile must capture 18 | // @param id: The id of the local state used when a buffer is needed 19 | Tile(Bitmask const & samples, Bitmask const & features, unsigned int id); 20 | Tile(void); 21 | ~Tile(void); 22 | 23 | // Assignment operator used to transfer ownership of data 24 | Tile & operator=(Tile const & other); 25 | 26 | // Comparison operators used to match different tiles 27 | bool operator==(Tile const & other) const; 28 | bool operator!=(Tile const & other) const; 29 | 30 | // Accessors used to inspect/modify the content 31 | Bitmask & content(void); 32 | void content(Bitmask const & _new_content); 33 | unsigned int width(void) const; 34 | void width(unsigned int _new_width); 35 | 36 | size_t hash(void) const; 37 | 38 | unsigned int size(void) const; 39 | void resize(unsigned int new_size); 40 | 41 | std::string 
to_string(void) const; 42 | 43 | private: 44 | Bitmask _content; 45 | unsigned int _width; 46 | }; 47 | 48 | // Overrides for STD containers 49 | namespace std { 50 | template <> 51 | struct hash< Tile > { 52 | std::size_t operator()(Tile const & tile) const { return tile.hash(); } 53 | }; 54 | 55 | template <> 56 | struct equal_to< Tile > { 57 | bool operator()(Tile const & left, Tile const & right) const { return left == right; } 58 | }; 59 | } 60 | 61 | #endif -------------------------------------------------------------------------------- /src/trie.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "model.hpp" 13 | #include "model_set.hpp" 14 | 15 | #include "cart_it.hpp" 16 | 17 | // #include "utils.h" 18 | // #include "alloc.h" 19 | // #include "rule.h" 20 | 21 | enum class DataStruct { Tree, Queue, Pmap }; 22 | template 23 | using tracking_vector = std::vector>; 24 | typedef struct rule { 25 | char *features; /* Representation of the rule. */ 26 | int support; /* Number of 1's in truth table. */ 27 | int cardinality; 28 | int *ids; 29 | // VECTOR truthtable; /* Truth table; one bit per sample. */ 30 | } rule_t; 31 | 32 | class Node { 33 | public: 34 | Node(); 35 | Node(std::vector id, Node *parent); 36 | 37 | virtual ~Node(); 38 | 39 | inline std::vector id() const; 40 | 41 | // Returns pair of prefixes and predictions for the path from this node to 42 | // the root 43 | inline std::pair, DataStruct::Tree>, 44 | tracking_vector> 45 | get_prefix_and_predictions(); 46 | 47 | inline size_t depth() const; 48 | inline Node *child(std::vector idx); 49 | inline Node *parent() const; 50 | inline void delete_child(std::vector idx); 51 | inline size_t num_children() const; 52 | 53 | inline typename std::map, Node *>::iterator 54 | children_begin(); 55 | inline typename std::map, Node *>::iterator children_end(); 56 | virtual inline double get_curiosity() { return 0.0; } 57 | 58 | bool terminal = false; // Flag specifying whether the node is terminal 59 | 60 | // @modifies node: JSON object representation of this model 61 | void to_json(json &node) const; 62 | 63 | protected: 64 | std::map, Node *> children_; 65 | 66 | Node *parent_; 67 | size_t depth_; 68 | std::vector id_; 69 | 70 | // Terminal members 71 | float loss; 72 | float complexity; 73 | 74 | friend class Trie; 75 | }; 76 | 77 | class Trie { 78 | public: 79 | Trie(){}; 80 | Trie(bool calculate_size, char const *type); 81 | ~Trie(); 82 | 83 | Node *construct_node(std::vector new_rule, Node *parent); 84 | 85 | inline size_t num_nodes() const; 86 | inline size_t num_evaluated() const; 87 | inline Node *root() const; 88 | 89 | void insert_model(const Model *model); 90 | 91 | void insert_model_set(const model_set_p model); 92 | // @param models: a vector of children sets. 
Any element in the cartesian 93 | // product of the sets is in the rashomon set and (supposedly) has the same 94 | // objective value by design 95 | void insert_model_set_children( 96 | const std::vector>> 97 | &models, 98 | Node *currNode, values_of_interest_t values_of_interest); 99 | 100 | void finalize_leaf_node(Node *currNode, 101 | values_of_interest_t values_of_interest); 102 | 103 | inline void increment_num_evaluated(); 104 | inline void decrement_num_nodes(); 105 | inline int ablation() const; 106 | inline bool calculate_size() const; 107 | 108 | void insert_root(); 109 | void insert(Node *node); 110 | void insert_if_not_exist(std::vector feats, Node *currNode, 111 | Node *&child); 112 | void prune_up(Node *node); 113 | Node * 114 | check_prefix(tracking_vector, DataStruct::Tree> &prefix); 115 | 116 | // @param spacing: number of spaces to used in the indentation format 117 | // @modifies serialization: string representation of the JSON object 118 | // representation of this model 119 | void serialize(std::string &serialization, int const spacing = 0) const; 120 | 121 | protected: 122 | Node *root_; 123 | 124 | size_t num_nodes_; 125 | size_t num_evaluated_; 126 | bool calculate_size_; 127 | 128 | char const *type_; 129 | void gc_helper(Node *node); 130 | }; 131 | 132 | inline std::vector Node::id() const { return id_; } 133 | 134 | // A function for debugging. 135 | // inline std::pair, DataStruct::Tree>, 136 | // tracking_vector > 137 | // Node::get_prefix_and_predictions() { 138 | // tracking_vector, DataStruct::Tree> prefix; 139 | // tracking_vector predictions; 140 | // tracking_vector, DataStruct::Tree>::iterator it1 = 141 | // prefix.begin(); tracking_vector::iterator it2 = 142 | // predictions.begin(); Node* node = this; for(size_t i = depth_; i > 0; 143 | // --i) { 144 | // it1 = prefix.insert(it1, node->id()); 145 | // it2 = predictions.insert(it2, node->prediction()); 146 | // node = node->parent(); 147 | // } 148 | // return std::make_pair(prefix, predictions); 149 | // } 150 | 151 | inline size_t Node::depth() const { return depth_; } 152 | 153 | inline Node *Node::child(std::vector idx) { 154 | typename std::map, Node *>::iterator iter; 155 | iter = children_.find(idx); 156 | if (iter == children_.end()) 157 | return NULL; 158 | else 159 | return iter->second; 160 | } 161 | 162 | inline void Node::delete_child(std::vector idx) { children_.erase(idx); } 163 | 164 | inline size_t Node::num_children() const { return children_.size(); } 165 | 166 | inline typename std::map, Node *>::iterator 167 | Node::children_begin() { 168 | return children_.begin(); 169 | } 170 | 171 | inline typename std::map, Node *>::iterator 172 | Node::children_end() { 173 | return children_.end(); 174 | } 175 | 176 | inline Node *Node::parent() const { return parent_; } 177 | 178 | inline size_t Trie::num_nodes() const { return num_nodes_; } 179 | 180 | inline size_t Trie::num_evaluated() const { return num_evaluated_; } 181 | 182 | inline Node *Trie::root() const { return root_; } 183 | 184 | inline bool Trie::calculate_size() const { return calculate_size_; } 185 | 186 | /* 187 | * Increment number of nodes evaluated after performing incremental computation 188 | * in evaluate_children. 189 | */ 190 | inline void Trie::increment_num_evaluated() { ++num_evaluated_; } 191 | 192 | /* 193 | * Called whenever a node is deleted from the tree. 
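 *
 * NOTE (illustrative aside; the cartesian() helper below is hypothetical, not from this project):
 * per the comment on its declaration above, insert_model_set_children() receives a vector of
 * children sets, and every element of their cartesian product is a tree in the Rashomon set.
 * A minimal standalone sketch of enumerating such a product over plain integers:
 *
 *   #include <vector>
 *
 *   // appends every way of picking one element from each group to `out`
 *   inline void cartesian(std::vector< std::vector< int > > const & groups,
 *                         std::vector< int > & current,
 *                         std::vector< std::vector< int > > & out) {
 *       if (current.size() == groups.size()) { out.push_back(current); return; }
 *       for (unsigned int i = 0; i < groups[current.size()].size(); ++i) {
 *           current.push_back(groups[current.size()][i]);
 *           cartesian(groups, current, out);
 *           current.pop_back();
 *       }
 *   }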
194 | */ 195 | inline void Trie::decrement_num_nodes() { --num_nodes_; } 196 | 197 | void delete_subtree(Trie *tree, Node *node, bool destructive, 198 | bool update_remaining_state_space); 199 | -------------------------------------------------------------------------------- /src/types.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TYPES_H 2 | #define TYPES_H 3 | 4 | #include "tile.hpp" 5 | 6 | #endif -------------------------------------------------------------------------------- /src/version.hpp: -------------------------------------------------------------------------------- 1 | 2 | #ifndef _VERSION_ 3 | #define _VERSION_ 1 4 | 5 | #ifndef BUILD_GIT_REV 6 | #define BUILD_GIT_REV "unknown" 7 | #endif 8 | 9 | #ifndef BUILD_HOST 10 | #define BUILD_HOST "unknown" 11 | #endif 12 | 13 | #ifndef BUILD_DATE 14 | #define BUILD_DATE "20210915-162422" 15 | #endif 16 | 17 | #endif 18 | -------------------------------------------------------------------------------- /test/.dirstamp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-systopia/treeFarms/5508f6063cbcac5124e164c72aafa9afc31f683d/test/.dirstamp -------------------------------------------------------------------------------- /test/fixtures/binary_sepal.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,class 2 | 5.1,3.5,1 3 | 4.9,3.0,1 4 | 4.7,3.2,1 5 | 4.6,3.1,1 6 | 5.0,3.6,1 7 | 5.4,3.9,1 8 | 4.6,3.4,1 9 | 5.0,3.4,1 10 | 4.4,2.9,1 11 | 4.9,3.1,1 12 | 5.4,3.7,1 13 | 4.8,3.4,1 14 | 4.8,3.0,1 15 | 4.3,3.0,1 16 | 5.8,4.0,1 17 | 5.7,4.4,1 18 | 5.4,3.9,1 19 | 5.1,3.5,1 20 | 5.7,3.8,1 21 | 5.1,3.8,1 22 | 5.4,3.4,1 23 | 5.1,3.7,1 24 | 4.6,3.6,1 25 | 5.1,3.3,1 26 | 4.8,3.4,1 27 | 5.0,3.0,1 28 | 5.0,3.4,1 29 | 5.2,3.5,1 30 | 5.2,3.4,1 31 | 4.7,3.2,1 32 | 4.8,3.1,1 33 | 5.4,3.4,1 34 | 5.2,4.1,1 35 | 5.5,4.2,1 36 | 4.9,3.1,1 37 | 5.0,3.2,1 38 | 5.5,3.5,1 39 | 4.9,3.1,1 40 | 4.4,3.0,1 41 | 5.1,3.4,1 42 | 5.0,3.5,1 43 | 4.5,2.3,1 44 | 4.4,3.2,1 45 | 5.0,3.5,1 46 | 5.1,3.8,1 47 | 4.8,3.0,1 48 | 5.1,3.8,1 49 | 4.6,3.2,1 50 | 5.3,3.7,1 51 | 5.0,3.3,1 52 | 7.0,3.2,0 53 | 6.4,3.2,0 54 | 6.9,3.1,0 55 | 5.5,2.3,0 56 | 6.5,2.8,0 57 | 5.7,2.8,0 58 | 6.3,3.3,0 59 | 4.9,2.4,0 60 | 6.6,2.9,0 61 | 5.2,2.7,0 62 | 5.0,2.0,0 63 | 5.9,3.0,0 64 | 6.0,2.2,0 65 | 6.1,2.9,0 66 | 5.6,2.9,0 67 | 6.7,3.1,0 68 | 5.6,3.0,0 69 | 5.8,2.7,0 70 | 6.2,2.2,0 71 | 5.6,2.5,0 72 | 5.9,3.2,0 73 | 6.1,2.8,0 74 | 6.3,2.5,0 75 | 6.1,2.8,0 76 | 6.4,2.9,0 77 | 6.6,3.0,0 78 | 6.8,2.8,0 79 | 6.7,3.0,0 80 | 6.0,2.9,0 81 | 5.7,2.6,0 82 | 5.5,2.4,0 83 | 5.5,2.4,0 84 | 5.8,2.7,0 85 | 6.0,2.7,0 86 | 5.4,3.0,0 87 | 6.0,3.4,0 88 | 6.7,3.1,0 89 | 6.3,2.3,0 90 | 5.6,3.0,0 91 | 5.5,2.5,0 92 | 5.5,2.6,0 93 | 6.1,3.0,0 94 | 5.8,2.6,0 95 | 5.0,2.3,0 96 | 5.6,2.7,0 97 | 5.7,3.0,0 98 | 5.7,2.9,0 99 | 6.2,2.9,0 100 | 5.1,2.5,0 101 | 5.7,2.8,0 102 | 6.3,3.3,0 103 | 5.8,2.7,0 104 | 7.1,3.0,0 105 | 6.3,2.9,0 106 | 6.5,3.0,0 107 | 7.6,3.0,0 108 | 4.9,2.5,0 109 | 7.3,2.9,0 110 | 6.7,2.5,0 111 | 7.2,3.6,0 112 | 6.5,3.2,0 113 | 6.4,2.7,0 114 | 6.8,3.0,0 115 | 5.7,2.5,0 116 | 5.8,2.8,0 117 | 6.4,3.2,0 118 | 6.5,3.0,0 119 | 7.7,3.8,0 120 | 7.7,2.6,0 121 | 6.0,2.2,0 122 | 6.9,3.2,0 123 | 5.6,2.8,0 124 | 7.7,2.8,0 125 | 6.3,2.7,0 126 | 6.7,3.3,0 127 | 7.2,3.2,0 128 | 6.2,2.8,0 129 | 6.1,3.0,0 130 | 6.4,2.8,0 131 | 7.2,3.0,0 132 | 7.4,2.8,0 133 | 7.9,3.8,0 134 | 6.4,2.8,0 135 | 6.3,2.8,0 136 | 6.1,2.6,0 137 | 7.7,3.0,0 138 | 6.3,3.4,0 139 | 
6.4,3.1,0 140 | 6.0,3.0,0 141 | 6.9,3.1,0 142 | 6.7,3.1,0 143 | 6.9,3.1,0 144 | 5.8,2.7,0 145 | 6.8,3.2,0 146 | 6.7,3.3,0 147 | 6.7,3.0,0 148 | 6.3,2.5,0 149 | 6.5,3.0,0 150 | 6.2,3.4,0 151 | 5.9,3.0,0 -------------------------------------------------------------------------------- /test/fixtures/binary_sepal.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "false": { 4 | "complexity": 0.05000000074505806, 5 | "loss": 0.046666666865348816, 6 | "name": "class", 7 | "prediction": 1 8 | }, 9 | "feature": 0, 10 | "name": "sepal_length", 11 | "reference": 5.45, 12 | "relation": ">=", 13 | "true": { 14 | "complexity": 0.05000000074505806, 15 | "loss": 0.03333333507180214, 16 | "name": "class", 17 | "prediction": 0 18 | } 19 | } 20 | ] -------------------------------------------------------------------------------- /test/fixtures/dataset.csv: -------------------------------------------------------------------------------- 1 | f1,f2,f3,f4,f5,l 2 | 1,0,0,0,0,TFF 3 | 1,0,0,0,0,TFF 4 | 0,1,0,0,0,FTF 5 | 0,1,0,0,0,FTF 6 | 0,0,1,0,0,FFT 7 | 0,0,1,0,0,FFT 8 | 0,0,0,1,0,TFF 9 | 0,0,0,1,0,TFF 10 | 0,0,0,0,1,FTF 11 | 0,0,0,0,1,FTT -------------------------------------------------------------------------------- /test/fixtures/dataset.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "false": { 4 | "false": { 5 | "false": { 6 | "complexity": 0.05000000074505806, 7 | "loss": 0.10000000149011612, 8 | "name": "l", 9 | "prediction": "FTF" 10 | }, 11 | "feature": 3, 12 | "name": "f4", 13 | "reference": 1, 14 | "relation": "==", 15 | "true": { 16 | "complexity": 0.05000000074505806, 17 | "loss": 0.0, 18 | "name": "l", 19 | "prediction": "TFF" 20 | } 21 | }, 22 | "feature": 2, 23 | "name": "f3", 24 | "reference": 1, 25 | "relation": "==", 26 | "true": { 27 | "complexity": 0.05000000074505806, 28 | "loss": 0.0, 29 | "name": "l", 30 | "prediction": "FFT" 31 | } 32 | }, 33 | "feature": 0, 34 | "name": "f1", 35 | "reference": 1, 36 | "relation": "==", 37 | "true": { 38 | "complexity": 0.05000000074505806, 39 | "loss": 0.0, 40 | "name": "l", 41 | "prediction": "TFF" 42 | } 43 | } 44 | ] -------------------------------------------------------------------------------- /test/fixtures/sepal.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,class 2 | 5.1,3.5,Iris-setosa 3 | 4.9,3.0,Iris-setosa 4 | 4.7,3.2,Iris-setosa 5 | 4.6,3.1,Iris-setosa 6 | 5.0,3.6,Iris-setosa 7 | 5.4,3.9,Iris-setosa 8 | 4.6,3.4,Iris-setosa 9 | 5.0,3.4,Iris-setosa 10 | 4.4,2.9,Iris-setosa 11 | 4.9,3.1,Iris-setosa 12 | 5.4,3.7,Iris-setosa 13 | 4.8,3.4,Iris-setosa 14 | 4.8,3.0,Iris-setosa 15 | 4.3,3.0,Iris-setosa 16 | 5.8,4.0,Iris-setosa 17 | 5.7,4.4,Iris-setosa 18 | 5.4,3.9,Iris-setosa 19 | 5.1,3.5,Iris-setosa 20 | 5.7,3.8,Iris-setosa 21 | 5.1,3.8,Iris-setosa 22 | 5.4,3.4,Iris-setosa 23 | 5.1,3.7,Iris-setosa 24 | 4.6,3.6,Iris-setosa 25 | 5.1,3.3,Iris-setosa 26 | 4.8,3.4,Iris-setosa 27 | 5.0,3.0,Iris-setosa 28 | 5.0,3.4,Iris-setosa 29 | 5.2,3.5,Iris-setosa 30 | 5.2,3.4,Iris-setosa 31 | 4.7,3.2,Iris-setosa 32 | 4.8,3.1,Iris-setosa 33 | 5.4,3.4,Iris-setosa 34 | 5.2,4.1,Iris-setosa 35 | 5.5,4.2,Iris-setosa 36 | 4.9,3.1,Iris-setosa 37 | 5.0,3.2,Iris-setosa 38 | 5.5,3.5,Iris-setosa 39 | 4.9,3.1,Iris-setosa 40 | 4.4,3.0,Iris-setosa 41 | 5.1,3.4,Iris-setosa 42 | 5.0,3.5,Iris-setosa 43 | 4.5,2.3,Iris-setosa 44 | 4.4,3.2,Iris-setosa 45 | 5.0,3.5,Iris-setosa 46 | 5.1,3.8,Iris-setosa 47 | 
4.8,3.0,Iris-setosa 48 | 5.1,3.8,Iris-setosa 49 | 4.6,3.2,Iris-setosa 50 | 5.3,3.7,Iris-setosa 51 | 5.0,3.3,Iris-setosa 52 | 7.0,3.2,Iris-versicolor 53 | 6.4,3.2,Iris-versicolor 54 | 6.9,3.1,Iris-versicolor 55 | 5.5,2.3,Iris-versicolor 56 | 6.5,2.8,Iris-versicolor 57 | 5.7,2.8,Iris-versicolor 58 | 6.3,3.3,Iris-versicolor 59 | 4.9,2.4,Iris-versicolor 60 | 6.6,2.9,Iris-versicolor 61 | 5.2,2.7,Iris-versicolor 62 | 5.0,2.0,Iris-versicolor 63 | 5.9,3.0,Iris-versicolor 64 | 6.0,2.2,Iris-versicolor 65 | 6.1,2.9,Iris-versicolor 66 | 5.6,2.9,Iris-versicolor 67 | 6.7,3.1,Iris-versicolor 68 | 5.6,3.0,Iris-versicolor 69 | 5.8,2.7,Iris-versicolor 70 | 6.2,2.2,Iris-versicolor 71 | 5.6,2.5,Iris-versicolor 72 | 5.9,3.2,Iris-versicolor 73 | 6.1,2.8,Iris-versicolor 74 | 6.3,2.5,Iris-versicolor 75 | 6.1,2.8,Iris-versicolor 76 | 6.4,2.9,Iris-versicolor 77 | 6.6,3.0,Iris-versicolor 78 | 6.8,2.8,Iris-versicolor 79 | 6.7,3.0,Iris-versicolor 80 | 6.0,2.9,Iris-versicolor 81 | 5.7,2.6,Iris-versicolor 82 | 5.5,2.4,Iris-versicolor 83 | 5.5,2.4,Iris-versicolor 84 | 5.8,2.7,Iris-versicolor 85 | 6.0,2.7,Iris-versicolor 86 | 5.4,3.0,Iris-versicolor 87 | 6.0,3.4,Iris-versicolor 88 | 6.7,3.1,Iris-versicolor 89 | 6.3,2.3,Iris-versicolor 90 | 5.6,3.0,Iris-versicolor 91 | 5.5,2.5,Iris-versicolor 92 | 5.5,2.6,Iris-versicolor 93 | 6.1,3.0,Iris-versicolor 94 | 5.8,2.6,Iris-versicolor 95 | 5.0,2.3,Iris-versicolor 96 | 5.6,2.7,Iris-versicolor 97 | 5.7,3.0,Iris-versicolor 98 | 5.7,2.9,Iris-versicolor 99 | 6.2,2.9,Iris-versicolor 100 | 5.1,2.5,Iris-versicolor 101 | 5.7,2.8,Iris-versicolor 102 | 6.3,3.3,Iris-virginica 103 | 5.8,2.7,Iris-virginica 104 | 7.1,3.0,Iris-virginica 105 | 6.3,2.9,Iris-virginica 106 | 6.5,3.0,Iris-virginica 107 | 7.6,3.0,Iris-virginica 108 | 4.9,2.5,Iris-virginica 109 | 7.3,2.9,Iris-virginica 110 | 6.7,2.5,Iris-virginica 111 | 7.2,3.6,Iris-virginica 112 | 6.5,3.2,Iris-virginica 113 | 6.4,2.7,Iris-virginica 114 | 6.8,3.0,Iris-virginica 115 | 5.7,2.5,Iris-virginica 116 | 5.8,2.8,Iris-virginica 117 | 6.4,3.2,Iris-virginica 118 | 6.5,3.0,Iris-virginica 119 | 7.7,3.8,Iris-virginica 120 | 7.7,2.6,Iris-virginica 121 | 6.0,2.2,Iris-virginica 122 | 6.9,3.2,Iris-virginica 123 | 5.6,2.8,Iris-virginica 124 | 7.7,2.8,Iris-virginica 125 | 6.3,2.7,Iris-virginica 126 | 6.7,3.3,Iris-virginica 127 | 7.2,3.2,Iris-virginica 128 | 6.2,2.8,Iris-virginica 129 | 6.1,3.0,Iris-virginica 130 | 6.4,2.8,Iris-virginica 131 | 7.2,3.0,Iris-virginica 132 | 7.4,2.8,Iris-virginica 133 | 7.9,3.8,Iris-virginica 134 | 6.4,2.8,Iris-virginica 135 | 6.3,2.8,Iris-virginica 136 | 6.1,2.6,Iris-virginica 137 | 7.7,3.0,Iris-virginica 138 | 6.3,3.4,Iris-virginica 139 | 6.4,3.1,Iris-virginica 140 | 6.0,3.0,Iris-virginica 141 | 6.9,3.1,Iris-virginica 142 | 6.7,3.1,Iris-virginica 143 | 6.9,3.1,Iris-virginica 144 | 5.8,2.7,Iris-virginica 145 | 6.8,3.2,Iris-virginica 146 | 6.7,3.3,Iris-virginica 147 | 6.7,3.0,Iris-virginica 148 | 6.3,2.5,Iris-virginica 149 | 6.5,3.0,Iris-virginica 150 | 6.2,3.4,Iris-virginica 151 | 5.9,3.0,Iris-virginica -------------------------------------------------------------------------------- /test/fixtures/sepal.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "false": { 4 | "false": { 5 | "complexity": 0.05000000074505806, 6 | "loss": 0.06666658818721771, 7 | "name": "class", 8 | "prediction": "Iris-versicolor" 9 | }, 10 | "feature": 1, 11 | "name": "sepal_width", 12 | "reference": 2.95, 13 | "relation": ">=", 14 | "true": { 15 | "complexity": 0.05000000074505806, 16 | "loss": 
0.0733330249786377, 17 | "name": "class", 18 | "prediction": "Iris-setosa" 19 | } 20 | }, 21 | "feature": 0, 22 | "name": "sepal_length", 23 | "reference": 6.15, 24 | "relation": ">=", 25 | "true": { 26 | "complexity": 0.05000000074505806, 27 | "loss": 0.10666656494140625, 28 | "name": "class", 29 | "prediction": "Iris-virginica" 30 | } 31 | } 32 | ] -------------------------------------------------------------------------------- /test/fixtures/sequences.csv: -------------------------------------------------------------------------------- 1 | redundant,optional,bin,enum,str,int,float,path 2 | x,null,yes,1,alpha,0,0.0,xnullyes1alpha00.0 3 | x,notnull,no,2,beta,1,0.123456789,xnotnullno2beta10.123456789 4 | x,nontrivial,yes,3,gamma,-2,-999999.9999999,xnullyes3gamma-2-999999.9999999 5 | x,,no,4,delta,3,192875.1024875,xno4delta3192875.1024875 6 | x,NA,yes,1,epsilon,-4,0,xNAyes1epsilon-40 7 | x,Null,no,2,zeta,5,918273645.564738291,xNullno2zeta5918273645.564738291 8 | x,NaN,yes,3,eta,6,1121121121211.3,xNaNyes3eta61121121121211.3 -------------------------------------------------------------------------------- /test/fixtures/sequences.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "false": { 4 | "false": { 5 | "false": { 6 | "false": { 7 | "complexity": 0.05000000074505806, 8 | "loss": 0.0, 9 | "name": "path", 10 | "prediction": "xNullno2zeta5918273645.564738291" 11 | }, 12 | "feature": 3, 13 | "name": "enum", 14 | "reference": 4, 15 | "relation": ">=", 16 | "true": { 17 | "complexity": 0.05000000074505806, 18 | "loss": 0.0, 19 | "name": "path", 20 | "prediction": "xno4delta3192875.1024875" 21 | } 22 | }, 23 | "feature": 2, 24 | "name": "bin", 25 | "reference": "yes", 26 | "relation": "==", 27 | "true": { 28 | "false": { 29 | "false": { 30 | "complexity": 0.05000000074505806, 31 | "loss": 0.0, 32 | "name": "path", 33 | "prediction": "xNAyes1epsilon-40" 34 | }, 35 | "feature": 4, 36 | "name": "str", 37 | "reference": "alpha", 38 | "relation": "==", 39 | "true": { 40 | "complexity": 0.05000000074505806, 41 | "loss": 0.0, 42 | "name": "path", 43 | "prediction": "xnullyes1alpha00.0" 44 | } 45 | }, 46 | "feature": 3, 47 | "name": "enum", 48 | "reference": 2, 49 | "relation": ">=", 50 | "true": { 51 | "complexity": 0.05000000074505806, 52 | "loss": 0.0, 53 | "name": "path", 54 | "prediction": "xNaNyes3eta61121121121211.3" 55 | } 56 | } 57 | }, 58 | "feature": 1, 59 | "name": "optional", 60 | "reference": "notnull", 61 | "relation": "==", 62 | "true": { 63 | "complexity": 0.05000000074505806, 64 | "loss": 0.0, 65 | "name": "path", 66 | "prediction": "xnotnullno2beta10.123456789" 67 | } 68 | }, 69 | "feature": 1, 70 | "name": "optional", 71 | "reference": "nontrivial", 72 | "relation": "==", 73 | "true": { 74 | "complexity": 0.05000000074505806, 75 | "loss": 0.0, 76 | "name": "path", 77 | "prediction": "xnullyes3gamma-2-999999.9999999" 78 | } 79 | } 80 | ] -------------------------------------------------------------------------------- /test/fixtures/small.csv: -------------------------------------------------------------------------------- 1 | f0,f1,f2,f3,c 2 | 0,0,0,0,1 3 | 0,0,0,1,0 4 | 0,0,1,0,1 5 | 0,0,1,1,1 6 | 0,1,0,0,0 7 | 0,1,0,1,1 8 | 0,1,1,0,1 9 | 0,1,1,1,0 10 | 1,0,0,0,1 11 | 1,0,0,1,1 12 | 1,0,1,0,0 13 | 1,0,1,1,1 14 | 1,1,0,0,1 15 | 1,1,0,1,0 16 | 1,1,1,0,1 17 | 1,1,1,1,1 -------------------------------------------------------------------------------- /test/fixtures/tree.csv: 
-------------------------------------------------------------------------------- 1 | b0,b1,b2,b3,str 2 | 0,0,0,0,b0000 3 | 0,0,0,1,b0001 4 | 0,0,1,0,b0010 5 | 0,0,1,1,b0011 6 | 0,1,0,0,b0100 7 | 0,1,0,1,b0101 8 | 0,1,1,0,b0110 9 | 0,1,1,1,b0111 10 | 1,0,0,0,b1000 11 | 1,0,0,1,b1001 12 | 1,0,1,0,b1010 13 | 1,0,1,1,b1011 14 | 1,1,0,0,b1100 15 | 1,1,0,1,b1101 16 | 1,1,1,0,b1110 17 | 1,1,1,1,b1111 -------------------------------------------------------------------------------- /test/fixtures/tree.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "false": { 4 | "false": { 5 | "false": { 6 | "false": { 7 | "complexity": 0.05000000074505806, 8 | "loss": 0.0, 9 | "name": "str", 10 | "prediction": "b0000" 11 | }, 12 | "feature": 3, 13 | "name": "b3", 14 | "reference": 1, 15 | "relation": "==", 16 | "true": { 17 | "complexity": 0.05000000074505806, 18 | "loss": 0.0, 19 | "name": "str", 20 | "prediction": "b0001" 21 | } 22 | }, 23 | "feature": 2, 24 | "name": "b2", 25 | "reference": 1, 26 | "relation": "==", 27 | "true": { 28 | "false": { 29 | "complexity": 0.05000000074505806, 30 | "loss": 0.0, 31 | "name": "str", 32 | "prediction": "b0010" 33 | }, 34 | "feature": 3, 35 | "name": "b3", 36 | "reference": 1, 37 | "relation": "==", 38 | "true": { 39 | "complexity": 0.05000000074505806, 40 | "loss": 0.0, 41 | "name": "str", 42 | "prediction": "b0011" 43 | } 44 | } 45 | }, 46 | "feature": 1, 47 | "name": "b1", 48 | "reference": 1, 49 | "relation": "==", 50 | "true": { 51 | "false": { 52 | "false": { 53 | "complexity": 0.05000000074505806, 54 | "loss": 0.0, 55 | "name": "str", 56 | "prediction": "b0100" 57 | }, 58 | "feature": 3, 59 | "name": "b3", 60 | "reference": 1, 61 | "relation": "==", 62 | "true": { 63 | "complexity": 0.05000000074505806, 64 | "loss": 0.0, 65 | "name": "str", 66 | "prediction": "b0101" 67 | } 68 | }, 69 | "feature": 2, 70 | "name": "b2", 71 | "reference": 1, 72 | "relation": "==", 73 | "true": { 74 | "false": { 75 | "complexity": 0.05000000074505806, 76 | "loss": 0.0, 77 | "name": "str", 78 | "prediction": "b0110" 79 | }, 80 | "feature": 3, 81 | "name": "b3", 82 | "reference": 1, 83 | "relation": "==", 84 | "true": { 85 | "complexity": 0.05000000074505806, 86 | "loss": 0.0, 87 | "name": "str", 88 | "prediction": "b0111" 89 | } 90 | } 91 | } 92 | }, 93 | "feature": 0, 94 | "name": "b0", 95 | "reference": 1, 96 | "relation": "==", 97 | "true": { 98 | "false": { 99 | "false": { 100 | "false": { 101 | "complexity": 0.05000000074505806, 102 | "loss": 0.0, 103 | "name": "str", 104 | "prediction": "b1000" 105 | }, 106 | "feature": 3, 107 | "name": "b3", 108 | "reference": 1, 109 | "relation": "==", 110 | "true": { 111 | "complexity": 0.05000000074505806, 112 | "loss": 0.0, 113 | "name": "str", 114 | "prediction": "b1001" 115 | } 116 | }, 117 | "feature": 2, 118 | "name": "b2", 119 | "reference": 1, 120 | "relation": "==", 121 | "true": { 122 | "false": { 123 | "complexity": 0.05000000074505806, 124 | "loss": 0.0, 125 | "name": "str", 126 | "prediction": "b1010" 127 | }, 128 | "feature": 3, 129 | "name": "b3", 130 | "reference": 1, 131 | "relation": "==", 132 | "true": { 133 | "complexity": 0.05000000074505806, 134 | "loss": 0.0, 135 | "name": "str", 136 | "prediction": "b1011" 137 | } 138 | } 139 | }, 140 | "feature": 1, 141 | "name": "b1", 142 | "reference": 1, 143 | "relation": "==", 144 | "true": { 145 | "false": { 146 | "false": { 147 | "complexity": 0.05000000074505806, 148 | "loss": 0.0, 149 | "name": "str", 150 | "prediction": 
"b1100" 151 | }, 152 | "feature": 3, 153 | "name": "b3", 154 | "reference": 1, 155 | "relation": "==", 156 | "true": { 157 | "complexity": 0.05000000074505806, 158 | "loss": 0.0, 159 | "name": "str", 160 | "prediction": "b1101" 161 | } 162 | }, 163 | "feature": 2, 164 | "name": "b2", 165 | "reference": 1, 166 | "relation": "==", 167 | "true": { 168 | "false": { 169 | "complexity": 0.05000000074505806, 170 | "loss": 0.0, 171 | "name": "str", 172 | "prediction": "b1110" 173 | }, 174 | "feature": 3, 175 | "name": "b3", 176 | "reference": 1, 177 | "relation": "==", 178 | "true": { 179 | "complexity": 0.05000000074505806, 180 | "loss": 0.0, 181 | "name": "str", 182 | "prediction": "b1111" 183 | } 184 | } 185 | } 186 | } 187 | } 188 | ] -------------------------------------------------------------------------------- /test/test.cpp: -------------------------------------------------------------------------------- 1 | #include "test.hpp" 2 | #include "test_bitmask.hpp" 3 | #include "test_index.hpp" 4 | #include "test_queue.hpp" 5 | #include "test_consistency.hpp" 6 | 7 | int main() { 8 | int failures = 0; 9 | std::map< std::string, int (*)(void) > units; 10 | units["Bitmask"] = test_bitmask; 11 | units["Index"] = test_index; 12 | units["Queue"] = test_queue; 13 | units["Consistency"] = test_consistency; 14 | 15 | for (std::map< std::string, int (*)(void) >::iterator iterator = units.begin(); iterator != units.end(); ++iterator ) { 16 | try { 17 | failures += run_tests(iterator -> first, iterator -> second); 18 | } catch (char const * exception) { 19 | std::cout << "\033[1;31m" << "Uncaught Exception in " << iterator -> first << " Tests" << "\033[0m" << std::endl; 20 | std::cout << "\033[1;31m" << "Uncaught Exception: " << exception << "\033[0m" << std::endl; 21 | failures += 1; 22 | } 23 | } 24 | 25 | if (failures == 0) { 26 | std::cout << "\033[1;32m" << "All Tests Passed" << "\033[0m" << std::endl; 27 | return 0; 28 | } else { 29 | std::cout << "\033[1;31m" << failures << " Tests Failed" << "\033[0m" << std::endl; 30 | return 1; 31 | } 32 | } -------------------------------------------------------------------------------- /test/test.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TEST_H 2 | #define TEST_H 3 | 4 | #define TEST_VERBOSE false 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | //#include 15 | 16 | #include 17 | 18 | #include "../include/json/json.hpp" 19 | 20 | template 21 | std::string error_message(T expectation, T reality, std::string message, std::string context = "") { 22 | std::stringstream error_message; 23 | error_message << "\033[1;31m"; 24 | if (context != "") { 25 | error_message << context << "\n "; 26 | } 27 | error_message << message << " Expectation: " << expectation << " Reality: " << reality; 28 | error_message << "\033[0m"; 29 | return error_message.str(); 30 | } 31 | 32 | std::string error_message(std::string message, std::string context = "") { 33 | std::stringstream error_message; 34 | error_message << "\033[1;31m"; 35 | if (context != "") { 36 | error_message << context << " :: "; 37 | } 38 | error_message << message; 39 | error_message << "\033[0m"; 40 | return error_message.str(); 41 | } 42 | 43 | int expect(bool assertion, std::string message, std::string context = "") { 44 | if (!assertion) { 45 | std::cout << error_message(message, context) << std::endl; 46 | return 1; 47 | } else { 48 | return 0; 49 | } 50 | } 51 | 52 | int expect(int 
expectation, int reality, std::string message, std::string context = "") { 53 | if (expectation != reality) { 54 | std::cout << error_message(expectation, reality, message, context) << std::endl; 55 | return 1; 56 | } else { 57 | return 0; 58 | } 59 | } 60 | 61 | int expect(float expectation, float reality, std::string message, std::string context = "") { 62 | float const epsilon = std::numeric_limits::epsilon(); 63 | if (std::abs(expectation - reality) >= epsilon) { 64 | std::cout << std::setprecision(15) << error_message(expectation, reality, message, context) << std::endl; 65 | return 1; 66 | } else { 67 | return 0; 68 | } 69 | } 70 | 71 | int expect_fuzzy(float expectation, float reality, std::string message, std::string context = "", unsigned int fuzziness = 2) { 72 | float const epsilon = fuzziness * std::numeric_limits::epsilon(); 73 | if (std::abs(expectation - reality) >= epsilon) { 74 | std::cout << std::setprecision(15) << error_message(expectation, reality, message, context) << std::endl; 75 | return 1; 76 | } else { 77 | return 0; 78 | } 79 | } 80 | 81 | int expect(std::string expectation, std::string reality, std::string message, std::string context = "") { 82 | if (expectation != reality) { 83 | std::cout << error_message(expectation, reality, message, context) << std::endl; 84 | return 1; 85 | } else { 86 | return 0; 87 | } 88 | } 89 | 90 | template 91 | int expect(T const & expectation, T const & reality, std::string message, std::string context = "") { 92 | if (expectation != reality) { 93 | std::cout << error_message(expectation, reality, message, context) << std::endl; 94 | return 1; 95 | } else { 96 | return 0; 97 | } 98 | } 99 | 100 | void pass(std::string message) { 101 | std::cout << "\033[1;32m" << message << "\033[0m" << std::endl; 102 | } 103 | 104 | void fail(std::string message) { 105 | std::cout << "\033[1;31m" << message << "\033[0m" << std::endl; 106 | } 107 | 108 | int run_tests(std::string unit_name, int (*tests)(void)) { 109 | int failures = tests(); 110 | if (failures == 0) { 111 | std::cout << "\033[1;32m" << unit_name << " Tests Passed" << "\033[0m" << std::endl; 112 | } else { 113 | std::cout << "\033[1;31m" << failures << " " << unit_name << " Tests Failed" << "\033[0m" << std::endl; 114 | } 115 | return failures; 116 | } 117 | 118 | 119 | #endif 120 | -------------------------------------------------------------------------------- /test/test_consistency.hpp: -------------------------------------------------------------------------------- 1 | #include "../src/gosdt.hpp" 2 | 3 | int test_consistency(void) { 4 | int failures = 0; 5 | 6 | { 7 | std::string context = "Test Consistency test/fixtures/binary_sepal"; 8 | std::ifstream data("test/fixtures/binary_sepal.csv"); 9 | std::ifstream expectation("test/fixtures/binary_sepal.json"); 10 | std::stringstream buffer; 11 | buffer << expectation.rdbuf(); 12 | 13 | GOSDT model; 14 | std::string result; 15 | model.fit(data, result); 16 | 17 | failures += expect(buffer.str(), result, "Consistency Test test/fixtures/binary_sepal", context); 18 | } 19 | 20 | { 21 | std::string context = "Test Consistency test/fixtures/dataset"; 22 | std::ifstream data("test/fixtures/dataset.csv"); 23 | std::ifstream expectation("test/fixtures/dataset.json"); 24 | std::stringstream buffer; 25 | buffer << expectation.rdbuf(); 26 | 27 | GOSDT model; 28 | std::string result; 29 | model.fit(data, result); 30 | 31 | failures += expect(buffer.str(), result, "Consistency Test test/fixtures/dataset", context); 32 | } 33 | 34 | { 35 | 
std::string context = "Test Consistency test/fixtures/sequences"; 36 | std::ifstream data("test/fixtures/sequences.csv"); 37 | std::ifstream expectation("test/fixtures/sequences.json"); 38 | std::stringstream buffer; 39 | buffer << expectation.rdbuf(); 40 | 41 | GOSDT model; 42 | std::string result; 43 | model.fit(data, result); 44 | 45 | failures += expect(buffer.str(), result, "Consistency Test test/fixtures/sequences", context); 46 | } 47 | 48 | { 49 | std::string context = "Test Consistency test/fixtures/tree"; 50 | std::ifstream data("test/fixtures/tree.csv"); 51 | std::ifstream expectation("test/fixtures/tree.json"); 52 | std::stringstream buffer; 53 | buffer << expectation.rdbuf(); 54 | 55 | GOSDT model; 56 | std::string result; 57 | model.fit(data, result); 58 | 59 | failures += expect(buffer.str(), result, "Consistency Test test/fixtures/tree", context); 60 | } 61 | 62 | return failures; 63 | } -------------------------------------------------------------------------------- /test/test_index.hpp: -------------------------------------------------------------------------------- 1 | 2 | #include "../src/index.hpp" 3 | 4 | int test_index(void) { 5 | int failures = 0; 6 | 7 | std::vector< std::vector< float > > data; 8 | for (unsigned int i = 0; i < 10; ++i) { 9 | std::vector< float > row; 10 | for (unsigned int j = 0; j < 10; ++j) { 11 | row.emplace_back(i * j * 0.01); 12 | } 13 | data.emplace_back(row); 14 | } 15 | Index index(data); 16 | 17 | unsigned int size = 10; 18 | 19 | { 20 | Bitmask mask(size); // 10101010101...... 21 | for (unsigned int i = 0; i < size; i += 2) { mask.set(i, true); } 22 | 23 | std::stringstream context; 24 | context << "Test Bitmask Sum with Mask: " << mask.to_string(); 25 | 26 | std::vector accumulator(size, 0.0); 27 | std::vector expectation(size, 0.0); 28 | 29 | for (unsigned int i = 0; i < 10; ++i) { 30 | if (!mask.get(i)) { continue; } 31 | for (unsigned int j = 0; j < 10; ++j) { 32 | expectation[j] += data[i][j]; 33 | } 34 | } 35 | 36 | index.sum(mask, accumulator.data()); 37 | 38 | for (unsigned int j = 0; j < 10; ++j) { 39 | std::stringstream specifier; 40 | specifier << "Index::sum element " << j << " is incorrect"; 41 | failures += expect_fuzzy(expectation[j], accumulator[j], specifier.str(), context.str(), 10); 42 | } 43 | } 44 | 45 | { 46 | Bitmask mask(size); // 010101010101...... 47 | for (unsigned int i = 1; i < size; i += 2) { mask.set(i, true); } 48 | 49 | std::stringstream context; 50 | context << "Test Bitmask Sum with Mask: " << mask.to_string(); 51 | 52 | std::vector accumulator(size, 0.0); 53 | std::vector expectation(size, 0.0); 54 | 55 | for (unsigned int i = 0; i < 10; ++i) { 56 | if (!mask.get(i)) { continue; } 57 | for (unsigned int j = 0; j < 10; ++j) { 58 | expectation[j] += data[i][j]; 59 | } 60 | } 61 | 62 | index.sum(mask, accumulator.data()); 63 | 64 | for (unsigned int j = 0; j < 10; ++j) { 65 | std::stringstream specifier; 66 | specifier << "Index::sum element " << j << " is incorrect"; 67 | failures += expect_fuzzy(expectation[j], accumulator[j], specifier.str(), context.str(), 10); 68 | } 69 | } 70 | 71 | { 72 | Bitmask mask(size); // 111111000000...... 
73 | for (unsigned int i = 0; i < size/2; i += 1) { mask.set(i, true); } 74 | 75 | std::stringstream context; 76 | context << "Test Bitmask Sum with Mask: " << mask.to_string(); 77 | 78 | std::vector accumulator(size, 0.0); 79 | std::vector expectation(size, 0.0); 80 | 81 | for (unsigned int i = 0; i < 10; ++i) { 82 | if (!mask.get(i)) { continue; } 83 | for (unsigned int j = 0; j < 10; ++j) { 84 | expectation[j] += data[i][j]; 85 | } 86 | } 87 | 88 | index.sum(mask, accumulator.data()); 89 | 90 | for (unsigned int j = 0; j < 10; ++j) { 91 | std::stringstream specifier; 92 | specifier << "Index::sum element " << j << " is incorrect"; 93 | failures += expect_fuzzy(expectation[j], accumulator[j], specifier.str(), context.str(), 10); 94 | } 95 | } 96 | 97 | { 98 | Bitmask mask(size); // 00000111111...... 99 | for (unsigned int i = size/2; i < size; i += 1) { mask.set(i, true); } 100 | 101 | std::stringstream context; 102 | context << "Test Bitmask Sum with Mask: " << mask.to_string(); 103 | 104 | std::vector accumulator(size, 0.0); 105 | std::vector expectation(size, 0.0); 106 | 107 | for (unsigned int i = 0; i < 10; ++i) { 108 | if (!mask.get(i)) { continue; } 109 | for (unsigned int j = 0; j < 10; ++j) { 110 | expectation[j] += data[i][j]; 111 | } 112 | } 113 | 114 | index.sum(mask, accumulator.data()); 115 | 116 | for (unsigned int j = 0; j < 10; ++j) { 117 | std::stringstream specifier; 118 | specifier << "Index::sum element " << j << " is incorrect"; 119 | failures += expect_fuzzy(expectation[j], accumulator[j], specifier.str(), context.str(), 10); 120 | } 121 | } 122 | 123 | return failures; 124 | } -------------------------------------------------------------------------------- /test/test_queue.hpp: -------------------------------------------------------------------------------- 1 | #include "../src/queue.hpp" 2 | 3 | int test_queue(void) { 4 | int failures = 0; 5 | 6 | { 7 | std::string context = "Test Queue Ordering"; 8 | // float equity_bias = 0.5; 9 | // Queue queue; 10 | // Message message; 11 | 12 | // Bitmask a(2); 13 | // a.set(0, false); 14 | // a.set(1, false); 15 | 16 | // Bitmask b(2); 17 | // b.set(0, true); 18 | // b.set(1, false); 19 | 20 | // Bitmask c(2); 21 | // c.set(0, false); 22 | // c.set(1, true); 23 | 24 | // Bitmask d(2); 25 | // d.set(0, true); 26 | // d.set(1, true); 27 | 28 | // // Messages 29 | // // a = (1, 00) 30 | // // b = (2, 10) 31 | // // c = (3, 01) 32 | // // d = (4, 11) 33 | 34 | // failures += expect(0, queue.size(), "Queue::size reports incorrect size", context); 35 | // queue.push(a, 1, 0.0, 0.5, 0.5, 0.7); // 4th 36 | // failures += expect(1, queue.size(), "Queue::size reports incorrect size", context); 37 | // queue.push(b, 2, 0.0, 0.5, 0.7, 0.5); // 3rd 38 | // failures += expect(2, queue.size(), "Queue::size reports incorrect size", context); 39 | // queue.push(c, 3, 0.0, 0.6, 0.3, 0.2); // 1st 40 | // failures += expect(3, queue.size(), "Queue::size reports incorrect size", context); 41 | // queue.push(d, 4, 0.0, 0.6, 0.2, 0.3); // 2nd 42 | // failures += expect(4, queue.size(), "Queue::size reports incorrect size", context); 43 | 44 | // // Expected Ordering: c, d, b, a 45 | // // Erroneous Ordering: d, a, b, c 46 | 47 | // failures += expect(false, queue.empty(), "Queue::empty reports extraneous message", context); 48 | 49 | // char code; 50 | // float weight; 51 | // Bitmask content(2); 52 | 53 | // queue.pop(content, & code, & weight); 54 | // failures += expect(3, code, "1st Message Context Mismatch", context); 55 | // failures += 
expect(c.to_string(), content.to_string(), "1st Message Content Mismatch", context); 56 | // failures += expect(3, queue.size(), "Queue::size reports incorrect size", context); 57 | 58 | // queue.pop(content, & code, & weight); 59 | // failures += expect(4, code, "2nd Message Context Mismatch", context); 60 | // failures += expect(d.to_string(), content.to_string(), "2nd Message Content Mismatch", context); 61 | // failures += expect(2, queue.size(), "Queue::size reports incorrect size", context); 62 | 63 | // queue.pop(content, & code, & weight); 64 | // failures += expect(2, code, "3rd Message Context Mismatch", context); 65 | // failures += expect(b.to_string(), content.to_string(), "3rd Message Content Mismatch", context); 66 | // failures += expect(1, queue.size(), "Queue::size reports incorrect size", context); 67 | 68 | // queue.pop(content, & code, & weight); 69 | // failures += expect(1, code, "4th Message Context Mismatch", context); 70 | // failures += expect(a.to_string(), content.to_string(), "4th Message Content Mismatch", context); 71 | // failures += expect(0, queue.size(), "Queue::size reports incorrect size", context); 72 | 73 | // failures += expect(queue.empty(), "Queue::empty reports extraneous message", context); 74 | } 75 | 76 | return failures; 77 | } -------------------------------------------------------------------------------- /test/test_trie.hpp: -------------------------------------------------------------------------------- 1 | #include "../src/trie.hpp" 2 | 3 | int test_trie(void) { 4 | int failures = 0; 5 | 6 | { 7 | std::string context = "Test Queue Ordering"; 8 | 9 | 10 | double c = 0.01; 11 | bool calculate_size = false; 12 | int ablation = 0; 13 | int nrules, nsamples, nlabels, nsamples_label; 14 | rule_t *rules, *labels; 15 | rule_t *meta; 16 | char const *type = "node"; 17 | Trie* tree = new Trie(calculate_size, type); 18 | tree->insert_root(); 19 | Node* root = tree->root(); 20 | 21 | std::vector id1 {1, 2}; 22 | Node* child1 = tree->construct_node(id1, root); 23 | tree->insert(child1); 24 | 25 | std::vector id2 {2, 3, 4, 5}; 26 | Node* child2 = tree->construct_node(id2, child1); 27 | tree->insert(child2); 28 | 29 | std::vector id3 {1, 3, 5, 6}; 30 | Node* child3 = tree->construct_node(id3, child1); 31 | tree->insert(child3); 32 | 33 | tracking_vector, DataStruct::Tree> prefix {{1, 2}, {2,3,4,5}}; 34 | failures += expect(child2 == tree->check_prefix(prefix), "test1", context); 35 | 36 | std::string serialization; 37 | tree->serialize(serialization, 2); 38 | std::cout << serialization << std::endl; 39 | 40 | // float equity_bias = 0.5; 41 | // Queue queue; 42 | // Message message; 43 | 44 | // Bitmask a(2); 45 | // a.set(0, false); 46 | // a.set(1, false); 47 | 48 | // Bitmask b(2); 49 | // b.set(0, true); 50 | // b.set(1, false); 51 | 52 | // Bitmask c(2); 53 | // c.set(0, false); 54 | // c.set(1, true); 55 | 56 | // Bitmask d(2); 57 | // d.set(0, true); 58 | // d.set(1, true); 59 | 60 | // // Messages 61 | // // a = (1, 00) 62 | // // b = (2, 10) 63 | // // c = (3, 01) 64 | // // d = (4, 11) 65 | 66 | // failures += expect(0, queue.size(), "Queue::size reports incorrect size", context); 67 | // queue.push(a, 1, 0.0, 0.5, 0.5, 0.7); // 4th 68 | // failures += expect(1, queue.size(), "Queue::size reports incorrect size", context); 69 | // queue.push(b, 2, 0.0, 0.5, 0.7, 0.5); // 3rd 70 | // failures += expect(2, queue.size(), "Queue::size reports incorrect size", context); 71 | // queue.push(c, 3, 0.0, 0.6, 0.3, 0.2); // 1st 72 | // failures += 
expect(3, queue.size(), "Queue::size reports incorrect size", context); 73 | // queue.push(d, 4, 0.0, 0.6, 0.2, 0.3); // 2nd 74 | // failures += expect(4, queue.size(), "Queue::size reports incorrect size", context); 75 | 76 | // // Expected Ordering: c, d, b, a 77 | // // Erroneous Ordering: d, a, b, c 78 | 79 | // failures += expect(false, queue.empty(), "Queue::empty reports extraneous message", context); 80 | 81 | // char code; 82 | // float weight; 83 | // Bitmask content(2); 84 | 85 | // queue.pop(content, & code, & weight); 86 | // failures += expect(3, code, "1st Message Context Mismatch", context); 87 | // failures += expect(c.to_string(), content.to_string(), "1st Message Content Mismatch", context); 88 | // failures += expect(3, queue.size(), "Queue::size reports incorrect size", context); 89 | 90 | // queue.pop(content, & code, & weight); 91 | // failures += expect(4, code, "2nd Message Context Mismatch", context); 92 | // failures += expect(d.to_string(), content.to_string(), "2nd Message Content Mismatch", context); 93 | // failures += expect(2, queue.size(), "Queue::size reports incorrect size", context); 94 | 95 | // queue.pop(content, & code, & weight); 96 | // failures += expect(2, code, "3rd Message Context Mismatch", context); 97 | // failures += expect(b.to_string(), content.to_string(), "3rd Message Content Mismatch", context); 98 | // failures += expect(1, queue.size(), "Queue::size reports incorrect size", context); 99 | 100 | // queue.pop(content, & code, & weight); 101 | // failures += expect(1, code, "4th Message Context Mismatch", context); 102 | // failures += expect(a.to_string(), content.to_string(), "4th Message Content Mismatch", context); 103 | // failures += expect(0, queue.size(), "Queue::size reports incorrect size", context); 104 | 105 | // failures += expect(queue.empty(), "Queue::empty reports extraneous message", context); 106 | } 107 | 108 | return failures; 109 | } -------------------------------------------------------------------------------- /treefarms/__init__.py: -------------------------------------------------------------------------------- 1 | # We're just going to bring these to the front 2 | # This is Tynan guessing what 3 | from treefarms.model.treefarms import TREEFARMS 4 | from treefarms.model.threshold_guess import get_thresholds -------------------------------------------------------------------------------- /treefarms/example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import time 4 | import pathlib 5 | from treefarms import TREEFARMS 6 | 7 | # read the dataset 8 | df = pd.read_csv("experiments/datasets/compas/binned.csv") 9 | X, y = df.iloc[:, :-1], df.iloc[:, -1] 10 | h = df.columns[:-1] 11 | 12 | config = { 13 | "regularization": 0.01, # regularization penalizes trees with more leaves. We recommend setting it to a relatively high value to find a sparse tree.
14 | "rashomon_bound_multiplier": 0.05, # rashomon bound multiplier indicates how large of a Rashomon set would you like to get 15 | } 16 | 17 | model = TREEFARMS(config) 18 | 19 | model.fit(X, y) 20 | 21 | first_tree = model[0] 22 | 23 | print("evaluating the first model in the Rashomon set", flush=True) 24 | 25 | # get the results 26 | train_acc = first_tree.score(X, y) 27 | n_leaves = first_tree.leaves() 28 | n_nodes = first_tree.nodes() 29 | 30 | print("Training accuracy: {}".format(train_acc)) 31 | print("# of leaves: {}".format(n_leaves)) 32 | print(first_tree) 33 | -------------------------------------------------------------------------------- /treefarms/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-systopia/treeFarms/5508f6063cbcac5124e164c72aafa9afc31f683d/treefarms/model/__init__.py -------------------------------------------------------------------------------- /treefarms/model/imbalance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-systopia/treeFarms/5508f6063cbcac5124e164c72aafa9afc31f683d/treefarms/model/imbalance/__init__.py -------------------------------------------------------------------------------- /treefarms/model/threshold_guess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import json 4 | import time 5 | import random 6 | import sys 7 | import os 8 | from queue import Queue 9 | import pathlib 10 | 11 | from math import ceil 12 | from sklearn.model_selection import KFold, cross_val_score, train_test_split 13 | from sklearn.ensemble import GradientBoostingClassifier 14 | from sklearn import metrics 15 | 16 | 17 | # fit the tree using gradient boosted classifier 18 | def fit_boosted_tree(X, y, n_est=10, lr=0.1, d=1): 19 | clf = GradientBoostingClassifier(loss='log_loss', learning_rate=lr, n_estimators=n_est, max_depth=d, 20 | random_state=42) 21 | clf.fit(X, y) 22 | out = clf.score(X, y) 23 | return clf, out 24 | 25 | 26 | # perform cut on the dataset 27 | def cut(X, ts): 28 | df = X.copy() 29 | colnames = X.columns 30 | for j in range(len(ts)): 31 | for s in range(len(ts[j])): 32 | X[colnames[j]+'<='+str(ts[j][s])] = 1 33 | k = df[colnames[j]] > ts[j][s] 34 | X.loc[k, colnames[j]+'<='+str(ts[j][s])] = 0 35 | X = X.drop(colnames[j], axis=1) 36 | return X 37 | 38 | 39 | # compute the thresholds 40 | def get_thresholds(X, y, n_est, lr, d, backselect=True): 41 | # got a complaint here... 
42 | y = np.ravel(y) 43 | # X is a dataframe 44 | clf, out = fit_boosted_tree(X, y, n_est, lr, d) 45 | #print('acc:', out, 'acc cv:', score.mean()) 46 | thresholds = [] 47 | for j in range(X.shape[1]): 48 | tj = np.array([]) 49 | for i in range(len(clf.estimators_)): 50 | f = clf.estimators_[i,0].tree_.feature 51 | t = clf.estimators_[i,0].tree_.threshold 52 | tj = np.append(tj, t[f==j]) 53 | tj = np.unique(tj) 54 | thresholds.append(tj.tolist()) 55 | 56 | X_new = cut(X, thresholds) 57 | clf1, out1 = fit_boosted_tree(X_new, y, n_est, lr, d) 58 | #print('acc','1:', out1, 'acc1 cv:', scorep.mean()) 59 | 60 | outp = 1 61 | Xp = X_new.copy() 62 | clfp = clf1 63 | itr=0 64 | if backselect: 65 | while outp >= out1 and itr < X_new.shape[1]-1: 66 | vi = clfp.feature_importances_ 67 | if vi.size > 0: 68 | c = Xp.columns 69 | i = np.argmin(vi) 70 | Xp = Xp.drop(c[i], axis=1) 71 | clfp, outp = fit_boosted_tree(Xp, y, n_est, lr, d) 72 | itr += 1 73 | else: 74 | break 75 | Xp[c[i]] = X_new[c[i]] 76 | #_, _ = fit_boosted_tree(Xp, y, n_est, lr, d) 77 | 78 | h = Xp.columns 79 | #print('features:', h) 80 | return Xp, thresholds, h 81 | 82 | # compute the thresholds 83 | def compute_thresholds(X, y, n_est, max_depth) : 84 | # n_est, max_depth: GBDT parameters 85 | # set LR to 0.1 86 | lr = 0.1 87 | start = time.perf_counter() 88 | X, thresholds, header = get_thresholds(X, y, n_est, lr, max_depth, backselect=True) 89 | guess_time = time.perf_counter()-start 90 | 91 | return X, thresholds, header, guess_time 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /treefarms/tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import numpy as np\n", 11 | "import pathlib\n", 12 | "from sklearn.ensemble import GradientBoostingClassifier\n", 13 | "from sklearn.model_selection import train_test_split\n", 14 | "from treefarms.model.threshold_guess import compute_thresholds, cut\n", 15 | "from treefarms import TREEFARMS\n", 16 | "from treefarms.model.model_set import ModelSetContainer" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "# Example\n", 24 | "\n", 25 | "In this example, we run TREEFARMS on COMPAS, a recidivism dataset. The COMPAS dataset contains 6907 samples and 7 continuous features. We visualize the Rashomon set using `timbertrek` package, as well as show the way to obtain individual trees from the Rashomon set.\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "# read the dataset\n", 35 | "df = pd.read_csv(\"../experiments/datasets/compas/binned.csv\")\n", 36 | "X, y = df.iloc[:, :-1], df.iloc[:, -1]\n", 37 | "h = df.columns[:-1]\n", 38 | "df\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "We fit the Rashomon set on the COMPAS dataset.\n" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# train TREEFARMS model\n", 55 | "config = {\n", 56 | " \"regularization\": 0.01, # regularization penalizes the tree with more leaves. 
We recommend to set it to relative high value to find a sparse tree.\n", 57 | " \"rashomon_bound_multiplier\": 0.05, # rashomon bound multiplier indicates how large of a Rashomon set would you like to get\n", 58 | "}\n", 59 | "\n", 60 | "model = TREEFARMS(config)\n", 61 | "\n", 62 | "model.fit(X, y)\n" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "We then visualize the Rashomon set. " 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# # TREEFARMS will attempt to obtain feature names from the DataFrame columns.\n", 79 | "# # However, it is also possible to manually set this value, such as the\n", 80 | "# # commented code snippet below\n", 81 | "\n", 82 | "# feature_names = df.columns\n", 83 | "\n", 84 | "# feature_description = {\n", 85 | "# \"sex\": {\"info\": \"Sex\", \"type\": \"is\", \"short\": \"Sex\"},\n", 86 | "# \"age\": {\"info\": \"Age\", \"type\": \"count\", \"short\": \"Age\"},\n", 87 | "# \"juvenile-felonies\": {\n", 88 | "# \"info\": \"Number of juvenile felonies\",\n", 89 | "# \"type\": \"count\",\n", 90 | "# \"short\": \"Juv felony\",\n", 91 | "# },\n", 92 | "# \"juvenile-misdemeanors\": {\n", 93 | "# \"info\": \"Number of juvenile misdemeanors\",\n", 94 | "# \"type\": \"count\",\n", 95 | "# \"short\": \"Juv misdemeanor\",\n", 96 | "# },\n", 97 | "# \"juvenile-crimes\": {\n", 98 | "# \"info\": \"Number of juvenile crimes\",\n", 99 | "# \"type\": \"count\",\n", 100 | "# \"short\": \"Juv crime\",\n", 101 | "# },\n", 102 | "# \"priors\": {\n", 103 | "# \"info\": \"Number of prior crimes\",\n", 104 | "# \"type\": \"count\",\n", 105 | "# \"short\": \"Prior crime\",\n", 106 | "# },\n", 107 | "# \"recidivate-within-two-years\": {\n", 108 | "# \"info\": \"Has recidivated within two years\",\n", 109 | "# \"type\": \"yes\",\n", 110 | "# \"short\": \"Recidivated\",\n", 111 | "# },\n", 112 | "# }\n", 113 | "# model.visualize(feature_names, feature_description)\n", 114 | "\n", 115 | "model.visualize()\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "It is also possible to obtain individual trees from the Rashomon set. The following cell demonstrates getting the accuracy of the first tree in the Rashomon set as well as printing out its structure." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "first_tree = model[0]\n", 132 | "print(f'The accuracy of the first tree on the data is: {first_tree.score(X, y)}')\n", 133 | "print(model[0])" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "Thank you for reading our tutorial. Please do try out our methods with different parameters and datasets. 
Happy tree farming!\n" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python 3.8.13 ('gosdt')", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.8.13" 161 | }, 162 | "vscode": { 163 | "interpreter": { 164 | "hash": "4c3c3f64da95a59853a098320396a9255cc0464439ce0191cb810800c64ab010" 165 | } 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 2 170 | } 171 | --------------------------------------------------------------------------------
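The example script (treefarms/example.py) and the tutorial notebook above both fit TREEFARMS with a regularization penalty and a rashomon_bound_multiplier, then pull a single tree out of the Rashomon set with model[0]. The sketch below extends that pattern to enumerate and rank every tree in the set. It is a minimal illustration rather than part of the repository: indexing, score, and leaves are used exactly as demonstrated in example.py, while get_tree_count() is an assumed accessor for the number of trees and should be swapped for whatever count accessor the installed version actually exposes.

# Sketch (not part of the repository): enumerate the trees in a fitted
# Rashomon set and rank them by training accuracy, breaking ties with sparsity.
# get_tree_count() is an assumed accessor for the number of trees in the set;
# model indexing, score(), and leaves() follow treefarms/example.py.
import pandas as pd

from treefarms import TREEFARMS

df = pd.read_csv("experiments/datasets/compas/binned.csv")
X, y = df.iloc[:, :-1], df.iloc[:, -1]

model = TREEFARMS({
    "regularization": 0.01,
    "rashomon_bound_multiplier": 0.05,
})
model.fit(X, y)

summary = []
for i in range(model.get_tree_count()):  # assumed count accessor
    tree = model[i]
    summary.append((i, tree.score(X, y), tree.leaves()))

# Highest training accuracy first; prefer fewer leaves on ties.
summary.sort(key=lambda row: (-row[1], row[2]))
for i, acc, n_leaves in summary[:5]:
    print(f"tree {i}: accuracy={acc:.4f}, leaves={n_leaves}")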
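treefarms/model/threshold_guess.py derives candidate split thresholds from a small gradient-boosted ensemble, re-encodes each continuous column as binary "feature<=threshold" columns via cut, and optionally back-selects low-importance columns. The sketch below shows one way to call compute_thresholds before fitting; the dataset path and the n_est / max_depth values are illustrative choices, not settings prescribed by the repository (the learning rate is fixed at 0.1 inside compute_thresholds).

# Sketch (illustrative settings): binarize continuous features with the
# threshold-guessing helpers from treefarms/model/threshold_guess.py,
# then fit TREEFARMS on the re-encoded data.
import pandas as pd

from treefarms import TREEFARMS
from treefarms.model.threshold_guess import compute_thresholds

# Any dataset with continuous columns works here; this path is illustrative.
df = pd.read_csv("experiments/datasets/compas/data.csv")
X, y = df.iloc[:, :-1], df.iloc[:, -1]

# Returns the re-encoded features, the per-column threshold lists, the
# retained column header, and the time spent guessing thresholds.
X_bin, thresholds, header, guess_time = compute_thresholds(X, y, n_est=40, max_depth=1)
print(f"guessed {sum(len(t) for t in thresholds)} thresholds in {guess_time:.2f}s")

model = TREEFARMS({"regularization": 0.01, "rashomon_bound_multiplier": 0.05})
model.fit(X_bin, y)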
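For context on the trie code earlier in this dump: the comment on Trie::insert_model_set_children in src/trie.hpp states that any element of the cartesian product of the per-child candidate sets is itself in the Rashomon set and, by design, shares the same objective value. The toy Python sketch below only illustrates that counting argument; it is not part of the C++ implementation.

# Illustration only: a split whose left child has three interchangeable
# candidate subtrees and whose right child has two expands into
# 3 * 2 = 6 trees, one per element of the cartesian product
# (cf. Trie::insert_model_set_children in src/trie.hpp).
from itertools import product

left_candidates = ["L1", "L2", "L3"]
right_candidates = ["R1", "R2"]

trees = [{"left": l, "right": r} for l, r in product(left_candidates, right_candidates)]
print(len(trees))  # 6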