├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .travis.yml ├── README.md ├── appveyor.yml ├── data └── test │ ├── test_data_medium.txt │ └── test_data_small.txt ├── experiments ├── __init__.py ├── dataloader │ ├── __init__.py │ ├── cp4im.py │ └── synth.py ├── experiments │ ├── __init__.py │ ├── fig1 │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── plotter.py │ │ └── runner.py │ ├── fig2 │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── plotter.py │ │ └── runner.py │ └── fig3 │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── plotter.py │ │ └── runner.py ├── globals.py └── searchers │ ├── __init__.py │ ├── binary_classification_tree.py │ ├── cart.py │ ├── dl85.py │ ├── gosdt.py │ ├── maptree.py │ ├── mcmc.py │ ├── smc.py │ └── tree_smc │ ├── COPYING │ ├── README │ ├── __init__.py │ ├── process_data │ ├── __init__.py │ ├── madelon │ │ ├── README │ │ ├── __init__.py │ │ ├── download.sh │ │ ├── process_data_feat_challenge.py │ │ └── process_data_feat_challenge.py.bak │ ├── magic04 │ │ ├── README │ │ ├── __init__.py │ │ ├── download.sh │ │ ├── process_data.py │ │ └── process_data.py.bak │ └── pendigits │ │ ├── README │ │ ├── __init__.py │ │ ├── download.sh │ │ ├── process_data.py │ │ └── process_data.py.bak │ └── src │ ├── __init__.py │ ├── bdtmcmc.py │ ├── bdtmcmc.py.bak │ ├── bdtsmc.py │ ├── bdtsmc.py.bak │ ├── tree_utils.py │ ├── tree_utils.py.bak │ ├── utils.py │ └── utils.py.bak ├── maptree ├── CMakeLists.txt ├── Doxyfile.in ├── app │ └── main.cpp ├── cmake │ ├── CodeCoverage.cmake │ ├── Colors.cmake │ ├── ConfigSafeGuards.cmake │ ├── Doctest.cmake │ ├── Documentation.cmake │ ├── LTO.cmake │ └── Warnings.cmake ├── include │ ├── cache │ │ ├── approx_bitset_cache.h │ │ └── base_cache.h │ ├── constants.h │ ├── data │ │ ├── binary_data_loader.h │ │ ├── bitset.h │ │ ├── data_manager.h │ │ ├── fixed_bitset.h │ │ ├── rnumber.h │ │ └── split.h │ ├── posterior │ │ ├── tree_likelihood.h │ │ └── tree_prior.h │ ├── search │ │ ├── base_map_search.h │ │ └── befs_map_search.h │ ├── solution │ │ ├── decision_tree.h │ │ └── solution.h │ └── subproblem.h ├── src │ ├── cache │ │ └── approx_bitset_cache.cpp │ ├── data │ │ ├── binary_data_loader.cpp │ │ ├── bitset.cpp │ │ ├── data_manager.cpp │ │ ├── fixed_bitset.cpp │ │ └── rnumber.cpp │ ├── posterior │ │ ├── tree_likelihood.cpp │ │ └── tree_prior.cpp │ ├── python_bindings │ │ └── search_bindings.cpp │ ├── search │ │ ├── base_map_search.cpp │ │ └── befs_map_search.cpp │ ├── solution │ │ └── decision_tree.cpp │ └── subproblem.cpp └── tests │ ├── CMakeLists.txt │ ├── bcart │ └── test_bcart_utils.cpp │ ├── main.cpp │ ├── search │ └── test_search.cpp │ └── test_data │ ├── test_data_medium.txt │ └── test_data_small.txt ├── plot_results.py ├── pyproject.toml ├── requirements.txt ├── requirements_m1mac.txt ├── run_experiment.py ├── script.slurm ├── setup.py └── setup_data.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | build-and-test: 9 | 10 | name: ${{ matrix.toolchain }} 11 | runs-on: ${{ matrix.os }} 12 | 13 | strategy: 14 | matrix: 15 | toolchain: 16 | - linux-gcc 17 | - macos-clang 18 | - windows-msvc 19 | 20 | configuration: 21 | - Debug 22 | 23 | include: 24 | - toolchain: linux-gcc 25 | os: ubuntu-latest 26 | compiler: gcc 27 | 28 | - toolchain: macos-clang 29 | os: macos-latest 30 | compiler: clang 31 | 32 | - toolchain: windows-msvc 33 | os: windows-latest 34 | compiler: msvc 35 | 36 | steps: 37 | - name: Checkout 
Code 38 | uses: actions/checkout@v2 39 | 40 | - name: Configure (${{ matrix.configuration }}) 41 | run: cmake -S . -Bbuild -DCMAKE_BUILD_TYPE=${{ matrix.configuration }} 42 | 43 | - name: Build with ${{ matrix.compiler }} 44 | run: cmake --build build --config ${{ matrix.configuration }} 45 | 46 | - name: Test 47 | working-directory: build 48 | env: 49 | CTEST_OUTPUT_ON_FAILURE: 1 50 | run: ctest -C ${{ matrix.configuration }} 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | *.x 34 | 35 | # Build 36 | .idea/ 37 | .vscode/ 38 | build/ 39 | cmake-build-debug 40 | */cmake-build-debug 41 | */cmake-build-release 42 | */cmake-build-relwithdebinfo 43 | *.egg-info/ 44 | .eggs/ 45 | 46 | data/ 47 | !maptree/include/data 48 | !maptree/src/data 49 | experiments/results/ 50 | figures/ 51 | **/__pycache__/ 52 | 53 | pybind11 -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: cpp 2 | dist: xenial 3 | notifications: 4 | email: false 5 | 6 | # Define builds on multiple OS/compiler combinations. 7 | # Feel free to add/remove entries from this list. 8 | matrix: 9 | include: 10 | - os: linux 11 | addons: 12 | apt: 13 | sources: 14 | - ubuntu-toolchain-r-test 15 | packages: 16 | - lcov 17 | - g++-7 18 | env: 19 | - MATRIX_EVAL="CXX_COMPILER=g++-7; sudo update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-7 90" 20 | 21 | - os: osx 22 | osx_image: xcode10.1 23 | addons: 24 | homebrew: 25 | packages: 26 | - lcov 27 | update: true 28 | 29 | env: 30 | - MATRIX_EVAL="CXX_COMPILER=clang++" 31 | 32 | 33 | before_install: 34 | - eval "${MATRIX_EVAL}" 35 | - PARENTDIR=$(pwd) 36 | - mkdir $PARENTDIR/build 37 | 38 | install: 39 | - cd $PARENTDIR/build 40 | - cmake $PARENTDIR -DCMAKE_BUILD_TYPE=Coverage -DCMAKE_CXX_COMPILER=$CXX_COMPILER 41 | - make 42 | 43 | script: 44 | - make coverage 45 | 46 | after_success: 47 | - cd $PARENTDIR/build 48 | - lcov --list coverage_out.info.cleaned # Show test report in travis log. 49 | # Install coveralls gem for uploading coverage to coveralls. 50 | - gem install coveralls-lcov 51 | - coveralls-lcov coverage_out.info.cleaned # uploads to coveralls 52 | - bash <(curl -s https://codecov.io/bash) -f coverage_out.info.cleaned || echo "Codecov did not collect coverage reports" 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAPTree 2 | 3 | This repository contains all of the code to reproduce the experiments from 4 | '_MAPTree: Beating "Optimal" Decision Trees with Bayesian Decision Trees_' 5 | by Colin Sullivan*, Mo Tiwari*, and Sebastian Thrun. 6 | 7 | Our main algorithm is written in C++ and called from Python via Python bindings. 8 | 9 | Below, we have instructions on how to reproduce all of our results.
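For orientation, here is a minimal sketch of how the compiled bindings are called from Python once the package is installed. It mirrors `experiments/searchers/maptree.py`; the toy data is illustrative only, and the exact binding signature is defined in `maptree/src/python_bindings/search_bindings.cpp`:

```
import numpy as np
from maptree import search
from experiments.searchers.binary_classification_tree import BinaryClassificationTree

# Illustrative toy data: binary features, labels equal to the first feature.
X = np.random.default_rng(0).integers(2, size=(100, 10))
y = X[:, 0]

# Positional arguments: alpha and beta parameterize the tree prior, rho is the
# Beta prior over leaf labels; -1 disables the expansion limit, 60 is a time
# limit in seconds.
sol = search(X, y, 0.95, 0.5, (2.5, 2.5), -1, 60)

tree = BinaryClassificationTree.parse(sol.tree)  # sol.tree is a string like "((1)3)"
tree.fit(X, y)                                   # populate leaf label counts
print(sol.lb, sol.ub)                            # a bound gap (sol.lb < sol.ub) indicates a timeout
```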
10 | If you have a question about our code, please submit a GitHub issue. 11 | 12 | ## Set Up The Environment 13 | 14 | Our code only supports Python 3.10. 15 | Note that our dependencies require installing the old `sklearn` package 16 | (_not_ `scikit-learn`); you may need to run the following command to allow this: 17 | 18 | ``` 19 | export SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True 20 | ``` 21 | 22 | ### M1 Macs 23 | 24 | If you're on an M1 Mac, you'll need to install `gosdt` from source; see the instructions 25 | [here](https://github.com/ubc-systopia/gosdt-guesses/blob/main/doc/build.md). Afterwards, install 26 | the other dependencies and the `maptree` Python package with 27 | 28 | ``` 29 | python -m pip install -r requirements_m1mac.txt 30 | python -m pip install . 31 | ``` 32 | 33 | ### Other platforms 34 | 35 | If you're not on an M1 Mac, you can install all dependencies and then build the `maptree` Python package directly: 36 | 37 | ``` 38 | python -m pip install -r requirements.txt 39 | python -m pip install . 40 | ``` 41 | 42 | ## Install Data 43 | 44 | Download all the necessary data with 45 | 46 | ``` 47 | python setup_data.py 48 | ``` 49 | 50 | ## Run Experiments on a Personal Computer 51 | 52 | We support running all experiments on a single machine (e.g., a personal laptop) with: 53 | 54 | ``` 55 | python run_experiment.py 56 | ``` 57 | 58 | Note that this may take a long time (1-2 weeks), depending on your hardware. 59 | 60 | ## Run Experiments on a Cluster 61 | 62 | We have also included a script to run the experiments on a cluster via SLURM, which can be invoked with: 63 | 64 | ``` 65 | sbatch script.slurm 66 | ``` 67 | 68 | Note that this may take some time (1-2 days) to complete, depending on your hardware, and that the 69 | parameters of the `script.slurm` file may need to be modified according to your cluster setup. 70 | 71 | ## Plot Results 72 | 73 | You can plot the results of all experiments with: 74 | 75 | ``` 76 | python plot_results.py 77 | ``` 78 | 79 | The plots can be found in `experiments/results/figures`. 80 | 81 | Hyperparameters for all the experiments can be found in the following files: 82 | ``` 83 | experiments/globals.py 84 | experiments/experiments/fig1/constants.py 85 | experiments/experiments/fig2/constants.py 86 | experiments/experiments/fig3/constants.py 87 | ``` -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | #---------------------------------# 2 | # environment configuration # 3 | #---------------------------------# 4 | 5 | # Build worker image (VM template) 6 | image: Visual Studio 2017 7 | 8 | clone_depth: 3 9 | 10 | platform: 11 | - Win32 12 | - x64 13 | 14 | configuration: 15 | - Debug 16 | - Release 17 | 18 | # environment: 19 | # matrix: 20 | # - TOOLSET: v140 21 | 22 | matrix: 23 | fast_finish: false 24 | 25 | # scripts that are called at very beginning, before repo cloning 26 | init: 27 | - cmd: cmake --version 28 | - cmd: msbuild /version 29 | 30 | before_build: 31 | - cmake . -Bbuild -A%PLATFORM% -DCMAKE_BUILD_TYPE=%configuration% 32 | 33 | build: 34 | project: build/CPP_BOILERPLATE.sln 35 | parallel: true 36 | verbosity: minimal 37 | 38 | test_script: 39 | - cd build 40 | - set CTEST_OUTPUT_ON_FAILURE=1 41 | - ctest -C %configuration% 42 | - cd ..
43 | -------------------------------------------------------------------------------- /data/test/test_data_small.txt: -------------------------------------------------------------------------------- 1 | 1 0 1 0 1 2 | 0 0 1 1 0 3 | 0 1 1 1 0 4 | 0 0 1 1 1 5 | 1 1 0 1 1 6 | 0 0 1 1 0 7 | 1 0 0 1 1 8 | 0 0 1 1 1 9 | 1 0 1 0 1 10 | 1 0 0 0 1 11 | 1 0 1 0 1 12 | 1 0 0 1 0 13 | 0 0 1 1 1 14 | 1 0 1 0 1 15 | 0 0 1 1 0 16 | 1 1 0 0 1 17 | 1 1 0 1 1 18 | 1 0 0 1 1 19 | 1 1 0 1 1 20 | 1 1 0 1 1 21 | 1 0 1 0 1 22 | 1 1 0 0 0 23 | 1 0 1 0 0 24 | 1 0 0 1 0 25 | 1 0 0 1 1 26 | 1 0 0 1 1 27 | 1 0 0 1 1 28 | 0 0 1 1 0 29 | 1 0 0 0 1 30 | 0 0 1 1 1 31 | 1 0 0 0 0 32 | 1 1 0 0 0 33 | 1 0 1 0 0 34 | 1 0 0 1 1 35 | 1 0 0 0 0 36 | 1 1 1 0 1 37 | 1 0 1 0 1 38 | 1 0 1 0 1 39 | 0 1 1 1 0 40 | 0 1 1 1 1 -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/dataloader/__init__.py -------------------------------------------------------------------------------- /experiments/dataloader/cp4im.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | import numpy as np 4 | from typing import List 5 | 6 | from ..globals import DIR_DATA_CP4IM, CP4IM_DATASET_NAMES, CP4IM_DATASET_URL 7 | 8 | 9 | def download_and_install(datasets: List[str] = None): 10 | if datasets is None: 11 | datasets = CP4IM_DATASET_NAMES 12 | 13 | for dataset in datasets: 14 | assert(dataset in CP4IM_DATASET_NAMES) 15 | print(f"Downloading CP4IM {dataset} dataset...") 16 | r = requests.get(CP4IM_DATASET_URL.format(dataset=dataset)) 17 | lines = r.text.split('\n') 18 | num_features = 0 19 | data_line = None 20 | for i, line in enumerate(lines): 21 | if line.startswith('@data'): 22 | data_line = i + 1 23 | elif line.startswith('@') \ 24 | and not line.startswith('@relation') \ 25 | and not line.startswith('@class'): 26 | line_tag = line[1:line.find(':')] 27 | num_features = max(num_features, int(line_tag) + 1) 28 | 29 | assert(num_features > 0) 30 | assert(data_line is not None) 31 | 32 | feats = [] 33 | labels = [] 34 | for line in lines[data_line:]: 35 | if not line: 36 | continue 37 | 38 | sample = line.split(' ') 39 | items = map(int, sample[:-1]) 40 | label = int(sample[-1]) 41 | feat = np.zeros(num_features, dtype=np.int64) 42 | for i in items: 43 | feat[i] = 1 44 | 45 | feats.append(feat) 46 | labels.append(label) 47 | 48 | X = np.row_stack(feats) 49 | y = np.array(labels) 50 | 51 | print(f'# Features: {X.shape[1]}') 52 | print(f'# Samples: {X.shape[0]}') 53 | 54 | if not os.path.exists(DIR_DATA_CP4IM): 55 | os.makedirs(DIR_DATA_CP4IM) 56 | 57 | data_path = os.path.join(DIR_DATA_CP4IM, f'{dataset}.txt') 58 | with open(data_path, 'w') as fp: 59 | A = np.column_stack((X, y)) 60 | np.savetxt(fp, A, fmt='%d') 61 | -------------------------------------------------------------------------------- /experiments/dataloader/synth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from 
typing import List 3 | import os 4 | 5 | from experiments.searchers.binary_classification_tree import BinaryClassificationTree 6 | from ..globals import ( 7 | SYNTH_NUM_TREES, 8 | SYNTH_TOTAL_SAMPLES_PER_TREE, 9 | SYNTH_NUM_FEATURES, 10 | SEED_SYNTH_TREE_GENERATOR, 11 | SEED_SYNTH_DATA_GENERATOR, 12 | SYNTH_TREE_NUM_INTERNAL_NODES, 13 | DIR_DATA_SYNTH 14 | ) 15 | 16 | 17 | def generate_random_tree( 18 | num_internal_nodes: int, 19 | available_features: List[int], 20 | rng: np.random.Generator): 21 | 22 | if num_internal_nodes == 0: 23 | return BinaryClassificationTree() 24 | 25 | # select feature from features not used by ancestors to avoid creating degenerate tree 26 | feature = rng.choice(available_features) 27 | available_features = [f for f in available_features if f != feature] 28 | 29 | # use 1 internal node for root, send uniform random amount left and the rest right 30 | num_internal_nodes_left = rng.integers(num_internal_nodes) 31 | left_subtree = generate_random_tree( 32 | num_internal_nodes_left, 33 | available_features, 34 | rng, 35 | ) 36 | right_subtree = generate_random_tree( 37 | num_internal_nodes - 1 - num_internal_nodes_left, 38 | available_features, 39 | rng, 40 | ) 41 | 42 | return BinaryClassificationTree(left_subtree, right_subtree, feature) 43 | 44 | 45 | def assign_random_labels(tree: BinaryClassificationTree, rng: np.random.Generator): 46 | assert not tree.is_leaf() 47 | leaves = tree.get_all_leaves() 48 | 49 | # assign alternating labels to leaf nodes 50 | label = rng.choice([False, True]) 51 | for leaf in leaves: 52 | leaf.label_counts = [0, 1] if label else [1, 0] 53 | label = not label 54 | 55 | 56 | def generate_synthetic_tree_data(): 57 | tree_rng = np.random.default_rng(SEED_SYNTH_TREE_GENERATOR) 58 | data_rng = np.random.default_rng(SEED_SYNTH_DATA_GENERATOR) 59 | 60 | X = data_rng.integers(2, size=(2, SYNTH_NUM_TREES, SYNTH_TOTAL_SAMPLES_PER_TREE, SYNTH_NUM_FEATURES)) 61 | for i in range(SYNTH_NUM_TREES): 62 | tree = generate_random_tree( 63 | SYNTH_TREE_NUM_INTERNAL_NODES, 64 | list(range(SYNTH_NUM_FEATURES)), 65 | tree_rng, 66 | ) 67 | assign_random_labels(tree, tree_rng) 68 | 69 | print(f"Generated Tree: {tree}") 70 | 71 | X_train = X[0][i] 72 | X_test = X[1][i] 73 | y_train = tree.predict(X_train) 74 | y_test = tree.predict(X_test) 75 | 76 | if not os.path.exists(DIR_DATA_SYNTH): 77 | os.makedirs(DIR_DATA_SYNTH) 78 | 79 | train_data_path = os.path.join(DIR_DATA_SYNTH, f'tree{i}-train.txt') 80 | with open(train_data_path, 'w') as fp: 81 | A = np.column_stack((X_train, y_train)) 82 | np.savetxt(fp, A, fmt='%d') 83 | 84 | test_data_path = os.path.join(DIR_DATA_SYNTH, f'tree{i}-test.txt') 85 | with open(test_data_path, 'w') as fp: 86 | A = np.column_stack((X_test, y_test)) 87 | np.savetxt(fp, A, fmt='%d') 88 | -------------------------------------------------------------------------------- /experiments/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/experiments/fig1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/experiments/fig1/__init__.py -------------------------------------------------------------------------------- 
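As a quick bridge to the experiment modules that follow, here is a sketch of how the synthetic data above is generated and then consumed. It is based on `generate_synthetic_tree_data` from `experiments/dataloader/synth.py` and `get_synth_data_samples` from `experiments/globals.py`; the `run_experiment.py` entry point may orchestrate these steps differently:

```
from experiments.dataloader.synth import generate_synthetic_tree_data
from experiments.globals import get_synth_data_samples

# Writes data/synth/tree{i}-train.txt and tree{i}-test.txt for each of the
# randomly generated ground-truth trees.
generate_synthetic_tree_data()

# Load the first 500 training samples for tree 0, with 10% label noise
# applied to the training labels only.
X_train, y_train, X_test, y_test = get_synth_data_samples(
    tree_id=0, sample_size=500, noise=0.1)
```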
/experiments/experiments/fig1/constants.py: -------------------------------------------------------------------------------- 1 | RHO = [2.5, 2.5] 2 | TIME_LIMIT = 60 3 | 4 | CART_PARAMS_LIST = [ 5 | {'max_depth': 2}, 6 | {'max_depth': 3}, 7 | {'max_depth': 4}, 8 | {'max_depth': 5}, 9 | {'max_depth': 6}, 10 | {'max_depth': 7}, 11 | {'max_depth': 8}, 12 | ] 13 | DL85_PARAMS_LIST = [ 14 | {'max_depth': 2, 'time_limit': TIME_LIMIT}, 15 | {'max_depth': 3, 'time_limit': TIME_LIMIT}, 16 | {'max_depth': 4, 'time_limit': TIME_LIMIT}, 17 | {'max_depth': 5, 'time_limit': TIME_LIMIT}, 18 | {'max_depth': 6, 'time_limit': TIME_LIMIT}, 19 | ] 20 | GOSDT_PARAMS_LIST = [ 21 | {'regularization': 0.03125, 'time_limit': TIME_LIMIT}, 22 | {'regularization': 0.3125, 'time_limit': TIME_LIMIT}, 23 | ] 24 | MAPTREE_PARAMS_LIST = [ 25 | {'alpha': 0.999, 'beta': 0.1, 'rho': RHO, 'time_limit': TIME_LIMIT}, 26 | {'alpha': 0.99, 'beta': 0.2, 'rho': RHO, 'time_limit': TIME_LIMIT}, 27 | {'alpha': 0.95, 'beta': 0.5, 'rho': RHO, 'time_limit': TIME_LIMIT}, 28 | {'alpha': 0.9, 'beta': 1.0, 'rho': RHO, 'time_limit': TIME_LIMIT}, 29 | {'alpha': 0.8, 'beta': 2.0, 'rho': RHO, 'time_limit': TIME_LIMIT}, 30 | {'alpha': 0.5, 'beta': 4.0, 'rho': RHO, 'time_limit': TIME_LIMIT}, 31 | {'alpha': 0.2, 'beta': 8.0, 'rho': RHO, 'time_limit': TIME_LIMIT}, 32 | ] 33 | 34 | SEARCHERS_AND_PARAMS_LISTS = [ 35 | ("MAPTree", MAPTREE_PARAMS_LIST), 36 | ("CART", CART_PARAMS_LIST), 37 | ("DL8.5", DL85_PARAMS_LIST), 38 | ("GOSDT", GOSDT_PARAMS_LIST), 39 | ] 40 | -------------------------------------------------------------------------------- /experiments/experiments/fig1/plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import os 4 | import pandas as pd 5 | 6 | from experiments.globals import get_latest_results, CP4IM_DATASET_NAMES, DIR_RESULTS_FIGS 7 | 8 | 9 | def boxplot(model_data, cart_data, file): 10 | data = pd.concat(model_data) 11 | 12 | cart_data = cart_data[["test_acc", "test_sll", "dataset"]].rename(columns={"test_acc": "base_acc", "test_sll": "base_sll"}) 13 | data = data.merge(cart_data, on="dataset") 14 | data["test_acc"] = data["test_acc"] - data["base_acc"] 15 | data["test_sll"] = data["test_sll"] - data["base_sll"] 16 | 17 | figure = plt.figure(layout="constrained", figsize=(8, 6)) 18 | ax = figure.subplots(1, 3) 19 | 20 | sns.boxplot( 21 | data[(data["searcher"] != "CART")], 22 | x="Model", 23 | y="test_acc", 24 | hue="searcher", 25 | ax=ax[0], 26 | ) 27 | 28 | sns.boxplot( 29 | data[(data["searcher"] != "CART")], 30 | x="Model", 31 | y="test_sll", 32 | hue="searcher", 33 | ax=ax[1], 34 | ) 35 | 36 | sns.boxplot( 37 | data, 38 | x="Model", 39 | y="size", 40 | hue="searcher", 41 | ax=ax[2] 42 | ) 43 | 44 | for axis in ax: 45 | axis.get_legend().remove() 46 | axis.set_xlabel("Model") 47 | axis.set_xticklabels(axis.get_xticklabels(), rotation=90) 48 | 49 | ax[0].set_ylabel("Relative Test Accuracy") 50 | ax[1].set_ylabel("Relative Per-Sample Test Log Likelihood") 51 | ax[2].set_ylabel("Tree Size") 52 | 53 | fig_file = os.path.join(DIR_RESULTS_FIGS, file) 54 | if not os.path.exists(DIR_RESULTS_FIGS): 55 | os.makedirs(DIR_RESULTS_FIGS) 56 | figure.savefig(fig_file, format='pdf', bbox_inches='tight') 57 | 58 | 59 | def run(): 60 | print(f"Plotting performance comparison of algorithms on CP4IM datasets...") 61 | 62 | all_results = [] 63 | for dataset in CP4IM_DATASET_NAMES: 64 | results = get_latest_results("fig1", dataset) 65 | 
results["dataset"] = dataset 66 | all_results.append(results) 67 | 68 | data = pd.concat(all_results) \ 69 | .groupby(['searcher', 'params_id', 'dataset'])[['test_acc', 'test_sll', 'size']] \ 70 | .agg('mean') \ 71 | .reset_index() 72 | 73 | dl85_depth_4_data = data[(data["searcher"] == "DL8.5") & (data["params_id"] == 2)] 74 | dl85_depth_5_data = data[(data["searcher"] == "DL8.5") & (data["params_id"] == 3)] 75 | dl85_depth_6_data = data[(data["searcher"] == "DL8.5") & (data["params_id"] == 4)] 76 | 77 | gosdt_slow_data = data[(data["searcher"] == "GOSDT") & (data["params_id"] == 0)] 78 | gosdt_fast_data = data[(data["searcher"] == "GOSDT") & (data["params_id"] == 1)] 79 | 80 | maptree_data_default = data[(data["searcher"] == "MAPTree") & (data["params_id"] == 2)] 81 | all_maptree_data = [] 82 | for params_id in range(7): 83 | all_maptree_data.append(data[(data["searcher"] == "MAPTree") & (data["params_id"] == params_id)]) 84 | all_maptree_data[-1]["Model"] = f"MAPTree (params={params_id})" 85 | 86 | cart_depth_4_data = data[(data["searcher"] == "CART") & (data["params_id"] == 2)] 87 | cart_depth_5_data = data[(data["searcher"] == "CART") & (data["params_id"] == 3)] 88 | cart_depth_6_data = data[(data["searcher"] == "CART") & (data["params_id"] == 4)] 89 | 90 | dl85_depth_4_data["Model"] = "DL8.5 (depth=4)" 91 | dl85_depth_5_data["Model"] = "DL8.5 (depth=5)" 92 | dl85_depth_6_data["Model"] = "DL8.5 (depth=6)" 93 | 94 | gosdt_slow_data["Model"] = "GOSDT (reg=0.03125)" 95 | gosdt_fast_data["Model"] = "GOSDT (reg=0.3125)" 96 | 97 | maptree_data_default["Model"] = "MAPTree" 98 | 99 | cart_depth_4_data["Model"] = "CART (depth=4)" 100 | cart_depth_5_data["Model"] = "CART (depth=5)" 101 | cart_depth_6_data["Model"] = "CART (depth=6)" 102 | 103 | all_model_data = [ 104 | maptree_data_default, 105 | dl85_depth_4_data, 106 | dl85_depth_5_data, 107 | gosdt_slow_data, 108 | gosdt_fast_data, 109 | cart_depth_4_data, 110 | ] 111 | 112 | boxplot(all_model_data, cart_depth_4_data, "fig1.pdf") 113 | boxplot(all_maptree_data, cart_depth_4_data, "fig1-appendix.pdf") 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /experiments/experiments/fig1/runner.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from experiments.globals import get_stratified_k_folds_cp4im_dataset, run_search, save_results 4 | from .constants import SEARCHERS_AND_PARAMS_LISTS, RHO 5 | 6 | 7 | def run(dataset: str): 8 | print(f"Performance comparison on CP4IM dataset: {dataset}") 9 | print("=====================================================") 10 | 11 | results = [] 12 | for i, fold in enumerate(get_stratified_k_folds_cp4im_dataset(dataset)): 13 | print(f"Fold: {i}") 14 | X_train, y_train, X_test, y_test = fold 15 | for searcher, params_list in SEARCHERS_AND_PARAMS_LISTS: 16 | print(f"Searcher: {searcher}") 17 | for j, params in enumerate(params_list): 18 | print(f"Params: {params}") 19 | 20 | result = run_search(searcher, X_train, y_train, **params) 21 | if result is None: 22 | print("Run Failed!!!") 23 | continue 24 | 25 | tree = result['tree'] 26 | time = result['time'] 27 | timeout = result['timeout'] 28 | 29 | tree.fit(X_train, y_train) 30 | 31 | # add results to results queue 32 | train_acc = (tree.predict(X_train) == y_train).sum() / len(y_train) 33 | test_acc = (tree.predict(X_test) == y_test).sum() / len(y_test) 34 | train_sll = tree.log_likelihood(X_train, y_train, rho=RHO) / len(y_train) 35 | 
test_sll = tree.log_likelihood(X_test, y_test, rho=RHO) / len(y_test) 36 | size = tree.size() 37 | 38 | print(f"Timed Out: {timeout}") 39 | print(f"Test Accuracy: {test_acc}") 40 | print(f"Test SLL: {test_sll}") 41 | 42 | results.append({ 43 | 'searcher': searcher, 44 | 'params_id': j, 45 | 'fold': i, 46 | 'tree': str(tree), 47 | 'time': time, 48 | 'train_acc': train_acc, 49 | 'test_acc': test_acc, 50 | 'train_sll': train_sll, 51 | 'test_sll': test_sll, 52 | 'size': size, 53 | 'timeout': timeout, 54 | }) 55 | 56 | save_results(pd.DataFrame(results), "fig1", dataset) 57 | -------------------------------------------------------------------------------- /experiments/experiments/fig2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/experiments/fig2/__init__.py -------------------------------------------------------------------------------- /experiments/experiments/fig2/constants.py: -------------------------------------------------------------------------------- 1 | POSTERIOR = { 2 | 'alpha': 0.95, 3 | 'beta': 0.5, 4 | 'rho': [2.5, 2.5], 5 | } 6 | 7 | FINAL_RUN_TIME_LIMIT = 180 8 | 9 | RANDOM_SEARCHER_SEEDS = list(range(42, 52)) 10 | RANDOM_SEARCHERS = ["MCMC", "SMC"] 11 | 12 | MCMC_PARAMS_LIST = [ 13 | {"num_iterations": 10, **POSTERIOR}, 14 | {"num_iterations": 30, **POSTERIOR}, 15 | {"num_iterations": 100, **POSTERIOR}, 16 | {"num_iterations": 300, **POSTERIOR}, 17 | {"num_iterations": 1000, **POSTERIOR}, 18 | ] 19 | SMC_PARAMS_LIST = [ 20 | {"num_particles": 10, **POSTERIOR}, 21 | {"num_particles": 30, **POSTERIOR}, 22 | {"num_particles": 100, **POSTERIOR}, 23 | {"num_particles": 300, **POSTERIOR}, 24 | {"num_particles": 1000, **POSTERIOR}, 25 | ] 26 | MAPTREE_PARAMS_LIST = [ 27 | {"num_expansions": 10, **POSTERIOR}, 28 | {"num_expansions": 30, **POSTERIOR}, 29 | {"num_expansions": 100, **POSTERIOR}, 30 | {"num_expansions": 300, **POSTERIOR}, 31 | {"num_expansions": 1000, **POSTERIOR}, 32 | {"num_expansions": 3000, **POSTERIOR}, 33 | {"num_expansions": 10000, **POSTERIOR}, 34 | {"num_expansions": 30000, **POSTERIOR}, 35 | {"num_expansions": 100000, **POSTERIOR}, 36 | {"time_limit": FINAL_RUN_TIME_LIMIT, **POSTERIOR}, 37 | ] 38 | 39 | SEARCHERS_AND_PARAMS_LISTS = [ 40 | ("MAPTree", MAPTREE_PARAMS_LIST), 41 | ("MCMC", MCMC_PARAMS_LIST), 42 | ("SMC", SMC_PARAMS_LIST), 43 | ] 44 | -------------------------------------------------------------------------------- /experiments/experiments/fig2/plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | import os 4 | 5 | from experiments.globals import get_latest_results, CP4IM_DATASET_NAMES, DIR_RESULTS_FIGS 6 | 7 | 8 | def run(): 9 | print(f"Plotting speed comparison of algorithms on CP4IM datasets...") 10 | 11 | figure = plt.figure(layout="constrained", figsize=(6, 4)) 12 | nrows = 2 13 | ncols = 2 14 | axes = figure.subplots(nrows, ncols) 15 | 16 | for i, dataset in enumerate(CP4IM_DATASET_NAMES[:4]): 17 | ax = axes[i // ncols, i % ncols] 18 | results = get_latest_results("fig2", dataset) 19 | 20 | # average times across runs 21 | avg_times = results \ 22 | .groupby(["searcher", "params_id"])['time'] \ 23 | .agg('mean') \ 24 | .reset_index() 25 | results = results[["searcher", "params_id", "post", "best_post"]] \ 26 | .merge(avg_times, on=["searcher", "params_id"]) 27 | 28 | sns.lineplot( 29 | 
results, 30 | x="time", 31 | y="post", 32 | hue="searcher", 33 | estimator="mean", 34 | errorbar=("ci", 95), 35 | ax=ax, 36 | ) 37 | 38 | ax.set_title(dataset) 39 | ax.set_xlabel("Time (s)") 40 | ax.set_xscale("log") 41 | ax.set_ylabel("Log Posterior") 42 | 43 | handles, labels = ax.get_legend_handles_labels() 44 | ax.get_legend().remove() 45 | 46 | legend = figure.legend( 47 | handles=handles, 48 | labels=labels, 49 | title='Model', 50 | loc='lower center', 51 | ncol=3, 52 | bbox_to_anchor=(0.5, -0.13), 53 | ) 54 | 55 | fig_file = os.path.join(DIR_RESULTS_FIGS, f"fig2.pdf") 56 | if not os.path.exists(DIR_RESULTS_FIGS): 57 | os.makedirs(DIR_RESULTS_FIGS) 58 | 59 | figure.savefig(fig_file, format='pdf', bbox_extra_artists=(legend,), bbox_inches='tight') 60 | 61 | -------------------------------------------------------------------------------- /experiments/experiments/fig2/runner.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from math import isclose 3 | 4 | from experiments.globals import get_full_cp4im_dataset, run_search, save_results 5 | from .constants import SEARCHERS_AND_PARAMS_LISTS, POSTERIOR, RANDOM_SEARCHERS, RANDOM_SEARCHER_SEEDS 6 | 7 | 8 | def run(dataset: str): 9 | print(f"Speed comparison on CP4IM dataset: {dataset}") 10 | print("=====================================================") 11 | 12 | data = get_full_cp4im_dataset(dataset) 13 | X, y = data 14 | 15 | results = [] 16 | 17 | for searcher, params_list in SEARCHERS_AND_PARAMS_LISTS: 18 | print(f"Searcher: {searcher}") 19 | for j, params in enumerate(params_list): 20 | print(f"Params: {params}") 21 | for k, seed in enumerate(RANDOM_SEARCHER_SEEDS): 22 | if searcher not in RANDOM_SEARCHERS and k > 0: 23 | continue 24 | 25 | if searcher in RANDOM_SEARCHERS: 26 | print(f"Seed: {seed}") 27 | params['seed'] = seed 28 | 29 | result = run_search(searcher, X, y, **params) 30 | if result is None: 31 | print("Run Failed!!!") 32 | continue 33 | 34 | tree = result['tree'] 35 | time = result['time'] 36 | tree.fit(X, y) 37 | size = tree.size() 38 | 39 | post = tree.log_posterior(X, y, **POSTERIOR) 40 | best_post = 0.0 41 | 42 | if "lower_bound" in result: 43 | best_post = -result["lower_bound"] 44 | if "upper_bound" in result: 45 | assert(isclose(post, -result["upper_bound"])) 46 | 47 | print(f"Time: {time}") 48 | print(f"Log Posterior: {post}") 49 | 50 | results.append({ 51 | 'searcher': searcher, 52 | 'params_id': j, 53 | 'tree': str(tree), 54 | 'time': time, 55 | 'post': post, 56 | 'best_post': best_post, 57 | 'size': size, 58 | 'seed': k, 59 | }) 60 | 61 | save_results(pd.DataFrame(results), "fig2", dataset) 62 | 63 | -------------------------------------------------------------------------------- /experiments/experiments/fig3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/experiments/fig3/__init__.py -------------------------------------------------------------------------------- /experiments/experiments/fig3/constants.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | NUM_SAMPLE_SIZE_VALUES = 10 4 | SAMPLE_SIZE_VALUES = list(np.round(np.linspace(100, 1000, NUM_SAMPLE_SIZE_VALUES)).astype(np.int32)) 5 | NOISE_VALUES = [0.0, 0.1, 0.25] 6 | 7 | TIME_LIMIT = 60 8 | RHO = [2.5, 2.5] 9 | 10 | CART_PARAMS_LIST = [ 11 | {'max_depth': 2}, 12 | {'max_depth': 
3}, 13 | {'max_depth': 4}, 14 | {'max_depth': 5}, 15 | {'max_depth': 6}, 16 | {'max_depth': 7}, 17 | {'max_depth': 8}, 18 | ] 19 | DL85_PARAMS_LIST = [ 20 | {'max_depth': 2, 'time_limit': TIME_LIMIT}, 21 | {'max_depth': 3, 'time_limit': TIME_LIMIT}, 22 | {'max_depth': 4, 'time_limit': TIME_LIMIT}, 23 | {'max_depth': 5, 'time_limit': TIME_LIMIT}, 24 | {'max_depth': 6, 'time_limit': TIME_LIMIT}, 25 | ] 26 | GOSDT_PARAMS_LIST = [ 27 | {'regularization': 0.03125, 'time_limit': TIME_LIMIT}, 28 | {'regularization': 0.3125, 'time_limit': TIME_LIMIT}, 29 | ] 30 | MAPTREE_PARAMS_LIST = [ 31 | {'alpha': 0.95, 'beta': 0.5, 'rho': RHO, 'time_limit': TIME_LIMIT}, 32 | ] 33 | 34 | SEARCHERS_AND_PARAMS_LISTS = [ 35 | ("MAPTree", MAPTREE_PARAMS_LIST), 36 | ("CART", CART_PARAMS_LIST), 37 | ("DL8.5", DL85_PARAMS_LIST), 38 | ("GOSDT", GOSDT_PARAMS_LIST), 39 | ] 40 | -------------------------------------------------------------------------------- /experiments/experiments/fig3/plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.lines import Line2D 3 | import seaborn as sns 4 | import os 5 | import pandas as pd 6 | 7 | from experiments.globals import get_latest_results, SYNTH_NUM_TREES, DIR_RESULTS_FIGS 8 | from .constants import SAMPLE_SIZE_VALUES, NOISE_VALUES 9 | 10 | 11 | def run(): 12 | print(f"Plotting performance comparison of algorithms on synth datasets...") 13 | 14 | figure = plt.figure(layout="constrained", figsize=(6, 8)) 15 | axes = figure.subplots(len(NOISE_VALUES), 1) 16 | 17 | all_results = [] 18 | for i in range(SYNTH_NUM_TREES): 19 | results = get_latest_results("fig3", f"tree{i}") 20 | all_results.append(results) 21 | 22 | data = pd.concat(all_results) 23 | dl85_depth_4_data = data[(data["searcher"] == "DL8.5") & (data["params_id"] == 2)] 24 | dl85_depth_5_data = data[(data["searcher"] == "DL8.5") & (data["params_id"] == 3)] 25 | dl85_depth_6_data = data[(data["searcher"] == "DL8.5") & (data["params_id"] == 4)] 26 | gosdt_slow_data = data[(data["searcher"] == "GOSDT") & (data["params_id"] == 0)] 27 | gosdt_fast_data = data[(data["searcher"] == "GOSDT") & (data["params_id"] == 1)] 28 | maptree_data = data[(data["searcher"] == "MAPTree") & (data["params_id"] == 0)] 29 | cart_4_data = data[(data["searcher"] == "CART") & (data["params_id"] == 2)] 30 | cart_5_data = data[(data["searcher"] == "CART") & (data["params_id"] == 3)] 31 | cart_6_data = data[(data["searcher"] == "CART") & (data["params_id"] == 4)] 32 | 33 | maptree_data["model"] = "MAPTree" 34 | 35 | dl85_depth_4_data["model"] = "DL8.5 (depth=4)" 36 | dl85_depth_5_data["model"] = "DL8.5 (depth=5)" 37 | dl85_depth_6_data["model"] = "DL8.5 (depth=6)" 38 | 39 | gosdt_slow_data["model"] = "GOSDT (reg=0.03125)" 40 | gosdt_fast_data["model"] = "GOSDT (reg=0.3125)" 41 | 42 | cart_4_data["model"] = "CART (depth=4)" 43 | cart_5_data["model"] = "CART (depth=5)" 44 | cart_6_data["model"] = "CART (depth=6)" 45 | 46 | data = pd.concat([ 47 | maptree_data, 48 | dl85_depth_4_data, 49 | dl85_depth_5_data, 50 | # dl85_depth_6_data, 51 | gosdt_slow_data, 52 | gosdt_fast_data, 53 | cart_4_data, 54 | # cart_5_data, 55 | # cart_6_data, 56 | ]) 57 | 58 | for i, noise in enumerate(NOISE_VALUES): 59 | noise_data = data[data["noise_id"] == i] 60 | noise_data["num_samples"] = noise_data["sample_size_id"].map(lambda id: SAMPLE_SIZE_VALUES[int(id)]) 61 | ax = axes[i] 62 | 63 | sns.lineplot( 64 | noise_data, 65 | x="num_samples", 66 | y="test_acc", 67 | style="model", 68
| markers=True, 69 | hue="searcher", 70 | estimator="mean", 71 | errorbar=("ci", 95), 72 | ax=ax, 73 | ) 74 | 75 | ax.set_title(f"$\epsilon$ = {noise}") 76 | ax.set_xlabel("Number of Training Samples") 77 | ax.set_ylabel("Test Accuracy") 78 | ax.set_ylim([0.5, 1.0]) 79 | ax.get_legend().remove() 80 | 81 | handles, labels = ax.get_legend_handles_labels() 82 | 83 | new_labels = [] 84 | new_handles = [] 85 | 86 | for i in range(6, len(labels)): 87 | model_line = handles[labels.index(labels[i].split(' ')[0])] 88 | new_labels.append(labels[i]) 89 | new_handles.append(Line2D([], [], color=model_line.get_color(), marker=handles[i].get_marker(), linestyle=handles[i].get_linestyle())) 90 | 91 | legend = figure.legend( 92 | labels=new_labels, 93 | handles=new_handles, 94 | title='Model', 95 | loc='lower center', 96 | ncol=3, 97 | bbox_to_anchor=(0.5, -0.1), 98 | ) 99 | 100 | fig_file = os.path.join(DIR_RESULTS_FIGS, f"fig3.pdf") 101 | if not os.path.exists(DIR_RESULTS_FIGS): 102 | os.makedirs(DIR_RESULTS_FIGS) 103 | figure.savefig(fig_file, format='pdf', bbox_extra_artists=(legend,), bbox_inches='tight') 104 | -------------------------------------------------------------------------------- /experiments/experiments/fig3/runner.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from experiments.globals import get_synth_data_samples, run_search, save_results 4 | from .constants import SEARCHERS_AND_PARAMS_LISTS, NOISE_VALUES, SAMPLE_SIZE_VALUES, RHO 5 | 6 | 7 | def run(tree_id: int): 8 | print(f"Performance comparison on synthetic tree-generated data for tree {tree_id}") 9 | print("=====================================================") 10 | 11 | results = [] 12 | for i, sample_size in enumerate(SAMPLE_SIZE_VALUES): 13 | print(f"Sample Size: {sample_size}") 14 | for j, noise in enumerate(NOISE_VALUES): 15 | print(f"Noise: {noise}") 16 | data = get_synth_data_samples(tree_id, sample_size, noise) 17 | X_train, y_train, X_test, y_test = data 18 | for searcher, params_list in SEARCHERS_AND_PARAMS_LISTS: 19 | print(f"Searcher: {searcher}") 20 | for k, params in enumerate(params_list): 21 | print(f"Params: {params}") 22 | 23 | result = run_search(searcher, X_train, y_train, **params) 24 | if result is None: 25 | print("Run Failed!!!") 26 | continue 27 | 28 | tree = result['tree'] 29 | time = result['time'] 30 | 31 | tree.fit(X_train, y_train) 32 | 33 | # add results to results queue 34 | train_acc = (tree.predict(X_train) == y_train).sum() / len(y_train) 35 | test_acc = (tree.predict(X_test) == y_test).sum() / len(y_test) 36 | train_sll = tree.log_likelihood(X_train, y_train, rho=RHO) / len(y_train) 37 | test_sll = tree.log_likelihood(X_test, y_test, rho=RHO) / len(y_test) 38 | size = tree.size() 39 | 40 | print(f"Test Acc: {test_acc}") 41 | 42 | results.append({ 43 | 'sample_size_id': i, 44 | 'noise_id': j, 45 | 'tree_data_id': tree_id, 46 | 'searcher': searcher, 47 | 'params_id': k, 48 | 'tree': str(tree), 49 | 'time': time, 50 | 'train_acc': train_acc, 51 | 'test_acc': test_acc, 52 | 'train_sll': train_sll, 53 | 'test_sll': test_sll, 54 | 'size': size, 55 | }) 56 | 57 | save_results(pd.DataFrame(results), "fig3", f"tree{tree_id}") 58 | -------------------------------------------------------------------------------- /experiments/globals.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from multiprocessing import Process, Queue, TimeoutError 5
| from queue import Empty 6 | from datetime import datetime 7 | from glob import glob 8 | from sklearn.model_selection import StratifiedKFold 9 | 10 | from experiments.searchers.maptree import run as maptree_search 11 | from experiments.searchers.mcmc import run as mcmc_search 12 | from experiments.searchers.smc import run as smc_search 13 | from experiments.searchers.cart import run as cart_search 14 | from experiments.searchers.dl85 import run as dl85_search 15 | from experiments.searchers.gosdt import run as gosdt_search 16 | 17 | DIR_DATA_CP4IM = os.path.join("data", "cp4im") 18 | DIR_DATA_SYNTH = os.path.join("data", "synth") 19 | 20 | CP4IM_DATASET_URL = 'https://dtai.cs.kuleuven.be/CP4IM/datasets/data/{dataset}.txt' 21 | CP4IM_DATASET_NAMES = sorted([ 22 | 'zoo-1', 23 | 'vote', 24 | 'tic-tac-toe', 25 | 'splice-1', 26 | 'soybean', 27 | 'primary-tumor', 28 | 'mushroom', 29 | 'lymph', 30 | 'kr-vs-kp', 31 | 'hypothyroid', 32 | 'hepatitis', 33 | 'heart-cleveland', 34 | 'german-credit', 35 | 'australian-credit', 36 | 'audiology', 37 | 'anneal', 38 | ]) 39 | 40 | DIR_RESULTS_DATA = os.path.join("experiments", "results", "data") 41 | DIR_RESULTS_FIGS = os.path.join("experiments", "results", "figures") 42 | TIMESTAMP_FORMAT = "%Y-%m-%d-%H:%M:%S" 43 | 44 | CP4IM_NUM_FOLDS = 10 45 | 46 | SYNTH_NUM_TREES = 20 47 | SYNTH_TREE_NUM_INTERNAL_NODES = 15 48 | SYNTH_TOTAL_SAMPLES_PER_TREE = 1000 49 | SYNTH_NUM_FEATURES = 40 50 | 51 | SEED_CP4IM_STRATIFIED_FOLD_CONSTRUCTOR = 84 52 | SEED_SYNTH_TREE_GENERATOR = 42 53 | SEED_SYNTH_DATA_GENERATOR = 21 54 | 55 | 56 | def search_target_decorator(search, queue: Queue): 57 | def search_target_func(*args, **kwargs): 58 | queue.put(search(*args, **kwargs)) 59 | return search_target_func 60 | 61 | 62 | # need to create actual functions in order to pickle :( 63 | def maptree_search_wrapper(queue: Queue, *args, **kwargs): 64 | return search_target_decorator(maptree_search, queue)(*args, **kwargs) 65 | 66 | 67 | def mcmc_search_wrapper(queue: Queue, *args, **kwargs): 68 | return search_target_decorator(mcmc_search, queue)(*args, **kwargs) 69 | 70 | 71 | def smc_search_wrapper(queue: Queue, *args, **kwargs): 72 | return search_target_decorator(smc_search, queue)(*args, **kwargs) 73 | 74 | 75 | def cart_search_wrapper(queue: Queue, *args, **kwargs): 76 | return search_target_decorator(cart_search, queue)(*args, **kwargs) 77 | 78 | 79 | def dl85_search_wrapper(queue: Queue, *args, **kwargs): 80 | return search_target_decorator(dl85_search, queue)(*args, **kwargs) 81 | 82 | 83 | def gosdt_search_wrapper(queue: Queue, *args, **kwargs): 84 | return search_target_decorator(gosdt_search, queue)(*args, **kwargs) 85 | 86 | 87 | ALL_SEARCHERS = { 88 | "CART": cart_search_wrapper, 89 | "DL8.5": dl85_search_wrapper, 90 | "GOSDT": gosdt_search_wrapper, 91 | "MAPTree": maptree_search_wrapper, 92 | "MCMC": mcmc_search_wrapper, 93 | "SMC": smc_search_wrapper, 94 | } 95 | 96 | 97 | def run_search(searcher: str, *args, **kwargs): 98 | if searcher not in ALL_SEARCHERS: 99 | raise ValueError(f"{searcher} is not a valid searcher") 100 | 101 | queue = Queue() 102 | p = Process(target=ALL_SEARCHERS[searcher], args=(queue,) + args, kwargs=kwargs) 103 | 104 | process_timeout = kwargs['time_limit'] * 2 if 'time_limit' in kwargs else None 105 | p.start() 106 | result = None 107 | try: 108 | result = queue.get(timeout=process_timeout) 109 | p.join(timeout=5) # give some time for a nice close 110 | except (TimeoutError, Empty): 111 | print("Process timed out") 112 | finally: 113 | p.close() 114 
| return result 115 | 116 | 117 | def load_binary_data(path): 118 | assert os.path.exists(path) 119 | binary_data = np.loadtxt(path, delimiter=' ', dtype=np.int32) 120 | assert np.all((binary_data == 0) | (binary_data == 1)) 121 | X = binary_data[:, :-1] 122 | y = binary_data[:, -1] 123 | return X, y 124 | 125 | 126 | def get_cp4im_data_path(dataset: str): 127 | assert dataset in CP4IM_DATASET_NAMES 128 | return os.path.join(DIR_DATA_CP4IM, f"{dataset}.txt") 129 | 130 | 131 | def get_full_cp4im_dataset(dataset: str): 132 | file = get_cp4im_data_path(dataset) 133 | return load_binary_data(file) 134 | 135 | 136 | def get_stratified_k_folds_cp4im_dataset(dataset: str, k: int = CP4IM_NUM_FOLDS): 137 | file = get_cp4im_data_path(dataset) 138 | X, y = load_binary_data(file) 139 | skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=SEED_CP4IM_STRATIFIED_FOLD_CONSTRUCTOR) 140 | for train_idxs, test_idxs in skf.split(X, y): 141 | X_train = X[train_idxs] 142 | y_train = y[train_idxs] 143 | X_test = X[test_idxs] 144 | y_test = y[test_idxs] 145 | yield X_train, y_train, X_test, y_test 146 | 147 | 148 | def get_synth_data_samples(tree_id: int, sample_size: int, noise: float): 149 | assert sample_size <= SYNTH_TOTAL_SAMPLES_PER_TREE 150 | 151 | path_train = os.path.join(DIR_DATA_SYNTH, f"tree{tree_id}-train.txt") 152 | path_test = os.path.join(DIR_DATA_SYNTH, f"tree{tree_id}-test.txt") 153 | X_train, y_train = load_binary_data(path_train) 154 | X_test, y_test = load_binary_data(path_test) 155 | 156 | # apply noise to the training data 157 | rng = np.random.default_rng(SEED_SYNTH_DATA_GENERATOR) 158 | flip = rng.random((SYNTH_NUM_TREES, SYNTH_TOTAL_SAMPLES_PER_TREE)) < noise 159 | y_train ^= flip[tree_id] 160 | 161 | return X_train[:sample_size], y_train[:sample_size], X_test, y_test 162 | 163 | 164 | def save_results(results: pd.DataFrame, experiment: str, dataset: str): 165 | df = pd.DataFrame(results) 166 | timestamp = datetime.now().strftime(TIMESTAMP_FORMAT) 167 | dir_dataset_results = os.path.join(DIR_RESULTS_DATA, experiment, dataset) 168 | if not os.path.exists(dir_dataset_results): 169 | os.makedirs(dir_dataset_results) 170 | 171 | file = os.path.join(dir_dataset_results, f"results-{timestamp}.csv") 172 | df.to_csv(file) 173 | print(f"Saved results for experiment {experiment} on dataset {dataset} to file {file}") 174 | 175 | 176 | def get_latest_results(experiment: str, dataset: str) -> pd.DataFrame: 177 | results_dir = os.path.join(DIR_RESULTS_DATA, experiment, dataset) 178 | if not os.path.exists(results_dir): 179 | raise ValueError(f"Results directory does not exist: {results_dir}") 180 | 181 | results_files = [ 182 | os.path.basename(f) for f in 183 | glob(os.path.join(results_dir, 'results-*.csv')) 184 | ] 185 | timestamps = [ 186 | datetime.strptime(f[len('results-'):-len('.csv')], TIMESTAMP_FORMAT) 187 | for f in results_files 188 | ] 189 | most_recent_results_file = results_files[timestamps.index(max(timestamps))] 190 | return pd.read_csv(os.path.join(results_dir, most_recent_results_file)) 191 | 192 | 193 | -------------------------------------------------------------------------------- /experiments/searchers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/searchers/__init__.py -------------------------------------------------------------------------------- /experiments/searchers/binary_classification_tree.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Tuple 3 | from math import lgamma 4 | 5 | 6 | def get_num_valid_splits(X): 7 | num_valid = 0 8 | for f in range(X.shape[1]): 9 | if 0 < np.count_nonzero(X[:, f]) < X.shape[0]: 10 | num_valid += 1 11 | return num_valid 12 | 13 | 14 | def log_prob_split(depth: int, alpha: float, beta: float) -> float: 15 | return np.log(alpha) - beta * np.log(depth + 1) 16 | 17 | 18 | def log_prob_stop(depth: int, alpha: float, beta: float) -> float: 19 | return np.log(1 - np.exp(log_prob_split(depth, alpha, beta))) 20 | 21 | 22 | def log_beta(count: Tuple[float, float]) -> float: 23 | return lgamma(count[0]) + lgamma(count[1]) - lgamma(count[0] + count[1]) 24 | 25 | 26 | def log_likelihood(y, rho: Tuple[float, float]) -> float: 27 | count = np.bincount(y, minlength=2) 28 | return log_beta((count[0] + rho[0], count[1] + rho[1])) - log_beta(rho) 29 | 30 | 31 | def split(X, feature: int) -> Tuple[np.ndarray, np.ndarray]: 32 | left = np.nonzero(X[:, feature] == False) 33 | right = np.nonzero(X[:, feature] == True) 34 | return left, right 35 | 36 | 37 | class BinaryClassificationTree: 38 | def __init__(self, 39 | left: 'BinaryClassificationTree' = None, 40 | right: 'BinaryClassificationTree' = None, 41 | feature: int = None): 42 | assert ((left is None) == (right is None) == (feature is None)) 43 | self.left = left 44 | self.right = right 45 | self.feature = feature 46 | self.label_counts = None 47 | 48 | def __str__(self): 49 | if self.is_leaf(): 50 | return "" 51 | return f"({self.left}{self.feature}{self.right})" 52 | 53 | @classmethod 54 | def parse(cls, tree: str) -> 'BinaryClassificationTree': 55 | if type(tree) == float: 56 | tree = str(tree) 57 | if tree in ['', 'nan']: 58 | return BinaryClassificationTree() 59 | 60 | def parse_feature(tree: str, i: int) -> Tuple[int, int]: 61 | j = i + 1 62 | while tree[j] not in ['(', ')']: 63 | j += 1 64 | return int(tree[i:j]), j 65 | 66 | def parse_node(tree: str, i: int = 0) -> Tuple[ 67 | BinaryClassificationTree, int]: 68 | if tree == '': 69 | return BinaryClassificationTree(), i 70 | if tree[i] == '(': 71 | left, i = parse_node(tree, i + 1) 72 | feature, i = parse_feature(tree, i) 73 | right, i = parse_node(tree, i) 74 | return BinaryClassificationTree(left, right, feature), i + 1 75 | else: 76 | return BinaryClassificationTree(), i 77 | 78 | return parse_node(tree)[0] 79 | 80 | def is_leaf(self) -> bool: 81 | return self.feature is None 82 | 83 | def size(self) -> int: 84 | if self.is_leaf(): 85 | return 1 86 | return 1 + self.left.size() + self.right.size() 87 | 88 | def depth(self) -> int: 89 | if self.is_leaf(): 90 | return 0 91 | return 1 + max(self.left.depth(), self.right.depth())  # a split node adds one level 92 | 93 | def fit(self, X, y): 94 | self.label_counts = np.bincount(y, minlength=2) 95 | if self.is_leaf(): 96 | return 97 | left, right = split(X, self.feature) 98 | self.left.fit(X[left], y[left]) 99 | self.right.fit(X[right], y[right]) 100 | 101 | def predict(self, X): 102 | if self.is_leaf(): 103 | assert (self.label_counts is not None) 104 | return np.argmax(self.label_counts) 105 | left, right = split(X, self.feature) 106 | y = np.zeros(X.shape[0], dtype=bool) 107 | y[left] = self.left.predict(X[left]) 108 | y[right] = self.right.predict(X[right]) 109 | return y 110 | 111 | def get_all_leaves(self): 112 | if self.is_leaf(): 113 | return [self] 114 | return self.left.get_all_leaves() + self.right.get_all_leaves() 115 | 116 | def log_prior(self, X,
alpha, beta, depth=0): 117 | if X.shape[0] == 0: 118 | return -np.inf 119 | num_valid_splits = get_num_valid_splits(X) 120 | if self.is_leaf(): 121 | return log_prob_stop(depth, alpha, 122 | beta) if num_valid_splits else 0.0 123 | left, right = split(X, self.feature) 124 | return log_prob_split(depth, alpha, beta) - np.log(num_valid_splits) + \ 125 | self.left.log_prior(X[left], alpha, beta, depth + 1) + \ 126 | self.right.log_prior(X[right], alpha, beta, depth + 1) 127 | 128 | def log_likelihood(self, X, y, rho): 129 | if self.is_leaf(): 130 | return log_likelihood(y, rho) 131 | left, right = split(X, self.feature) 132 | return self.left.log_likelihood(X[left], y[left], rho) + \ 133 | self.right.log_likelihood(X[right], y[right], rho) 134 | 135 | def log_posterior(self, X, y, alpha, beta, rho): 136 | return self.log_prior(X, alpha, beta) + self.log_likelihood(X, y, rho) 137 | -------------------------------------------------------------------------------- /experiments/searchers/cart.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Dict, Any 3 | from sklearn.tree import DecisionTreeClassifier 4 | 5 | from experiments.searchers.binary_classification_tree import BinaryClassificationTree 6 | 7 | 8 | def run( 9 | X_train, 10 | y_train, 11 | max_depth: int = None, 12 | max_leaf_nodes: int = None, 13 | ) -> Dict[str, Any]: 14 | assert(((X_train == 0) | (X_train == 1)).all()) 15 | assert(((y_train == 0) | (y_train == 1)).all()) 16 | 17 | start = time.perf_counter() 18 | clf = DecisionTreeClassifier(max_depth=max_depth, max_leaf_nodes=max_leaf_nodes) 19 | clf.fit(X_train, y_train) 20 | end = time.perf_counter() 21 | 22 | tree = parse(clf.tree_) 23 | tree.fit(X_train, y_train) 24 | 25 | return { 26 | 'tree': tree, 27 | 'time': end - start, 28 | 'timeout': False, 29 | } 30 | 31 | 32 | def parse(tree, idx: int=0) -> BinaryClassificationTree: 33 | if tree.children_left[idx] == -1: 34 | return BinaryClassificationTree() 35 | return BinaryClassificationTree( 36 | parse(tree, tree.children_left[idx]), 37 | parse(tree, tree.children_right[idx]), 38 | tree.feature[idx]) 39 | -------------------------------------------------------------------------------- /experiments/searchers/dl85.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Dict, Any 3 | from pydl85 import DL85Classifier 4 | 5 | from experiments.searchers.binary_classification_tree import BinaryClassificationTree 6 | 7 | def run( 8 | X_train, 9 | y_train, 10 | max_depth: int = 3, 11 | time_limit: int = 0, 12 | ) -> Dict[str, Any]: 13 | assert(((X_train == 0) | (X_train == 1)).all()) 14 | assert(((y_train == 0) | (y_train == 1)).all()) 15 | 16 | start = time.perf_counter() 17 | clf = DL85Classifier(max_depth=max_depth, time_limit=time_limit) 18 | clf.fit(X_train, y_train) 19 | end = time.perf_counter() 20 | 21 | tree = parse(clf) 22 | tree.fit(X_train, y_train) 23 | 24 | return { 25 | 'tree': tree, 26 | 'time': end - start, 27 | 'timeout': clf.timeout_, 28 | } 29 | 30 | 31 | def parse(clf: DL85Classifier) -> BinaryClassificationTree: 32 | def parse_node(node: dict) -> BinaryClassificationTree: 33 | if "value" in node: 34 | return BinaryClassificationTree() 35 | return BinaryClassificationTree( 36 | parse_node(node["right"]), 37 | parse_node(node["left"]), 38 | int(node["feat"])) 39 | return parse_node(clf.base_tree_) 40 | -------------------------------------------------------------------------------- 
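The searchers above and below all reduce their outputs to the `BinaryClassificationTree` serialization defined earlier: a leaf prints as the empty string and an internal node as `(<left><feature><right>)`. A small round-trip sketch (the feature indices here are made up for illustration):

```
from experiments.searchers.binary_classification_tree import BinaryClassificationTree

# A root split on feature 3 whose left child splits on feature 1.
leaf = BinaryClassificationTree
tree = BinaryClassificationTree(
    BinaryClassificationTree(leaf(), leaf(), 1),  # left subtree: split on feature 1
    leaf(),                                       # right subtree: a single leaf
    3,                                            # root split feature
)
assert str(tree) == "((1)3)"

# parse() inverts __str__, which is how trees stored as strings in the
# results CSVs can be reloaded for later analysis.
assert str(BinaryClassificationTree.parse("((1)3)")) == "((1)3)"
```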
/experiments/searchers/gosdt.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Dict, Any 3 | import pandas as pd 4 | from gosdt import GOSDT 5 | 6 | from experiments.searchers.binary_classification_tree import BinaryClassificationTree 7 | 8 | 9 | def run( 10 | X_train, 11 | y_train, 12 | max_depth: int = 0, 13 | regularization: float = 0.01, 14 | time_limit: int = 0, 15 | ) -> Dict[str, Any]: 16 | assert(((X_train == 0) | (X_train == 1)).all()) 17 | assert(((y_train == 0) | (y_train == 1)).all()) 18 | 19 | df_X_train = pd.DataFrame(X_train) 20 | df_y_train = pd.DataFrame(y_train) 21 | config = { 22 | 'depth_budget': max_depth, 23 | 'regularization': regularization, 24 | 'time_limit': time_limit, 25 | 'allow_small_reg': True, 26 | } 27 | 28 | start = time.perf_counter() 29 | clf = GOSDT(config) 30 | clf.fit(df_X_train, df_y_train) 31 | end = time.perf_counter() 32 | 33 | tree = parse(clf) 34 | tree.fit(X_train, y_train) 35 | 36 | return { 37 | 'tree': tree, 38 | 'time': end - start, 39 | 'timeout': clf.timeout 40 | } 41 | 42 | 43 | def parse(clf: GOSDT) -> BinaryClassificationTree: 44 | def parse_node(node: dict) -> BinaryClassificationTree: 45 | if "prediction" in node: 46 | return BinaryClassificationTree() 47 | assert(node["relation"] == "==") 48 | assert(node["reference"] == 1.0) 49 | return BinaryClassificationTree( 50 | parse_node(node["false"]), 51 | parse_node(node["true"]), 52 | int(node["feature"])) 53 | 54 | root = clf.tree.source 55 | return parse_node(root) -------------------------------------------------------------------------------- /experiments/searchers/maptree.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Tuple, Dict, Any 3 | from maptree import search as maptree_search 4 | from experiments.searchers.binary_classification_tree import BinaryClassificationTree 5 | 6 | 7 | def run( 8 | X_train, 9 | y_train, 10 | alpha: float = 0.95, 11 | beta: float = 0.5, 12 | rho: Tuple[float, float] = (2.5, 2.5), 13 | num_expansions: int = -1, 14 | time_limit: int = -1, 15 | ) -> Dict[str, Any]: 16 | assert(((X_train == 0) | (X_train == 1)).all()) 17 | assert(((y_train == 0) | (y_train == 1)).all()) 18 | 19 | start = time.perf_counter() 20 | sol = maptree_search(X_train, y_train, alpha, beta, rho, num_expansions, time_limit) 21 | end = time.perf_counter() 22 | 23 | tree = parse(sol.tree) 24 | tree.fit(X_train, y_train) 25 | 26 | return { 27 | 'tree': tree, 28 | 'time': end - start, 29 | 'timeout': sol.lb < sol.ub, 30 | 'lower_bound': sol.lb, 31 | 'upper_bound': sol.ub 32 | } 33 | 34 | 35 | def parse(tree: str) -> BinaryClassificationTree: 36 | return BinaryClassificationTree.parse(tree) 37 | -------------------------------------------------------------------------------- /experiments/searchers/mcmc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | from typing import Tuple, Dict, Any 4 | import numpy as np 5 | import random 6 | from experiments.searchers.binary_classification_tree import BinaryClassificationTree 7 | from .tree_smc.src.bdtmcmc import sample_tree, precompute, parser_add_common_options, parser_add_mcmc_options, parser_add_smc_options 8 | 9 | 10 | def run( 11 | X_train, 12 | y_train, 13 | alpha: float = 0.95, 14 | beta: float = 0.5, 15 | rho: Tuple[float, float] = [2.5, 2.5], 16 | num_iterations: int = 10, 17 | seed: int = 42, 18 | ) -> Dict[str, Any]: 19 | 
assert(((X_train == 0) | (X_train == 1)).all()) 20 | assert(((y_train == 0) | (y_train == 1)).all()) 21 | assert(rho[0] == rho[1]) 22 | 23 | parser = parser_add_common_options() 24 | parser = parser_add_smc_options(parser) 25 | parser = parser_add_mcmc_options(parser) 26 | settings = parser.parse_args([ 27 | '--alpha_split', str(alpha), 28 | '--beta_split', str(beta), 29 | '--alpha', str(rho[0] + rho[1]), 30 | '--verbose', '0', 31 | ])[0] 32 | 33 | data = { 34 | 'x_train': X_train, 35 | 'y_train': y_train, 36 | 'n_train': X_train.shape[0], 37 | 'n_dim': X_train.shape[1], 38 | 'n_class': 2, 39 | } 40 | 41 | np.random.seed(seed) 42 | random.seed(seed) 43 | 44 | start = time.perf_counter() 45 | param, cache, cache_tmp = precompute(data, settings) 46 | best_node_info_dump = None 47 | best_tree_leaves = None 48 | best_post = -np.inf 49 | p = sample_tree(data, settings, param, cache, cache_tmp) 50 | for _ in range(num_iterations): 51 | p.sample(data, settings, param, cache) 52 | post = p.compute_logprob() 53 | if post > best_post: 54 | best_node_info_dump = json.dumps(p.node_info) 55 | best_tree_leaves = list(p.leaf_nodes) 56 | best_post = post 57 | end = time.perf_counter() 58 | 59 | best_node_info = {int(k):v for k, v in json.loads(best_node_info_dump).items()} 60 | tree = parse((best_node_info, best_tree_leaves)) 61 | tree.fit(X_train, y_train) 62 | 63 | return { 64 | 'tree': tree, 65 | 'time': end - start, 66 | } 67 | 68 | 69 | def parse(tree: Tuple[dict, list], node_idx: int=0) -> BinaryClassificationTree: 70 | node_info, leaves = tree 71 | if node_idx in leaves or node_idx not in node_info: 72 | return BinaryClassificationTree() 73 | return BinaryClassificationTree( 74 | parse(tree, 2 * node_idx + 1), 75 | parse(tree, 2 * node_idx + 2), 76 | int(node_info[node_idx][0])) -------------------------------------------------------------------------------- /experiments/searchers/smc.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Tuple, Dict, Any 3 | import numpy as np 4 | import random 5 | from .mcmc import parse 6 | from .tree_smc.src.bdtsmc import init_smc, run_smc, parser_add_common_options, parser_add_smc_options 7 | 8 | 9 | def run( 10 | X_train, 11 | y_train, 12 | alpha: float = 0.95, 13 | beta: float = 0.5, 14 | rho: Tuple[float, float] = [2.5, 2.5], 15 | num_particles: int = 10, 16 | seed: int = 42, 17 | ) -> Dict[str, Any]: 18 | assert(((X_train == 0) | (X_train == 1)).all()) 19 | assert(((y_train == 0) | (y_train == 1)).all()) 20 | assert(rho[0] == rho[1]) 21 | 22 | parser = parser_add_common_options() 23 | parser = parser_add_smc_options(parser) 24 | settings = parser.parse_args([ 25 | '--alpha_split', str(alpha), 26 | '--beta_split', str(beta), 27 | '--alpha', str(rho[0] + rho[1]), 28 | '--n_particles', str(num_particles), 29 | '--n_islands', '1', 30 | '--verbose', '0', 31 | ])[0] 32 | 33 | data = { 34 | 'x_train': X_train, 35 | 'y_train': y_train, 36 | 'n_train': X_train.shape[0], 37 | 'n_dim': X_train.shape[1], 38 | 'n_class': 2, 39 | } 40 | 41 | np.random.seed(seed) 42 | random.seed(seed) 43 | 44 | start = time.perf_counter() 45 | (particles, param, log_weights, cache, cache_tmp) = init_smc(data, settings) 46 | (particles, ess_itr, log_weights_itr, log_pd, particle_stats_itr_d, particles_itr_d, log_pd_islands) = \ 47 | run_smc(particles, data, settings, param, log_weights, cache) 48 | best_particle = None 49 | best_post = -np.inf 50 | for p in particles: 51 | post = p.compute_logprob() 52 | if post > 
best_post: 53 | best_particle = p 54 | best_post = post 55 | end = time.perf_counter() 56 | 57 | tree = parse((best_particle.node_info, best_particle.leaf_nodes)) 58 | tree.fit(X_train, y_train) 59 | 60 | return { 61 | 'tree': tree, 62 | 'time': end - start, 63 | } 64 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/COPYING: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------------- 2 | The standard MIT License for code in this archive written by Balaji Lakshminarayanan 3 | http://www.opensource.org/licenses/mit-license.php 4 | ------------------------------------------------------------------------------- 5 | Copyright (c) 2013 Balaji Lakshminarayanan 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to 9 | deal in the Software without restriction, including without limitation the 10 | rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 22 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 23 | IN THE SOFTWARE. 24 | ------------------------------------------------------------------------------- 25 | 26 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/README: -------------------------------------------------------------------------------- 1 | This folder contains the scripts used in the following paper: 2 | "Top-down particle filtering for Bayesian decision trees", 3 | Balaji Lakshminarayanan, Daniel M. Roy, Yee Whye Teh 4 | http://arxiv.org/abs/1303.0561 5 | 6 | I ran my experiments using Enthought python (which includes all the necessary python packages). 7 | If you are running a different version of python, you will need the following python packages 8 | (and possibly other packages) to run the scripts: 9 | - numpy 10 | - scipy 11 | - matplotlib 12 | - sklearn (for CART experiments) 13 | 14 | The datasets are not included here; you need to download them from the UCI repository. You can run 15 | experiments using toy data though. See instructions in README in process_data/madelon, 16 | process_data/magic04 and process_data/pendigits folders for automatically downloading and processing the datasets. 17 | 18 | If you have any questions/comments/suggestions, please contact me at 19 | balaji@gatsby.ucl.ac.uk. 20 | 21 | Code released under MIT license (see COPYING for more info). 
22 | 23 | ---------------------------------------------------------------------------- 24 | 25 | List of scripts in the src folder: 26 | - bdtsmc.py 27 | - bdtmcmc.py 28 | - tree_utils.py 29 | - utils.py 30 | 31 | Help on usage can be obtained by typing the following commands on the terminal: 32 | ./bdtsmc.py -h 33 | ./bdtmcmc.py -h 34 | 35 | Example usage: 36 | ./bdtsmc.py --dataset toy --alpha 5.0 --alpha_split 0.95 --beta_split 0.5 --save 1 --n_particles 100 --proposal prior --grow next 37 | ./bdtmcmc.py --dataset toy --alpha 5.0 --alpha_split 0.95 --beta_split 0.5 --save 1 --n_iter 1000 -v 0 38 | 39 | Note that the results (predictions, accuracy, log predictive probability on training/test data, runtimes) are stored in the pickle files. 40 | You need to write additional scripts to aggregate the results from these pickle files and generate the plots in the PDF. 41 | 42 | I generated commands for parameter sweeps using the 'build_cmds' script by Jan Gasthaus 43 | (available publicly at https://github.com/jgasthaus/Gitsby/tree/master/pbs/python). 44 | Some examples of parameter sweeps are: 45 | 46 | SMC design choice experiments: 47 | ./build_cmds ./bdtsmc.py "--op_dir={results}" "--init_id=1:1:11" "--resample={multinomial}" "--grow={next,layer}" "--dataset={pendigits,magic04}" "--proposal={posterior,prior,empirical}" "--n_particles={5,10,25,50,100,250,500,750,1000,1500,2000}" "--max_iterations=5000" "--ess_threshold={0.1}" "--save={1}" "--alpha={5.0}" "--alpha_split={0.95}" "--beta_split={0.5}" 48 | 49 | Effect of irrelevant features: madelon dataset 50 | ./build_cmds ./bdtsmc.py "--op_dir={results}" "--init_id=1:1:11" "--resample={multinomial}" "--grow={next}" "--dataset={madelon}" "--proposal={posterior,prior}" "--n_particles={10,50,100,250}" "--max_iterations=5000" "--ess_threshold={0.1}" "--save={1}" "--alpha={5.0}" "--alpha_split={0.95}" "--beta_split={0.5}" 51 | 52 | Island-model SMC: 53 | ./build_cmds ./bdtsmc.py "--op_dir={results}" "--init_id=1:1:11" "--resample={multinomial}" "--grow={next}" "--dataset={pendigits,magic04}" "--proposal={posterior,prior}" "--n_particles={100,250,500,750,1000,1500,2000}" "--max_iterations=5000" "--ess_threshold={0.1}" "--save={1}" "--alpha={5.0}" "--alpha_split={0.95}" "--beta_split={0.5}" "--n_islands={5}" 54 | 55 | MCMC experiments: 56 | ./build_cmds ./bdtmcmc.py "--mcmc_type={chipman}" "--sample_y={0}" "--dataset={pendigits,magic04}" "--save={1}" "--n_iterations={100000}" "--init_id=1:1:11" "--alpha_split={0.95}" "--beta_split={0.5}" "--alpha={5.0}" 57 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/searchers/tree_smc/__init__.py -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/searchers/tree_smc/process_data/__init__.py -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/madelon/README: -------------------------------------------------------------------------------- 1 | Run the following commands to download and create the pickle files.
2 | 3 | ./download.sh 4 | ./process_data_feat_challenge.py madelon 5 | 6 | Remember to use the right 'data_path' argument while loading files. 7 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/madelon/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/searchers/tree_smc/process_data/madelon/__init__.py -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/madelon/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/madelon/MADELON/madelon_train.data 3 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/madelon/MADELON/madelon_train.labels 4 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/madelon/MADELON/madelon_valid.data 5 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/madelon/madelon_valid.labels 6 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/madelon/process_data_feat_challenge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # script to create a pickle file with the dataset in a dictionary 3 | # the script uses *_train for training and *_valid as test 4 | # Example usage: ./process_data_feat_challenge.py madelon 5 | 6 | import numpy as np 7 | import pickle as pickle 8 | import sys 9 | 10 | 11 | name = sys.argv[1] 12 | DATA_PATH = name + '/' # tweak this if the data files are in the same directory as this script 13 | # DATA_PATH = '' 14 | 15 | data = {} 16 | 17 | def load_files(name, tag): 18 | x = np.loadtxt(DATA_PATH + name + '_' + tag + '.data', dtype='float') 19 | y = np.loadtxt(DATA_PATH + name + '_' + tag + '.labels', dtype='float') 20 | y = (y + 1) / 2.0 21 | y = y.astype('int') 22 | return (x, y) 23 | 24 | x_train, y_train = load_files(name, 'train') 25 | x_test, y_test = load_files(name, 'valid') # labels are not available for test data 26 | 27 | data['n_dim'] = x_train.shape[1] 28 | 29 | n_train = len(y_train) 30 | n_test = len(y_test) 31 | data['n_train'] = n_train 32 | data['n_test'] = n_test 33 | data['y_train'] = y_train 34 | data['y_test'] = y_test 35 | data['x_train'] = x_train 36 | data['x_test'] = x_test 37 | data['n_class'] = len(np.unique(y_train)) 38 | data['is_sparse'] = False 39 | 40 | print('dataset statistics:') 41 | print('n_train = %d, n_test = %d, n_dim = %d' % (n_train, n_test, data['n_dim'])) 42 | 43 | pickle.dump(data, open(DATA_PATH + name + ".p", "wb"), protocol=pickle.HIGHEST_PROTOCOL) 44 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/madelon/process_data_feat_challenge.py.bak: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # script to create a pickle file with the dataset in a dictionary 3 | # the script uses *_train for training and *_valid as test 4 | # Example usage: ./process_data_feat_challenge madelon 5 | 6 | import numpy as np 7 | import cPickle as pickle 8 | import sys 9 | 10 | 11 | name = sys.argv[1] 12 | DATA_PATH = name + '/' # tweak this if process_data_feat_challenge is in the same directory 13 | # 
DATA_PATH = '' 14 | 15 | data = {} 16 | 17 | def load_files(name, tag): 18 | x = np.loadtxt(DATA_PATH + name + '_' + tag + '.data', dtype='float') 19 | y = np.loadtxt(DATA_PATH + name + '_' + tag + '.labels', dtype='float') 20 | y = (y + 1) / 2.0 21 | y = y.astype('int') 22 | return (x, y) 23 | 24 | x_train, y_train = load_files(name, 'train') 25 | x_test, y_test = load_files(name, 'valid') # labels are not available for test data 26 | 27 | data['n_dim'] = x_train.shape[1] 28 | 29 | n_train = len(y_train) 30 | n_test = len(y_test) 31 | data['n_train'] = n_train 32 | data['n_test'] = n_test 33 | data['y_train'] = y_train 34 | data['y_test'] = y_test 35 | data['x_train'] = x_train 36 | data['x_test'] = x_test 37 | data['n_class'] = len(np.unique(y_train)) 38 | data['is_sparse'] = False 39 | 40 | print 'dataset statistics:' 41 | print 'n_train = %d, n_test = %d, n_dim = %d' % (n_train, n_test, data['n_dim']) 42 | 43 | pickle.dump(data, open(DATA_PATH + name + ".p", "wb"), protocol=pickle.HIGHEST_PROTOCOL) 44 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/magic04/README: -------------------------------------------------------------------------------- 1 | Run the following commands to download and create the pickle files. 2 | 3 | ./download.sh 4 | ./process_data.py 5 | 6 | Remember to use the right 'data_path' argument while loading files. 7 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/magic04/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/searchers/tree_smc/process_data/magic04/__init__.py -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/magic04/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/magic/magic04.names 3 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/magic/magic04.data 4 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/magic04/process_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # script to create a pickle file with the dataset in a dictionary 3 | # the script uses FRACTION_TRAINING for training and remaining as test data 4 | 5 | import numpy as np 6 | import pickle as pickle 7 | import random 8 | random.seed(123456789) # fix for reproducibility 9 | 10 | name = 'magic04' 11 | DATA_PATH = './' # ./download.sh places magic04.data in this directory 12 | FRACTION_TRAINING = 0.7 13 | 14 | data = {} 15 | filename = DATA_PATH + name + '.data' 16 | x = np.loadtxt(filename, dtype='float', delimiter=',', usecols = list(range(0, 10))) 17 | n = x.shape[0] 18 | print('ndim = %d' % x.shape[1]) 19 | y = np.zeros(n, dtype = 'int') 20 | for i, line in enumerate(open(filename, 'r')): 21 | a = line.rstrip('\n').split(',')[-1] 22 | if a == 'g': 23 | y[i] = 1 24 | elif a == 'h': 25 | y[i] = 0 26 | else: 27 | print('Unknown string %s' % a) 28 | print(line) 29 | raise Exception 30 | 31 | y = y.astype('int') 32 | data['n_dim'] = x.shape[1] 33 | idx = list(range(n)) 34 | random.shuffle(idx) 35 | 36 | n_train = int(FRACTION_TRAINING * n) 37 | idx_train = idx[:n_train] 38 | 
idx_test = idx[n_train:] 39 | data['n_train'] = n_train 40 | data['n_test'] = len(idx_test) 41 | data['y_train'] = y[idx_train] 42 | data['y_test'] = y[idx_test] 43 | data['x_train'] = x[idx_train, :] 44 | data['x_test'] = x[idx_test, :] 45 | data['n_class'] = len(np.unique(y)) 46 | data['is_sparse'] = False 47 | 48 | #pickle.dump(data, open(DATA_PATH + name + ".p", "wb")) 49 | pickle.dump(data, open(DATA_PATH + name + ".p", "wb"), protocol=pickle.HIGHEST_PROTOCOL) 50 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/magic04/process_data.py.bak: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # script to create a pickle file with the dataset in a dictionary 3 | # the script uses FRACTION_TRAINING for training and remaining as test data 4 | 5 | import numpy as np 6 | import cPickle as pickle 7 | import random 8 | random.seed(123456789) # fix for reproducibility 9 | 10 | name = 'magic04' 11 | DATA_PATH = '/' 12 | FRACTION_TRAINING = 0.7 13 | 14 | data = {} 15 | filename = DATA_PATH + name + '.data' 16 | x = np.loadtxt(filename, dtype='float', delimiter=',', usecols = range(0, 10)) 17 | n = x.shape[0] 18 | print 'ndim = %d' % x.shape[1] 19 | y = np.zeros(n, dtype = 'int') 20 | for i, line in enumerate(open(filename, 'r')): 21 | a = line.rstrip('\n').split(',')[-1] 22 | if a == 'g': 23 | y[i] = 1 24 | elif a == 'h': 25 | y[i] = 0 26 | else: 27 | print 'Unknown string %s' % a 28 | print line 29 | raise Exception 30 | 31 | y = y.astype('int') 32 | data['n_dim'] = x.shape[1] 33 | idx = range(n) 34 | random.shuffle(idx) 35 | 36 | n_train = int(FRACTION_TRAINING * n) 37 | idx_train = idx[:n_train] 38 | idx_test = idx[n_train:] 39 | data['n_train'] = n_train 40 | data['n_test'] = len(idx_test) 41 | data['y_train'] = y[idx_train] 42 | data['y_test'] = y[idx_test] 43 | data['x_train'] = x[idx_train, :] 44 | data['x_test'] = x[idx_test, :] 45 | data['n_class'] = len(np.unique(y)) 46 | data['is_sparse'] = False 47 | 48 | #pickle.dump(data, open(DATA_PATH + name + ".p", "wb")) 49 | pickle.dump(data, open(DATA_PATH + name + ".p", "wb"), protocol=pickle.HIGHEST_PROTOCOL) 50 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/pendigits/README: -------------------------------------------------------------------------------- 1 | Run the following commands to download and create the pickle files. 2 | 3 | ./download.sh 4 | ./process_data.py 5 | 6 | Remember to use the right 'data_path' argument while loading files. 
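As a quick sanity check after processing (a minimal sketch; it assumes the
pickle was written to the current directory, which is where process_data.py
puts it):

    python -c "import pickle; d = pickle.load(open('pendigits.p', 'rb')); print(d['n_train'], d['n_test'], d['n_dim'])"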
7 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/pendigits/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/searchers/tree_smc/process_data/pendigits/__init__.py -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/pendigits/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.names 3 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tes 4 | wget http://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tra 5 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/pendigits/process_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # script to create a pickle file with the dataset in a dictionary 3 | 4 | import numpy as np 5 | import pickle as pickle 6 | 7 | name = 'pendigits' 8 | 9 | data = {} 10 | 11 | tmp = np.loadtxt(name + '.tra', dtype='float', delimiter=',') 12 | y = tmp[:, -1] 13 | x = tmp[:, :-1] 14 | y = y.astype('int') 15 | data['x_train'] = x 16 | data['n_train'] = x.shape[0] 17 | data['y_train'] = y 18 | tmp = np.loadtxt(name + '.tes', dtype='float', delimiter=',') 19 | y = tmp[:, -1] 20 | x = tmp[:, :-1] 21 | y = y.astype('int') 22 | data['x_test'] = x 23 | data['n_test'] = x.shape[0] 24 | data['y_test'] = y 25 | data['n_dim'] = tmp.shape[1] - 1 # last column contains labels 26 | data['n_class'] = len(np.unique(y)) 27 | data['is_sparse'] = False 28 | 29 | pickle.dump(data, open(name + ".p", "wb"), protocol=pickle.HIGHEST_PROTOCOL) 30 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/process_data/pendigits/process_data.py.bak: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # script to create a pickle file with the dataset in a dictionary 3 | 4 | import numpy as np 5 | import cPickle as pickle 6 | 7 | name = 'pendigits' 8 | 9 | data = {} 10 | 11 | tmp = np.loadtxt(name + '.tra', dtype='float', delimiter=',') 12 | y = tmp[:, -1] 13 | x = tmp[:, :-1] 14 | y = y.astype('int') 15 | data['x_train'] = x 16 | data['n_train'] = x.shape[0] 17 | data['y_train'] = y 18 | tmp = np.loadtxt(name + '.tes', dtype='float', delimiter=',') 19 | y = tmp[:, -1] 20 | x = tmp[:, :-1] 21 | y = y.astype('int') 22 | data['x_test'] = x 23 | data['n_test'] = x.shape[0] 24 | data['y_test'] = y 25 | data['n_dim'] = tmp.shape[1] - 1 # last column contains labels 26 | data['n_class'] = len(np.unique(y)) 27 | data['is_sparse'] = False 28 | 29 | pickle.dump(data, open(name + ".p", "wb"), protocol=pickle.HIGHEST_PROTOCOL) 30 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThrunGroup/maptree/0a5efe414a5d4fcd69e9c858eef9cd8c14bcca06/experiments/searchers/tree_smc/src/__init__.py 
-------------------------------------------------------------------------------- /experiments/searchers/tree_smc/src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def hist_count(x, basis): 4 | """ 5 | counts number of times each element in basis appears in x 6 | op is a vector of same size as basis 7 | assume no duplicates in basis 8 | """ 9 | op = np.zeros((len(basis)), dtype=int) 10 | map_basis = {} 11 | for n, k in enumerate(basis): 12 | map_basis[k] = n 13 | for t in x: 14 | op[map_basis[t]] += 1 15 | return op 16 | 17 | def logsumexp(x): 18 | tmp = x.copy() 19 | tmp_max = np.max(tmp) 20 | tmp -= tmp_max 21 | op = np.log(np.sum(np.exp(tmp))) + tmp_max 22 | return op 23 | 24 | 25 | def softmax(x): 26 | tmp = x.copy() 27 | tmp_max = np.max(tmp) 28 | tmp -= float(tmp_max) 29 | tmp = np.exp(tmp) 30 | op = tmp / np.sum(tmp) 31 | return op 32 | 33 | 34 | def assert_no_nan(mat, name='matrix'): 35 | try: 36 | assert(not any(np.isnan(mat))) 37 | except AssertionError: 38 | print('%s contains NaN' % name) 39 | print(mat) 40 | raise AssertionError 41 | 42 | def check_if_one(val): 43 | try: 44 | assert(np.abs(val - 1) < 1e-12) 45 | except AssertionError: 46 | print('val = %s (needs to be equal to 1)' % val) 47 | raise AssertionError 48 | 49 | def check_if_zero(val): 50 | try: 51 | assert(np.abs(val) < 1e-10) 52 | except AssertionError: 53 | print('val = %s (needs to be equal to 0)' % val) 54 | raise AssertionError 55 | 56 | 57 | def sample_multinomial(prob): 58 | try: 59 | k = int(np.where(np.random.multinomial(1, prob, size=1)[0]==1)[0]) 60 | except TypeError: 61 | print('problem in sample_multinomial: prob = ') 62 | print(prob) 63 | raise TypeError 64 | except: 65 | raise Exception 66 | return k 67 | 68 | 69 | def sample_multinomial_scores(scores): 70 | scores_cumsum = np.cumsum(scores) 71 | s = scores_cumsum[-1] * np.random.rand(1) 72 | k = 0 73 | while s > scores_cumsum[k]: 74 | k += 1 75 | return k 76 | 77 | 78 | def sample_polya(alpha_vec, n): 79 | """ alpha_vec is the parameter of the Dirichlet distribution, n is the #samples """ 80 | prob = np.random.dirichlet(alpha_vec) 81 | n_vec = np.random.multinomial(n, prob) 82 | return n_vec 83 | 84 | 85 | def get_kth_minimum(x, k=1): 86 | """ gets the k^th minimum element of the list x 87 | (note: k=1 is the minimum, k=2 is 2nd minimum) ... 
88 | based on the incomplete selection sort pseudocode """ 89 | n = len(x) 90 | for i in range(n): 91 | minIndex = i 92 | minValue = x[i] 93 | for j in range(i+1, n): 94 | if x[j] < minValue: 95 | minIndex = j 96 | minValue = x[j] 97 | x[i], x[minIndex] = x[minIndex], x[i] 98 | return x[k-1] 99 | 100 | 101 | class empty(object): 102 | def __init__(self): 103 | pass 104 | -------------------------------------------------------------------------------- /experiments/searchers/tree_smc/src/utils.py.bak: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def hist_count(x, basis): 4 | """ 5 | counts number of times each element in basis appears in x 6 | op is a vector of same size as basis 7 | assume no duplicates in basis 8 | """ 9 | op = np.zeros((len(basis)), dtype=int) 10 | map_basis = {} 11 | for n, k in enumerate(basis): 12 | map_basis[k] = n 13 | for t in x: 14 | op[map_basis[t]] += 1 15 | return op 16 | 17 | def logsumexp(x): 18 | tmp = x.copy() 19 | tmp_max = np.max(tmp) 20 | tmp -= tmp_max 21 | op = np.log(np.sum(np.exp(tmp))) + tmp_max 22 | return op 23 | 24 | 25 | def softmax(x): 26 | tmp = x.copy() 27 | tmp_max = np.max(tmp) 28 | tmp -= float(tmp_max) 29 | tmp = np.exp(tmp) 30 | op = tmp / np.sum(tmp) 31 | return op 32 | 33 | 34 | def assert_no_nan(mat, name='matrix'): 35 | try: 36 | assert(not any(np.isnan(mat))) 37 | except AssertionError: 38 | print '%s contains NaN' % name 39 | print mat 40 | raise AssertionError 41 | 42 | def check_if_one(val): 43 | try: 44 | assert(np.abs(val - 1) < 1e-12) 45 | except AssertionError: 46 | print 'val = %s (needs to be equal to 1)' % val 47 | raise AssertionError 48 | 49 | def check_if_zero(val): 50 | try: 51 | assert(np.abs(val) < 1e-10) 52 | except AssertionError: 53 | print 'val = %s (needs to be equal to 0)' % val 54 | raise AssertionError 55 | 56 | 57 | def sample_multinomial(prob): 58 | try: 59 | k = int(np.where(np.random.multinomial(1, prob, size=1)[0]==1)[0]) 60 | except TypeError: 61 | print 'problem in sample_multinomial: prob = ' 62 | print prob 63 | raise TypeError 64 | except: 65 | raise Exception 66 | return k 67 | 68 | 69 | def sample_multinomial_scores(scores): 70 | scores_cumsum = np.cumsum(scores) 71 | s = scores_cumsum[-1] * np.random.rand(1) 72 | k = 0 73 | while s > scores_cumsum[k]: 74 | k += 1 75 | return k 76 | 77 | 78 | def sample_polya(alpha_vec, n): 79 | """ alpha_vec is the parameter of the Dirichlet distribution, n is the #samples """ 80 | prob = np.random.dirichlet(alpha_vec) 81 | n_vec = np.random.multinomial(n, prob) 82 | return n_vec 83 | 84 | 85 | def get_kth_minimum(x, k=1): 86 | """ gets the k^th minimum element of the list x 87 | (note: k=1 is the minimum, k=2 is 2nd minimum) ... 88 | based on the incomplete selection sort pseudocode """ 89 | n = len(x) 90 | for i in range(n): 91 | minIndex = i 92 | minValue = x[i] 93 | for j in range(i+1, n): 94 | if x[j] < minValue: 95 | minIndex = j 96 | minValue = x[j] 97 | x[i], x[minIndex] = x[minIndex], x[i] 98 | return x[k-1] 99 | 100 | 101 | class empty(object): 102 | def __init__(self): 103 | pass 104 | -------------------------------------------------------------------------------- /maptree/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # This file specifies how the project should be built, using CMake. 2 | # If you are unfamiliar with CMake, don't worry about all the details. 
3 | # The sections you might want to edit are marked as such, and 4 | # the comments should hopefully make most of it clear. 5 | # 6 | # For many purposes, you may not need to change anything about this file. 7 | 8 | cmake_minimum_required(VERSION 3.14) 9 | 10 | # Set project name, version and languages here. (change as needed) 11 | # Version numbers are available by including "exampleConfig.h" in 12 | # the source. See exampleConfig.h.in for some more details. 13 | project(BDT_MAP VERSION 0.0 LANGUAGES CXX) 14 | 15 | 16 | # Options: Things you can set via commandline options to cmake (e.g. -DENABLE_LTO=[ON|OFF]) 17 | option(ENABLE_WARNINGS_SETTINGS "Allow target_set_warnings to add flags and defines. 18 | Set this to OFF if you want to provide your own warning parameters." ON) 19 | option(ENABLE_LTO "Enable link time optimization" ON) 20 | option(ENABLE_DOCTESTS "Include tests in the library. Setting this to OFF will remove all doctest related code. 21 | Tests in tests/*.cpp will still be enabled." ON) 22 | 23 | # Include stuff. No change needed. 24 | set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/") 25 | include(ConfigSafeGuards) 26 | include(Colors) 27 | include(CTest) 28 | include(Doctest) 29 | include(Documentation) 30 | include(LTO) 31 | include(Warnings) 32 | 33 | # Check for LTO support. 34 | find_lto(CXX) 35 | 36 | 37 | # -------------------------------------------------------------------------------- 38 | # Locate files (change as needed). 39 | # -------------------------------------------------------------------------------- 40 | set(SOURCES # All .cpp files in src/ 41 | src/cache/approx_bitset_cache.cpp 42 | src/data/binary_data_loader.cpp 43 | src/data/bitset.cpp 44 | src/data/data_manager.cpp 45 | src/data/fixed_bitset.cpp 46 | src/data/rnumber.cpp 47 | src/posterior/tree_likelihood.cpp 48 | src/posterior/tree_prior.cpp 49 | src/search/befs_map_search.cpp 50 | src/search/base_map_search.cpp 51 | src/solution/decision_tree.cpp 52 | src/subproblem.cpp 53 | ) 54 | 55 | set(LIBRARY_NAME engine) # Default name for the library built from src/*.cpp (change if you wish) 56 | 57 | # -------------------------------------------------------------------------------- 58 | # Build! (Change as needed) 59 | # -------------------------------------------------------------------------------- 60 | # Compile all sources into a library. 61 | add_library(${LIBRARY_NAME} OBJECT ${SOURCES}) 62 | 63 | # Lib needs its header files, and users of the library must also see these (PUBLIC). (No change needed) 64 | target_include_directories(${LIBRARY_NAME} PUBLIC ${PROJECT_SOURCE_DIR}/include) 65 | 66 | # There's also (probably) doctests within the library, so we need to see this as well. 67 | target_link_libraries(${LIBRARY_NAME} PUBLIC doctest) 68 | 69 | # Set the compile options you want (change as needed). 70 | target_set_warnings(${LIBRARY_NAME} ENABLE ALL AS_ERROR ALL DISABLE Annoying) 71 | # target_compile_options(${LIBRARY_NAME} ... ) # For setting manually. 72 | 73 | # Add an executable for the file app/main.cpp. 74 | # If you add more executables, copy these lines accordingly. 75 | add_executable(main app/main.cpp) # Name of exec. and location of file. 76 | target_link_libraries(main PRIVATE ${LIBRARY_NAME}) # Link the executable to library (if it uses it). 77 | target_set_warnings(main ENABLE ALL AS_ERROR ALL DISABLE Annoying) # Set warnings (if needed).
78 | target_enable_lto(main optimized) # enable link-time-optimization if available for non-debug configurations 79 | 80 | # Set the properties you require, e.g. what C++ standard to use. Here applied to library and main (change as needed). 81 | set_target_properties( 82 | ${LIBRARY_NAME} main 83 | PROPERTIES 84 | CXX_STANDARD 17 85 | CXX_STANDARD_REQUIRED YES 86 | CXX_EXTENSIONS NO 87 | ) 88 | 89 | file(CREATE_LINK "${PROJECT_SOURCE_DIR}/tests/test_data" 90 | "${CMAKE_CURRENT_BINARY_DIR}/data" SYMBOLIC) 91 | 92 | set(CMAKE_C_FLAGS_RELEASE "-O3 -DNDEBUG") 93 | set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG") 94 | set(CMAKE_C_FLAGS_DEBUG "-g -O0 -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer") 95 | set(CMAKE_CXX_FLAGS_DEBUG "-g -O0 -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer") 96 | 97 | add_subdirectory(tests) -------------------------------------------------------------------------------- /maptree/app/main.cpp: -------------------------------------------------------------------------------- 1 | #include <array> 2 | #include <chrono> 3 | #include <cstdlib> 4 | #include <getopt.h> 5 | #include <iostream> 6 | #include <string> 7 | 8 | #include "data/binary_data_loader.h" 9 | #include "search/befs_map_search.h" 10 | 11 | int main(int argc, char** argv) { 12 | std::string file; 13 | double alpha = 0.8; 14 | double beta = 1.0; 15 | std::array<double, 2> rho = {1.0, 1.0}; 16 | 17 | struct option longopts[] = { 18 | { "file", required_argument, NULL, 'f' }, 19 | { "alpha", optional_argument, NULL, 'a' }, 20 | { "beta", optional_argument, NULL, 'b' }, 21 | { "rho", optional_argument, NULL, 'r' }, 22 | { NULL, 0, NULL, 0 } 23 | }; 24 | 25 | while (true) { 26 | int opt = getopt_long(argc, argv, "f:a:b:r:", longopts, 0); 27 | if (opt == -1) break; 28 | switch (opt) { 29 | case 'f': { 30 | file = std::string(optarg); 31 | break; 32 | } 33 | case 'a': { 34 | alpha = std::atof(optarg); 35 | break; 36 | } 37 | case 'b': { 38 | beta = std::atof(optarg); 39 | break; 40 | } 41 | case 'r': { 42 | double total = std::atof(optarg); 43 | rho[0] = total / 2.0; 44 | rho[1] = total / 2.0; 45 | break; 46 | } 47 | case '?': { 48 | std::cout << "Usage: " << argv[0] << " -f <file> [-a <alpha>] [-b <beta>] [-r <rho>]" << std::endl; 49 | return EXIT_FAILURE; 50 | break; 51 | } 52 | } 53 | } 54 | 55 | BinaryDataLoader bdl(file); 56 | std::vector<std::vector<bool>> features = bdl.getFeatures(); 57 | std::vector<int> labels = bdl.getLabels(); 58 | 59 | auto start = std::chrono::high_resolution_clock::now(); 60 | BestFirstSearchMAPSearch search(features, labels, alpha, beta, rho); 61 | Solution result = search.search(); 62 | auto stop = std::chrono::high_resolution_clock::now(); 63 | 64 | auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start); 65 | 66 | std::cout << "Tree: " << result.treeRepresentation << std::endl; 67 | std::cout << "Lower Bound: " << result.lowerBound << std::endl; 68 | std::cout << "Upper Bound: " << result.upperBound << std::endl; 69 | std::cout << "Training Time (ms): " << duration.count() << std::endl; 70 | 71 | return EXIT_SUCCESS; 72 | } -------------------------------------------------------------------------------- /maptree/cmake/CodeCoverage.cmake: -------------------------------------------------------------------------------- 1 | # 2012-01-31, Lars Bilke 2 | # - Enable Code Coverage 3 | # 4 | # 2013-09-17, Joakim Söderberg 5 | # - Added support for Clang. 6 | # - Some additional usage instructions. 7 | # 8 | # 2018-03-31, Bendik Samseth 9 | # - Relax debug output. 10 | # - Keep a copy of the coverage output for later use. 11 | # - Updated coverage exclude patterns.
12 | # 13 | # 2018-01-03, HenryRLee 14 | # - Allow for *Clang compiler names, not just Clang. 15 | # 16 | # 2018-01-03, Bendik Samseth 17 | # - Only check compiler compatibility if in a coverage build. 18 | 19 | # USAGE: 20 | # 21 | # 0. (Mac only) If you use Xcode 5.1 make sure to patch geninfo as described here: 22 | # http://stackoverflow.com/a/22404544/80480 23 | # 24 | # 1. Copy this file into your cmake modules path. 25 | # 26 | # 2. Add the following line to your CMakeLists.txt: 27 | # include(CodeCoverage) 28 | # 29 | # 3. Set compiler flags to turn off optimization and enable coverage: 30 | # set(CMAKE_CXX_FLAGS "-g -O0 -fprofile-arcs -ftest-coverage") 31 | # set(CMAKE_C_FLAGS "-g -O0 -fprofile-arcs -ftest-coverage") 32 | # 33 | # 4. Use the function setup_target_for_coverage to create a custom make target 34 | # which runs your test executable and produces an lcov code coverage report: 35 | # Example: 36 | # setup_target_for_coverage( 37 | # my_coverage_target # Name for custom target. 38 | # test_driver # Name of the test driver executable that runs the tests. 39 | # # NOTE! This should always have a ZERO as exit code 40 | # # otherwise the coverage generation will not complete. 41 | # coverage # Name of output directory. 42 | # ) 43 | # 44 | # 5. Build a Debug build: 45 | # cmake -DCMAKE_BUILD_TYPE=Debug .. 46 | # make 47 | # make my_coverage_target 48 | 49 | # Param _targetname The name of the new custom make target 50 | # Param _testrunner The name of the target which runs the tests. 51 | # MUST return ZERO always, even on errors. 52 | # If not, no coverage report will be created! 53 | # Param _outputname lcov output is generated as _outputname.info 54 | # HTML report is generated in _outputname/index.html 55 | # Optional fourth parameter is passed as arguments to _testrunner 56 | # Pass them in list form, e.g.: "-j;2" for -j 2 57 | function(setup_target_for_coverage _targetname _testrunner _outputname) 58 | 59 | if(NOT LCOV_PATH) 60 | message(FATAL_ERROR "lcov not found! Aborting...") 61 | endif() # NOT LCOV_PATH 62 | 63 | if(NOT GENHTML_PATH) 64 | message(FATAL_ERROR "genhtml not found! Aborting...") 65 | endif() # NOT GENHTML_PATH 66 | 67 | # Setup target 68 | add_custom_target(${_targetname} 69 | 70 | # Cleanup lcov 71 | ${LCOV_PATH} --directory . --zerocounters 72 | 73 | # Run tests 74 | COMMAND ${_testrunner} ${ARGV3} 75 | 76 | # Capturing lcov counters and generating report 77 | COMMAND ${LCOV_PATH} --directory . --capture --output-file ${_outputname}.info 78 | 79 | COMMAND ${LCOV_PATH} --remove ${_outputname}.info '*/tests/*' '/usr/*' '*/external/*' '*/3rd/*' '/Applications/*' --output-file ${_outputname}.info.cleaned 80 | COMMAND ${GENHTML_PATH} -o ${_outputname} ${_outputname}.info.cleaned 81 | COMMAND ${LCOV_PATH} --list ${_outputname}.info.cleaned 82 | 83 | WORKING_DIRECTORY ${CMAKE_BINARY_DIR} 84 | COMMENT "Resetting code coverage counters to zero.\nProcessing code coverage counters and generating report."
85 | ) 86 | 87 | # Show info where to find the report 88 | add_custom_command(TARGET ${_targetname} POST_BUILD 89 | COMMAND ; 90 | COMMENT "${BoldMagenta}Open ./${_outputname}/index.html in your browser to view the coverage report.${ColorReset}" 91 | ) 92 | 93 | endfunction() # setup_target_for_coverage 94 | 95 | 96 | string(TOLOWER "${CMAKE_BUILD_TYPE}" cmake_build_type_tolower) 97 | if (cmake_build_type_tolower STREQUAL "coverage") 98 | 99 | 100 | # Check prereqs 101 | find_program(GCOV_PATH gcov) 102 | find_program(LCOV_PATH lcov) 103 | find_program(GENHTML_PATH genhtml) 104 | find_program(GCOVR_PATH gcovr PATHS ${CMAKE_SOURCE_DIR}/tests) 105 | 106 | if(NOT GCOV_PATH) 107 | message(FATAL_ERROR "gcov not found! Aborting...") 108 | endif() # NOT GCOV_PATH 109 | 110 | if(NOT CMAKE_COMPILER_IS_GNUCXX) 111 | if(NOT "${CMAKE_CXX_COMPILER_ID}" MATCHES ".*Clang") 112 | message(FATAL_ERROR "Compiler is not GNU gcc or Clang! Aborting...") 113 | endif() 114 | endif() # NOT CMAKE_COMPILER_IS_GNUCXX 115 | 116 | SET(CMAKE_CXX_FLAGS_COVERAGE 117 | "-g -O0 -fprofile-arcs -ftest-coverage" 118 | CACHE STRING "Flags used by the C++ compiler during coverage builds." 119 | FORCE ) 120 | SET(CMAKE_C_FLAGS_COVERAGE 121 | "-g -O0 -fprofile-arcs -ftest-coverage" 122 | CACHE STRING "Flags used by the C compiler during coverage builds." 123 | FORCE ) 124 | SET(CMAKE_EXE_LINKER_FLAGS_COVERAGE 125 | "" 126 | CACHE STRING "Flags used for linking binaries during coverage builds." 127 | FORCE ) 128 | SET(CMAKE_SHARED_LINKER_FLAGS_COVERAGE 129 | "" 130 | CACHE STRING "Flags used by the shared libraries linker during coverage builds." 131 | FORCE ) 132 | mark_as_advanced( 133 | CMAKE_CXX_FLAGS_COVERAGE 134 | CMAKE_C_FLAGS_COVERAGE 135 | CMAKE_EXE_LINKER_FLAGS_COVERAGE 136 | CMAKE_SHARED_LINKER_FLAGS_COVERAGE ) 137 | 138 | 139 | # If unwanted files are included in the coverage reports, you can 140 | # adjust the exclude patterns on line 79. 141 | setup_target_for_coverage( 142 | coverage # Name for custom target. 143 | ${TEST_MAIN} # Name of the test driver executable that runs the tests. 144 | # NOTE! This should always have a ZERO as exit code 145 | # otherwise the coverage generation will not complete. 146 | coverage_out # Name of output directory. 147 | ) 148 | else() 149 | add_custom_target(coverage 150 | COMMAND echo "${Red}Code coverage only available in coverage builds."
151 | COMMAND echo "${Green}Make a new build directory and rerun cmake with -DCMAKE_BUILD_TYPE=Coverage to enable this target.${ColorReset}" 152 | ) 153 | endif() 154 | -------------------------------------------------------------------------------- /maptree/cmake/Colors.cmake: -------------------------------------------------------------------------------- 1 | IF(NOT WIN32) 2 | string(ASCII 27 Esc) 3 | set(ColorReset "${Esc}[m") 4 | set(ColorBold "${Esc}[1m") 5 | set(Red "${Esc}[31m") 6 | set(Green "${Esc}[32m") 7 | set(Yellow "${Esc}[33m") 8 | set(Blue "${Esc}[34m") 9 | set(Magenta "${Esc}[35m") 10 | set(Cyan "${Esc}[36m") 11 | set(White "${Esc}[37m") 12 | set(BoldRed "${Esc}[1;31m") 13 | set(BoldGreen "${Esc}[1;32m") 14 | set(BoldYellow "${Esc}[1;33m") 15 | set(BoldBlue "${Esc}[1;34m") 16 | set(BoldMagenta "${Esc}[1;35m") 17 | set(BoldCyan "${Esc}[1;36m") 18 | set(BoldWhite "${Esc}[1;37m") 19 | ENDIF() -------------------------------------------------------------------------------- /maptree/cmake/ConfigSafeGuards.cmake: -------------------------------------------------------------------------------- 1 | # guard against in-source builds 2 | if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR}) 3 | message(FATAL_ERROR "In-source builds not allowed. Please make a new directory (called a build directory) and run CMake from there.") 4 | endif() 5 | 6 | # guard against bad build-type strings 7 | if (NOT CMAKE_BUILD_TYPE) 8 | message(STATUS "No build type selected, default to Debug") 9 | set(CMAKE_BUILD_TYPE "Debug") 10 | endif() 11 | 12 | string(TOLOWER "${CMAKE_BUILD_TYPE}" cmake_build_type_tolower) 13 | string(TOUPPER "${CMAKE_BUILD_TYPE}" cmake_build_type_toupper) 14 | if( NOT cmake_build_type_tolower STREQUAL "debug" 15 | AND NOT cmake_build_type_tolower STREQUAL "release" 16 | AND NOT cmake_build_type_tolower STREQUAL "profile" 17 | AND NOT cmake_build_type_tolower STREQUAL "relwithdebinfo" 18 | AND NOT cmake_build_type_tolower STREQUAL "coverage") 19 | message(FATAL_ERROR "Unknown build type \"${CMAKE_BUILD_TYPE}\". Allowed values are Debug, Coverage, Release, Profile, RelWithDebInfo (case-insensitive).") 20 | endif() -------------------------------------------------------------------------------- /maptree/cmake/Doctest.cmake: -------------------------------------------------------------------------------- 1 | if(ENABLE_DOCTESTS) 2 | add_definitions(-DENABLE_DOCTEST_IN_LIBRARY) 3 | include(FetchContent) 4 | FetchContent_Declare( 5 | DocTest 6 | GIT_REPOSITORY "https://github.com/onqtam/doctest" 7 | GIT_TAG "b7c21ec5ceeadb4951b00396fc1e4642dd347e5f" 8 | ) 9 | 10 | FetchContent_MakeAvailable(DocTest) 11 | include_directories(${DOCTEST_INCLUDE_DIR}) 12 | endif() -------------------------------------------------------------------------------- /maptree/cmake/Documentation.cmake: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # Documentation (no change needed). 3 | # -------------------------------------------------------------------------------- 4 | # Add a make target 'doc' to generate API documentation with Doxygen. 5 | # You should set options to your liking in the file 'Doxyfile.in'. 
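# Once a build directory has been configured, the documentation can be built
# with (assuming Doxygen was found at configure time):
#   cmake --build <build-dir> --target doc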
6 | find_package(Doxygen) 7 | if(DOXYGEN_FOUND) 8 | configure_file(${CMAKE_CURRENT_SOURCE_DIR}/Doxyfile.in ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile @ONLY) 9 | add_custom_target(doc 10 | ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile &> doxygen.log 11 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 12 | COMMENT "${BoldMagenta}Generating API documentation with Doxygen (open ./html/index.html to view).${ColorReset}" VERBATIM 13 | ) 14 | endif(DOXYGEN_FOUND) 15 | -------------------------------------------------------------------------------- /maptree/cmake/LTO.cmake: -------------------------------------------------------------------------------- 1 | # Usage : 2 | # 3 | # Variable : ENABLE_LTO | Enable or disable LTO support for this build 4 | # 5 | # find_lto(lang) 6 | # - lang is C or CXX (the language to test LTO for) 7 | # - call it after project() so that the compiler is already detected 8 | # 9 | # This will check for LTO support and create a target_enable_lto(target [debug,optimized,general]) macro. 10 | # The 2nd parameter has the same meaning as in target_link_libraries, and is used to enable LTO only for those build configurations 11 | # 'debug' is by default the Debug configuration, and 'optimized' all the other configurations 12 | # 13 | # if ENABLE_LTO is set to false, an empty macro will be generated 14 | # 15 | # Then to enable LTO for your target use 16 | # 17 | # target_enable_lto(mytarget general) 18 | # 19 | # It is however recommended to use it only for non-debug builds the following way : 20 | # 21 | # target_enable_lto(mytarget optimized) 22 | # 23 | # Note : For CMake versions < 3.9, target_link_library is used in its non-plain version. 24 | # You will need to specify PUBLIC/PRIVATE/INTERFACE to all your other target_link_library calls for the target 25 | # 26 | # WARNING for cmake versions older than 3.9 : 27 | # This module will override CMAKE_AR, CMAKE_RANLIB and CMAKE_NM with the gcc versions if found when building with gcc 28 | 29 | 30 | # License: 31 | # 32 | # Copyright (C) 2016 Lectem 33 | # 34 | # Permission is hereby granted, free of charge, to any person 35 | # obtaining a copy of this software and associated documentation files 36 | # (the 'Software'), to deal in the Software without restriction, 37 | # including without limitation the rights to use, copy, modify, merge, 38 | # publish, distribute, sublicense, and/or sell copies of the Software, 39 | # and to permit persons to whom the Software is furnished to do so, 40 | # subject to the following conditions: 41 | # 42 | # The above copyright notice and this permission notice shall be 43 | # included in all copies or substantial portions of the Software. 44 | # 45 | # THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, 46 | # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 47 | # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 48 | # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 49 | # BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 50 | # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 51 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 52 | # SOFTWARE.
53 | 54 | 55 | macro(find_lto lang) 56 | if(ENABLE_LTO AND NOT LTO_${lang}_CHECKED) 57 | 58 | #LTO support was added for clang/gcc in 3.9 59 | if(${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION} VERSION_LESS 3.9) 60 | cmake_policy(SET CMP0054 NEW) 61 | message(STATUS "Checking for LTO Compatibility") 62 | # Since GCC 4.9 we need to use gcc-ar / gcc-ranlib / gcc-nm 63 | if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") 64 | if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND NOT CMAKE_GCC_AR OR NOT CMAKE_GCC_RANLIB OR NOT CMAKE_GCC_NM) 65 | find_program(CMAKE_GCC_AR NAMES 66 | "${_CMAKE_TOOLCHAIN_PREFIX}gcc-ar" 67 | "${_CMAKE_TOOLCHAIN_PREFIX}gcc-ar-${_version}" 68 | DOC "gcc provided wrapper for ar which adds the --plugin option" 69 | ) 70 | find_program(CMAKE_GCC_RANLIB NAMES 71 | "${_CMAKE_TOOLCHAIN_PREFIX}gcc-ranlib" 72 | "${_CMAKE_TOOLCHAIN_PREFIX}gcc-ranlib-${_version}" 73 | DOC "gcc provided wrapper for ranlib which adds the --plugin option" 74 | ) 75 | # Not needed, but at least stay coherent 76 | find_program(CMAKE_GCC_NM NAMES 77 | "${_CMAKE_TOOLCHAIN_PREFIX}gcc-nm" 78 | "${_CMAKE_TOOLCHAIN_PREFIX}gcc-nm-${_version}" 79 | DOC "gcc provided wrapper for nm which adds the --plugin option" 80 | ) 81 | mark_as_advanced(CMAKE_GCC_AR CMAKE_GCC_RANLIB CMAKE_GCC_NM) 82 | set(CMAKE_LTO_AR ${CMAKE_GCC_AR}) 83 | set(CMAKE_LTO_RANLIB ${CMAKE_GCC_RANLIB}) 84 | set(CMAKE_LTO_NM ${CMAKE_GCC_NM}) 85 | endif() 86 | if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") 87 | set(CMAKE_LTO_AR ${CMAKE_AR}) 88 | set(CMAKE_LTO_RANLIB ${CMAKE_RANLIB}) 89 | set(CMAKE_LTO_NM ${CMAKE_NM}) 90 | endif() 91 | 92 | if(CMAKE_LTO_AR AND CMAKE_LTO_RANLIB) 93 | set(__lto_flags -flto) 94 | 95 | if(NOT CMAKE_${lang}_COMPILER_VERSION VERSION_LESS 4.7) 96 | list(APPEND __lto_flags -fno-fat-lto-objects) 97 | endif() 98 | 99 | if(NOT DEFINED CMAKE_${lang}_PASSED_LTO_TEST) 100 | set(__output_dir "${CMAKE_PLATFORM_INFO_DIR}/LtoTest1${lang}") 101 | file(MAKE_DIRECTORY "${__output_dir}") 102 | set(__output_base "${__output_dir}/lto-test-${lang}") 103 | 104 | execute_process( 105 | COMMAND ${CMAKE_COMMAND} -E echo "void foo() {}" 106 | COMMAND ${CMAKE_${lang}_COMPILER} ${__lto_flags} -c -xc - 107 | -o "${__output_base}.o" 108 | RESULT_VARIABLE __result 109 | ERROR_QUIET 110 | OUTPUT_QUIET 111 | ) 112 | 113 | if("${__result}" STREQUAL "0") 114 | execute_process( 115 | COMMAND ${CMAKE_LTO_AR} cr "${__output_base}.a" "${__output_base}.o" 116 | RESULT_VARIABLE __result 117 | ERROR_QUIET 118 | OUTPUT_QUIET 119 | ) 120 | endif() 121 | 122 | if("${__result}" STREQUAL "0") 123 | execute_process( 124 | COMMAND ${CMAKE_LTO_RANLIB} "${__output_base}.a" 125 | RESULT_VARIABLE __result 126 | ERROR_QUIET 127 | OUTPUT_QUIET 128 | ) 129 | endif() 130 | 131 | if("${__result}" STREQUAL "0") 132 | execute_process( 133 | COMMAND ${CMAKE_COMMAND} -E echo "void foo(); int main() {foo();}" 134 | COMMAND ${CMAKE_${lang}_COMPILER} ${__lto_flags} -xc - 135 | -x none "${__output_base}.a" -o "${__output_base}" 136 | RESULT_VARIABLE __result 137 | ERROR_QUIET 138 | OUTPUT_QUIET 139 | ) 140 | endif() 141 | 142 | if("${__result}" STREQUAL "0") 143 | set(__lto_found TRUE) 144 | endif() 145 | 146 | set(CMAKE_${lang}_PASSED_LTO_TEST 147 | ${__lto_found} CACHE INTERNAL 148 | "If the compiler passed a simple LTO test compile") 149 | endif() 150 | if(CMAKE_${lang}_PASSED_LTO_TEST) 151 | message(STATUS "Checking for LTO Compatibility - works") 152 | set(LTO_${lang}_SUPPORT TRUE CACHE BOOL "Do we have LTO support ?") 153 | set(LTO_COMPILE_FLAGS -flto 
CACHE STRING "Link Time Optimization compile flags") 154 | set(LTO_LINK_FLAGS -flto CACHE STRING "Link Time Optimization link flags") 155 | else() 156 | message(STATUS "Checking for LTO Compatibility - not working") 157 | endif() 158 | 159 | endif() 160 | elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang") 161 | message(STATUS "Checking for LTO Compatibility - works (assumed for clang)") 162 | set(LTO_${lang}_SUPPORT TRUE CACHE BOOL "Do we have LTO support ?") 163 | set(LTO_COMPILE_FLAGS -flto CACHE STRING "Link Time Optimization compile flags") 164 | set(LTO_LINK_FLAGS -flto CACHE STRING "Link Time Optimization link flags") 165 | elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") 166 | message(STATUS "Checking for LTO Compatibility - works") 167 | set(LTO_${lang}_SUPPORT TRUE CACHE BOOL "Do we have LTO support ?") 168 | set(LTO_COMPILE_FLAGS /GL CACHE STRING "Link Time Optimization compile flags") 169 | set(LTO_LINK_FLAGS -LTCG:INCREMENTAL CACHE STRING "Link Time Optimization link flags") 170 | else() 171 | message(STATUS "Checking for LTO Compatibility - compiler not handled by module") 172 | endif() 173 | mark_as_advanced(LTO_${lang}_SUPPORT LTO_COMPILE_FLAGS LTO_LINK_FLAGS) 174 | 175 | 176 | set(LTO_${lang}_CHECKED TRUE CACHE INTERNAL "" ) 177 | 178 | if(CMAKE_GCC_AR AND CMAKE_GCC_RANLIB AND CMAKE_GCC_NM) 179 | # THIS IS HACKY BUT THERE IS NO OTHER SOLUTION ATM 180 | set(CMAKE_AR ${CMAKE_GCC_AR} CACHE FILEPATH "Forcing gcc-ar instead of ar" FORCE) 181 | set(CMAKE_NM ${CMAKE_GCC_NM} CACHE FILEPATH "Forcing gcc-nm instead of nm" FORCE) 182 | set(CMAKE_RANLIB ${CMAKE_GCC_RANLIB} CACHE FILEPATH "Forcing gcc-ranlib instead of ranlib" FORCE) 183 | endif() 184 | endif(${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION} VERSION_LESS 3.9) 185 | endif(ENABLE_LTO AND NOT LTO_${lang}_CHECKED) 186 | 187 | 188 | if(ENABLE_LTO) 189 | #Special case for cmake older than 3.9, using a library for gcc/clang, but could dataloader the flags directly. 190 | #Taking advantage of the [debug,optimized] parameter of target_link_libraries 191 | if(${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION} VERSION_LESS 3.9) 192 | if(LTO_${lang}_SUPPORT) 193 | if(NOT TARGET __enable_lto_tgt) 194 | add_library(__enable_lto_tgt INTERFACE) 195 | endif() 196 | target_compile_options(__enable_lto_tgt INTERFACE ${LTO_COMPILE_FLAGS}) 197 | #this might not work for all platforms... in which case we'll have to set the link flags on the target directly 198 | target_link_libraries(__enable_lto_tgt INTERFACE ${LTO_LINK_FLAGS} ) 199 | macro(target_enable_lto _target _build_configuration) 200 | if(${_build_configuration} STREQUAL "optimized" OR ${_build_configuration} STREQUAL "debug" ) 201 | target_link_libraries(${_target} PRIVATE ${_build_configuration} __enable_lto_tgt) 202 | else() 203 | target_link_libraries(${_target} PRIVATE __enable_lto_tgt) 204 | endif() 205 | endmacro() 206 | else() 207 | #In old cmake versions, we can set INTERPROCEDURAL_OPTIMIZATION even if not supported by the compiler 208 | #So if we didn't detect it, let cmake give it a try 209 | set(__IPO_SUPPORTED TRUE) 210 | endif() 211 | else() 212 | cmake_policy(SET CMP0069 NEW) 213 | include(CheckIPOSupported) 214 | # Optional IPO. Do not use IPO if it's not supported by compiler. 
215 | check_ipo_supported(RESULT __IPO_SUPPORTED OUTPUT output) 216 | if(NOT __IPO_SUPPORTED) 217 | message(STATUS "IPO is not supported or broken.") 218 | else() 219 | message(STATUS "IPO is supported") 220 | endif() 221 | endif() 222 | if(__IPO_SUPPORTED) 223 | macro(target_enable_lto _target _build_configuration) 224 | if(NOT ${_build_configuration} STREQUAL "debug" ) 225 | #enable for all configurations 226 | set_target_properties(${_target} PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE) 227 | endif() 228 | if(${_build_configuration} STREQUAL "optimized" ) 229 | #blacklist debug configurations 230 | set(__enable_debug_lto FALSE) 231 | else() 232 | #enable only for debug configurations 233 | set(__enable_debug_lto TRUE) 234 | endif() 235 | get_property(DEBUG_CONFIGURATIONS GLOBAL PROPERTY DEBUG_CONFIGURATIONS) 236 | if(NOT DEBUG_CONFIGURATIONS) 237 | set(DEBUG_CONFIGURATIONS DEBUG) # This is what is done by CMAKE internally... since DEBUG_CONFIGURATIONS is empty by default 238 | endif() 239 | foreach(config IN LISTS DEBUG_CONFIGURATIONS) 240 | set_target_properties(${_target} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_${config} ${__enable_debug_lto}) 241 | endforeach() 242 | endmacro() 243 | endif() 244 | endif() 245 | if(NOT COMMAND target_enable_lto) 246 | macro(target_enable_lto _target _build_configuration) 247 | endmacro() 248 | endif() 249 | endmacro() -------------------------------------------------------------------------------- /maptree/cmake/Warnings.cmake: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2017 Lectem 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
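# Usage, as in this project's top-level CMakeLists.txt:
#   target_set_warnings(mytarget ENABLE ALL AS_ERROR ALL DISABLE Annoying)
# ENABLE/DISABLE/AS_ERROR each take a list of warning groups; ALL and the
# pseudo-group "Annoying" are the groups handled by the function below.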
22 | 23 | 24 | function(target_set_warnings) 25 | if(NOT ENABLE_WARNINGS_SETTINGS) 26 | return() 27 | endif() 28 | if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") 29 | set(WMSVC TRUE) 30 | elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") 31 | set(WGCC TRUE) 32 | elseif ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") 33 | set(WCLANG TRUE) 34 | endif() 35 | set(multiValueArgs ENABLE DISABLE AS_ERROR) 36 | cmake_parse_arguments(this "" "" "${multiValueArgs}" ${ARGN}) 37 | list(FIND this_ENABLE "ALL" enable_all) 38 | list(FIND this_DISABLE "ALL" disable_all) 39 | list(FIND this_AS_ERROR "ALL" as_error_all) 40 | if(NOT ${enable_all} EQUAL -1) 41 | if(WMSVC) 42 | # Not all the warnings, but WAll is unusable when using libraries 43 | # Unless you'd like to support MSVC in the code with pragmas, this is probably the best option 44 | list(APPEND WarningFlags "/W4") 45 | elseif(WGCC) 46 | list(APPEND WarningFlags "-Wall" "-Wextra" "-Wpedantic") 47 | elseif(WCLANG) 48 | list(APPEND WarningFlags "-Wall" "-Wextra" "-Wconversion" "-Wsign-conversion" "-Wpedantic") 49 | endif() 50 | elseif(NOT ${disable_all} EQUAL -1) 51 | set(SystemIncludes TRUE) # Treat includes as if coming from system 52 | if(WMSVC) 53 | list(APPEND WarningFlags "/w" "/W0") 54 | elseif(WGCC OR WCLANG) 55 | list(APPEND WarningFlags "-w") 56 | endif() 57 | endif() 58 | 59 | list(FIND this_DISABLE "Annoying" disable_annoying) 60 | if(NOT ${disable_annoying} EQUAL -1) 61 | if(WMSVC) 62 | # bounds-checked functions require to set __STDC_WANT_LIB_EXT1__ which we usually don't need/want 63 | list(APPEND WarningDefinitions -D_CRT_SECURE_NO_WARNINGS) 64 | # disable C4514 C4710 C4711... Those are useless to add most of the time 65 | #list(APPEND WarningFlags "/wd4514" "/wd4710" "/wd4711") 66 | #list(APPEND WarningFlags "/wd4365") #signed/unsigned mismatch 67 | #list(APPEND WarningFlags "/wd4668") # 'symbol' is not defined as a preprocessor macro, replacing with '0' for 'directives' 68 | elseif(WGCC OR WCLANG) 69 | list(APPEND WarningFlags -Wno-switch-enum) 70 | if(WCLANG) 71 | list(APPEND WarningFlags -Wno-unknown-warning-option -Wno-padded -Wno-undef -Wno-reserved-id-macro -fcomment-block-commands=test,retval) 72 | if(NOT CMAKE_CXX_STANDARD EQUAL 98) 73 | list(APPEND WarningFlags -Wno-c++98-compat -Wno-c++98-compat-pedantic) 74 | endif() 75 | if ("${CMAKE_CXX_SIMULATE_ID}" STREQUAL "MSVC") # clang-cl has some VCC flags by default that it will not recognize... 76 | list(APPEND WarningFlags -Wno-unused-command-line-argument) 77 | endif() 78 | endif(WCLANG) 79 | endif() 80 | endif() 81 | 82 | if(NOT ${as_error_all} EQUAL -1) 83 | if(WMSVC) 84 | list(APPEND WarningFlags "/WX") 85 | elseif(WGCC OR WCLANG) 86 | list(APPEND WarningFlags "-Werror") 87 | endif() 88 | endif() 89 | foreach(target IN LISTS this_UNPARSED_ARGUMENTS) 90 | if(WarningFlags) 91 | target_compile_options(${target} PRIVATE ${WarningFlags}) 92 | endif() 93 | if(WarningDefinitions) 94 | target_compile_definitions(${target} PRIVATE ${WarningDefinitions}) 95 | endif() 96 | if(SystemIncludes) 97 | set_target_properties(${target} PROPERTIES 98 | INTERFACE_SYSTEM_INCLUDE_DIRECTORIES $<TARGET_PROPERTY:${target},INTERFACE_INCLUDE_DIRECTORIES>) 99 | endif() 100 | endforeach() 101 | endfunction(target_set_warnings) -------------------------------------------------------------------------------- /maptree/include/cache/approx_bitset_cache.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file approx_bitset_cache.h 3 | * @brief Subproblem cache based on bitset hashes.
4 | * 5 | * This file contains a cache implementation which stores results for 6 | * subproblems based on several hash values of their bitset and depth. Note that 7 | * this cache may return false positives, but due to the limited number of 8 | * queries and the large hash values, this is unlikely. 9 | */ 10 | 11 | #ifndef APPROX_BITSET_CACHE_H 12 | #define APPROX_BITSET_CACHE_H 13 | 14 | #include 15 | #include 16 | #include 17 | 18 | #include "constants.h" 19 | #include "cache/base_cache.h" 20 | #include "data/bitset.h" 21 | 22 | //! Number of unsigned long longs in the ApproxBitsetCacheKey. 23 | #define APPROX_BITSET_CACHE_NUM_ULL_HASH_VALUES 2 24 | 25 | /** 26 | * @struct ApproxBitsetCacheKey 27 | * @brief Key for the ApproxBitsetCache. 28 | * 29 | * This struct contains the key for ApproxBitsetCache: the bitset hash values 30 | * and the depth. 31 | */ 32 | struct ApproxBitsetCacheKey { 33 | std::array hashedBitset; 34 | size_t depth = 0; 35 | friend bool operator==(const ApproxBitsetCacheKey &lhs, const ApproxBitsetCacheKey &rhs); 36 | }; 37 | 38 | /** 39 | * @struct ApproxBitsetCacheKeyHash 40 | * @brief Hash function for ApproxBitsetCacheKey. 41 | */ 42 | struct ApproxBitsetCacheKeyHash { 43 | size_t operator()(const ApproxBitsetCacheKey &key) const; 44 | }; 45 | 46 | /** 47 | * @class ApproxBitsetCache 48 | * @brief Subproblem cache based on hashed bitsets. 49 | * @implements BaseCache 50 | */ 51 | class ApproxBitsetCache : BaseCache { 52 | public: 53 | ApproxBitsetCache( 54 | size_t numBlocks 55 | ) { 56 | initBlockMults(numBlocks); 57 | }; 58 | 59 | /** 60 | * @implements BaseCache::put 61 | */ 62 | void put( 63 | Subproblem& subproblem, 64 | void *value 65 | ) override; 66 | 67 | /** 68 | * @implements BaseCache::get 69 | */ 70 | void *get( 71 | Subproblem& subproblem 72 | ) override; 73 | 74 | /** 75 | * @implements BaseCache::size 76 | */ 77 | size_t size() const override; 78 | 79 | static constexpr std::array BLOCK_MULT_BASE = { 80 | 377424577268497867ULL, 81 | 285989758769553131ULL, 82 | }; 83 | static constexpr BLOCK DEPTH_MULT = 234902547182092241ULL; 84 | 85 | private: 86 | std::unordered_map cache_; 87 | std::array, APPROX_BITSET_CACHE_NUM_ULL_HASH_VALUES> blockMults_; 88 | 89 | void initBlockMults(size_t numBlocks); 90 | ApproxBitsetCacheKey constructKey(Subproblem& subproblem) const; 91 | }; 92 | 93 | #endif 94 | -------------------------------------------------------------------------------- /maptree/include/cache/base_cache.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file base_cache.h 3 | * @brief Base cache interface for subproblem caches. 4 | * 5 | * This file contains a base interface for cache implementations that store 6 | * results for subproblems already explored. 7 | */ 8 | 9 | #ifndef BASE_CACHE_H 10 | #define BASE_CACHE_H 11 | 12 | #include "subproblem.h" 13 | 14 | /** 15 | * @class BaseCache 16 | * @brief Interface for subproblem caches. 17 | */ 18 | class BaseCache { 19 | public: 20 | virtual ~BaseCache() = default; 21 | 22 | /** 23 | * @brief Stores the provided value for the provided subproblem. 24 | * @param subproblem The subproblem to store the value for. 25 | * @param value The value to store. 26 | * @returns void 27 | */ 28 | virtual void put( 29 | Subproblem& subproblem, 30 | void *value 31 | ) = 0; 32 | 33 | /** 34 | * @brief Retrieves the value for the provided subproblem. 35 | * @param subproblem The subproblem to retrieve the value for. 
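 * @note Implementations may return nullptr when no value has been stored for
 * the subproblem; ApproxBitsetCache::get does exactly this.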
36 | * @returns The value for the provided subproblem. 37 | */ 38 | virtual void *get( 39 | Subproblem& subproblem 40 | ) = 0; 41 | 42 | /** 43 | * @brief Returns the number of subproblems stored in the cache. 44 | * @returns The number of subproblems stored in the cache. 45 | */ 46 | virtual size_t size() const = 0; 47 | }; 48 | 49 | #endif 50 | -------------------------------------------------------------------------------- /maptree/include/constants.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file constants.h 3 | * @brief Constants used throughout the library. 4 | */ 5 | 6 | typedef unsigned long long BLOCK; 7 | #define BLOCK_BITS 64 8 | #define FULL_BLOCK 0xFFFFFFFFFFFFFFFFULL 9 | #define NUM_BLOCKS(numBits) ((numBits + BLOCK_BITS - 1) / BLOCK_BITS) -------------------------------------------------------------------------------- /maptree/include/data/binary_data_loader.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file binary_data_loader.h 3 | * @brief Loads binary data from a file. 4 | */ 5 | 6 | #ifndef BINARY_DATA_LOADER_H 7 | #define BINARY_DATA_LOADER_H 8 | 9 | #include 10 | #include 11 | 12 | /** 13 | * @class BinaryDataLoader 14 | * @brief Loads binary data from a file. 15 | * 16 | * Initialized with a filename, this class loads binary data from a file. The 17 | * file should contain only space delimited 0's and 1's, with the same number of 18 | * values on each line. Every line represents a single data point, and the first 19 | * value on each line is the label for that data point. 20 | */ 21 | class BinaryDataLoader { 22 | public: 23 | BinaryDataLoader( 24 | const std::string& filename 25 | ) 26 | : filename_(filename) 27 | { 28 | load(); 29 | } 30 | 31 | /** 32 | * @brief Returns the features loaded from the file. 33 | */ 34 | const std::vector>& getFeatures() const; 35 | 36 | /** 37 | * @brief Returns the labels loaded from the file. 38 | */ 39 | const std::vector& getLabels() const; 40 | private: 41 | std::string filename_; 42 | std::vector> features_; 43 | std::vector labels_; 44 | void load(); 45 | }; 46 | 47 | #endif 48 | -------------------------------------------------------------------------------- /maptree/include/data/bitset.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file bitset.h 3 | * @brief Reversible sparse bitset. 4 | * 5 | * This file contains an implementation of a reversible sparse bitset. This 6 | * bitset is implemented as a vector of reversible blocks (RNumber) and a vector 7 | * of indices of the blocks. The bitset is reversible in that we can apply masks 8 | * to it that remove bits or blocks from the bitset and then reverse the mask to 9 | * restore the bitset to its original state. 10 | * 11 | * @see https://arxiv.org/abs/1604.06641 12 | */ 13 | 14 | #ifndef BITSET_H 15 | #define BITSET_H 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "constants.h" 23 | #include "data/rnumber.h" 24 | #include "data/fixed_bitset.h" 25 | 26 | /** 27 | * @class Bitset 28 | * @brief Reversible sparse bitset implementation. 
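 *
 * A minimal usage sketch (numSamples, maxDepth, and mask are placeholders;
 * in this codebase the masks come from a DataManager):
 *
 *   Bitset bs(numSamples, maxDepth);
 *   bs.intersect(mask);          // keep only the points selected by mask
 *   int remaining = bs.count();  // number of points left
 *   bs.reverse();                // undo the intersection over the active blocks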
29 |  */
30 | class Bitset {
31 | public:
32 |     Bitset(
33 |         size_t numSamples,
34 |         size_t maxLevel
35 |     )
36 |     : numBlocks_(NUM_BLOCKS(numSamples))
37 |     , maxLevel_(maxLevel)
38 |     , blocks_(numBlocks_, RNumber(maxLevel + 1, FULL_BLOCK))
39 |     , indices_(numBlocks_)
40 |     , limit_(maxLevel + 1, numBlocks_)
41 |     {
42 |         for (size_t i = 0; i < numBlocks_; i++) indices_[i] = i;
43 |         BLOCK lastBlock = (1ULL << (numSamples % BLOCK_BITS)) - 1;
44 |         blocks_[numBlocks_ - 1].set(lastBlock);
45 |     };
46 | 
47 |     /**
48 |      * @brief Returns the current level of the bitset, or the number of
49 |      * masks that have been applied to it.
50 |      * @returns The current level of the bitset.
51 |      */
52 |     size_t level() const;
53 | 
54 |     /**
55 |      * @brief Returns the number of set bits, i.e. points, remaining in the bitset.
56 |      * @returns The number of set bits remaining in the bitset.
57 |      */
58 |     int count() const;
59 | 
60 |     /**
61 |      * @brief Returns the number of bits in the intersection of this bitset
62 |      * and the provided fixed bitset.
63 |      * @param other The fixed bitset to intersect with.
64 |      * @returns The number of bits in the intersection of this bitset and
65 |      * the provided fixed bitset.
66 |      */
67 |     int countIntersection(const FixedBitset& other) const;
68 | 
69 |     /**
70 |      * @brief Checks if this bitset is a subset of the provided fixed
71 |      * bitset.
72 |      * @param other The fixed bitset to check if this bitset is a subset of.
73 |      * @returns True if this bitset is a subset of the provided fixed
74 |      * bitset, false otherwise.
75 |      */
76 |     bool isSubset(const FixedBitset& other) const;
77 | 
78 |     /**
79 |      * @brief Updates the bitset by applying the provided mask.
80 |      * @param other The fixed bitset mask to apply to the bitset.
81 |      * @returns void
82 |      */
83 |     void intersect(const FixedBitset& other);
84 | 
85 |     /**
86 |      * @brief Reverses the last mask applied to the bitset.
87 |      * @returns void
88 |      */
89 |     void reverse();
90 | 
91 |     /**
92 |      * @brief Resets the bitset to its original state.
93 |      * @returns void
94 |      */
95 |     void reset();
96 | 
97 |     /**
98 |      * @brief Returns a weighted sum across all blocks of the bitset.
99 |      * @param blockWeights The weights to use for each block.
100 |      * @returns The weighted sum across all blocks of the bitset.
101 |      */
102 |     BLOCK sumOfBlocks(const std::vector<BLOCK>& blockWeights) const;
103 | 
104 |     /**
105 |      * @brief Outputs the bitset as a string of its blocks.
106 |      */
107 |     friend std::ostream& operator<<(std::ostream& os, const Bitset& bitset);
108 | 
109 | private:
110 |     size_t level_ = 0;
111 |     size_t numBlocks_;
112 |     [[maybe_unused]] size_t maxLevel_;
113 |     std::vector<RNumber> blocks_;
114 |     std::vector<size_t> indices_;
115 |     RNumber limit_;
116 | };
117 | 
118 | #endif
119 | 
--------------------------------------------------------------------------------
/maptree/include/data/data_manager.h:
--------------------------------------------------------------------------------
1 | /**
2 |  * @file data_manager.h
3 |  * @brief Precomputes information about data for searchers.
4 |  *
5 |  * This file contains the DataManager class, which precomputes feature and label
6 |  * masks for the searchers.
7 |  */
8 | 
9 | #ifndef DATA_MANAGER_H
10 | #define DATA_MANAGER_H
11 | 
12 | #include <cstddef>
13 | #include <vector>
14 | #include 
15 | 
16 | #include "data/fixed_bitset.h"
17 | 
18 | /**
19 |  * @class DataManager
20 |  * @brief Precomputes information about data for searchers.
21 |  *
22 |  * This class uses the provided features and labels to precompute feature and
23 |  * label masks for the searchers.
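One mask is built for every (feature, value) pair and one for each of the two label values.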
These masks are stored as vectors of blocks 24 | * (unsigned long longs) and used by searchers to quickly update subproblems. 1 25 | * bits in label masks indicate that the point at that bit's index contains the 26 | * mask's label or feature value and a 0 bit indicates that it does not. 27 | */ 28 | class DataManager { 29 | public: 30 | DataManager( 31 | const std::vector>& features, 32 | const std::vector& labels 33 | ) 34 | : numFeatures_(features[0].size()) 35 | , numSamples_(features.size()) 36 | , featureMasks_(numFeatures_ * 2, FixedBitset(numSamples_)) 37 | , labelMasks_(2, FixedBitset(numSamples_)) 38 | { 39 | buildFeatureMasks(features); 40 | buildLabelMasks(labels); 41 | } 42 | 43 | /** 44 | * @brief Returns the number of features in the data. 45 | */ 46 | size_t getNumFeatures() const; 47 | 48 | /** 49 | * @brief Returns the number of samples in the data. 50 | */ 51 | size_t getNumSamples() const; 52 | 53 | /** 54 | * @brief Returns the mask for the provided feature and feature value. 55 | * @param feature The feature to get the mask for. 56 | * @param value The feature value to get the mask for. 57 | * @returns The mask for the provided feature and feature value. 58 | */ 59 | const FixedBitset& getFeatureMask( 60 | size_t feature, 61 | bool value 62 | ) const; 63 | 64 | /** 65 | * @brief Returns the mask for the provided label value. 66 | * @param value The label value to get the mask for. 67 | * @returns The mask for the provided label value. 68 | */ 69 | const FixedBitset& getLabelMask( 70 | bool value 71 | ) const; 72 | 73 | private: 74 | size_t numFeatures_; 75 | size_t numSamples_; 76 | std::vector featureMasks_; 77 | std::vector labelMasks_; 78 | 79 | void buildFeatureMasks(const std::vector>& features); 80 | void buildLabelMasks(const std::vector& labels); 81 | }; 82 | 83 | #endif 84 | -------------------------------------------------------------------------------- /maptree/include/data/fixed_bitset.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file fixed_bitset.h 3 | * @brief Fixed size bitset. 4 | * 5 | * This file contains a fixed size bitset implementation. This bitset is 6 | * implemented as a vector of blocks (unsigned long longs) which are not 7 | * modified after construction. 8 | */ 9 | 10 | #ifndef FIXED_BITSET_H 11 | #define FIXED_BITSET_H 12 | 13 | #include 14 | #include 15 | 16 | #include "constants.h" 17 | 18 | /** 19 | * @class FixedBitset 20 | * @brief Fixed size bitset. 21 | */ 22 | class FixedBitset { 23 | public: 24 | FixedBitset( 25 | size_t numSamples 26 | ) 27 | : numSamples_(numSamples) 28 | , blocks_(NUM_BLOCKS(numSamples), 0) 29 | {}; 30 | 31 | /** 32 | * @brief Sets the blocks of this bitset using the provided bit vector. 33 | * @param bits The bit vector to set the blocks of this bitset with. 34 | * @pre The size of the provided bit vector must be equal to the number 35 | * of samples in this bitset. 36 | * @returns void 37 | */ 38 | void setBits( 39 | const std::vector& bits 40 | ); 41 | 42 | /** 43 | * @brief Returns the block at the provided index. 44 | * @param index The index of the block to access. 45 | * @returns The block at the provided index. 
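 * @pre blockIdx is less than the number of blocks, NUM_BLOCKS(numSamples).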
46 | */ 47 | BLOCK getBlock( 48 | size_t blockIdx 49 | ) const; 50 | 51 | private: 52 | [[maybe_unused]] size_t numSamples_; 53 | std::vector blocks_; 54 | }; 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /maptree/include/data/rnumber.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file rnumber.h 3 | * @brief This file contains the reversible number class used in reversible 4 | * sparse bitsets. 5 | * 6 | * This file contains the reversible number class used in reversible sparse 7 | * bitsets. Several actions can be performed on a reversible number, including 8 | * setting, updating, intersecting, reversing, resetting, and counting bits. The 9 | * reversible number is implemented as a vector of blocks (unsigned long longs) 10 | * with a head index. The head index is used to keep track of the current 11 | * position in the vector. Reversing can be done by simply decrementing the head 12 | * index in constant time. 13 | */ 14 | 15 | #ifndef RNUMBER_H 16 | #define RNUMBER_H 17 | 18 | #include 19 | #include 20 | 21 | #include "constants.h" 22 | 23 | class RNumber { 24 | public: 25 | RNumber( 26 | size_t capacity, 27 | BLOCK initValue = 0 28 | ) 29 | : values_(capacity, initValue) 30 | {}; 31 | 32 | /** 33 | * @brief Returns the current value of the reversible number. 34 | */ 35 | BLOCK get() const; 36 | 37 | /** 38 | * @brief Sets the current value to the provided value 39 | * @param value The value to set the reversible number to. 40 | * 41 | * This action is irreversible. 42 | */ 43 | void set( 44 | BLOCK value 45 | ); 46 | 47 | /** 48 | * @brief Updates the current value to the provided value. 49 | * @param value The value to update the reversible number with. 50 | * 51 | * This action is reversible. 52 | */ 53 | void update( 54 | BLOCK value 55 | ); 56 | 57 | /** 58 | * @brief Intersects the current value with the provided value. 59 | * @param value The value to intersect the reversible number with. 60 | * 61 | * This action is reversible. 62 | */ 63 | void intersect( 64 | BLOCK other 65 | ); 66 | 67 | /** 68 | * @brief Reverts previous action. 69 | */ 70 | void reverse(); 71 | 72 | /** 73 | * @brief Resets the reversible number to its initial state. 74 | */ 75 | void reset(); 76 | 77 | /** 78 | * @brief Counts the number of bits in the current value. 79 | * @returns The number of bits in the current value. 80 | */ 81 | int countBits() const; 82 | 83 | /** 84 | * @brief Counts the number of bits at the intersection of the current 85 | * value and the provided value. 86 | * @param other The value to intersect the reversible number with. 87 | * @returns The number of bits at the intersection of the current value 88 | * and the provided value. 89 | */ 90 | int countBitsAtIntersection( 91 | BLOCK other 92 | ) const; 93 | 94 | /** 95 | * @brief Checks if the current value is a subset of the provided value. 96 | * @param other The value to check if the current value is a subset of. 97 | * @returns True if the current value is a subset of the provided value, 98 | * false otherwise. 99 | */ 100 | bool isSubset( 101 | BLOCK other 102 | ) const; 103 | 104 | /** 105 | * @brief Checks if the current value is 0 106 | * @returns True if the current value is 0, false otherwise. 
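 * @note Bitset::intersect uses this to detect blocks whose points have all
 * been removed, so that they can be swapped out of the active range.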
107 | */ 108 | bool empty() const; 109 | private: 110 | size_t head_ = 0; 111 | std::vector values_; 112 | }; 113 | 114 | #endif -------------------------------------------------------------------------------- /maptree/include/data/split.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file split.h 3 | * @brief Contains a struct which represents a feature-value split. 4 | */ 5 | 6 | #include 7 | 8 | #ifndef SPLIT_H 9 | #define SPLIT_H 10 | 11 | /** 12 | * @struct Split 13 | * @brief Represents a feature-value split. 14 | */ 15 | struct Split { 16 | size_t feature; 17 | bool value; 18 | }; 19 | 20 | #endif 21 | -------------------------------------------------------------------------------- /maptree/include/posterior/tree_likelihood.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file tree_likelihood.h 3 | * @brief This file contains the tree likelihoods used in MAP tree search. 4 | * 5 | * This file contains the functions used to compute the likelihood of a given 6 | * tree based on the BCART statistical model for use in MAP tree search. 7 | * 8 | * @see https://www.jstor.org/stable/2669832 9 | */ 10 | 11 | #ifndef TREE_LIKELIHOOD_H 12 | #define TREE_LIKELIHOOD_H 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | /** 20 | * @class TreeLikelihood 21 | * @brief Contains the likelihood function for the leaves of a BCART tree as 22 | * well as a couple other utility functions. 23 | */ 24 | class TreeLikelihood { 25 | public: 26 | TreeLikelihood( 27 | const std::array rho 28 | ) 29 | : rho_(rho) 30 | {}; 31 | 32 | /** 33 | * @brief Computes the natural log of the Beta function. 34 | * @param alpha The first parameter of the Beta function. 35 | * @param beta The second parameter of the Beta function. 36 | * @returns The natural log of the Beta function. 37 | */ 38 | static double logBeta( 39 | double alpha, 40 | double beta 41 | ) { 42 | return std::lgamma(alpha) + std::lgamma(beta) - std::lgamma(alpha + beta); 43 | } 44 | 45 | /** 46 | * @brief Computes the natural log likelihood of the provided binary label 47 | * counts in a particular leaf node. 48 | * @param labelCounts The number of points in the leaf node with each binary 49 | * label. 50 | * @returns The natural log likelihood of the provided binary label counts 51 | */ 52 | double logLikelihood( 53 | const std::array& labelCounts 54 | ) const; 55 | 56 | /** 57 | * @brief Computes the natural log likelihood of a perfect split of the 58 | * provided binary label counts. 59 | * @param labelCounts The count of points with each binary label. 60 | * @param rho The index of the Beta distribution prior for the Bernoulli 61 | * distribution in the node. 62 | * @returns The natural log likelihood of a perfect split of the. 63 | */ 64 | double logLikelihoodPerfectSplit( 65 | const std::array& labelCounts 66 | ) const; 67 | 68 | 69 | private: 70 | std::array rho_; 71 | }; 72 | 73 | #endif 74 | -------------------------------------------------------------------------------- /maptree/include/posterior/tree_prior.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file tree_prior.h 3 | * @brief This file contains the tree priors used in MAP tree search. 4 | * 5 | * Tree priors include: 6 | * - BCART: Constructive prior which assumes a probably of splitting the 7 | * tree that decreases exponentially as depth increases. 
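 * (Concretely, BCARTTreePrior::logSplitProb implements a split probability of
 * alpha * (1 + depth)^(-beta), shared evenly across the valid splits.)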
This prior does not 8 | * support degenerate trees, or trees consisting of leaf nodes containing no 9 | * points in the training data. 10 | * - BCART-degen: Same as BCART prior but this prior supports degenerate 11 | * trees. 12 | * - Uniform: Uniform prior across all trees. Does not support degenerate 13 | * trees. 14 | * 15 | * @see https://www.jstor.org/stable/2669832 16 | */ 17 | 18 | #ifndef TREE_PRIOR_H 19 | #define TREE_PRIOR_H 20 | 21 | #include 22 | 23 | /** 24 | * @class TreePrior 25 | * @brief Interface for tree priors used in MAP tree search. 26 | */ 27 | class TreePrior { 28 | public: 29 | virtual ~TreePrior() = default; 30 | 31 | virtual double logSplitProb( 32 | size_t depth, 33 | size_t numValidSplits, 34 | size_t numFeatures 35 | ) const = 0; 36 | 37 | virtual double logStopProb( 38 | size_t depth, 39 | size_t numValidSplits, 40 | size_t numFeatures 41 | ) const = 0; 42 | }; 43 | 44 | class BCARTTreePrior : public TreePrior { 45 | public: 46 | BCARTTreePrior( 47 | double alpha, 48 | double beta 49 | ) 50 | : alpha_(alpha) 51 | , beta_(beta) 52 | {}; 53 | 54 | double logSplitProb(size_t depth, size_t numValidSplits, size_t numFeatures) const override; 55 | double logStopProb(size_t depth, size_t numValidSplits, size_t numFeatures) const override; 56 | private: 57 | double alpha_; 58 | double beta_; 59 | }; 60 | 61 | class BCARTDegenTreePrior : public TreePrior { 62 | public: 63 | BCARTDegenTreePrior( 64 | double alpha, 65 | double beta 66 | ) 67 | : alpha_(alpha) 68 | , beta_(beta) 69 | {}; 70 | 71 | double logSplitProb(size_t depth, size_t numValidSplits, size_t numFeatures) const override; 72 | double logStopProb(size_t depth, size_t numValidSplits, size_t numFeatures) const override; 73 | private: 74 | double alpha_; 75 | double beta_; 76 | }; 77 | 78 | class UniformTreePrior : public TreePrior { 79 | public: 80 | double logSplitProb(size_t depth, size_t numValidSplits, size_t numFeatures) const override; 81 | double logStopProb(size_t depth, size_t numValidSplits, size_t numFeatures) const override; 82 | }; 83 | 84 | #endif 85 | -------------------------------------------------------------------------------- /maptree/include/search/base_map_search.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file base_map_search.h 3 | * @brief Base class for MAP searchers. 4 | * 5 | * This file contains the base class for MAP searchers. MAP searchers are 6 | * searchers which search for the maximum a posteriori (MAP) decision tree. 7 | * The base class handles common functionality of MAP searchers, including 8 | * data preprocessing. 9 | */ 10 | 11 | #ifndef BASE_MAP_SEARCH_H 12 | #define BASE_MAP_SEARCH_H 13 | 14 | #include 15 | #include 16 | #include 17 | #include "data/data_manager.h" 18 | #include "solution/solution.h" 19 | #include "posterior/tree_prior.h" 20 | #include "posterior/tree_likelihood.h" 21 | /** 22 | * @class BaseMAPSearch 23 | * @brief Base class for MAP searchers. 24 | */ 25 | class BaseMAPSearch { 26 | public: 27 | BaseMAPSearch( 28 | const DataManager& dm, 29 | const TreeLikelihood& likelihood, 30 | const TreePrior& prior 31 | ) 32 | : dm_(dm) 33 | , likelihood_(likelihood) 34 | , prior_(prior) 35 | {}; 36 | virtual ~BaseMAPSearch() = default; 37 | 38 | /** 39 | * @brief Finds the MAP decision tree. 40 | * @returns a Solution object containing the MAP decision tree or the 41 | * best tree found thus far, as well as an upper and lower bound on the 42 | * unnormalized log posterior probability of the tree. 
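 * @note As implemented in getLowerBound/getUpperBound below, both bounds are
 * negated log posteriors, so smaller values are better and a search is
 * complete when the two bounds meet.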
43 | */ 44 | virtual Solution search() = 0; 45 | 46 | /** 47 | * @brief Computes the lower bound for a subproblem based on its label 48 | * counts, depth, and number of valid splits. 49 | * @param labelCounts 50 | * @param depth 51 | * @param numValidSplits 52 | * @returns The lower bound. 53 | */ 54 | double getLowerBound( 55 | const std::array& labelCounts, 56 | size_t depth, 57 | size_t numValidSplits = UNKNOWN_VALID_SPLITS 58 | ) const; 59 | 60 | /** 61 | * @brief Computes the upper bound for a subproblem based on its label 62 | * counts, depth, and number of valid splits. 63 | * @param labelCounts 64 | * @param depth 65 | * @param numValidSplits 66 | * @returns The lower bound. 67 | */ 68 | double getUpperBound( 69 | const std::array& labelCounts, 70 | size_t depth, 71 | size_t numValidSplits = UNKNOWN_VALID_SPLITS 72 | ) const; 73 | 74 | protected: 75 | const DataManager& dm_; 76 | const TreeLikelihood& likelihood_; 77 | const TreePrior& prior_; 78 | 79 | private: 80 | static constexpr size_t UNKNOWN_VALID_SPLITS = 498126491684794917; 81 | }; 82 | 83 | #endif 84 | -------------------------------------------------------------------------------- /maptree/include/search/befs_map_search.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file befs_map_search.h 3 | * @brief Best-first search for MAP tree. 4 | * 5 | * This file contains the BestFirstSearchMAPSearch class, which implements a 6 | * best-first search for the MAP tree. The algorithm is adapted from AO* and 7 | * utilizes an admissible heuristic based on the "perfect split" lower bound. 8 | */ 9 | 10 | #ifndef BEFS_MAP_SEARCH_H 11 | #define BEFS_MAP_SEARCH_H 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include "constants.h" 18 | #include "subproblem.h" 19 | #include "cache/approx_bitset_cache.h" 20 | #include "search/base_map_search.h" 21 | #include "solution/decision_tree.h" 22 | 23 | //! Forward declaration of AndNode for use in OrNode. 24 | struct AndNode; 25 | 26 | /** 27 | * @struct OrNode 28 | * @brief Represents an OR node in the explicit AND/OR search graph or a 29 | * subproblem in the optimal decision tree search problem. 30 | * 31 | * This struct represents an OR node in the explicit AND/OR search graph or a 32 | * subproblem in the optimal decision tree search problem. The OR node contains 33 | * the lower and upper bounds on the unnormalized log posterior for the 34 | * subproblem, as well as the children of the OR node. The children of the OR 35 | * node are AND nodes which represents valid splits of the subproblem. The OR 36 | * node also contains a pointer to the AND nodes with the best upper/lower 37 | * bounds. These are used to efficiently identify the next tree to expand and 38 | * the best tree found so far. The OR node also contains a list of pointers to 39 | * its parents, which is used to efficiently backpropagate upper/lower bounds 40 | * through the explicit search graph. 41 | */ 42 | struct OrNode { 43 | size_t depth; 44 | double lowerBound; 45 | double upperBound; 46 | bool expanded; 47 | AndNode *childWithBestLB; 48 | AndNode *childWithBestUB; 49 | std::vector children; 50 | std::forward_list parents; 51 | bool isSolved() { 52 | return lowerBound == upperBound; 53 | }; 54 | }; 55 | 56 | /** 57 | * @struct AndNode 58 | * @brief Represents an AND node in the explicit AND/OR search graph or a 59 | * split of its parent subproblem in the optimal decision tree search problem. 
60 |  *
61 |  * This struct represents an AND node in the explicit AND/OR search graph or a
62 |  * split of its parent subproblem in the optimal decision tree search problem.
63 |  * The AND node contains the feature it splits on, as well as the two resulting
64 |  * OR Nodes representing child subproblems (left is feature = 0, right is
65 |  * feature = 1). The AND node also contains a pointer to its parent OR node,
66 |  * which is used to efficiently backpropagate upper/lower bounds through the
67 |  * explicit search graph.
68 |  */
69 | struct AndNode {
70 |     size_t feature;
71 |     OrNode *leftChild;
72 |     OrNode *rightChild;
73 |     OrNode *parent;
74 |     bool isSolved() {
75 |         return leftChild != nullptr && leftChild->isSolved() && rightChild != nullptr && rightChild->isSolved();
76 |     };
77 | };
78 | 
79 | /**
80 |  * @class BestFirstSearchMAPSearch
81 |  * @brief Best-first search for MAP tree.
82 |  * @implements BaseMAPSearch
83 |  *
84 |  * This class implements a best-first search for the MAP tree. The algorithm is
85 |  * adapted from AO* and utilizes an admissible heuristic based on the "perfect
86 |  * split" lower bound. It can be summarized loosely as follows:
87 |  *
88 |  * 1. Initialize the explicit graph with the full subproblem as the root OR
89 |  * Node.
90 |  * 2. While the root OR Node is not solved and the expansion limit and time
91 |  * limit have not been reached:
92 |  * 2a. Find an unexpanded leaf of the tree in the explicit graph with the
93 |  * lowest lower bound.
94 |  * 2b. Expand this leaf, adding its children to the explicit graph.
95 |  * 2c. Update the bounds of this leaf and its ancestors.
96 |  * 3. Return the best tree found so far.
97 |  */
98 | class BestFirstSearchMAPSearch : BaseMAPSearch {
99 | public:
100 |     static constexpr int INF_EXPANSIONS = -1;
101 |     static constexpr int INF_TIME_LIMIT = -1;
102 | 
103 |     BestFirstSearchMAPSearch(
104 |         const DataManager& dm,
105 |         const TreeLikelihood& likelihood,
106 |         const TreePrior& prior,
107 |         int numExpansions = INF_EXPANSIONS,
108 |         int timeLimit = INF_TIME_LIMIT
109 |     )
110 |     : BaseMAPSearch(dm, likelihood, prior)
111 |     , cache_(NUM_BLOCKS(dm_.getNumSamples()))
112 |     , expansionLimit_(numExpansions)
113 |     , timeLimit_(timeLimit)
114 |     , subproblem_(dm_)
115 |     , rootNode_(buildNode(subproblem_.getLabelCounts(), 0))
116 |     {};
117 |     ~BestFirstSearchMAPSearch() override {
118 |         for (OrNode *orNode : orNodes_) delete orNode;
119 |         for (AndNode *andNode : andNodes_) delete andNode;
120 |     };
121 |     Solution search() override;
122 | 
123 | private:
124 |     ApproxBitsetCache cache_;
125 |     std::forward_list<OrNode *> orNodes_ = std::forward_list<OrNode *>();
126 |     std::forward_list<AndNode *> andNodes_ = std::forward_list<AndNode *>();
127 |     int expansionLimit_;
128 |     int timeLimit_;
129 |     Subproblem subproblem_;
130 |     OrNode *rootNode_;
131 | 
132 |     OrNode *buildNode(const std::array<int, 2>& labelCounts, size_t depth);
133 |     OrNode *findExpandableLeaf();
134 |     void expand(OrNode *node);
135 |     bool updateLowerBound(OrNode *node);
136 |     void backpropagateLowerBound(OrNode *source);
137 |     void backpropagateUpperBound(OrNode *source);
138 |     DecisionTree *buildDecisionTree(OrNode *node);
139 | };
140 | 
141 | #endif
142 | 
--------------------------------------------------------------------------------
/maptree/include/solution/decision_tree.h:
--------------------------------------------------------------------------------
1 | /**
2 |  * @file decision_tree.h
3 |  *
4 |  * @brief Defines the DecisionTree class.
5 | * 6 | * This file contains the definition of the DecisionTree class, which is used to 7 | * represent a decision tree. The decision tree is represented as a binary tree 8 | * where each internal node represent a binary split on a feature. 9 | */ 10 | 11 | #ifndef DECISION_TREE_H 12 | #define DECISION_TREE_H 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | /** 21 | * @class DecisionTree 22 | * @brief Represents a decision tree. 23 | * 24 | * This class contains the definition of the DecisionTree class, consisting of 25 | * the feature that the root node splits on and pointers to the left and right 26 | * subtrees. If the tree is a leaf, then the feature is set to NO_FEATURE and 27 | * the left and right subtrees are null. 28 | */ 29 | class DecisionTree { 30 | public: 31 | static constexpr size_t NO_FEATURE = std::numeric_limits::max(); 32 | 33 | size_t feature; 34 | DecisionTree *left, *right; 35 | 36 | DecisionTree( 37 | size_t feature = NO_FEATURE, 38 | DecisionTree *left = nullptr, 39 | DecisionTree *right = nullptr 40 | ) 41 | : feature(feature) 42 | , left(left) 43 | , right(right) 44 | { 45 | assert((left == nullptr) == (right == nullptr)); 46 | assert((feature == NO_FEATURE) == isLeaf()); 47 | }; 48 | 49 | ~DecisionTree() { 50 | if (left != nullptr) delete left; 51 | if (right != nullptr) delete right; 52 | }; 53 | 54 | /** 55 | * @brief Checks if the tree is a leaf 56 | */ 57 | bool isLeaf() const; 58 | 59 | /** 60 | * @brief Returns a string representation of the tree. 61 | * @returns A string representation of the tree. 62 | * 63 | * Leaf nodes are represented as an empty string, and internal nodes are 64 | * represented as a string of the form: 65 | * 66 | * "()". 67 | * 68 | * For example, see the following tree: 69 | * 70 | * 1 71 | * / \ 72 | * 5 9 ------> "((5)1(9))" 73 | * / \ / \ 74 | * . . . . 75 | * 76 | */ 77 | std::string toString() const; 78 | 79 | /** 80 | * @brief Returns a string representation of the tree. 81 | * @see DecisionTree::toString 82 | */ 83 | friend std::ostream& operator<<(std::ostream& os, const DecisionTree& tree) { 84 | os << tree.toString(); 85 | return os; 86 | }; 87 | }; 88 | 89 | #endif -------------------------------------------------------------------------------- /maptree/include/solution/solution.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file solution.h 3 | * @brief Defines the Solution struct. 4 | * 5 | * This file contains the definition of the Solution struct, which is used to 6 | * return the results of a search. The Solution struct contains the unnormalized 7 | * log posterior upper and lower bounds and a string representation of the 8 | * output tree. 9 | */ 10 | 11 | #ifndef SOLUTION_H 12 | #define SOLUTION_H 13 | 14 | #include 15 | 16 | /** 17 | * @struct Solution 18 | * @brief Contains the results of a search. 19 | */ 20 | struct Solution { 21 | double lowerBound; 22 | double upperBound; 23 | std::string treeRepresentation; 24 | }; 25 | 26 | #endif -------------------------------------------------------------------------------- /maptree/include/subproblem.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file subproblem.h 3 | * @brief Decision tree search subproblem. 4 | * 5 | * This file contains a class for storing a subproblem for the decision tree 6 | * search problem. 
A subproblem is defined by a list of splits that have been
7 |  * applied to the root node and the set of points remaining, which is
8 |  * represented by a Bitset object.
9 |  */
10 | 
11 | #ifndef SUBPROBLEM_H
12 | #define SUBPROBLEM_H
13 | 
14 | #include <array>
15 | #include <vector>
16 | #include "data/split.h"
17 | #include "data/bitset.h"
18 | #include "data/data_manager.h"
19 | 
20 | /**
21 |  * @class Subproblem
22 |  * @brief A subproblem for the decision tree search problem.
23 |  */
24 | class Subproblem {
25 | public:
26 |     Subproblem(
27 |         const DataManager& dm
28 |     )
29 |     : dm_(dm)
30 |     , path_()
31 |     , bitset_(dm.getNumSamples(), dm.getNumFeatures())
32 |     {};
33 | 
34 |     /**
35 |      * @brief Returns the list of splits that this subproblem has taken.
36 |      * @returns The list of this subproblem's splits.
37 |      */
38 |     const std::vector<Split>& getPath() const;
39 | 
40 |     /**
41 |      * @brief Returns the subproblem's bitset.
42 |      * @returns The subproblem's bitset.
43 |      */
44 |     const Bitset& getBitset() const;
45 | 
46 |     /**
47 |      * @brief Returns the features on which this subproblem can be split.
48 |      * @returns A list of the features on which this subproblem can be
49 |      * split.
50 |      */
51 |     const std::vector<size_t>& getValidSplits();
52 | 
53 |     /**
54 |      * @brief Returns the label counts of the subproblem.
55 |      * @returns The label counts for the subproblem.
56 |      */
57 |     const std::array<int, 2>& getLabelCounts();
58 | 
59 |     /**
60 |      * @brief Returns the depth of the subproblem.
61 |      * @returns The depth of the subproblem.
62 |      */
63 |     size_t getDepth() const;
64 | 
65 |     /**
66 |      * @brief Applies the provided feature-value split, descending into the
67 |      * child subproblem whose points take the given value for that feature.
68 |      * @param feature The feature to split on. @param value The value to keep.
69 |      */
70 |     void applySplit(
71 |         size_t feature,
72 |         bool value
73 |     );
74 | 
75 |     /**
76 |      * @brief Reverts the last split applied to the subproblem.
77 |      * @returns void
78 |      */
79 |     void revertSplit();
80 | 
81 |     /**
82 |      * @brief Resets the subproblem to its initial state.
83 |      *
84 |      * This method resets the subproblem to its initial state by resetting
85 |      * its bitset and clearing its split list.
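 *
 * A sketch of the apply/revert discipline the searchers use (f and v are a
 * placeholder feature index and feature value):
 *
 *   Subproblem sp(dm);
 *   sp.applySplit(f, v);            // descend into one child subproblem
 *   auto counts = sp.getLabelCounts();
 *   sp.revertSplit();               // undo the split, restoring the parent
 *   sp.reset();                     // or jump straight back to the root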
86 | */ 87 | void reset(); 88 | 89 | private: 90 | const DataManager& dm_; 91 | std::vector path_; 92 | Bitset bitset_; 93 | std::array labelCounts_; 94 | std::vector validSplits; 95 | 96 | bool hasLabelCounts_ = false; 97 | bool hasValidSplits_ = false; 98 | }; 99 | 100 | #endif 101 | -------------------------------------------------------------------------------- /maptree/src/cache/approx_bitset_cache.cpp: -------------------------------------------------------------------------------- 1 | #include "cache/approx_bitset_cache.h" 2 | #include 3 | 4 | constexpr std::array ApproxBitsetCache::BLOCK_MULT_BASE; 5 | constexpr BLOCK ApproxBitsetCache::DEPTH_MULT; 6 | 7 | bool operator==(const ApproxBitsetCacheKey &lhs, const ApproxBitsetCacheKey &rhs) { 8 | return lhs.hashedBitset == rhs.hashedBitset && lhs.depth == rhs.depth; 9 | } 10 | 11 | size_t ApproxBitsetCacheKeyHash::operator()(const ApproxBitsetCacheKey &key) const { 12 | size_t hash = key.depth * ApproxBitsetCache::DEPTH_MULT; 13 | for (size_t i = 0; i < APPROX_BITSET_CACHE_NUM_ULL_HASH_VALUES; i++) { 14 | hash ^= key.hashedBitset[i]; 15 | } 16 | return hash; 17 | } 18 | 19 | void ApproxBitsetCache::put(Subproblem& subproblem, void *value) { 20 | ApproxBitsetCacheKey key = constructKey(subproblem); 21 | cache_.insert({key, value}); 22 | } 23 | 24 | void *ApproxBitsetCache::get(Subproblem& subproblem) { 25 | ApproxBitsetCacheKey key = constructKey(subproblem); 26 | auto entry = cache_.find(key); 27 | if (entry == cache_.end()) return nullptr; 28 | return entry->second; 29 | } 30 | 31 | size_t ApproxBitsetCache::size() const { 32 | return cache_.size(); 33 | } 34 | 35 | ApproxBitsetCacheKey ApproxBitsetCache::constructKey(Subproblem& subproblem) const { 36 | ApproxBitsetCacheKey key; 37 | key.depth = subproblem.getDepth(); 38 | for (size_t i = 0; i < APPROX_BITSET_CACHE_NUM_ULL_HASH_VALUES; i++) { 39 | key.hashedBitset[i] = subproblem.getBitset().sumOfBlocks(blockMults_[i]); 40 | } 41 | return key; 42 | } 43 | 44 | void ApproxBitsetCache::initBlockMults(size_t numBlocks) { 45 | for (size_t p = 0; p < APPROX_BITSET_CACHE_NUM_ULL_HASH_VALUES; p++) { 46 | blockMults_[p].resize(numBlocks); 47 | blockMults_[p][0] = BLOCK_MULT_BASE[p]; 48 | for (size_t i = 1; i < numBlocks; i++) { 49 | blockMults_[p][i] = blockMults_[p][i - 1] * blockMults_[p][0]; 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /maptree/src/data/binary_data_loader.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "data/binary_data_loader.h" 3 | 4 | const std::vector>& BinaryDataLoader::getFeatures() const { 5 | return features_; 6 | } 7 | 8 | const std::vector& BinaryDataLoader::getLabels() const { 9 | return labels_; 10 | } 11 | 12 | void BinaryDataLoader::load() { 13 | std::ifstream file(filename_); 14 | if (!file) { 15 | throw std::runtime_error("Could not open file " + filename_); 16 | } 17 | 18 | std::string line; 19 | size_t lastLineSize = 0; 20 | size_t lineNum = 0; 21 | while (std::getline(file, line)) { 22 | std::vector sample; 23 | for (char c : line) { 24 | if (c == '0' || c == '1') sample.push_back(c == '1'); 25 | } 26 | if (sample.size() < 2) continue; 27 | if (lastLineSize != 0 && sample.size() != lastLineSize) { 28 | throw std::runtime_error("Inconsistent sample size on line " + std::to_string(lineNum) + " of " + filename_); 29 | } 30 | features_.push_back({sample.begin() + 1, sample.end()}); 31 | labels_.push_back(sample[0]); 32 | lastLineSize 
= sample.size(); 33 | lineNum++; 34 | } 35 | file.close(); 36 | } -------------------------------------------------------------------------------- /maptree/src/data/bitset.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "data/bitset.h" 4 | 5 | size_t Bitset::level() const { 6 | return level_; 7 | } 8 | 9 | int Bitset::count() const { 10 | int count = 0; 11 | size_t idx; 12 | for (size_t i = 0; i < limit_.get(); i++) { 13 | idx = indices_[i]; 14 | count += blocks_[idx].countBits(); 15 | } 16 | return count; 17 | } 18 | 19 | int Bitset::countIntersection(const FixedBitset& other) const { 20 | int count = 0; 21 | size_t idx; 22 | for (size_t i = 0; i < limit_.get(); i++) { 23 | idx = indices_[i]; 24 | count += blocks_[idx].countBitsAtIntersection(other.getBlock(idx)); 25 | } 26 | return count; 27 | } 28 | 29 | bool Bitset::isSubset(const FixedBitset& other) const { 30 | size_t idx; 31 | for (size_t i = 0; i < limit_.get(); i++) { 32 | idx = indices_[i]; 33 | if (!blocks_[idx].isSubset(other.getBlock(idx))) return false; 34 | } 35 | return true; 36 | } 37 | 38 | void Bitset::intersect(const FixedBitset& other) { 39 | assert(level_ + 1 < maxLevel_); 40 | size_t limit = limit_.get(); 41 | size_t idx; 42 | if (limit == 0) return; 43 | for (size_t i = limit; i--;) { 44 | idx = indices_[i]; 45 | blocks_[idx].intersect(other.getBlock(idx)); 46 | if (blocks_[idx].empty()) { 47 | assert(limit > 0); 48 | limit--; 49 | indices_[i] = indices_[limit]; 50 | indices_[limit] = idx; 51 | } 52 | } 53 | limit_.update(limit); 54 | level_++; 55 | } 56 | 57 | void Bitset::reverse() { 58 | size_t idx; 59 | limit_.reverse(); 60 | for (size_t i = 0; i < limit_.get(); i++) { 61 | idx = indices_[i]; 62 | blocks_[idx].reverse(); 63 | } 64 | level_--; 65 | } 66 | 67 | void Bitset::reset() { 68 | limit_.reset(); 69 | for (size_t i = 0; i < numBlocks_; i++) { 70 | blocks_[i].reset(); 71 | } 72 | level_ = 0; 73 | } 74 | 75 | BLOCK Bitset::sumOfBlocks(const std::vector& blockMults) const { 76 | size_t idx; 77 | BLOCK sum = 0; 78 | for (size_t i = 0; i < limit_.get(); i++) { 79 | idx = indices_[i]; 80 | sum += blocks_[idx].get() * blockMults[idx]; 81 | } 82 | return sum; 83 | } 84 | 85 | std::ostream& operator<<(std::ostream& os, const Bitset& bitset) { 86 | os << "[ "; 87 | for (size_t i = 0; i < bitset.limit_.get(); i++) { 88 | os << "(" << bitset.indices_[i] << ": " << bitset.blocks_[bitset.indices_[i]].get() << ") "; 89 | } 90 | os << "]"; 91 | return os; 92 | } -------------------------------------------------------------------------------- /maptree/src/data/data_manager.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "data/data_manager.h" 4 | 5 | size_t DataManager::getNumFeatures() const { 6 | return numFeatures_; 7 | } 8 | 9 | size_t DataManager::getNumSamples() const { 10 | return numSamples_; 11 | } 12 | 13 | const FixedBitset& DataManager::getFeatureMask(size_t feature, bool value) const { 14 | return featureMasks_[feature * 2 + value]; 15 | } 16 | 17 | const FixedBitset& DataManager::getLabelMask(bool value) const { 18 | return labelMasks_[value]; 19 | } 20 | 21 | void DataManager::buildFeatureMasks(const std::vector>& features) { 22 | std::vector featureValues(numSamples_); 23 | for (size_t f = 0; f < numFeatures_; f++) { 24 | for (size_t i = 0; i < numSamples_; i++) { 25 | featureValues[i] = features[i][f]; 26 | } 27 | featureMasks_[f * 2 + 1].setBits(featureValues); 28 | 
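// Flip the per-sample values in place so the complementary mask (feature value 0) can be built from the same buffer.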
featureValues.flip(); 29 | featureMasks_[f * 2].setBits(featureValues); 30 | } 31 | } 32 | 33 | void DataManager::buildLabelMasks(const std::vector& labels) { 34 | std::vector labelValues(labels); 35 | labelMasks_[1].setBits(labelValues); 36 | labelValues.flip(); 37 | labelMasks_[0].setBits(labelValues); 38 | } -------------------------------------------------------------------------------- /maptree/src/data/fixed_bitset.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "data/fixed_bitset.h" 4 | 5 | void FixedBitset::setBits(const std::vector& bits) { 6 | assert(bits.size() == numSamples_); 7 | for (size_t i = 0; i < blocks_.size(); i++) { 8 | for (size_t j = 0; j < BLOCK_BITS; j++) { 9 | size_t idx = i * BLOCK_BITS + j; 10 | if (idx >= bits.size()) break; 11 | if (bits[idx]) { 12 | blocks_[i] |= 1ULL << j; 13 | } 14 | } 15 | } 16 | } 17 | 18 | BLOCK FixedBitset::getBlock(size_t idx) const { 19 | return blocks_[idx]; 20 | } -------------------------------------------------------------------------------- /maptree/src/data/rnumber.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "data/rnumber.h" 4 | 5 | #ifdef _MSC_VER 6 | # include 7 | # define __builtin_popcountll __popcnt64 8 | #endif 9 | 10 | BLOCK RNumber::get() const { 11 | assert(head_ < values_.size()); 12 | return values_[head_]; 13 | } 14 | 15 | void RNumber::set(BLOCK value) { 16 | assert(head_ < values_.size()); 17 | values_[head_] = value; 18 | } 19 | 20 | void RNumber::update(BLOCK value) { 21 | assert(head_ < values_.size()); 22 | values_[++head_] = value; 23 | } 24 | 25 | void RNumber::intersect(BLOCK other) { 26 | assert(head_ < values_.size()); 27 | values_[head_ + 1] = other & get(); 28 | head_++; 29 | } 30 | 31 | void RNumber::reverse() { 32 | assert(head_ > 0); 33 | head_--; 34 | } 35 | 36 | void RNumber::reset() { 37 | head_ = 0; 38 | } 39 | 40 | int RNumber::countBits() const { 41 | return __builtin_popcountll(get()); 42 | } 43 | 44 | int RNumber::countBitsAtIntersection(BLOCK other) const { 45 | return __builtin_popcountll(get() & other); 46 | } 47 | 48 | bool RNumber::isSubset(BLOCK other) const { 49 | return (get() & other) == get(); 50 | } 51 | 52 | bool RNumber::empty() const { 53 | return get() == 0; 54 | } 55 | 56 | bool operator==(const RNumber& lhs, const RNumber& rhs) { 57 | return lhs.get() == rhs.get(); 58 | } -------------------------------------------------------------------------------- /maptree/src/posterior/tree_likelihood.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "posterior/tree_likelihood.h" 6 | 7 | double TreeLikelihood::logLikelihood(const std::array& labelCounts) const { 8 | return logBeta( 9 | static_cast(labelCounts[0]) + rho_[0], 10 | static_cast(labelCounts[1]) + rho_[1] 11 | ) - logBeta(rho_[0], rho_[1]); 12 | } 13 | 14 | double TreeLikelihood::logLikelihoodPerfectSplit(const std::array& labelCounts) const { 15 | return logBeta(static_cast(labelCounts[0]) + rho_[0], rho_[1]) \ 16 | + logBeta(rho_[0], static_cast(labelCounts[1]) + rho_[1]) \ 17 | - 2 * logBeta(rho_[0], rho_[1]); 18 | } -------------------------------------------------------------------------------- /maptree/src/posterior/tree_prior.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "posterior/tree_prior.h" 4 | 5 | double 
BCARTTreePrior::logSplitProb( 6 | size_t depth, 7 | size_t numValidSplits, 8 | [[maybe_unused]] size_t numFeatures 9 | ) const { 10 | double logSplitProb = std::log(alpha_) - beta_ * std::log(depth + 1); 11 | return logSplitProb - std::log(numValidSplits); 12 | } 13 | 14 | double BCARTTreePrior::logStopProb( 15 | size_t depth, 16 | size_t numValidSplits, 17 | [[maybe_unused]] size_t numFeatures 18 | ) const { 19 | if (numValidSplits == 0) return 0.0; 20 | double logSplitProb = std::log(alpha_) - beta_ * std::log(depth + 1); 21 | return std::log(1.0 - std::exp(logSplitProb)); 22 | } 23 | 24 | double BCARTDegenTreePrior::logSplitProb( 25 | size_t depth, 26 | [[maybe_unused]] size_t numValidSplits, 27 | size_t numFeatures 28 | ) const { 29 | double logSplitProb = std::log(alpha_) - beta_ * std::log(depth + 1); 30 | return logSplitProb - std::log(numFeatures); 31 | } 32 | 33 | double BCARTDegenTreePrior::logStopProb( 34 | size_t depth, 35 | [[maybe_unused]] size_t numValidSplits, 36 | [[maybe_unused]] size_t numFeatures 37 | ) const { 38 | double logSplitProb = std::log(alpha_) - beta_ * std::log(depth + 1); 39 | return std::log(1.0 - std::exp(logSplitProb)); 40 | } 41 | 42 | double UniformTreePrior::logSplitProb( 43 | [[maybe_unused]] size_t depth, 44 | [[maybe_unused]] size_t numValidSplits, 45 | [[maybe_unused]] size_t numFeatures 46 | ) const { 47 | return 0.0; 48 | } 49 | 50 | double UniformTreePrior::logStopProb( 51 | [[maybe_unused]] size_t depth, 52 | [[maybe_unused]] size_t numValidSplits, 53 | [[maybe_unused]] size_t numFeatures 54 | ) const { 55 | return 0.0; 56 | } -------------------------------------------------------------------------------- /maptree/src/python_bindings/search_bindings.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file search_bindings.cpp 3 | * @brief Set up bindings for search functions to Python with pybind11. 4 | * 5 | * This file contains the bindings for the search functions to Python with 6 | * pybind11. The bindings are exported as a Python module called maptree. 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include "search/befs_map_search.h" 17 | #include "solution/solution.h" 18 | #include "data/data_manager.h" 19 | #include "posterior/tree_prior.h" 20 | #include "posterior/tree_likelihood.h" 21 | 22 | namespace py = pybind11; 23 | 24 | /** 25 | * @brief MAP Tree search function 26 | * @param features (num samples) x (num features) 2D boolean vector of features. 27 | * @param labels (num samples) 1D boolean vector of labels. 28 | * @param alpha The alpha parameter of the constructive BCART prior. 29 | * @param beta The beta parameter of the constructive BCART prior. 30 | * @param rho A 2-item array indexing BCART's Beta distribution prior for the 31 | * Bernoulli distributions in each of the leaf nodes. 32 | * @param numExpansions The maximum number of expansions to perform. 33 | * @param timeLimit The time limit in seconds. If -1, no time limit. 34 | * @param degen Whether or not the BCART prior should support degenerate trees. 35 | * Note that it is still guaranteed that a degenerate tree will not be returned. 36 | * @returns A Solution object containing the unnormalized log posterior upper/ 37 | * lower bound and a string representation of the output tree. 38 | * 39 | * The search function uses the best-first search algorithm to find the MAP 40 | * tree. 
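 *
 * A usage sketch from Python (the data X, y and the hyperparameter values
 * here are hypothetical; the keyword names and Solution attributes are the
 * ones bound below):
 *
 *   import maptree
 *   sol = maptree.search(X, y, alpha=0.95, beta=0.5, rho=[1.0, 1.0])
 *   print(sol.lb, sol.ub, sol.tree)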
41 | * 42 | * @see BestFirstSearchMAPSearch 43 | */ 44 | Solution searchBeFS( 45 | std::vector> features, 46 | std::vector labels, 47 | double alpha, 48 | double beta, 49 | std::array rho, 50 | int numExpansions, 51 | int timeLimit, 52 | bool degen 53 | ) 54 | { 55 | DataManager dm(features, labels); 56 | 57 | TreePrior *prior; 58 | prior = degen 59 | ? static_cast(new BCARTDegenTreePrior(alpha, beta)) 60 | : static_cast(new BCARTTreePrior(alpha, beta)); 61 | 62 | TreeLikelihood likelihood(rho); 63 | BestFirstSearchMAPSearch searchObj(dm, likelihood, *prior, numExpansions, timeLimit); 64 | 65 | Solution result = searchObj.search(); 66 | delete prior; 67 | 68 | return result; 69 | } 70 | 71 | //! Here, we define the maptree Python module, binding the search function. 72 | PYBIND11_MODULE(maptree, m) { 73 | m.doc() = "MAP tree search binding"; 74 | 75 | m.def( 76 | "search", 77 | &searchBeFS, 78 | "Best first search", 79 | py::arg("features"), 80 | py::arg("labels"), 81 | py::arg("alpha"), 82 | py::arg("beta"), 83 | py::arg("rho"), 84 | py::arg("numExpansions")=BestFirstSearchMAPSearch::INF_EXPANSIONS, 85 | py::arg("timeLimit")=BestFirstSearchMAPSearch::INF_TIME_LIMIT, 86 | py::arg("degen")=false 87 | ); 88 | 89 | py::class_(m, "Solution") \ 90 | .def_readwrite("lb", &Solution::lowerBound) \ 91 | .def_readwrite("ub", &Solution::upperBound) \ 92 | .def_readwrite("tree", &Solution::treeRepresentation); 93 | } -------------------------------------------------------------------------------- /maptree/src/search/base_map_search.cpp: -------------------------------------------------------------------------------- 1 | #include "search/base_map_search.h" 2 | 3 | constexpr size_t BaseMAPSearch::UNKNOWN_VALID_SPLITS; 4 | 5 | double BaseMAPSearch::getLowerBound( 6 | const std::array& labelCounts, 7 | size_t depth, 8 | size_t numValidSplits 9 | ) const { 10 | double perfectSplitValue = -( 11 | prior_.logSplitProb( 12 | depth, 13 | numValidSplits == UNKNOWN_VALID_SPLITS ? 1 : numValidSplits, 14 | dm_.getNumFeatures()) 15 | + 2 * prior_.logStopProb( 16 | depth + 1, 17 | 0, 18 | dm_.getNumFeatures()) 19 | + likelihood_.logLikelihoodPerfectSplit(labelCounts) 20 | ); 21 | double stopValue = -( 22 | prior_.logStopProb( 23 | depth, 24 | numValidSplits == UNKNOWN_VALID_SPLITS ? 0 : numValidSplits, 25 | dm_.getNumFeatures()) 26 | + likelihood_.logLikelihood(labelCounts) 27 | ); 28 | 29 | return std::min(perfectSplitValue, stopValue); 30 | } 31 | 32 | double BaseMAPSearch::getUpperBound( 33 | const std::array& labelCounts, 34 | size_t depth, 35 | size_t numValidSplits 36 | ) const { 37 | return -( 38 | prior_.logStopProb( 39 | depth, 40 | numValidSplits == UNKNOWN_VALID_SPLITS ? 
1 : numValidSplits, 41 | dm_.getNumFeatures()) 42 | + likelihood_.logLikelihood(labelCounts) 43 | ); 44 | } -------------------------------------------------------------------------------- /maptree/src/search/befs_map_search.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "search/befs_map_search.h" 8 | 9 | constexpr int BestFirstSearchMAPSearch::INF_EXPANSIONS; 10 | constexpr int BestFirstSearchMAPSearch::INF_TIME_LIMIT; 11 | 12 | Solution BestFirstSearchMAPSearch::search() { 13 | bool hasExpansionLimit = expansionLimit_ != BestFirstSearchMAPSearch::INF_EXPANSIONS; 14 | bool hasTimeLimit = timeLimit_ != BestFirstSearchMAPSearch::INF_TIME_LIMIT; 15 | 16 | size_t expansionsRemaining = static_cast(expansionLimit_); 17 | long long timeLimit = static_cast(timeLimit_); 18 | std::chrono::time_point startTime = std::chrono::steady_clock::now(); 19 | long long secondsElapsed; 20 | 21 | while (!rootNode_->isSolved()) { 22 | subproblem_.reset(); 23 | OrNode *leaf = findExpandableLeaf(); 24 | expand(leaf); 25 | backpropagateLowerBound(leaf); 26 | backpropagateUpperBound(leaf); 27 | 28 | expansionsRemaining--; 29 | secondsElapsed = std::chrono::duration_cast( 30 | std::chrono::steady_clock::now() - startTime).count(); 31 | if (hasExpansionLimit && expansionsRemaining == 0) break; 32 | if (hasTimeLimit && secondsElapsed >= timeLimit) break; 33 | } 34 | 35 | DecisionTree *dt = buildDecisionTree(rootNode_); 36 | std::string treeRepresentation = dt->toString(); 37 | delete dt; 38 | 39 | return { 40 | rootNode_->lowerBound, 41 | rootNode_->upperBound, 42 | treeRepresentation 43 | }; 44 | } 45 | 46 | OrNode *BestFirstSearchMAPSearch::buildNode(const std::array& labelCounts, size_t depth) { 47 | OrNode *node = new OrNode(); 48 | orNodes_.push_front(node); 49 | 50 | node->depth = depth; 51 | node->children = std::vector(0); 52 | node->parents = std::forward_list(0); 53 | node->childWithBestLB = node->childWithBestUB = nullptr; 54 | node->upperBound = getUpperBound(labelCounts, depth); 55 | node->lowerBound = getLowerBound(labelCounts, depth); 56 | assert(node->lowerBound > 0); 57 | node->expanded = false; 58 | 59 | return node; 60 | } 61 | 62 | OrNode *BestFirstSearchMAPSearch::findExpandableLeaf() { 63 | assert(subproblem_.getDepth() == 0); 64 | assert(!rootNode_->isSolved()); 65 | 66 | OrNode *node = rootNode_; 67 | AndNode *markedChild; 68 | double leftSpread, rightSpread; 69 | bool value; 70 | while (node->expanded) { 71 | markedChild = node->childWithBestLB; 72 | leftSpread = markedChild->leftChild->upperBound - markedChild->leftChild->lowerBound; 73 | rightSpread = markedChild->rightChild->upperBound - markedChild->rightChild->lowerBound; 74 | value = leftSpread < rightSpread; 75 | node = value ? 
markedChild->rightChild : markedChild->leftChild; 76 | subproblem_.applySplit(markedChild->feature, value); 77 | } 78 | 79 | // we should not end up at a solved node 80 | assert(!node->isSolved()); 81 | 82 | return node; 83 | } 84 | 85 | void BestFirstSearchMAPSearch::expand(OrNode *node) { 86 | assert(!node->expanded); 87 | 88 | node->expanded = true; 89 | 90 | const std::vector& validSplits = subproblem_.getValidSplits(); 91 | if (validSplits.empty()) { 92 | node->upperBound = node->lowerBound = getUpperBound(subproblem_.getLabelCounts(), node->depth, 0); 93 | return; 94 | } else { 95 | node->children.resize(validSplits.size()); 96 | } 97 | 98 | double splitPenalty = -prior_.logSplitProb(node->depth, validSplits.size(), dm_.getNumFeatures()); 99 | std::array outerLabelCounts = subproblem_.getLabelCounts(); 100 | 101 | double splitValue; 102 | AndNode *child; 103 | OrNode *subChild; 104 | std::array subChildLabelCounts; 105 | size_t childIdx = 0; 106 | for (size_t feature : validSplits) { 107 | child = new AndNode(); 108 | andNodes_.push_front(child); 109 | child->feature = feature; 110 | child->parent = node; 111 | for (bool value : {true, false}) { 112 | subproblem_.applySplit(feature, value); 113 | 114 | if (value) { 115 | subChildLabelCounts = subproblem_.getLabelCounts(); 116 | } else { 117 | subChildLabelCounts[0] = outerLabelCounts[0] - subChildLabelCounts[0]; 118 | subChildLabelCounts[1] = outerLabelCounts[1] - subChildLabelCounts[1]; 119 | } 120 | 121 | subChild = static_cast(cache_.get(subproblem_)); 122 | if (subChild == nullptr) { 123 | subChild = buildNode(subChildLabelCounts, node->depth + 1); 124 | cache_.put(subproblem_, subChild); 125 | } 126 | 127 | subChild->parents.push_front(child); 128 | if (value) { 129 | child->rightChild = subChild; 130 | } else { 131 | child->leftChild = subChild; 132 | } 133 | 134 | subproblem_.revertSplit(); 135 | } 136 | 137 | splitValue = child->leftChild->upperBound + child->rightChild->upperBound + splitPenalty; 138 | if (splitValue < node->upperBound) { 139 | node->upperBound = splitValue; 140 | node->childWithBestUB = child; 141 | } 142 | 143 | node->children[childIdx++] = child; 144 | } 145 | } 146 | 147 | bool BestFirstSearchMAPSearch::updateLowerBound(OrNode *node) { 148 | assert(node->expanded); 149 | 150 | double bestLowerBound = node->upperBound; 151 | node->childWithBestLB = nullptr; 152 | double splitPenalty = -prior_.logSplitProb(node->depth, node->children.size(), dm_.getNumFeatures()); 153 | 154 | double splitValueLowerBound; 155 | for (AndNode *child : node->children) { 156 | splitValueLowerBound = child->leftChild->lowerBound + child->rightChild->lowerBound + splitPenalty; 157 | if (splitValueLowerBound < bestLowerBound) { 158 | bestLowerBound = splitValueLowerBound; 159 | node->childWithBestLB = child; 160 | } 161 | } 162 | 163 | //! 
164 | assert(bestLowerBound >= node->lowerBound); 165 | 166 | bool improvedLowerBound = bestLowerBound > node->lowerBound; 167 | node->lowerBound = bestLowerBound; 168 | 169 | return improvedLowerBound; 170 | } 171 | 172 | void BestFirstSearchMAPSearch::backpropagateLowerBound(OrNode *source) { 173 | std::set<OrNode*> visited; 174 | std::queue<OrNode*> toVisit; 175 | toVisit.push(source); 176 | visited.insert(source); 177 | OrNode *front; 178 | while (!toVisit.empty()) { 179 | front = toVisit.front(); 180 | toVisit.pop(); 181 | if (!updateLowerBound(front)) continue; 182 | for (AndNode *parent : front->parents) { 183 | if (visited.find(parent->parent) == visited.end() 184 | && !parent->parent->isSolved() 185 | && parent->parent->childWithBestLB == parent 186 | ) { 187 | toVisit.push(parent->parent); 188 | visited.insert(parent->parent); 189 | } 190 | } 191 | } 192 | } 193 | 194 | void BestFirstSearchMAPSearch::backpropagateUpperBound(OrNode *source) { 195 | std::set<OrNode*> visited; 196 | std::queue<OrNode*> toVisit; 197 | toVisit.push(source); 198 | visited.insert(source); 199 | OrNode *front; 200 | double splitPenalty; 201 | double splitValue; 202 | while (!toVisit.empty()) { 203 | front = toVisit.front(); 204 | toVisit.pop(); 205 | for (AndNode *parent : front->parents) { 206 | splitPenalty = -prior_.logSplitProb(parent->parent->depth, parent->parent->children.size(), dm_.getNumFeatures()); 207 | splitValue = parent->leftChild->upperBound + parent->rightChild->upperBound + splitPenalty; 208 | if (splitValue < parent->parent->upperBound) { 209 | parent->parent->upperBound = splitValue; 210 | parent->parent->childWithBestUB = parent; 211 | if (visited.find(parent->parent) == visited.end()) { 212 | toVisit.push(parent->parent); 213 | visited.insert(parent->parent); 214 | } 215 | } 216 | } 217 | } 218 | } 219 | 220 | DecisionTree *BestFirstSearchMAPSearch::buildDecisionTree(OrNode *node) { 221 | // no possible splits — return leaf 222 | if (node->children.empty() || !node->expanded || node->childWithBestUB == nullptr) { 223 | return new DecisionTree(); 224 | } 225 | 226 | AndNode *markedChild = node->childWithBestUB; 227 | return new DecisionTree( 228 | markedChild->feature, 229 | buildDecisionTree(markedChild->leftChild), 230 | buildDecisionTree(markedChild->rightChild) 231 | ); 232 | } 233 | --------------------------------------------------------------------------------
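For readers tracing expand() and updateLowerBound() above: both bounds at an OR node obey the same minimization, taken over stopping at a leaf and over every valid split, where a split contributes its two children's bounds plus the split penalty -log p(split). A simplified Python model of that recurrence follows; Node, leaf_cost, and split_penalty are illustrative stand-ins, not the C++ API, and the real code additionally seeds unexpanded nodes with heuristic bounds and shares OR nodes across parents through the cache:

def update_bounds(node, leaf_cost, split_penalty):
    # Costs are negative log posterior mass, so both bounds are
    # minimized over the same choices: stop here, or take a split.
    node.upper = node.lower = leaf_cost(node)   # stopping is always available
    for split in node.children:                 # one AND node per valid feature
        penalty = split_penalty(node)           # -log p(split at this depth)
        node.upper = min(node.upper, split.left.upper + split.right.upper + penalty)
        node.lower = min(node.lower, split.left.lower + split.right.lower + penalty)
    # the subproblem is solved once the two bounds meet
    node.solved = node.upper == node.lower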
"" : ("(" + left->toString() + std::to_string(feature) + right->toString() + ")"); 11 | } 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /maptree/src/subproblem.cpp: -------------------------------------------------------------------------------- 1 | #include "subproblem.h" 2 | 3 | const std::vector& Subproblem::getPath() const { 4 | return path_; 5 | } 6 | 7 | const Bitset& Subproblem::getBitset() const { 8 | return bitset_; 9 | } 10 | 11 | const std::vector& Subproblem::getValidSplits() { 12 | if (hasValidSplits_) return validSplits; 13 | validSplits.clear(); 14 | for (size_t f = 0; f < dm_.getNumFeatures(); f++) { 15 | if (!bitset_.isSubset(dm_.getFeatureMask(f, false)) 16 | && !bitset_.isSubset(dm_.getFeatureMask(f, true))) { 17 | validSplits.push_back(f); 18 | } 19 | } 20 | hasValidSplits_ = true; 21 | return validSplits; 22 | } 23 | 24 | const std::array& Subproblem::getLabelCounts() { 25 | if (hasLabelCounts_) return labelCounts_; 26 | int count = bitset_.count(); 27 | labelCounts_[1] = bitset_.countIntersection(dm_.getLabelMask(true)); 28 | labelCounts_[0] = count - labelCounts_[1]; 29 | hasLabelCounts_ = true; 30 | return labelCounts_; 31 | } 32 | 33 | size_t Subproblem::getDepth() const { 34 | return path_.size(); 35 | } 36 | 37 | void Subproblem::applySplit(size_t feature, bool value) { 38 | path_.push_back({feature, value}); 39 | bitset_.intersect(dm_.getFeatureMask(feature, value)); 40 | hasValidSplits_ = false; 41 | hasLabelCounts_ = false; 42 | } 43 | 44 | void Subproblem::revertSplit() { 45 | path_.pop_back(); 46 | bitset_.reverse(); 47 | hasValidSplits_ = false; 48 | hasLabelCounts_ = false; 49 | } 50 | 51 | void Subproblem::reset() { 52 | path_.clear(); 53 | bitset_.reset(); 54 | hasValidSplits_ = false; 55 | hasLabelCounts_ = false; 56 | } 57 | 58 | -------------------------------------------------------------------------------- /maptree/tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | set(TESTFILES 4 | main.cpp 5 | bcart/test_bcart_utils.cpp 6 | search/test_search.cpp 7 | data/test_fixed_bitset.cpp 8 | data/test_rnumber.cpp 9 | data/test_bitset.cpp 10 | data/test_data_manager.cpp 11 | ) 12 | 13 | set(TEST_MAIN unit_tests) 14 | add_executable(${TEST_MAIN} ${TESTFILES}) 15 | target_link_libraries(${TEST_MAIN} PRIVATE ${LIBRARY_NAME} doctest) 16 | set_target_properties(${TEST_MAIN} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) 17 | target_set_warnings(${TEST_MAIN} ENABLE ALL AS_ERROR ALL DISABLE Annoying) # Set warnings (if needed). 
/maptree/tests/CMakeLists.txt: --------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.14) 2 | 3 | set(TESTFILES 4 | main.cpp 5 | bcart/test_bcart_utils.cpp 6 | search/test_search.cpp 7 | data/test_fixed_bitset.cpp 8 | data/test_rnumber.cpp 9 | data/test_bitset.cpp 10 | data/test_data_manager.cpp 11 | ) 12 | 13 | set(TEST_MAIN unit_tests) 14 | add_executable(${TEST_MAIN} ${TESTFILES}) 15 | target_link_libraries(${TEST_MAIN} PRIVATE ${LIBRARY_NAME} doctest) 16 | set_target_properties(${TEST_MAIN} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) 17 | target_set_warnings(${TEST_MAIN} ENABLE ALL AS_ERROR ALL DISABLE Annoying) # Set warnings (if needed). 18 | 19 | set_target_properties(${TEST_MAIN} PROPERTIES 20 | CXX_STANDARD 17 21 | CXX_STANDARD_REQUIRED YES 22 | CXX_EXTENSIONS NO 23 | ) 24 | 25 | add_test( 26 | NAME ${LIBRARY_NAME}.${TEST_MAIN} 27 | COMMAND ${TEST_MAIN} 28 | ) 29 | 30 | include(CodeCoverage) 31 | --------------------------------------------------------------------------------
/maptree/tests/bcart/test_bcart_utils.cpp: --------------------------------------------------------------------------------
1 | #include <array> 2 | #include "doctest/doctest.h" 3 | 4 | #include "posterior/tree_likelihood.h" 5 | #include "posterior/tree_prior.h" 6 | 7 | TEST_CASE("prior/likelihood test") 8 | { 9 | std::array<size_t, 2> labelCounts = {3, 5}; 10 | std::array<double, 2> rho = {1.0, 1.0}; 11 | 12 | BCARTTreePrior prior = BCARTTreePrior(.95, .5); 13 | TreeLikelihood likelihood = TreeLikelihood(rho); 14 | 15 | CHECK(likelihood.logBeta(1, 1) == doctest::Approx(0)); 16 | CHECK(likelihood.logBeta(3, 5) == doctest::Approx(-4.65396)); 17 | CHECK(likelihood.logLikelihood(labelCounts) == doctest::Approx(-6.222576)); 18 | CHECK(prior.logSplitProb(5, 1, 1) == doctest::Approx(-0.94717)); 19 | CHECK(prior.logStopProb(5, 1, 1) == doctest::Approx(-0.490755)); 20 | } 21 | --------------------------------------------------------------------------------
/maptree/tests/main.cpp: --------------------------------------------------------------------------------
1 | #define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN 2 | #include "doctest/doctest.h" --------------------------------------------------------------------------------
/maptree/tests/search/test_search.cpp: --------------------------------------------------------------------------------
1 | #include <array> 2 | #include <string> 3 | 4 | #include "doctest/doctest.h" 5 | // #include "search/bnb_map_search.h" 6 | #include "search/befs_map_search.h" 7 | #include "data/binary_data_loader.h" 8 | 9 | using namespace std; 10 | 11 | TEST_CASE("search test on small dataset") 12 | { 13 | 14 | // test data generating tree 15 | // x_2 16 | // / \ 17 | // 0 / \ 1 18 | // (0: 0, 1: 19) x_3 19 | // / \ 20 | // 0 / \ 1 21 | // (0: 0, 1: 10) (0: 11, 1: 0) 22 | 23 | 24 | // alpha_s = 0.95 25 | // beta_s = 0.5 26 | // alpha = (1, 1) 27 | 28 | // log prior = ln((0.95/(sqrt(1) * 4)) * (0.95/(sqrt(2) * 3)) * (1 - 0.95/sqrt(2)) * (1 - 0.95/sqrt(3))^2) = -5.638 29 | // log likelihood = ln [(Beta(1, 20) * Beta(1, 11) * Beta(12, 1)) / Beta(1, 1)^3] = -7.879 30 | // log posterior = log prior + log likelihood = -13.517 31 | 32 | BinaryDataLoader bdl("data/test_data_small.txt"); 33 | double alpha = 0.95; 34 | double beta = 0.5; 35 | array<double, 2> rho = {1, 1}; 36 | 37 | DataManager dm(bdl.getFeatures(), bdl.getLabels()); 38 | TreeLikelihood likelihood(rho); 39 | BCARTTreePrior prior(alpha, beta); 40 | 41 | // BranchAndBoundMAPSearch bnbSearch(bdl.getFeatures(), bdl.getLabels(), alpha, beta, rho); 42 | BestFirstSearchMAPSearch befsSearch(dm, likelihood, prior); 43 | 44 | // Solution bnbResult = bnbSearch.search(); 45 | Solution befsResult = befsSearch.search(); 46 | 47 | // CHECK(bnbResult.upperBound == doctest::Approx(13.517)); 48 | CHECK(befsResult.upperBound == doctest::Approx(13.517)); 49 | } 50 | 51 | TEST_CASE("search test on medium dataset") 52 | { 53 | BinaryDataLoader bdl("data/test_data_medium.txt"); 54 | double alpha = 0.95; 55 | double beta = 0.5; 56 | array<double, 2> rho = {2.5, 2.5}; 57 | 58 | DataManager dm(bdl.getFeatures(), bdl.getLabels()); 59 | TreeLikelihood likelihood(rho); 60 | BCARTTreePrior prior(alpha, beta); 61 | 62 | BestFirstSearchMAPSearch befsSearch(dm, likelihood, 
prior); 63 | Solution befsResult = befsSearch.search(); 64 | CHECK(befsResult.upperBound == doctest::Approx(66.006945)); 65 | 66 | } --------------------------------------------------------------------------------
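The hard-coded expectations in the two test files above can be reproduced independently. Below is a standalone sanity-check sketch; it is not a file from this repository, and it assumes the standard BCART split prior alpha * (1 + d)^(-beta), shared uniformly across the valid splits, together with the Beta-Binomial marginal likelihood, both of which match the checked constants:

from math import lgamma, log, sqrt

def log_beta(a, b):
    return lgamma(a) + lgamma(b) - lgamma(a + b)

def log_likelihood(c0, c1, rho0=1.0, rho1=1.0):
    # Beta-Binomial marginal likelihood of a leaf's label counts
    return log_beta(c0 + rho0, c1 + rho1) - log_beta(rho0, rho1)

alpha, beta = 0.95, 0.5

# constants from test_bcart_utils.cpp
assert abs(log_beta(3, 5) - (-4.65396)) < 1e-4
assert abs(log_likelihood(3, 5) - (-6.222576)) < 1e-5
assert abs(log(alpha * (1 + 5) ** -beta) - (-0.94717)) < 1e-4      # logSplitProb(5, 1, 1)
assert abs(log(1 - alpha * (1 + 5) ** -beta) - (-0.490755)) < 1e-4 # logStopProb(5, 1, 1)

# small-dataset posterior from test_search.cpp: two splits, three leaves
log_prior = (log(alpha / (sqrt(1) * 4)) + log(alpha / (sqrt(2) * 3))
             + log(1 - alpha / sqrt(2)) + 2 * log(1 - alpha / sqrt(3)))
log_lik = log_likelihood(0, 19) + log_likelihood(0, 10) + log_likelihood(11, 0)
assert abs(log_prior - (-5.638)) < 1e-3
assert abs(log_lik - (-7.879)) < 1e-3
assert abs(-(log_prior + log_lik) - 13.517) < 1e-3                 # befsResult.upperBound
print("all test constants check out")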
/maptree/tests/test_data/test_data_small.txt: --------------------------------------------------------------------------------
1 | 1 0 1 0 1 2 | 0 0 1 1 0 3 | 0 1 1 1 0 4 | 0 0 1 1 1 5 | 1 1 0 1 1 6 | 0 0 1 1 0 7 | 1 0 0 1 1 8 | 0 0 1 1 1 9 | 1 0 1 0 1 10 | 1 0 0 0 1 11 | 1 0 1 0 1 12 | 1 0 0 1 0 13 | 0 0 1 1 1 14 | 1 0 1 0 1 15 | 0 0 1 1 0 16 | 1 1 0 0 1 17 | 1 1 0 1 1 18 | 1 0 0 1 1 19 | 1 1 0 1 1 20 | 1 1 0 1 1 21 | 1 0 1 0 1 22 | 1 1 0 0 0 23 | 1 0 1 0 0 24 | 1 0 0 1 0 25 | 1 0 0 1 1 26 | 1 0 0 1 1 27 | 1 0 0 1 1 28 | 0 0 1 1 0 29 | 1 0 0 0 1 30 | 0 0 1 1 1 31 | 1 0 0 0 0 32 | 1 1 0 0 0 33 | 1 0 1 0 0 34 | 1 0 0 1 1 35 | 1 0 0 0 0 36 | 1 1 1 0 1 37 | 1 0 1 0 1 38 | 0 1 1 1 0 39 | 0 1 1 1 0 40 | 0 1 1 1 1 --------------------------------------------------------------------------------
/plot_results.py: --------------------------------------------------------------------------------
1 | from experiments.experiments.fig1.plotter import run as plot_fig1 2 | from experiments.experiments.fig2.plotter import run as plot_fig2 3 | from experiments.experiments.fig3.plotter import run as plot_fig3 4 | 5 | 6 | if __name__ == '__main__': 7 | plot_fig1() 8 | plot_fig2() 9 | plot_fig3() 10 | --------------------------------------------------------------------------------
/pyproject.toml: --------------------------------------------------------------------------------
1 | [build-system] 2 | requires = ["setuptools", "wheel", "pybind11~=2.6.1"] --------------------------------------------------------------------------------
/requirements.txt: --------------------------------------------------------------------------------
1 | gosdt==0.1.7 2 | pydl8.5==0.1.8 3 | scikit-learn~=1.0 4 | setuptools~=68.0.0 5 | pybind11~=2.6.1 6 | numpy~=1.25.1 7 | pandas~=2.0.3 8 | matplotlib~=3.7.2 9 | scipy~=1.11.1 10 | requests~=2.29.0 11 | seaborn~=0.12.2 --------------------------------------------------------------------------------
/requirements_m1mac.txt: --------------------------------------------------------------------------------
1 | pydl8.5==0.1.8 2 | scikit-learn~=1.0 3 | setuptools~=68.0.0 4 | pybind11~=2.6.1 5 | numpy~=1.25.1 6 | pandas~=2.0.3 7 | matplotlib~=3.7.2 8 | scipy~=1.11.1 9 | requests~=2.29.0 10 | seaborn~=0.12.2 --------------------------------------------------------------------------------
/run_experiment.py: --------------------------------------------------------------------------------
1 | from argparse import ArgumentParser 2 | 3 | from experiments.experiments.fig1.runner import run as run_fig1 4 | from experiments.experiments.fig2.runner import run as run_fig2 5 | from experiments.experiments.fig3.runner import run as run_fig3 6 | 7 | from experiments.globals import CP4IM_DATASET_NAMES, SYNTH_NUM_TREES 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = ArgumentParser() 12 | parser.add_argument('--job_index', '-j', default=None, type=int) 13 | args = parser.parse_args() 14 | 15 | jobs = list(range(52)) if args.job_index is None else [args.job_index] 16 | 17 | for job in jobs: 18 | if job < 2 * len(CP4IM_DATASET_NAMES): 19 | run = run_fig1 if (job % 2 == 0) else run_fig2 20 | run(CP4IM_DATASET_NAMES[job // 2]) 21 | else: 22 | job -= 2 * len(CP4IM_DATASET_NAMES) 23 | tree_id = job % SYNTH_NUM_TREES 24 | run_fig3(tree_id) 25 | --------------------------------------------------------------------------------
/script.slurm: --------------------------------------------------------------------------------
1 | #!/bin/bash 2 | #SBATCH --job-name=bdt-map 3 | #SBATCH --account=thrun 4 | #SBATCH --partition=thrun 5 | #SBATCH --time=1-0 6 | #SBATCH --mem-per-cpu=16GB 7 | #SBATCH --cpus-per-task=1 8 | #SBATCH --mail-type=ALL 9 | #SBATCH --array=0-51 10 | 11 | python -u run_experiment.py -j ${SLURM_ARRAY_TASK_ID} --------------------------------------------------------------------------------
/setup.py: --------------------------------------------------------------------------------
1 | from glob import glob 2 | from setuptools import setup 3 | try: 4 | from pybind11.setup_helpers import Pybind11Extension 5 | except ImportError: 6 | from setuptools import Extension as Pybind11Extension 7 | from os.path import join 8 | 9 | __version__ = "0.0.1" 10 | 11 | 12 | class get_pybind_include(object): 13 | """ 14 | Helper class to determine the pybind11 include path. 15 | The purpose of this class is to postpone importing pybind11 16 | until it is actually installed via this package's setup_requires arg, 17 | so that the ``get_include()`` method can be invoked. 18 | """ 19 | def __str__(self): 20 | import pybind11 21 | 22 | return pybind11.get_include() 23 | 24 | 25 | ALL_SOURCE_FILES = sorted( 26 | glob(join("maptree", "src", "*", "*.cpp")) 27 | + glob(join("maptree", "src", "*.cpp")) 28 | ) 29 | 30 | ALL_HEADER_FILES = sorted( 31 | glob(join("maptree", "include", "*", "*.h")) 32 | + glob(join("maptree", "include", "*.h")) 33 | ) 34 | 35 | ext_modules = [ 36 | Pybind11Extension( 37 | "maptree", 38 | ALL_SOURCE_FILES, 39 | include_dirs=[get_pybind_include(), join("maptree", "include")], 40 | extra_compile_args=['-O3', '-DNDEBUG'], 41 | ), 42 | ] 43 | 44 | setup( 45 | name="maptree", 46 | version=__version__, 47 | author="Redacted", 48 | maintainer="Redacted", 49 | author_email="Redacted", 50 | ext_modules=ext_modules, 51 | setup_requires=[ 52 | "pybind11>=2.5.0", 53 | "numpy>=1.18", 54 | ], 55 | include_package_data=True, 56 | zip_safe=False, 57 | headers=ALL_HEADER_FILES, 58 | ) --------------------------------------------------------------------------------
/setup_data.py: --------------------------------------------------------------------------------
1 | import experiments.dataloader.cp4im 2 | import experiments.dataloader.synth 3 | 4 | 5 | def main(): 6 | experiments.dataloader.cp4im.download_and_install() 7 | experiments.dataloader.synth.generate_synthetic_tree_data() 8 | 9 | 10 | if __name__ == '__main__': 11 | main() 12 | --------------------------------------------------------------------------------
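Taken together, the three top-level scripts form the full experiment pipeline: setup_data.py fetches the CP4IM benchmarks and generates the synthetic trees, run_experiment.py executes the 52 jobs (all of them when -j is omitted, or one per SLURM array task via script.slurm), and plot_results.py renders the three figures. A hypothetical end-to-end driver, not a file in this repository, would chain them as:

import subprocess
import sys

# run the repository's three entry points in their intended order
for script, note in [
    ("setup_data.py", "download CP4IM data and generate synthetic trees"),
    ("run_experiment.py", "run all 52 experiment jobs serially"),
    ("plot_results.py", "render figures 1-3 from the saved results"),
]:
    print(f"==> {script}: {note}")
    subprocess.run([sys.executable, script], check=True)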