├── hiclass ├── _calibration │ ├── __init__.py │ ├── BinaryCalibrator.py │ ├── calibration_utils.py │ ├── PlattScaling.py │ ├── IsotonicRegression.py │ ├── BetaCalibrator.py │ └── Calibrator.py ├── _hiclass_utils.py ├── probability_combiner │ ├── __init__.py │ ├── ProbabilityCombiner.py │ ├── MultiplyCombiner.py │ ├── ArithmeticMeanCombiner.py │ └── GeometricMeanCombiner.py ├── __init__.py ├── Pipeline.py ├── ConstantClassifier.py ├── datasets.py └── FlatClassifier.py ├── .gitattributes ├── tests ├── __init__.py ├── fixtures │ └── small_dag_edgelist.csv ├── test_FlatClassifier.py ├── test_ConstantClassifier.py ├── test_Datasets.py └── test_LocalClassifiers.py ├── .coveragerc ├── benchmarks └── consumer_complaints │ ├── results │ ├── flat │ │ ├── lightgbm │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── prediction_benchmark.txt │ │ │ ├── optimization_results.md │ │ │ └── training_benchmark.txt │ │ ├── random_forest │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── optimization_results.md │ │ │ ├── prediction_benchmark.txt │ │ │ └── training_benchmark.txt │ │ └── logistic_regression │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── optimization_results.md │ │ │ ├── prediction_benchmark.txt │ │ │ └── training_benchmark.txt │ ├── local_classifier_per_node │ │ ├── lightgbm │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── prediction_benchmark.txt │ │ │ ├── optimization_results.md │ │ │ └── training_benchmark.txt │ │ ├── random_forest │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── optimization_results.md │ │ │ ├── prediction_benchmark.txt │ │ │ └── training_benchmark.txt │ │ └── logistic_regression │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── optimization_results.md │ │ │ ├── prediction_benchmark.txt │ │ │ └── training_benchmark.txt │ ├── local_classifier_per_level │ │ ├── lightgbm │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── prediction_benchmark.txt │ │ │ ├── optimization_results.md │ │ │ └── training_benchmark.txt │ │ ├── random_forest │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── optimization_results.md │ │ │ ├── prediction_benchmark.txt │ │ │ └── training_benchmark.txt │ │ └── logistic_regression │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── optimization_results.md │ │ │ ├── prediction_benchmark.txt │ │ │ └── training_benchmark.txt │ ├── local_classifier_per_parent_node │ │ ├── lightgbm │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── prediction_benchmark.txt │ │ │ ├── optimization_results.md │ │ │ └── training_benchmark.txt │ │ ├── random_forest │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── optimization_results.md │ │ │ ├── prediction_benchmark.txt │ │ │ └── training_benchmark.txt │ │ └── logistic_regression │ │ │ ├── metrics.csv │ │ │ ├── optimization_results.yaml │ │ │ ├── optimization_results.md │ │ │ ├── prediction_benchmark.txt │ │ │ └── training_benchmark.txt │ └── statistics │ │ └── statistics.csv │ ├── pytest.ini │ ├── configs │ ├── random_forest.yaml │ ├── logistic_regression.yaml │ ├── lightgbm.yaml │ ├── optuna.yaml │ └── snakemake.yml │ ├── envs │ ├── hiclass.yml │ └── snakemake.yml │ ├── rules │ ├── download.smk │ ├── metrics.smk │ ├── tune_table.smk │ ├── statistics.smk │ ├── predict.smk │ ├── split_data.smk │ ├── train.smk │ └── tune.smk │ ├── setup.cfg │ ├── Snakefile │ ├── tests │ ├── test_predict.py │ ├── test_metrics.py │ ├── test_statistics.py │ ├── 
test_train.py │ ├── test_split_data.py │ ├── test_tune_table.py │ └── test_data.py │ ├── Makefile │ └── scripts │ ├── predict.py │ ├── metrics.py │ ├── data.py │ ├── train.py │ ├── tune_table.py │ ├── statistics.py │ └── split_data.py ├── docs ├── source │ ├── algorithms │ │ ├── hc_format.png │ │ ├── hc_metrics.png │ │ ├── hiclass-uml.png │ │ ├── hc_background.png │ │ ├── hc_prediction.png │ │ ├── shap_explanation.png │ │ ├── hc_dog_breed_hierarchy.png │ │ ├── index.rst │ │ ├── local_classifier_per_parent_node.rst │ │ ├── local_classifier_per_level.rst │ │ ├── local_classifier_per_node.rst │ │ ├── metrics.rst │ │ ├── training_policies.rst │ │ └── calibration.rst │ ├── introduction │ │ ├── index.rst │ │ ├── learn.rst │ │ └── what.rst │ ├── get_started │ │ ├── index.rst │ │ ├── prerequisites.rst │ │ ├── training_and_predicting.rst │ │ ├── hello_hiclass.rst │ │ ├── pipenv.rst │ │ ├── conda.rst │ │ ├── virtual_environments.rst │ │ ├── install.rst │ │ ├── venv.rst │ │ ├── verify.rst │ │ ├── full_example.rst │ │ ├── hierarchical_data.rst │ │ └── local_classifier.rst │ ├── api │ │ ├── index.rst │ │ ├── classifiers.rst │ │ └── utilities.rst │ ├── index.rst │ └── conf.py ├── requirements.txt ├── examples │ ├── README.rst │ ├── plot_hello_hiclass.py │ ├── plot_empty_levels.py │ ├── plot_model_persistence.py │ ├── plot_pipeline.py │ ├── plot_parallel_training.py │ ├── plot_binary_policies.py │ └── plot_calibration.py ├── Makefile └── make.bat ├── MANIFEST.in ├── .readthedocs.yaml ├── .pre-commit-config.yaml ├── .github ├── workflows │ ├── mirror-gitlab.yml │ ├── test-pr.yml │ └── deploy-pypi.yml └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── Pipfile ├── setup.cfg ├── LICENSE └── CONTRIBUTING.md /hiclass/_calibration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | hiclass/_version.py export-subst 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the hiclass package.""" 2 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | hiclass/_version.py 4 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/lightgbm/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.37099081415338825 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/random_forest/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.6672103122674672 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/logistic_regression/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.7748312934595746 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/lightgbm/metrics.csv: 
-------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.790120724604213 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/lightgbm/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.6853751315012958 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/random_forest/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.7668433915787853 3 | -------------------------------------------------------------------------------- /tests/fixtures/small_dag_edgelist.csv: -------------------------------------------------------------------------------- 1 | parent, child 2 | 1,2 3 | 1,3 4 | 1,4 5 | 2,5 6 | 2,6 7 | 5,7 8 | 5,8 9 | 4,7 10 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/random_forest/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.7383252508146666 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/lightgbm/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.7530838657532138 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/random_forest/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.74074679906602 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/statistics/statistics.csv: -------------------------------------------------------------------------------- 1 | Snapshot,Training set size,Test set size 2 | 02/11/2022,727495,311784 3 | -------------------------------------------------------------------------------- /docs/source/algorithms/hc_format.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/hiclass/main/docs/source/algorithms/hc_format.png -------------------------------------------------------------------------------- /docs/source/algorithms/hc_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/hiclass/main/docs/source/algorithms/hc_metrics.png -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/logistic_regression/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.7794578939265646 3 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/logistic_regression/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.7762713930156775 3 | -------------------------------------------------------------------------------- 
/docs/source/algorithms/hiclass-uml.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/hiclass/main/docs/source/algorithms/hiclass-uml.png -------------------------------------------------------------------------------- /docs/source/introduction/index.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | 7 | what 8 | learn 9 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/logistic_regression/metrics.csv: -------------------------------------------------------------------------------- 1 | f1_hierarchical 2 | 0.7798507941395325 3 | -------------------------------------------------------------------------------- /docs/source/algorithms/hc_background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/hiclass/main/docs/source/algorithms/hc_background.png -------------------------------------------------------------------------------- /docs/source/algorithms/hc_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/hiclass/main/docs/source/algorithms/hc_prediction.png -------------------------------------------------------------------------------- /docs/source/algorithms/shap_explanation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/hiclass/main/docs/source/algorithms/shap_explanation.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx_rtd_theme 3 | readthedocs-sphinx-search 4 | sphinx_code_tabs 5 | sphinx-gallery 6 | matplotlib 7 | ray 8 | numpy 9 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include versioneer.py 2 | include hiclass/_version.py 3 | include README.md 4 | include LICENSE 5 | graft hiclass 6 | recursive-exclude * *.py[co] 7 | -------------------------------------------------------------------------------- /docs/source/algorithms/hc_dog_breed_hierarchy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/hiclass/main/docs/source/algorithms/hc_dog_breed_hierarchy.png -------------------------------------------------------------------------------- /docs/examples/README.rst: -------------------------------------------------------------------------------- 1 | Gallery of Examples 2 | =================== 3 | 4 | These examples illustrate the main features of HiClass. 5 | 6 | .. 
toctree:: 7 | :hidden: 8 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/random_forest/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | n_estimators: 200 4 | criterion: gini 5 | best_value: 0.6589227417370566 6 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/logistic_regression/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | solver: saga 4 | max_iter: 10000 5 | best_value: 0.7727118399439171 6 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/lightgbm/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | num_leaves: 31 4 | n_estimators: 200 5 | min_child_samples: 40 6 | best_value: 0.40766190832926685 7 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/logistic_regression/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | solver: lbfgs 4 | max_iter: 10000 5 | best_value: 0.7778761366057501 6 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/random_forest/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | n_estimators: 200 4 | criterion: gini 5 | best_value: 0.7303734046282104 6 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/random_forest/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | n_estimators: 200 4 | criterion: gini 5 | best_value: 0.7618615935504711 6 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/logistic_regression/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | solver: liblinear 4 | max_iter: 10000 5 | best_value: 0.7742795483130468 6 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/logistic_regression/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | solver: lbfgs 4 | max_iter: 10000 5 | best_value: 0.7782135959697317 6 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/random_forest/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | n_estimators: 200 4 | criterion: gini 5 | best_value: 0.7327919779517386 6 | -------------------------------------------------------------------------------- 
/benchmarks/consumer_complaints/results/local_classifier_per_level/lightgbm/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | num_leaves: 31 4 | n_estimators: 100 5 | min_child_samples: 40 6 | best_value: 0.6696726437982392 7 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/lightgbm/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | num_leaves: 62 4 | n_estimators: 200 5 | min_child_samples: 20 6 | best_value: 0.7876225953442979 7 | -------------------------------------------------------------------------------- /docs/source/get_started/index.rst: -------------------------------------------------------------------------------- 1 | Get Started 2 | ============ 3 | 4 | .. toctree:: 5 | :includehidden: 6 | :maxdepth: 3 7 | 8 | prerequisites 9 | virtual_environments 10 | install 11 | hello_hiclass 12 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/lightgbm/optimization_results.yaml: -------------------------------------------------------------------------------- 1 | name: optuna 2 | best_params: 3 | num_leaves: 31 4 | n_estimators: 100 5 | min_child_samples: 40 6 | best_value: 0.7433824287452148 7 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | pythonpath = . scripts 3 | testpaths=scripts tests 4 | addopts = --flake8 5 | --pydocstyle 6 | --cov=scripts 7 | --cov-fail-under=90 8 | --cov-report html 9 | --disable-warnings 10 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/configs/random_forest.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - optuna 4 | 5 | hydra: 6 | sweeper: 7 | params: 8 | n_estimators: choice(100, 200) 9 | criterion: choice("gini", "entropy", "log_loss") 10 | 11 | n_estimators: 1 12 | criterion: 1 13 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/configs/logistic_regression.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - optuna 4 | 5 | hydra: 6 | sweeper: 7 | params: 8 | solver: choice("newton-cg", "lbfgs", "liblinear", "sag", "saga") 9 | max_iter: choice(10000) 10 | 11 | solver: 1 12 | max_iter: 1 13 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/configs/lightgbm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - optuna 4 | 5 | hydra: 6 | sweeper: 7 | params: 8 | num_leaves: choice(31, 62) 9 | n_estimators: choice(100, 200) 10 | min_child_samples: choice(20, 40) 11 | 12 | num_leaves: 1 13 | n_estimators: 1 14 | min_child_samples: 1 15 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/envs/hiclass.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | dependencies: 4 | - 
python=3.8 5 | - pip 6 | - pip: 7 | - hydra-core==1.2.0 8 | - hydra-optuna-sweeper==1.2.0 9 | - hydra-colorlog==1.2.0 10 | - scikit-learn==1.1.1 11 | - lightgbm==3.3.2 12 | - hiclass==4.2.2 13 | - pandas==1.4.1 14 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/rules/download.smk: -------------------------------------------------------------------------------- 1 | from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider 2 | 3 | HTTP = HTTPRemoteProvider() 4 | 5 | rule download: 6 | input: 7 | ancient(HTTP.remote(config["data"], keep_local=True)) 8 | output: 9 | "data/complaints.csv.zip" 10 | shell: 11 | """ 12 | mv {input} {output} 13 | """ 14 | -------------------------------------------------------------------------------- /tests/test_FlatClassifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_array_equal 3 | 4 | from hiclass import FlatClassifier 5 | 6 | 7 | def test_fit_predict(): 8 | flat = FlatClassifier() 9 | x = np.array([[1, 2], [3, 4]]) 10 | y = np.array([["a", "b"], ["b", "c"]]) 11 | flat.fit(x, y) 12 | predictions = flat.predict(x) 13 | assert_array_equal(y, predictions) 14 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths=scripts tests 3 | addopts = --flake8 4 | --pydocstyle 5 | --cov=scripts 6 | --cov-fail-under=90 7 | --cov-report html 8 | --disable-warnings 9 | 10 | [flake8] 11 | ignore = E203, E266, E501, W503, F403, F401 12 | max-line-length = 120 13 | exclude = **/__init__.py, docs/source/conf.py 14 | -------------------------------------------------------------------------------- /hiclass/_hiclass_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _normalize_probabilities(proba): 5 | if isinstance(proba, np.ndarray): 6 | return np.nan_to_num(proba / proba.sum(axis=1, keepdims=True)) 7 | return [ 8 | np.nan_to_num( 9 | level_probabilities / level_probabilities.sum(axis=1, keepdims=True) 10 | ) 11 | for level_probabilities in proba 12 | ] 13 | -------------------------------------------------------------------------------- /hiclass/probability_combiner/__init__.py: -------------------------------------------------------------------------------- 1 | """Init the probability combiner module.""" 2 | 3 | from .MultiplyCombiner import MultiplyCombiner 4 | from .ArithmeticMeanCombiner import ArithmeticMeanCombiner 5 | from .GeometricMeanCombiner import GeometricMeanCombiner 6 | 7 | __all__ = [ 8 | "MultiplyCombiner", 9 | "ArithmeticMeanCombiner", 10 | "GeometricMeanCombiner", 11 | ] 12 | 13 | init_strings = [ 14 | "multiply", 15 | "geometric", 16 | "arithmetic", 17 | ] 18 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-20.04" 5 | tools: 6 | python: "3.12" 7 | 8 | # Build from the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | 12 | # Explicitly set the version of Python and its requirements 13 | python: 14 | install: 15 | - requirements: docs/requirements.txt 16 | - method: pip 17 | path: . 
18 | extra_requirements: 19 | - docs 20 | - method: setuptools 21 | path: . 22 | -------------------------------------------------------------------------------- /docs/source/get_started/prerequisites.rst: -------------------------------------------------------------------------------- 1 | Installation prerequisites 2 | ========================== 3 | 4 | - HiClass supports GNU/Linux, Windows and macOS. If you encounter any problems on these platforms, please open an issue describing the problem at `https://github.com/scikit-learn-contrib/hiclass/issues <https://github.com/scikit-learn-contrib/hiclass/issues>`_. 5 | 6 | - In order to use HiClass, we highly recommend that you download and install `Anaconda <https://www.anaconda.com>`_. 7 | -------------------------------------------------------------------------------- /docs/source/get_started/training_and_predicting.rst: -------------------------------------------------------------------------------- 1 | Training and Predicting 2 | ======================= 3 | 4 | HiClass adheres to the same API as the popular machine learning library scikit-learn. Hence, training is as easy as calling the :literal:`fit` method on the training data: 5 | 6 | .. code-block:: python 7 | 8 | classifier.fit(X_train, Y_train) 9 | 10 | Prediction is performed by calling the :literal:`predict` method on the test features: 11 | 12 | .. code-block:: python 13 | 14 | predictions = classifier.predict(X_test) 15 | -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | .. _code: 2 | 3 | API reference 4 | ============= 5 | The documentation lists all available functions for each of the implemented classes. This includes inherited functions. 6 | Therefore, not everything that is listed under a class's documentation is necessarily implemented by said class. 7 | This is done in order to provide a complete list of the callable functions for each of the classes. 8 | 9 | .................................. 10 | 11 | .. toctree:: 12 | :maxdepth: 3 13 | 14 | classifiers 15 | utilities 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 24.2.0 10 | hooks: 11 | - id: black 12 | - repo: https://github.com/pycqa/flake8 13 | rev: 7.1.1 14 | hooks: 15 | - id: flake8 16 | - repo: https://github.com/pycqa/pydocstyle 17 | rev: 6.3.0 18 | hooks: 19 | - id: pydocstyle 20 | files: ^hiclass/ 21 | -------------------------------------------------------------------------------- /docs/source/get_started/hello_hiclass.rst: -------------------------------------------------------------------------------- 1 | A "Hello World" example 2 | ======================= 3 | 4 | It is time to introduce the most basic elements of HiClass. We have split a small example into sections to discuss each of the concepts with code. 5 | 6 | You can copy the example as one chunk of code from the last page of this section. 7 | 8 | .. note:: 9 | 10 | We do not use real data in this example, but illustrate the concepts with a single ``.py`` script. 11 | 12 | .. 
toctree:: 13 | :hidden: 14 | 15 | hierarchical_data 16 | local_classifier 17 | training_and_predicting 18 | full_example 19 | -------------------------------------------------------------------------------- /docs/source/get_started/pipenv.rst: -------------------------------------------------------------------------------- 1 | pipenv 2 | ====== 3 | 4 | You will need to install :literal:`pipenv` as follows: 5 | 6 | .. code-block:: bash 7 | 8 | pip install pipenv 9 | 10 | Create a directory for the virtual environment and change to that directory: 11 | 12 | .. code-block:: bash 13 | 14 | mkdir hiclass-environment && cd hiclass-environment 15 | 16 | Once HiClass and its dependencies are installed in this environment (see :ref:`Install HiClass`), start a session with the virtual environment activated: 17 | 18 | .. code-block:: bash 19 | 20 | pipenv shell 21 | 22 | To exit the shell session: 23 | 24 | .. code-block:: bash 25 | 26 | exit 27 | -------------------------------------------------------------------------------- /docs/source/introduction/learn.rst: -------------------------------------------------------------------------------- 1 | Learn how to use HiClass 2 | ======================== 3 | 4 | In the next few chapters, you will learn how to :ref:`Install HiClass` and set up your own hierarchical machine learning pipelines. 5 | 6 | Once you are set up, we suggest working through our examples, including: 7 | 8 | - The :ref:`A "Hello World" example`, which gives an entry-level description of the main concepts. 9 | - Further examples are displayed in our :ref:`Gallery of Examples` to give you hands-on experience. 10 | 11 | We also recommend the sections :ref:`Algorithms Overview` and :ref:`API reference` for additional information. 12 | -------------------------------------------------------------------------------- /hiclass/_calibration/BinaryCalibrator.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | 4 | 5 | class _BinaryCalibrator(abc.ABC): 6 | @abc.abstractmethod 7 | def fit( 8 | self, y: np.ndarray, scores: np.ndarray, X: np.ndarray = None 9 | ): # pragma: no cover 10 | ... 11 | 12 | @abc.abstractmethod 13 | def predict_proba( 14 | self, scores: np.ndarray, X: np.ndarray = None 15 | ): # pragma: no cover 16 | ... 17 | 18 | def __sklearn_is_fitted__(self): 19 | """Check fitted status and return a Boolean value.""" 20 | return hasattr(self, "_is_fitted") and self._is_fitted 21 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/rules/metrics.smk: -------------------------------------------------------------------------------- 1 | rule metrics: 2 | input: 3 | predictions = "results/{model}/{classifier}/predictions.csv.zip", 4 | ground_truth = "results/split_data/y_test.csv.zip" 5 | output: 6 | metrics = "results/{model}/{classifier}/metrics.csv", 7 | params: 8 | classifier = "{classifier}", 9 | model = "{model}", 10 | conda: 11 | "../envs/hiclass.yml" 12 | shell: 13 | """ 14 | python scripts/metrics.py \ 15 | --predictions {input.predictions} \ 16 | --ground-truth {input.ground_truth} \ 17 | --metrics {output.metrics} 18 | """ 19 | -------------------------------------------------------------------------------- /docs/source/introduction/what.rst: -------------------------------------------------------------------------------- 1 | What is HiClass? 2 | ================ 3 | 4 | HiClass is an open-source Python library for hierarchical classification compatible with scikit-learn. 
5 | It mirrors the popular API from scikit-learn to train and predict with the most common design patterns for local hierarchical classification. 6 | Implemented models include the local classifiers per node, per parent node and per level. 7 | HiClass is released under the `BSD 3-Clause license <https://opensource.org/licenses/BSD-3-Clause>`_. 8 | 9 | For the source code, please take a look at the repository on `GitHub <https://github.com/scikit-learn-contrib/hiclass>`_. 10 | -------------------------------------------------------------------------------- /hiclass/__init__.py: -------------------------------------------------------------------------------- 1 | """Init module for the library.""" 2 | 3 | import os 4 | 5 | from ._version import get_versions 6 | from .FlatClassifier import FlatClassifier 7 | from .LocalClassifierPerLevel import LocalClassifierPerLevel 8 | from .LocalClassifierPerNode import LocalClassifierPerNode 9 | from .LocalClassifierPerParentNode import LocalClassifierPerParentNode 10 | from .Pipeline import Pipeline 11 | 12 | __version__ = get_versions()["version"] 13 | del get_versions 14 | 15 | __all__ = [ 16 | "LocalClassifierPerNode", 17 | "LocalClassifierPerParentNode", 18 | "LocalClassifierPerLevel", 19 | "Pipeline", 20 | "FlatClassifier", 21 | "datasets", 22 | ] 23 | -------------------------------------------------------------------------------- /hiclass/_calibration/calibration_utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import LabelBinarizer 2 | from sklearn.base import BaseEstimator 3 | import numpy as np 4 | 5 | 6 | def _one_vs_rest_split(y: np.ndarray, scores: np.ndarray, estimator: BaseEstimator): 7 | # binarize multiclass labels 8 | label_binarizer = LabelBinarizer() 9 | label_binarizer.fit(estimator.classes_) 10 | binary_labels = label_binarizer.transform(y).T 11 | 12 | # split scores into k one vs rest splits 13 | score_splits = [scores[:, i] for i in range(scores.shape[1])] 14 | label_splits = [binary_labels[i] for i in range(len(score_splits))] 15 | 16 | return score_splits, label_splits 17 | -------------------------------------------------------------------------------- /.github/workflows/mirror-gitlab.yml: -------------------------------------------------------------------------------- 1 | name: Mirror to GitLab 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | mirror: 10 | runs-on: ubuntu-latest 11 | steps: # <-- must use actions/checkout before mirroring! 12 | - uses: actions/checkout@v2 13 | with: 14 | fetch-depth: 0 15 | - uses: pixta-dev/repository-mirroring-action@v1 16 | with: 17 | target_repo_url: 18 | git@gitlab.com:dacs-hpi/hiclass.git 19 | ssh_private_key: # <-- use 'secrets' to pass credential information. 
20 | ${{ secrets.GITLAB_SSH_PRIVATE_KEY }} 21 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/logistic_regression/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: flat 2 | ## Base classifier: logistic_regression 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'solver': 'lbfgs', 'max_iter': 10000}|[0.772, 0.773, 0.772, 0.774, 0.772]|0.773|0.001| 6 | |{'solver': 'liblinear', 'max_iter': 10000}|[0.763, 0.764, 0.763, 0.764, 0.763]|0.763|0.001| 7 | |{'solver': 'sag', 'max_iter': 10000}|[0.772, 0.773, 0.772, 0.774, 0.772]|0.773|0.001| 8 | |{'solver': 'newton-cg', 'max_iter': 10000}|[0.772, 0.773, 0.772, 0.774, 0.772]|0.773|0.001| 9 | |{'solver': 'saga', 'max_iter': 10000}|[0.772, 0.773, 0.772, 0.774, 0.772]|0.773|0.001| 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/configs/optuna.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - override hydra/sweeper: optuna 3 | - override hydra/sweeper/sampler: grid 4 | - override hydra/job_logging: colorlog 5 | - override hydra/hydra_logging: colorlog 6 | 7 | hydra: 8 | sweeper: 9 | n_jobs: 1 10 | direction: maximize 11 | params: 12 | classifier: choice(None) 13 | model: choice(None) 14 | n_jobs: choice(1) 15 | x_train: choice(None) 16 | y_train: choice(None) 17 | output_dir: choice(None) 18 | mem_gb: choice(None) 19 | n_splits: choice(None) 20 | 21 | classifier: 1 22 | model: 1 23 | n_jobs: 1 24 | x_train: 1 25 | y_train: 1 26 | output_dir: 1 27 | mem_gb: 1 28 | n_splits: 1 29 | -------------------------------------------------------------------------------- /hiclass/Pipeline.py: -------------------------------------------------------------------------------- 1 | """Custom Pipeline class that supports the `calibrate` method.""" 2 | 3 | from sklearn.pipeline import Pipeline as skPipeline 4 | 5 | 6 | class Pipeline(skPipeline): 7 | """Custom Pipeline class that supports the `calibrate` method.""" 8 | 9 | def __init__(self, steps, **kwargs): 10 | """Create Pipeline object.""" 11 | super().__init__(steps, **kwargs) 12 | 13 | def calibrate(self, X, y, **params): 14 | """Transform the data and apply `calibrate` with the final estimator.""" 15 | Xt = X 16 | for _, name, transform in self._iter(with_final=False): 17 | Xt = transform.transform(Xt) 18 | return self.steps[-1][1].calibrate(Xt, y) 19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.python.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | networkx = "*" 8 | numpy = "*" 9 | scikit-learn = "1.4.2" 10 | scipy = "1.11.4" 11 | 12 | [dev-packages] 13 | pytest = "7.1.2" 14 | flake8 = "4.0.1" 15 | pytest-flake8 = "1.1.1" 16 | pydocstyle = "6.1.1" 17 | pytest-pydocstyle = "2.3.0" 18 | pytest-cov = "3.0.0" 19 | twine = "*" 20 | sphinx = "5.0.0" 21 | sphinx-rtd-theme = "1.0.0" 22 | readthedocs-sphinx-search = "0.1.2" 23 | sphinx_code_tabs = "0.5.3" 24 | sphinx-gallery = "0.10.1" 25 | matplotlib = "3.9.2" 26 | pandas = "1.4.2" 27 | black = {version = "24.3.0", extras = ["colorama"]} 28 | pre-commit = "2.20.0" 29 | pyfakefs = "*" 30 | 31 | [extras] 32 | ray = "*" 33 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/logistic_regression/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_node 2 | ## Base classifier: logistic_regression 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'solver': 'lbfgs', 'max_iter': 10000}|[0.774, 0.774, 0.774, 0.775, 0.774]|0.774|0.001| 6 | |{'solver': 'liblinear', 'max_iter': 10000}|[0.774, 0.774, 0.774, 0.775, 0.774]|0.774|0.001| 7 | |{'solver': 'newton-cg', 'max_iter': 10000}|[0.774, 0.774, 0.774, 0.775, 0.774]|0.774|0.001| 8 | |{'solver': 'sag', 'max_iter': 10000}|[0.774, 0.774, 0.774, 0.775, 0.774]|0.774|0.001| 9 | |{'solver': 'saga', 'max_iter': 10000}|[0.774, 0.774, 0.774, 0.775, 0.774]|0.774|0.001| 10 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/logistic_regression/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_level 2 | ## Base classifier: logistic_regression 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'solver': 'liblinear', 'max_iter': 10000}|[0.773, 0.774, 0.774, 0.775, 0.774]|0.774|0.001| 6 | |{'solver': 'saga', 'max_iter': 10000}|[0.777, 0.778, 0.778, 0.779, 0.778]|0.778|0.001| 7 | |{'solver': 'lbfgs', 'max_iter': 10000}|[0.777, 0.778, 0.778, 0.779, 0.778]|0.778|0.001| 8 | |{'solver': 'sag', 'max_iter': 10000}|[0.777, 0.778, 0.778, 0.779, 0.778]|0.778|0.001| 9 | |{'solver': 'newton-cg', 'max_iter': 10000}|[0.777, 0.778, 0.778, 0.779, 0.778]|0.778|0.001| 10 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/logistic_regression/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_parent_node 2 | ## Base classifier: logistic_regression 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'solver': 'saga', 'max_iter': 10000}|[0.777, 0.779, 0.778, 0.779, 0.778]|0.778|0.001| 6 | |{'solver': 'liblinear', 'max_iter': 10000}|[0.774, 0.774, 0.774, 0.775, 0.774]|0.774|0.001| 7 | 
|{'solver': 'lbfgs', 'max_iter': 10000}|[0.777, 0.779, 0.778, 0.779, 0.778]|0.778|0.001| 8 | |{'solver': 'sag', 'max_iter': 10000}|[0.777, 0.779, 0.778, 0.779, 0.778]|0.778|0.001| 9 | |{'solver': 'newton-cg', 'max_iter': 10000}|[0.777, 0.779, 0.778, 0.779, 0.778]|0.778|0.001| 10 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/Snakefile: -------------------------------------------------------------------------------- 1 | configfile: "configs/snakemake.yml" 2 | 3 | workdir: config["workdir"] 4 | 5 | include: "rules/download.smk" 6 | include: "rules/split_data.smk" 7 | include: "rules/tune.smk" 8 | include: "rules/train.smk" 9 | include: "rules/predict.smk" 10 | include: "rules/metrics.smk" 11 | include: "rules/statistics.smk" 12 | include: "rules/tune_table.smk" 13 | 14 | rule all: 15 | input: 16 | metrics = [f"results/{model}/{classifier}/metrics.csv" for classifier in config["classifiers"] for model in config["models"]], 17 | tables = [f"results/{model}/{classifier}/optimization_results.md" for classifier in config["classifiers"] for model in config["models"]], 18 | statistics = "results/statistics/statistics.csv" 19 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/configs/snakemake.yml: -------------------------------------------------------------------------------- 1 | workdir: /hpi/fs00/home/fabio.malchermiranda/scratch/hiclass/benchmarks/consumer_complaints 2 | 3 | threads: 12 4 | 5 | data: 'https://files.consumerfinance.gov/ccdb/complaints.csv.zip' 6 | 7 | # Random state used for splitting the data in training and test 8 | random_state: 42 9 | 10 | # Number of folds for cross-validation 11 | n_splits: 5 12 | 13 | # Memory allocated for tune and train rules 14 | mem_gb: 450 15 | 16 | # Uncomment the next line to use a subset of the data 17 | #nrows: 2000 18 | 19 | classifiers: 20 | - logistic_regression 21 | - random_forest 22 | - lightgbm 23 | 24 | models: 25 | - flat 26 | - local_classifier_per_node 27 | - local_classifier_per_parent_node 28 | - local_classifier_per_level 29 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths=hiclass tests 3 | addopts = --disable-warnings 4 | --color=yes 5 | --ignore=hiclass/_version.py, 6 | 7 | [flake8] 8 | ignore = E203, E266, E501, W503, F403, F401 9 | max-line-length = 120 10 | exclude = **/__init__.py, docs/source/conf.py 11 | 12 | ;per-file-ignores = 13 | 14 | ;file.py: error 15 | 16 | [requires] 17 | python_version = ">=3.9,<3.13" 18 | 19 | # See the docstring in versioneer.py for instructions. Note that you must 20 | # re-run 'versioneer.py setup' after changing this section, and commit the 21 | # resulting files. 
22 | 23 | [versioneer] 24 | VCS = git 25 | style = pep440 26 | versionfile_source = hiclass/_version.py 27 | versionfile_build = hiclass/_version.py 28 | tag_prefix = 29 | parentdir_prefix = hiclass- 30 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/rules/tune_table.smk: -------------------------------------------------------------------------------- 1 | rule tune_table: 2 | input: 3 | best_parameters = "results/{model}/{classifier}/optimization_results.yaml", 4 | output: 5 | table = "results/{model}/{classifier}/optimization_results.md", 6 | params: 7 | model = "{model}", 8 | classifier = "{classifier}", 9 | folder = "results/{model}/{classifier}", 10 | conda: 11 | "../envs/hiclass.yml" 12 | threads: 13 | config["threads"] 14 | resources: 15 | mem_gb = config["mem_gb"] 16 | shell: 17 | """ 18 | python scripts/tune_table.py \ 19 | --folder {params.folder} \ 20 | --model {params.model} \ 21 | --classifier {params.classifier} \ 22 | --output {output.table} 23 | """ 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/rules/statistics.smk: -------------------------------------------------------------------------------- 1 | rule statistics: 2 | input: 3 | data = "data/complaints.csv.zip", 4 | x_train = "results/split_data/x_train.csv.zip", 5 | y_train = "results/split_data/y_train.csv.zip", 6 | x_test = "results/split_data/x_test.csv.zip", 7 | y_test ="results/split_data/y_test.csv.zip", 8 | output: 9 | statistics = "results/statistics/statistics.csv", 10 | conda: 11 | "../envs/hiclass.yml" 12 | shell: 13 | """ 14 | python scripts/statistics.py \ 15 | --data {input.data} \ 16 | --x-train {input.x_train} \ 17 | --y-train {input.y_train} \ 18 | --x-test {input.x_test} \ 19 | --y-test {input.y_test} \ 20 | --statistics {output.statistics} 21 | """ 22 | -------------------------------------------------------------------------------- /hiclass/_calibration/PlattScaling.py: -------------------------------------------------------------------------------- 1 | from hiclass._calibration.BinaryCalibrator import _BinaryCalibrator 2 | from sklearn.calibration import _SigmoidCalibration 3 | from sklearn.utils.validation import check_is_fitted 4 | import numpy as np 5 | 6 | 7 | class _PlattScaling(_BinaryCalibrator): 8 | name = "PlattScaling" 9 | 10 | def __init__(self) -> None: 11 | self._is_fitted = False 12 | self.platt_scaling = _SigmoidCalibration() 13 | 14 | def fit(self, y: np.ndarray, scores: np.ndarray, X: np.ndarray = None): 15 | self.platt_scaling.fit(scores, y) 16 | self._is_fitted = True 17 | return self 18 | 19 | def predict_proba(self, scores: np.ndarray, X: np.ndarray = None): 20 | check_is_fitted(self) 21 | return self.platt_scaling.predict(scores) 22 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/random_forest/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: flat 2 | ## Base classifier: random_forest 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'n_estimators': 200, 'criterion': 'entropy'}|[0.617, 0.616, 0.618, 0.618, 0.614]|0.617|0.001| 6 | |{'n_estimators': 100, 'criterion': 'gini'}|[0.656, 0.655, 0.656, 0.656, 0.658]|0.656|0.001| 7 | |{'n_estimators': 100, 'criterion': 'entropy'}|[0.613, 0.613, 0.615, 0.616, 0.615]|0.614|0.001| 8 | |{'n_estimators': 
100, 'criterion': 'log_loss'}|[0.615, 0.617, 0.619, 0.618, 0.617]|0.617|0.001| 9 | |{'n_estimators': 200, 'criterion': 'gini'}|[0.659, 0.658, 0.66, 0.66, 0.657]|0.659|0.001| 10 | |{'n_estimators': 200, 'criterion': 'log_loss'}|[0.618, 0.616, 0.619, 0.618, 0.616]|0.617|0.001| 11 | -------------------------------------------------------------------------------- /docs/source/algorithms/index.rst: -------------------------------------------------------------------------------- 1 | .. _algorithms: 2 | 3 | Algorithms Overview 4 | =================== 5 | 6 | HiClass provides implementations for the most popular machine learning models for hierarchical classification, including the Local Classifier Per Node, the Local Classifier Per Parent Node and the Local Classifier Per Level. Additionally, the library includes metrics to evaluate model performance on hierarchical data. In this section, we present in detail the different approaches for hierarchical classification as well as the hierarchical metrics. 7 | 8 | .................................. 9 | 10 | .. toctree:: 11 | :includehidden: 12 | :maxdepth: 3 13 | 14 | local_classifier_per_node 15 | local_classifier_per_parent_node 16 | local_classifier_per_level 17 | metrics 18 | calibration 19 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/tests/test_predict.py: -------------------------------------------------------------------------------- 1 | from scripts.predict import parse_args 2 | 3 | 4 | def test_parser(): 5 | parser = parse_args( 6 | [ 7 | "--trained-model", 8 | "model.sav", 9 | "--x-test", 10 | "x_test.csv.zip", 11 | "--predictions", 12 | "predictions.csv.zip", 13 | "--classifier", 14 | "hist_gradient", 15 | ] 16 | ) 17 | assert parser.trained_model is not None 18 | assert "model.sav" == parser.trained_model 19 | assert parser.x_test is not None 20 | assert "x_test.csv.zip" == parser.x_test 21 | assert parser.predictions is not None 22 | assert "predictions.csv.zip" == parser.predictions 23 | assert parser.classifier is not None 24 | assert "hist_gradient" == parser.classifier 25 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/random_forest/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_level 2 | ## Base classifier: random_forest 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'n_estimators': 100, 'criterion': 'log_loss'}|[0.713, 0.713, 0.715, 0.712, 0.715]|0.714|0.001| 6 | |{'n_estimators': 100, 'criterion': 'entropy'}|[0.714, 0.712, 0.716, 0.715, 0.714]|0.714|0.001| 7 | |{'n_estimators': 200, 'criterion': 'log_loss'}|[0.718, 0.718, 0.72, 0.72, 0.718]|0.719|0.001| 8 | |{'n_estimators': 100, 'criterion': 'gini'}|[0.724, 0.726, 0.726, 0.725, 0.724]|0.725|0.001| 9 | |{'n_estimators': 200, 'criterion': 'gini'}|[0.729, 0.73, 0.732, 0.73, 0.73]|0.730|0.001| 10 | |{'n_estimators': 200, 'criterion': 'entropy'}|[0.719, 0.72, 0.719, 0.72, 0.719]|0.720|0.001| 11 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/random_forest/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_node 2 | ## Base classifier: random_forest 3 | 
|Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'n_estimators': 100, 'criterion': 'entropy'}|[0.75, 0.75, 0.749, 0.75, 0.749]|0.750|0.001| 6 | |{'n_estimators': 100, 'criterion': 'gini'}|[0.759, 0.758, 0.759, 0.757, 0.757]|0.758|0.001| 7 | |{'n_estimators': 200, 'criterion': 'log_loss'}|[0.754, 0.753, 0.753, 0.752, 0.753]|0.753|0.001| 8 | |{'n_estimators': 100, 'criterion': 'log_loss'}|[0.75, 0.749, 0.75, 0.75, 0.749]|0.750|0.000| 9 | |{'n_estimators': 200, 'criterion': 'entropy'}|[0.753, 0.753, 0.753, 0.753, 0.753]|0.753|0.000| 10 | |{'n_estimators': 200, 'criterion': 'gini'}|[0.762, 0.762, 0.763, 0.761, 0.761]|0.762|0.001| 11 | -------------------------------------------------------------------------------- /docs/source/get_started/conda.rst: -------------------------------------------------------------------------------- 1 | conda 2 | ===== 3 | 4 | Install :literal:`conda` on your computer, following the `official guide <https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html>`_. 5 | 6 | Create a new virtual environment called :literal:`hiclass-environment` using :literal:`conda`: 7 | 8 | .. code-block:: bash 9 | 10 | conda create --name hiclass-environment python=3.8 --yes 11 | 12 | 13 | This will create an isolated Python 3.8 environment. To activate it: 14 | 15 | .. code-block:: bash 16 | 17 | conda activate hiclass-environment 18 | 19 | 20 | To exit :literal:`hiclass-environment`: 21 | 22 | .. code-block:: bash 23 | 24 | conda deactivate 25 | 26 | 27 | .. note:: 28 | 29 | The :literal:`conda` virtual environment is not dependent on your current working directory and can be activated from any folder. 30 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/rules/predict.smk: -------------------------------------------------------------------------------- 1 | rule predict: 2 | input: 3 | trained_model = "results/{model}/{classifier}/trained_model.sav", 4 | x_test = "results/split_data/x_test.csv.zip" 5 | output: 6 | predictions = "results/{model}/{classifier}/predictions.csv.zip", 7 | benchmark = "results/{model}/{classifier}/prediction_benchmark.txt" 8 | params: 9 | model = "{model}" 10 | conda: 11 | "../envs/hiclass.yml" 12 | threads: 13 | config["threads"] 14 | shell: 15 | """ 16 | /usr/bin/time -v \ 17 | -o {output.benchmark} \ 18 | python scripts/predict.py \ 19 | --trained-model {input.trained_model} \ 20 | --x-test {input.x_test} \ 21 | --predictions {output.predictions} \ 22 | --classifier {params.model} 23 | """ 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/random_forest/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_parent_node 2 | ## Base classifier: random_forest 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'n_estimators': 100, 'criterion': 'gini'}|[0.727, 0.729, 0.729, 0.728, 0.729]|0.728|0.001| 6 | |{'n_estimators': 200, 'criterion': 'entropy'}|[0.72, 0.721, 0.723, 0.721, 0.722]|0.721|0.001| 7 | |{'n_estimators': 100, 'criterion': 'log_loss'}|[0.717, 0.719, 0.717, 0.717, 0.717]|0.717|0.001| 8 | |{'n_estimators': 200, 'criterion': 'gini'}|[0.733, 0.732, 0.733, 0.733, 0.732]|0.733|0.001| 9 | |{'n_estimators': 200, 'criterion': 'log_loss'}|[0.72, 0.722, 0.72, 0.721, 0.721]|0.721|0.001| 10 | |{'n_estimators': 100, 'criterion': 'entropy'}|[0.715, 0.716, 
0.716, 0.717, 0.717]|0.716|0.001| 11 | -------------------------------------------------------------------------------- /docs/source/get_started/virtual_environments.rst: -------------------------------------------------------------------------------- 1 | Virtual environments 2 | ==================== 3 | 4 | The main purpose of Python virtual environments is to create an isolated environment for a Python project to have its own dependencies, regardless of other projects. We recommend that you create a new virtual environment for HiClass. 5 | 6 | .. seealso:: 7 | 8 | Read more about Python Virtual Environments here: `https://realpython.com/python-virtual-environments-a-primer/ <https://realpython.com/python-virtual-environments-a-primer/>`_. 9 | 10 | Depending on your preferred installation method, you can create virtual environments for HiClass as follows: 11 | 12 | - With :ref:`conda`, a package and environment manager program bundled with Anaconda. 13 | 14 | - Without Anaconda, using :ref:`venv` or :ref:`pipenv`. 15 | 16 | .. toctree:: 17 | :hidden: 18 | 19 | conda 20 | venv 21 | pipenv 22 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/examples/plot_hello_hiclass.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ===================== 4 | Hello HiClass 5 | ===================== 6 | 7 | A minimalist example showing how to use HiClass to train and predict. 
8 | """ 9 | from sklearn.ensemble import RandomForestClassifier 10 | 11 | from hiclass import LocalClassifierPerNode 12 | 13 | # Define data 14 | X_train = [[1], [2], [3], [4]] 15 | X_test = [[4], [3], [2], [1]] 16 | Y_train = [ 17 | ["Animal", "Mammal", "Sheep"], 18 | ["Animal", "Mammal", "Cow"], 19 | ["Animal", "Reptile", "Snake"], 20 | ["Animal", "Reptile", "Lizard"], 21 | ] 22 | 23 | # Use random forest classifiers for every node 24 | rf = RandomForestClassifier() 25 | classifier = LocalClassifierPerNode(local_classifier=rf) 26 | 27 | # Train local classifier per node 28 | classifier.fit(X_train, Y_train) 29 | 30 | # Predict 31 | predictions = classifier.predict(X_test) 32 | print(predictions) 33 | -------------------------------------------------------------------------------- /docs/source/get_started/install.rst: -------------------------------------------------------------------------------- 1 | Install HiClass 2 | =============== 3 | 4 | To install HiClass from the Python Package Index (PyPI) simply run: 5 | 6 | .. code-block:: bash 7 | 8 | pip install hiclass 9 | 10 | Additionally, it is also possible to install optional packages along. To install optional packages run: 11 | 12 | .. code-block:: bash 13 | 14 | pip install hiclass"[]" 15 | 16 | :literal:`` can have one of the following options: 17 | 18 | - ray: Installs the ray package, which is required for parallel processing support. 19 | 20 | It is also possible to install HiClass using :literal:`conda`, as follows: 21 | 22 | .. code-block:: bash 23 | 24 | conda install -c conda-forge hiclass --yes 25 | 26 | .. note:: 27 | 28 | We recommend using :literal:`pip` at this point to eliminate any potential dependency issues. 29 | 30 | .. toctree:: 31 | :hidden: 32 | 33 | verify 34 | -------------------------------------------------------------------------------- /docs/source/algorithms/local_classifier_per_parent_node.rst: -------------------------------------------------------------------------------- 1 | .. _local-classifier-per-parent-node-overview: 2 | 3 | Local Classifier Per Parent Node 4 | ================================ 5 | 6 | The local classifier per parent node approach consists of training a multi-class classifier for each parent node existing in the hierarchy, as shown in the image below. 7 | 8 | .. figure:: local_classifier_per_parent_node.svg 9 | :align: center 10 | 11 | Visual representation of the local classifier per parent node approach. 12 | 13 | While training is executed in parallel, prediction is performed in a top-down style in order to avoid inconsistencies. For example, let's suppose that the classifier located at the root node decides that a test example belongs to class "Reptile", then the next level can only be predicted by the classifier located at node "Reptile", which in turn will decide if the test example belongs to class "Snake" or "Lizard". 
14 | -------------------------------------------------------------------------------- /hiclass/_calibration/IsotonicRegression.py: -------------------------------------------------------------------------------- 1 | from hiclass._calibration.BinaryCalibrator import _BinaryCalibrator 2 | from sklearn.isotonic import IsotonicRegression as SkLearnIR 3 | from sklearn.utils.validation import check_is_fitted 4 | import numpy as np 5 | 6 | 7 | class _IsotonicRegression(_BinaryCalibrator): 8 | name = "IsotonicRegression" 9 | 10 | def __init__(self, params=None) -> None: 11 | # Copy the parameters to avoid a shared mutable default argument, which was mutated below. 12 | params = dict(params) if params else {} 13 | self._is_fitted = False 14 | if "out_of_bounds" not in params: 15 | params["out_of_bounds"] = "clip" 16 | self.isotonic_regression = SkLearnIR(**params) 17 | 18 | def fit(self, y: np.ndarray, scores: np.ndarray, X: np.ndarray = None): 19 | self.isotonic_regression.fit(scores, y) 20 | self._is_fitted = True 21 | return self 22 | 23 | def predict_proba(self, scores: np.ndarray, X: np.ndarray = None): 24 | check_is_fitted(self) 25 | return self.isotonic_regression.predict(scores) 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/rules/split_data.smk: -------------------------------------------------------------------------------- 1 | rule split_data: 2 | input: 3 | "data/complaints.csv.zip" 4 | output: 5 | x_train = "results/split_data/x_train.csv.zip", 6 | x_test = "results/split_data/x_test.csv.zip", 7 | y_train = "results/split_data/y_train.csv.zip", 8 | y_test = "results/split_data/y_test.csv.zip", 9 | benchmark = "results/split_data/benchmark.txt" 10 | params: 11 | nrows = config["nrows"] if "nrows" in config else None, 12 | random_state = config["random_state"] 13 | conda: 14 | "../envs/hiclass.yml" 15 | shell: 16 | """ 17 | /usr/bin/time -v \ 18 | -o {output.benchmark} \ 19 | python scripts/split_data.py \ 20 | --data {input} \ 21 | --x-train {output.x_train} \ 22 | --x-test {output.x_test} \ 23 | --y-train {output.y_train} \ 24 | --y-test {output.y_test} \ 25 | --random-state {params.random_state} \ 26 | --nrows {params.nrows} 27 | """ 28 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/lightgbm/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/flat/lightgbm/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/flat/lightgbm/predictions.csv.zip --classifier flat" 2 | User time (seconds): 556.90 3 | System time (seconds): 2.27 4 | Percent of CPU this job got: 664% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 1:24.14 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 1475340 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 878715 14 | Voluntary context switches: 480 15 | Involuntary context switches: 247096 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 8 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/random_forest/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/flat/random_forest/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/flat/random_forest/predictions.csv.zip --classifier flat" 2 | User time (seconds): 235.15 3 | System time (seconds): 77.72 4 | Percent of CPU this job got: 259% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 2:00.72 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 168561396 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 49000524 14 | Voluntary context switches: 5374 15 | Involuntary context switches: 224793 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size 
(bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/logistic_regression/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/flat/logistic_regression/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/flat/logistic_regression/predictions.csv.zip --classifier flat" 2 | User time (seconds): 34.03 3 | System time (seconds): 2.68 4 | Percent of CPU this job got: 104% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 0:35.30 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 1545352 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 859901 14 | Voluntary context switches: 252 15 | Involuntary context switches: 190661 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | 3 | import numpy as np 4 | 5 | from scripts.metrics import parse_args, compute_f1 6 | 7 | 8 | def test_parser(): 9 | parser = parse_args( 10 | [ 11 | "--predictions", 12 | "predictions.tsv", 13 | "--ground-truth", 14 | "ground_truth.tsv", 15 | "--metrics", 16 | "metrics.tsv", 17 | ] 18 | ) 19 | assert parser.predictions is not None 20 | assert "predictions.tsv" == parser.predictions 21 | assert parser.ground_truth is not None 22 | assert "ground_truth.tsv" == parser.ground_truth 23 | assert parser.metrics is not None 24 | assert "metrics.tsv" == parser.metrics 25 | 26 | 27 | def test_compute_f1(): 28 | ground_truth = "f1_hierarchical\n" 29 | ground_truth += "1.0\n" 30 | output = StringIO() 31 | y = np.array( 32 | [["Reports", "Credit"], ["Debt", "Mortgage"], ["Loan", "Student loan"]] 33 | ) 34 | compute_f1(y, y, output) 35 | output.seek(0) 36 | assert ground_truth == output.read() 37 | -------------------------------------------------------------------------------- /.github/workflows/test-pr.yml: -------------------------------------------------------------------------------- 1 | name: Test PR 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | 7 | lint: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | - uses: psf/black@stable 12 | 13 | test: 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.9", "3.10", "3.11", "3.12"] 19 | os: [ubuntu-latest, macOS-latest, windows-latest] 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install -e ".[dev]" 30 | - name: Test with pytest 31 | run: | 32 | flake8 33 | pydocstyle hiclass tests 34 | pytest -v --cov=hiclass --cov-fail-under=90 --cov-report html 35 | coverage xml 36 | - name: 
Upload Coverage to Codecov 37 | uses: codecov/codecov-action@v2 38 | -------------------------------------------------------------------------------- /docs/source/algorithms/local_classifier_per_level.rst: -------------------------------------------------------------------------------- 1 | .. _local-classifier-per-level-overview: 2 | 3 | Local Classifier Per Level 4 | ========================== 5 | 6 | The local classifier per level approach consists of training a multi-class classifier for each level of the class taxonomy. An example is displayed in the figure below. 7 | 8 | .. figure:: local_classifier_per_level.svg 9 | :align: center 10 | 11 | Visual representation of the local classifier per level approach. 12 | 13 | Similar to the other hierarchical classifiers, the local classifier per level can also be trained in parallel, and prediction is performed in a top-down style to avoid inconsistencies. For example, suppose that for a given test example the classifier at the first level returns probabilities of 0.91 and 0.7 for classes "Reptile" and "Mammal", respectively; then the class with the highest probability, in this case "Reptile", is taken as the prediction. For the second level, only the probabilities for classes "Snake" and "Lizard" are considered, and the one with the highest probability is the final prediction. 14 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/lightgbm/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: flat 2 | ## Base classifier: lightgbm 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'num_leaves': 31, 'n_estimators': 100, 'min_child_samples': 20}|[0.469, 0.423, 0.195, 0.409, 0.29]|0.357|0.100| 6 | |{'num_leaves': 62, 'n_estimators': 100, 'min_child_samples': 40}|[0.225, 0.47, 0.313, 0.373, 0.299]|0.336|0.082| 7 | |{'num_leaves': 31, 'n_estimators': 100, 'min_child_samples': 40}|[0.248, 0.37, 0.459, 0.438, 0.476]|0.398|0.083| 8 | |{'num_leaves': 62, 'n_estimators': 200, 'min_child_samples': 40}|[0.453, 0.177, 0.455, 0.357, 0.135]|0.316|0.135| 9 | |{'num_leaves': 62, 'n_estimators': 200, 'min_child_samples': 20}|[0.403, 0.471, 0.158, 0.385, 0.376]|0.358|0.106| 10 | |{'num_leaves': 31, 'n_estimators': 200, 'min_child_samples': 40}|[0.338, 0.397, 0.46, 0.401, 0.444]|0.408|0.043| 11 | |{'num_leaves': 31, 'n_estimators': 200, 'min_child_samples': 20}|[0.296, 0.48, 0.2, 0.203, 0.244]|0.285|0.104| 12 | |{'num_leaves': 62, 'n_estimators': 100, 'min_child_samples': 20}|[0.463, 0.368, 0.477, 0.076, 0.452]|0.367|0.151| 13 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/rules/train.smk: -------------------------------------------------------------------------------- 1 | rule train: 2 | input: 3 | x_train = "results/split_data/x_train.csv.zip", 4 | y_train = "results/split_data/y_train.csv.zip", 5 | best_parameters = "results/{model}/{classifier}/optimization_results.yaml", 6 | output: 7 | trained_model = "results/{model}/{classifier}/trained_model.sav", 8 | benchmark = "results/{model}/{classifier}/training_benchmark.txt" 9 | params: 10 | classifier = "{classifier}", 11 | model = "{model}", 12 | conda: 13 | "../envs/hiclass.yml" 14 | threads: 15 | config["threads"] 16 | resources: 17 | mem_gb = config["mem_gb"] 18 | shell: 19 | """ 20 | /usr/bin/time -v \ 21 | -o 
{output.benchmark} \ 22 | python scripts/train.py \ 23 | --n-jobs {threads} \ 24 | --x-train {input.x_train} \ 25 | --y-train {input.y_train} \ 26 | --trained-model {output.trained_model} \ 27 | --classifier {params.classifier} \ 28 | --model {params.model} \ 29 | --best-parameters {input.best_parameters} 30 | """ 31 | -------------------------------------------------------------------------------- /tests/test_ConstantClassifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from numpy.testing import assert_array_equal 4 | 5 | from hiclass.ConstantClassifier import ConstantClassifier 6 | 7 | 8 | def test_fit_1(): 9 | X = [1, 2, 3] 10 | y = ["a", "a", "a"] 11 | classifier = ConstantClassifier() 12 | classifier.fit(X, y) 13 | assert classifier.classes_ == "a" 14 | 15 | 16 | def test_fit_2(): 17 | X = [1, 2, 3] 18 | y = ["a", "b", "c"] 19 | classifier = ConstantClassifier() 20 | with pytest.raises(ValueError): 21 | classifier.fit(X, y) 22 | 23 | 24 | def test_predict_proba(): 25 | X = np.array([1, 2, 3]) 26 | classifier = ConstantClassifier() 27 | predict_proba = classifier.predict_proba(X) 28 | ground_truth = np.array([[1], [1], [1]]) 29 | assert_array_equal(ground_truth, predict_proba) 30 | 31 | 32 | def test_predict(): 33 | X = np.array([1, 2, 3]) 34 | classifier = ConstantClassifier() 35 | classifier.classes_ = "a" 36 | predictions = classifier.predict(X) 37 | ground_truth = np.array([["a"], ["a"], ["a"]]) 38 | assert_array_equal(ground_truth, predictions) 39 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/lightgbm/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/local_classifier_per_node/lightgbm/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/local_classifier_per_node/lightgbm/predictions.csv.zip --classifier local_classifier_per_node" 2 | User time (seconds): 901.92 3 | System time (seconds): 5.05 4 | Percent of CPU this job got: 783% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 1:55.71 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 3598540 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 2088661 14 | Voluntary context switches: 5456 15 | Involuntary context switches: 150437 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 8 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/lightgbm/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/flat/lightgbm/trained_model.sav --classifier lightgbm --model flat --best-parameters results/flat/lightgbm/optimization_results.yaml" 2 | User time (seconds): 47707.20 3 | System time (seconds): 43.99 4 | Percent of CPU this 
job got: 1114% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 1:11:24 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 4764032 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 20940825 14 | Voluntary context switches: 718248 15 | Involuntary context switches: 4907613 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/lightgbm/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/local_classifier_per_level/lightgbm/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/local_classifier_per_level/lightgbm/predictions.csv.zip --classifier local_classifier_per_level" 2 | User time (seconds): 684.15 3 | System time (seconds): 5.29 4 | Percent of CPU this job got: 616% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 1:51.76 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 8796884 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 3050846 14 | Voluntary context switches: 3319 15 | Involuntary context switches: 290984 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/random_forest/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/local_classifier_per_level/random_forest/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/local_classifier_per_level/random_forest/predictions.csv.zip --classifier local_classifier_per_level" 2 | User time (seconds): 191.85 3 | System time (seconds): 90.04 4 | Percent of CPU this job got: 99% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 4:44.13 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 199640884 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 72801163 14 | Voluntary context switches: 1712 15 | Involuntary context switches: 225086 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- 
/benchmarks/consumer_complaints/results/local_classifier_per_node/random_forest/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/local_classifier_per_node/random_forest/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/local_classifier_per_node/random_forest/predictions.csv.zip --classifier local_classifier_per_node" 2 | User time (seconds): 1124.76 3 | System time (seconds): 26.68 4 | Percent of CPU this job got: 99% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 19:16.97 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 55965648 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 21032658 14 | Voluntary context switches: 3305 15 | Involuntary context switches: 308560 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/lightgbm/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_node 2 | ## Base classifier: lightgbm 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'num_leaves': 62, 'n_estimators': 100, 'min_child_samples': 40}|[0.775, 0.778, 0.777, 0.777, 0.774]|0.776|0.001| 6 | |{'num_leaves': 62, 'n_estimators': 200, 'min_child_samples': 20}|[0.789, 0.789, 0.789, 0.784, 0.788]|0.788|0.002| 7 | |{'num_leaves': 31, 'n_estimators': 200, 'min_child_samples': 40}|[0.779, 0.781, 0.78, 0.78, 0.777]|0.779|0.001| 8 | |{'num_leaves': 31, 'n_estimators': 100, 'min_child_samples': 20}|[0.769, 0.771, 0.757, 0.77, 0.768]|0.767|0.005| 9 | |{'num_leaves': 31, 'n_estimators': 200, 'min_child_samples': 20}|[0.782, 0.783, 0.769, 0.781, 0.78]|0.779|0.005| 10 | |{'num_leaves': 31, 'n_estimators': 100, 'min_child_samples': 40}|[0.77, 0.772, 0.768, 0.771, 0.77]|0.770|0.001| 11 | |{'num_leaves': 62, 'n_estimators': 200, 'min_child_samples': 40}|[0.784, 0.786, 0.784, 0.784, 0.783]|0.784|0.001| 12 | |{'num_leaves': 62, 'n_estimators': 100, 'min_child_samples': 20}|[0.776, 0.778, 0.774, 0.776, 0.775]|0.776|0.001| 13 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/random_forest/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/flat/random_forest/trained_model.sav --classifier random_forest --model flat --best-parameters results/flat/random_forest/optimization_results.yaml" 2 | User time (seconds): 55970.21 3 | System time (seconds): 188.34 4 | Percent of CPU this job got: 1076% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 1:26:55 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size 
(kbytes): 0 10 | Maximum resident set size (kbytes): 170290436 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 98700661 14 | Voluntary context switches: 18058 15 | Involuntary context switches: 5556045 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/lightgbm/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_level 2 | ## Base classifier: lightgbm 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'num_leaves': 62, 'n_estimators': 100, 'min_child_samples': 20}|[0.676, 0.666, 0.66, 0.657, 0.661]|0.664|0.007| 6 | |{'num_leaves': 31, 'n_estimators': 200, 'min_child_samples': 40}|[0.661, 0.668, 0.643, 0.664, 0.658]|0.659|0.009| 7 | |{'num_leaves': 62, 'n_estimators': 100, 'min_child_samples': 40}|[0.66, 0.669, 0.676, 0.665, 0.664]|0.667|0.006| 8 | |{'num_leaves': 31, 'n_estimators': 200, 'min_child_samples': 20}|[0.649, 0.651, 0.66, 0.658, 0.655]|0.655|0.004| 9 | |{'num_leaves': 62, 'n_estimators': 200, 'min_child_samples': 40}|[0.651, 0.66, 0.668, 0.657, 0.654]|0.658|0.006| 10 | |{'num_leaves': 31, 'n_estimators': 100, 'min_child_samples': 40}|[0.675, 0.681, 0.652, 0.669, 0.672]|0.670|0.010| 11 | |{'num_leaves': 62, 'n_estimators': 200, 'min_child_samples': 20}|[0.668, 0.654, 0.654, 0.651, 0.651]|0.656|0.007| 12 | |{'num_leaves': 31, 'n_estimators': 100, 'min_child_samples': 20}|[0.645, 0.646, 0.655, 0.664, 0.657]|0.653|0.007| 13 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/logistic_regression/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/local_classifier_per_node/logistic_regression/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/local_classifier_per_node/logistic_regression/predictions.csv.zip --classifier local_classifier_per_node" 2 | User time (seconds): 34.02 3 | System time (seconds): 2.83 4 | Percent of CPU this job got: 104% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 0:35.35 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 1820992 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 840372 14 | Voluntary context switches: 258 15 | Involuntary context switches: 92586 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/lightgbm/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being 
timed: "python scripts/predict.py --trained-model results/local_classifier_per_parent_node/lightgbm/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/local_classifier_per_parent_node/lightgbm/predictions.csv.zip --classifier local_classifier_per_parent_node" 2 | User time (seconds): 279.34 3 | System time (seconds): 3.34 4 | Percent of CPU this job got: 493% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 0:57.30 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 1723380 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 1652406 14 | Voluntary context switches: 1081 15 | Involuntary context switches: 237075 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 8 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/logistic_regression/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/local_classifier_per_level/logistic_regression/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/local_classifier_per_level/logistic_regression/predictions.csv.zip --classifier local_classifier_per_level" 2 | User time (seconds): 56.31 3 | System time (seconds): 5.31 4 | Percent of CPU this job got: 101% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 1:00.75 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 8860800 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 3136787 14 | Voluntary context switches: 278 15 | Involuntary context switches: 216171 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /docs/source/get_started/venv.rst: -------------------------------------------------------------------------------- 1 | venv 2 | ==== 3 | 4 | If you are using Python 3, you should already have the :literal:`venv` module installed with the standard library. Create a directory for HiClass within your virtual environment: 5 | 6 | .. code-block:: bash 7 | 8 | mkdir hiclass-environment && cd hiclass-environment 9 | 10 | 11 | This will create a folder called :literal:`hiclass-environment` in your current working directory. Then you should create a new virtual environment in this directory by running: 12 | 13 | .. tabs:: 14 | 15 | .. code-tab:: bash 16 | :caption: GNU/Linux or macOS 17 | 18 | python -m venv env/hiclass-environment 19 | 20 | .. code-tab:: bash 21 | :caption: Windows 22 | 23 | python -m venv env\hiclass-environment 24 | 25 | Activate this virtual environment: 26 | 27 | .. tabs:: 28 | 29 | .. 
code-tab:: bash 30 | :caption: GNU/Linux or macOS 31 | 32 | source env/hiclass-environment/bin/activate 33 | 34 | .. code-tab:: bash 35 | :caption: Windows 36 | 37 | .\env\hiclass-environment\Scripts\activate 38 | 39 | To exit the environment: 40 | 41 | .. code-block:: bash 42 | 43 | deactivate 44 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/lightgbm/optimization_results.md: -------------------------------------------------------------------------------- 1 | # Model: local_classifier_per_parent_node 2 | ## Base classifier: lightgbm 3 | |Parameters|Scores|Average|Standard deviation| 4 | |----------|------|-------|------------------| 5 | |{'num_leaves': 62, 'n_estimators': 200, 'min_child_samples': 40}|[0.729, 0.736, 0.738, 0.729, 0.727]|0.732|0.005| 6 | |{'num_leaves': 62, 'n_estimators': 100, 'min_child_samples': 20}|[0.744, 0.738, 0.734, 0.732, 0.734]|0.736|0.004| 7 | |{'num_leaves': 31, 'n_estimators': 200, 'min_child_samples': 40}|[0.732, 0.733, 0.719, 0.741, 0.73]|0.731|0.007| 8 | |{'num_leaves': 62, 'n_estimators': 100, 'min_child_samples': 40}|[0.741, 0.745, 0.748, 0.738, 0.738]|0.742|0.004| 9 | |{'num_leaves': 31, 'n_estimators': 100, 'min_child_samples': 20}|[0.72, 0.717, 0.726, 0.739, 0.724]|0.725|0.008| 10 | |{'num_leaves': 31, 'n_estimators': 200, 'min_child_samples': 20}|[0.725, 0.723, 0.732, 0.733, 0.722]|0.727|0.005| 11 | |{'num_leaves': 31, 'n_estimators': 100, 'min_child_samples': 40}|[0.748, 0.748, 0.73, 0.746, 0.746]|0.743|0.007| 12 | |{'num_leaves': 62, 'n_estimators': 200, 'min_child_samples': 20}|[0.735, 0.724, 0.726, 0.723, 0.721]|0.726|0.005| 13 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/flat/logistic_regression/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/flat/logistic_regression/trained_model.sav --classifier logistic_regression --model flat --best-parameters results/flat/logistic_regression/optimization_results.yaml" 2 | User time (seconds): 3407.11 3 | System time (seconds): 27.12 4 | Percent of CPU this job got: 99% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 57:16.20 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 3303628 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 20909985 14 | Voluntary context switches: 4108 15 | Involuntary context switches: 432244 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 8 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/random_forest/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/local_classifier_per_parent_node/random_forest/trained_model.sav --x-test results/split_data/x_test.csv.zip 
--predictions results/local_classifier_per_parent_node/random_forest/predictions.csv.zip --classifier local_classifier_per_parent_node" 2 | User time (seconds): 153.62 3 | System time (seconds): 23.07 4 | Percent of CPU this job got: 100% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 2:56.25 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 49191956 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 17647998 14 | Voluntary context switches: 4211 15 | Involuntary context switches: 132350 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/logistic_regression/prediction_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/predict.py --trained-model results/local_classifier_per_parent_node/logistic_regression/trained_model.sav --x-test results/split_data/x_test.csv.zip --predictions results/local_classifier_per_parent_node/logistic_regression/predictions.csv.zip --classifier local_classifier_per_parent_node" 2 | User time (seconds): 34.70 3 | System time (seconds): 3.97 4 | Percent of CPU this job got: 103% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 0:37.22 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 1566736 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 1429523 14 | Voluntary context switches: 273 15 | Involuntary context switches: 136455 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/lightgbm/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_level/lightgbm/trained_model.sav --classifier lightgbm --model local_classifier_per_level --best-parameters results/local_classifier_per_level/lightgbm/optimization_results.yaml" 2 | User time (seconds): 7024.20 3 | System time (seconds): 53.93 4 | Percent of CPU this job got: 123% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 1:35:26 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 9481704 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 100554869 14 | Voluntary context switches: 5535 15 | Involuntary context switches: 977451 
16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/lightgbm/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_node/lightgbm/trained_model.sav --classifier lightgbm --model local_classifier_per_node --best-parameters results/local_classifier_per_node/lightgbm/optimization_results.yaml" 2 | User time (seconds): 32925.40 3 | System time (seconds): 158.01 4 | Percent of CPU this job got: 985% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 55:55.98 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 33474944 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 159692014 14 | Voluntary context switches: 53730 15 | Involuntary context switches: 3540181 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /docs/source/get_started/verify.rst: -------------------------------------------------------------------------------- 1 | Verify a successful installation 2 | ================================ 3 | 4 | To check that HiClass is installed, start the Python interpreter by running ``python`` on the terminal, then try to import HiClass: 5 | 6 | .. code-block:: python 7 | 8 | Python 3.8.13 (default, Mar 28 2022, 11:38:47) 9 | [GCC 7.5.0] :: Anaconda, Inc. on linux 10 | Type "help", "copyright", "credits" or "license" for more information. 11 | >>> import hiclass 12 | >>> 13 | 14 | If everything goes smoothly, your installation works. However, if HiClass was not installed successfully, you will see an error instead. For example: 15 | 16 | .. code-block:: python 17 | 18 | Python 3.9.7 (default, Sep 16 2021, 13:09:58) 19 | [GCC 7.5.0] :: Anaconda, Inc. on linux 20 | Type "help", "copyright", "credits" or "license" for more information. 21 | >>> import hiclass 22 | Traceback (most recent call last): 23 | File "<stdin>", line 1, in <module> 24 | ModuleNotFoundError: No module named 'hiclass' 25 | 26 | If you have any problems with your installation, please open an issue describing it at `https://github.com/scikit-learn-contrib/hiclass/issues `_. 27 | -------------------------------------------------------------------------------- /docs/examples/plot_empty_levels.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ========================== 4 | Different Number of Levels 5 | ========================== 6 | 7 | HiClass supports hierarchies in which paths have different numbers of levels. 8 | For this example, we will train a local classifier per node 9 | with a hierarchy similar to the following image: 10 | 11 | .. 
figure:: ../algorithms/local_classifier_per_node.svg 12 | :align: center 13 | """ 14 | import numpy as np 15 | from sklearn.linear_model import LogisticRegression 16 | 17 | from hiclass import LocalClassifierPerNode 18 | 19 | # Define data 20 | X_train = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] 21 | X_test = [[9, 10], [7, 8], [5, 6], [3, 4], [1, 2]] 22 | Y_train = np.array( 23 | [ 24 | ["Bird"], 25 | ["Reptile", "Snake"], 26 | ["Reptile", "Lizard"], 27 | ["Mammal", "Cat"], 28 | ["Mammal", "Wolf", "Dog"], 29 | ], 30 | dtype=object, 31 | ) 32 | 33 | # Use logistic regression classifiers for every node 34 | lr = LogisticRegression() 35 | classifier = LocalClassifierPerNode(local_classifier=lr) 36 | 37 | # Train local classifier per node 38 | classifier.fit(X_train, Y_train) 39 | 40 | # Predict 41 | predictions = classifier.predict(X_test) 42 | print(predictions) 43 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/random_forest/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_node/random_forest/trained_model.sav --classifier random_forest --model local_classifier_per_node --best-parameters results/local_classifier_per_node/random_forest/optimization_results.yaml" 2 | User time (seconds): 175580.64 3 | System time (seconds): 158.39 4 | Percent of CPU this job got: 713% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 6:50:17 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 57871116 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 57422298 14 | Voluntary context switches: 71165 15 | Involuntary context switches: 17531687 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/rules/tune.smk: -------------------------------------------------------------------------------- 1 | rule tune: 2 | input: 3 | x_train = "results/split_data/x_train.csv.zip", 4 | y_train = "results/split_data/y_train.csv.zip" 5 | output: 6 | best_parameters = "results/{model}/{classifier}/optimization_results.yaml", 7 | params: 8 | classifier = "{classifier}", 9 | model = "{model}", 10 | output_dir = "results/{model}/{classifier}", 11 | study_name = "{model}_{classifier}", 12 | n_splits = config["n_splits"], 13 | conda: 14 | "../envs/hiclass.yml" 15 | threads: 16 | config["threads"] 17 | resources: 18 | mem_gb = config["mem_gb"] 19 | shell: 20 | """ 21 | python scripts/tune.py \ 22 | --config-name {params.classifier} \ 23 | --multirun \ 24 | 'classifier={params.classifier}' \ 25 | 'model={params.model}' \ 26 | 'n_jobs={threads}' \ 27 | 'x_train={input.x_train}' \ 28 | 'y_train={input.y_train}' \ 29 | 'output_dir={params.output_dir}' \ 30 | 'mem_gb={resources.mem_gb}' \ 31 | 'n_splits={params.n_splits}' \ 32 | hydra.sweep.dir={params.output_dir} \ 33 | hydra.sweeper.study_name={params.study_name} 
34 | """ 35 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/random_forest/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_level/random_forest/trained_model.sav --classifier random_forest --model local_classifier_per_level --best-parameters results/local_classifier_per_level/random_forest/optimization_results.yaml" 2 | User time (seconds): 60475.69 3 | System time (seconds): 829.12 4 | Percent of CPU this job got: 174% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 9:45:18 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 200689644 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 147569336 14 | Voluntary context switches: 22850 15 | Involuntary context switches: 6300609 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/lightgbm/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_parent_node/lightgbm/trained_model.sav --classifier lightgbm --model local_classifier_per_parent_node --best-parameters results/local_classifier_per_parent_node/lightgbm/optimization_results.yaml" 2 | User time (seconds): 3489.46 3 | System time (seconds): 27.48 4 | Percent of CPU this job got: 209% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 28:00.93 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 9434040 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 27807258 14 | Voluntary context switches: 22109 15 | Involuntary context switches: 564583 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_node/logistic_regression/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_node/logistic_regression/trained_model.sav --classifier logistic_regression --model local_classifier_per_node --best-parameters 
results/local_classifier_per_node/logistic_regression/optimization_results.yaml" 2 | User time (seconds): 1126.79 3 | System time (seconds): 679.72 4 | Percent of CPU this job got: 736% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 4:05.23 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 22026536 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 18755931 14 | Voluntary context switches: 80018 15 | Involuntary context switches: 372099166 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_level/logistic_regression/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_level/logistic_regression/trained_model.sav --classifier logistic_regression --model local_classifier_per_level --best-parameters results/local_classifier_per_level/logistic_regression/optimization_results.yaml" 2 | User time (seconds): 11364.30 3 | System time (seconds): 6577.55 4 | Percent of CPU this job got: 206% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 2:24:44 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 12006424 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 400972427 14 | Voluntary context switches: 69089 15 | Involuntary context switches: 306163777 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/random_forest/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_parent_node/random_forest/trained_model.sav --classifier random_forest --model local_classifier_per_parent_node --best-parameters results/local_classifier_per_parent_node/random_forest/optimization_results.yaml" 2 | User time (seconds): 38410.18 3 | System time (seconds): 236.38 4 | Percent of CPU this job got: 141% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 7:34:47 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 50648180 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 36768287 
14 | Voluntary context switches: 52700 15 | Involuntary context switches: 4129465 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 8 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /docs/source/algorithms/local_classifier_per_node.rst: -------------------------------------------------------------------------------- 1 | .. _local-classifier-per-node-overview: 2 | 3 | Local Classifier Per Node 4 | ========================= 5 | 6 | One of the most popular approaches in the literature, the local classifier per node consists of training one binary classifier for each node of the class taxonomy, except for the root node. A visual representation of the local classifier per node is shown in the image below. 7 | 8 | .. figure:: local_classifier_per_node.svg 9 | :align: center 10 | 11 | Visual representation of the local classifier per node approach. 12 | 13 | .. toctree:: 14 | :hidden: 15 | 16 | training_policies 17 | 18 | The binary classifiers are trained in parallel using the library `Ray `_. In order to avoid inconsistencies, prediction is performed in a top-down manner. For example, given a hypothetical test example, the local classifier per node first queries the binary classifiers at nodes "Reptile" and "Mammal". Suppose that in this case the probability of the test example belonging to class "Reptile" is 0.8, while the probability of belonging to class "Mammal" is 0.5; then class "Reptile" is picked. At the next level, only the classifiers at nodes "Snake" and "Lizard" are queried, and again the one with the highest probability is selected. 19 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/results/local_classifier_per_parent_node/logistic_regression/training_benchmark.txt: -------------------------------------------------------------------------------- 1 | Command being timed: "python scripts/train.py --n-jobs 12 --x-train results/split_data/x_train.csv.zip --y-train results/split_data/y_train.csv.zip --trained-model results/local_classifier_per_parent_node/logistic_regression/trained_model.sav --classifier logistic_regression --model local_classifier_per_parent_node --best-parameters results/local_classifier_per_parent_node/logistic_regression/optimization_results.yaml" 2 | User time (seconds): 3842.86 3 | System time (seconds): 6103.31 4 | Percent of CPU this job got: 516% 5 | Elapsed (wall clock) time (h:mm:ss or m:ss): 32:05.04 6 | Average shared text size (kbytes): 0 7 | Average unshared data size (kbytes): 0 8 | Average stack size (kbytes): 0 9 | Average total size (kbytes): 0 10 | Maximum resident set size (kbytes): 9744672 11 | Average resident set size (kbytes): 0 12 | Major (requiring I/O) page faults: 0 13 | Minor (reclaiming a frame) page faults: 49920910 14 | Voluntary context switches: 493049 15 | Involuntary context switches: 1928027088 16 | Swaps: 0 17 | File system inputs: 0 18 | File system outputs: 0 19 | Socket messages sent: 0 20 | Socket messages received: 0 21 | Signals delivered: 0 22 | Page size (bytes): 4096 23 | Exit status: 0 24 | -------------------------------------------------------------------------------- /docs/examples/plot_model_persistence.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ===================== 4 | 
Model Persistence
5 | =====================
6 |
7 | HiClass is fully compatible with Pickle.
8 | Pickle can be used to easily store machine learning models on disk.
9 | In this example, we demonstrate how to use pickle to store and load trained classifiers.
10 | """
11 | import pickle
12 |
13 | from sklearn.linear_model import LogisticRegression
14 |
15 | from hiclass import LocalClassifierPerLevel
16 |
17 | # Define data
18 | X_train = [[1, 2], [3, 4], [5, 6], [7, 8]]
19 | X_test = [[7, 8], [5, 6], [3, 4], [1, 2]]
20 | Y_train = [
21 |     ["Animal", "Mammal", "Sheep"],
22 |     ["Animal", "Mammal", "Cow"],
23 |     ["Animal", "Reptile", "Snake"],
24 |     ["Animal", "Reptile", "Lizard"],
25 | ]
26 |
27 | # Use Logistic Regression classifiers for every level in the hierarchy
28 | lr = LogisticRegression()
29 | classifier = LocalClassifierPerLevel(local_classifier=lr)
30 |
31 | # Train local classifier per level
32 | classifier.fit(X_train, Y_train)
33 |
34 | # Save the model to disk
35 | filename = "trained_model.sav"
36 | pickle.dump(classifier, open(filename, "wb"))
37 |
38 | # Some time in the future...
39 |
40 | # Load the model from disk
41 | loaded_model = pickle.load(open(filename, "rb"))
42 |
43 | # Predict
44 | predictions = loaded_model.predict(X_test)
45 | print(predictions)
46 |
--------------------------------------------------------------------------------
/.github/workflows/deploy-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Deploy to PyPI
2 |
3 | on:
4 |   push:
5 |     branches:
6 |       - main
7 |
8 | jobs:
9 |
10 |   bump:
11 |     runs-on: ubuntu-latest
12 |     permissions:
13 |       contents: write
14 |     steps:
15 |       - uses: actions/checkout@v3
16 |         with:
17 |           fetch-depth: '0'
18 |       - name: Bump version and push tag
19 |         uses: anothrNick/github-tag-action@1.64.0
20 |         env:
21 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
22 |           WITH_V: true
23 |           DEFAULT_BUMP: patch
24 |
25 |   build-n-publish:
26 |     needs: bump
27 |     name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
28 |     runs-on: ubuntu-latest
29 |     steps:
30 |       - uses: actions/checkout@v1
31 |       - name: Set up Python 3.12
32 |         uses: actions/setup-python@v2
33 |         with:
34 |           python-version: 3.12
35 |       - name: Install pypa/build
36 |         run: |
37 |           git fetch --tags
38 |           git branch --create-reflog main origin/main
39 |           python -m pip install build --user .
40 |       - name: Build a binary wheel and a source tarball
41 |         run: |
42 |           python -m build --sdist --wheel --outdir dist/ .
43 |       - name: Publish distribution 📦 to PyPI
44 |         uses: pypa/gh-action-pypi-publish@release/v1
45 |         with:
46 |           user: __token__
47 |           password: ${{ secrets.PYPI_API_TOKEN }}
48 |           verbose: true
49 |
--------------------------------------------------------------------------------
/docs/examples/plot_pipeline.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | =====================
4 | Building Pipelines
5 | =====================
6 |
7 | HiClass can be adopted in scikit-learn pipelines, and fully supports sparse matrices as input.
8 | This example demonstrates the use of both of these features.
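The pipeline below extracts sparse TF-IDF features from raw text with CountVectorizer and TfidfTransformer before fitting a LocalClassifierPerParentNode.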
9 | """ 10 | from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer 11 | from sklearn.linear_model import LogisticRegression 12 | from sklearn.pipeline import Pipeline 13 | 14 | from hiclass import LocalClassifierPerParentNode 15 | 16 | # Define data 17 | X_train = [ 18 | "Struggling to repay loan", 19 | "Unable to get annual report", 20 | ] 21 | X_test = [ 22 | "Unable to get annual report", 23 | "Struggling to repay loan", 24 | ] 25 | Y_train = [["Loan", "Student loan"], ["Credit reporting", "Reports"]] 26 | 27 | # We will use logistic regression classifiers for every parent node 28 | lr = LogisticRegression() 29 | 30 | # Let's build a pipeline using CountVectorizer and TfidfTransformer 31 | # to extract features as sparse matrices 32 | pipeline = Pipeline( 33 | [ 34 | ("count", CountVectorizer()), 35 | ("tfidf", TfidfTransformer()), 36 | ("lcppn", LocalClassifierPerParentNode(local_classifier=lr)), 37 | ] 38 | ) 39 | 40 | # Now, let's train a local classifier per parent node 41 | pipeline.fit(X_train, Y_train) 42 | 43 | # Finally, let's predict using the pipeline 44 | predictions = pipeline.predict(X_test) 45 | print(predictions) 46 | -------------------------------------------------------------------------------- /docs/source/api/classifiers.rst: -------------------------------------------------------------------------------- 1 | .. _classifiers: 2 | 3 | Hierarchical Classifiers 4 | ======================== 5 | Shared classes 6 | -------------- 7 | 8 | HierarchicalClassifier 9 | ^^^^^^^^^^^^^^^^^^^^^^ 10 | .. autoclass:: HierarchicalClassifier.HierarchicalClassifier 11 | :members: 12 | :special-members: __init__ 13 | 14 | 15 | .................................. 16 | 17 | ConstantClassifier 18 | ^^^^^^^^^^^^^^^^^^ 19 | .. autoclass:: ConstantClassifier.ConstantClassifier 20 | :members: 21 | :special-members: __init__ 22 | 23 | .................................. 24 | 25 | LocalClassifierPerLevel 26 | ----------------------- 27 | .. autoclass:: LocalClassifierPerLevel.LocalClassifierPerLevel 28 | :members: 29 | :show-inheritance: 30 | :inherited-members: 31 | :special-members: __init__ 32 | 33 | .................................. 34 | 35 | LocalClassifierPerNode 36 | ---------------------- 37 | .. autoclass:: LocalClassifierPerNode.LocalClassifierPerNode 38 | :members: 39 | :show-inheritance: 40 | :inherited-members: 41 | :special-members: __init__ 42 | 43 | .................................. 44 | 45 | LocalClassifierPerParentNode 46 | ---------------------------- 47 | .. autoclass:: LocalClassifierPerParentNode.LocalClassifierPerParentNode 48 | :members: 49 | :show-inheritance: 50 | :inherited-members: 51 | :special-members: __init__ 52 | 53 | Flat Classifier 54 | =============== 55 | 56 | FlatClassifier 57 | ^^^^^^^^^^^^^^^^^^^^^^ 58 | .. autoclass:: FlatClassifier.FlatClassifier 59 | :members: 60 | :special-members: __init__ 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, dacs-hpi 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /docs/source/get_started/full_example.rst: -------------------------------------------------------------------------------- 1 | Hello HiClass! 2 | ============== 3 | 4 | It is now time to stitch the code together. Here is the full example: 5 | 6 | .. code-block:: python 7 | 8 | """Contents of hello_hiclass.py""" 9 | from hiclass import LocalClassifierPerNode 10 | from sklearn.ensemble import RandomForestClassifier 11 | 12 | # Define data 13 | X_train = [[1], [2], [3], [4]] 14 | X_test = [[4], [3], [2], [1]] 15 | Y_train = [ 16 | ['Animal', 'Mammal', 'Sheep'], 17 | ['Animal', 'Mammal', 'Cow'], 18 | ['Animal', 'Reptile', 'Snake'], 19 | ['Animal', 'Reptile', 'Lizard'], 20 | ] 21 | 22 | # Use random forest classifiers for every node 23 | rf = RandomForestClassifier() 24 | classifier = LocalClassifierPerNode(local_classifier=rf) 25 | 26 | # Train local classifier per node 27 | classifier.fit(X_train, Y_train) 28 | 29 | # Predict 30 | predictions = classifier.predict(X_test) 31 | print(predictions) 32 | 33 | Save the code above in a file called :literal:`hello_hiclass.py`, then open a terminal and run the following command: 34 | 35 | .. code-block:: bash 36 | 37 | python hello_hiclass.py 38 | 39 | The array below should be printed on the terminal: 40 | 41 | .. code-block:: python 42 | 43 | [['Animal' 'Reptile' 'Lizard'] 44 | ['Animal' 'Reptile' 'Snake'] 45 | ['Animal' 'Mammal' 'Cow'] 46 | ['Animal' 'Mammal' 'Sheep']] 47 | 48 | There is more to HiClass than what is shown in this "Hello World" example, such as training with missing leaf nodes, storing trained models and computation of hierarchical metrics. These concepts are covered in the :ref:`Gallery of Examples`. 
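
For instance, the hierarchical metrics mentioned above can be computed with the :literal:`hiclass.metrics` module. The snippet below is a minimal sketch rather than part of the original example: it assumes :literal:`predictions` from :literal:`hello_hiclass.py` is still in scope, defines the expected labels by hand, and assumes :literal:`precision` and :literal:`recall` take the same :literal:`(y_true, y_pred)` arguments as :literal:`f1`.

.. code-block:: python

    from hiclass.metrics import precision, recall, f1

    # True labels for X_test, i.e., the expected output shown above
    Y_test = [
        ['Animal', 'Reptile', 'Lizard'],
        ['Animal', 'Reptile', 'Snake'],
        ['Animal', 'Mammal', 'Cow'],
        ['Animal', 'Mammal', 'Sheep'],
    ]

    # Hierarchical precision, recall and F-score over predicted vs. true paths
    print(precision(Y_test, predictions))
    print(recall(Y_test, predictions))
    print(f1(Y_test, predictions))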
49 |
--------------------------------------------------------------------------------
/hiclass/probability_combiner/ProbabilityCombiner.py:
--------------------------------------------------------------------------------
1 | """Abstract class defining the structure of a probability combiner."""
2 |
3 | import abc
4 | import numpy as np
5 | from typing import List
6 | from collections import defaultdict
7 | from networkx.exception import NetworkXError
8 | from hiclass import HierarchicalClassifier
9 | from hiclass._hiclass_utils import _normalize_probabilities
10 |
11 |
12 | class ProbabilityCombiner(abc.ABC):
13 |     """Abstract class defining the structure of a probability combiner."""
14 |
15 |     def __init__(
16 |         self, classifier: HierarchicalClassifier, normalize: bool = True
17 |     ) -> None:
18 |         """Initialize probability combiner object."""
19 |         self.classifier = classifier
20 |         self.normalize = normalize
21 |
22 |     @abc.abstractmethod
23 |     def combine(self, proba: List[np.ndarray]) -> List[np.ndarray]:
24 |         """Combine probabilities over multiple levels."""
25 |         ...
26 |
27 |     def _normalize(self, proba: List[np.ndarray]):
28 |         return _normalize_probabilities(proba)
29 |
30 |     def _find_predecessors(self, level: int):
31 |         predecessors = defaultdict(list)
32 |         for node in self.classifier.global_classes_[level]:
33 |             if self.classifier.hierarchy_.has_node(node):
34 |                 predecessor = list(self.classifier.hierarchy_.predecessors(node))[0]
35 |                 predecessor_name = str(predecessor).split(self.classifier.separator_)[
36 |                     -1
37 |                 ]
38 |                 node_name = str(node).split(self.classifier.separator_)[-1]
39 |                 predecessors[node_name].append(predecessor_name)
40 |         return predecessors
41 |
--------------------------------------------------------------------------------
/docs/source/algorithms/metrics.rst:
--------------------------------------------------------------------------------
1 | .. _metrics-overview:
2 |
3 | Metrics
4 | ====================
5 |
6 | Classification Metrics
7 | ----------------------
8 |
9 | According to [1]_, the use of flat classification metrics might not be adequate to give enough insight into which algorithm is better at classifying hierarchical data. Hence, in HiClass we implemented the metrics of hierarchical precision (hP), hierarchical recall (hR) and hierarchical F-score (hF), which are extensions of the renowned metrics of precision, recall and F-score, but tailored to the hierarchical classification scenario. These hierarchical counterparts were initially proposed by [2]_, and are defined as follows:
10 |
11 | :math:`\displaystyle{hP = \frac{\sum_i|\alpha_i\cap\beta_i|}{\sum_i|\alpha_i|}}`, :math:`\displaystyle{hR = \frac{\sum_i|\alpha_i\cap\beta_i|}{\sum_i|\beta_i|}}`, :math:`\displaystyle{hF = \frac{2 \times hP \times hR}{hP + hR}}`
12 |
13 | where :math:`\alpha_i` is the set consisting of the most specific classes predicted for test example :math:`i` and all their ancestor classes, while :math:`\beta_i` is the set containing the true most specific classes of test example :math:`i` and all their ancestors, with summations computed over all test examples.
14 |
15 | Calibration Metrics
16 | -------------------
17 |
18 |
19 | .. [1] Silla, C. N., & Freitas, A. A. (2011). A survey of hierarchical classification across different application domains. Data Mining and Knowledge Discovery, 22(1), 31-72.
20 |
21 | .. [2] Kiritchenko, S., Matwin, S., Nock, R., & Famili, A. F. (2006, June). Learning and evaluation in the presence of class hierarchies: Application to text categorization.
In Conference of the Canadian Society for Computational Studies of Intelligence (pp. 395-406). Springer, Berlin, Heidelberg. 22 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/envs/snakemake.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - bioconda 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=4.5 8 | - amply=0.1.4 9 | - appdirs=1.4.4 10 | - attrs=20.3.0 11 | - brotlipy=0.7.0 12 | - ca-certificates=2020.12.5 13 | - certifi=2020.12.5 14 | - cffi=1.14.5 15 | - chardet=4.0.0 16 | - coincbc=2.10.5 17 | - configargparse=1.3 18 | - cryptography=3.4.4 19 | - datrie=0.8.2 20 | - docutils=0.16 21 | - gitdb=4.0.5 22 | - gitpython=3.1.13 23 | - idna=2.10 24 | - importlib-metadata=3.4.0 25 | - importlib_metadata=3.4.0 26 | - ipython_genutils=0.2.0 27 | - jsonschema=3.2.0 28 | - jupyter_core=4.7.1 29 | - ld_impl_linux-64=2.35.1 30 | - libblas=3.9.0 31 | - libcblas=3.9.0 32 | - libffi=3.3 33 | - libgcc-ng=9.3.0 34 | - libgfortran-ng=9.3.0 35 | - libgfortran5=9.3.0 36 | - libgomp=9.3.0 37 | - liblapack=3.9.0 38 | - libopenblas=0.3.12 39 | - libstdcxx-ng=9.3.0 40 | - nbformat=5.1.2 41 | - ncurses=6.2 42 | - openssl=1.1.1j 43 | - pandas=1.2.2 44 | - pip=21.0.1 45 | - psutil=5.8.0 46 | - pulp=2.3.1 47 | - pycparser=2.20 48 | - pyopenssl=20.0.1 49 | - pyparsing=2.4.7 50 | - pyrsistent=0.17.3 51 | - pysocks=1.7.1 52 | - python=3.9.2 53 | - python_abi=3.9 54 | - pyyaml=5.4.1 55 | - ratelimiter=1.2.0 56 | - readline=8.0 57 | - requests=2.25.1 58 | - setuptools=49.6.0 59 | - six=1.15.0 60 | - smmap=3.0.5 61 | - snakemake-minimal=5.32.2 62 | - sqlite=3.34.0 63 | - tk=8.6.10 64 | - toposort=1.6 65 | - traitlets=5.0.5 66 | - tzdata=2021a 67 | - urllib3=1.26.3 68 | - wheel=0.36.2 69 | - wrapt=1.12.1 70 | - xz=5.2.5 71 | - yaml=0.2.5 72 | - zipp=3.4.0 73 | - zlib=1.2.11 74 | - biopython=1.78 75 | -------------------------------------------------------------------------------- /hiclass/probability_combiner/MultiplyCombiner.py: -------------------------------------------------------------------------------- 1 | """Defines the MultiplyCombiner.""" 2 | 3 | import numpy as np 4 | from hiclass.probability_combiner.ProbabilityCombiner import ProbabilityCombiner 5 | from typing import List 6 | 7 | 8 | class MultiplyCombiner(ProbabilityCombiner): 9 | """Combine probabilities of multiple levels by multiplication.""" 10 | 11 | def combine(self, proba: List[np.ndarray]): 12 | """Combine probabilities of each level with probabilities of previous levels. 13 | 14 | Multiply node probabilities with the probabilities of its predecessors. 
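Concretely, the combined score of a node at level l equals its own probability multiplied by the sum of the already-combined probabilities of its predecessors at level l - 1.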
15 | """ 16 | res = [proba[0]] 17 | for level in range(1, self.classifier.max_levels_): 18 | level_probs = np.zeros_like(proba[level]) 19 | # find all predecessors of a node 20 | predecessors = self._find_predecessors(level) 21 | 22 | for node in predecessors.keys(): 23 | index = self.classifier.class_to_index_mapping_[level][node] 24 | # find indices of all predecessors 25 | predecessor_indices = [ 26 | self.classifier.class_to_index_mapping_[level - 1][predecessor] 27 | for predecessor in predecessors[node] 28 | ] 29 | # combine probabilities of all predecessors 30 | predecessors_combined_prob = np.sum( 31 | [res[level - 1][:, pre_index] for pre_index in predecessor_indices], 32 | axis=0, 33 | ) 34 | level_probs[:, index] = ( 35 | predecessors_combined_prob * proba[level][:, index] 36 | ) 37 | 38 | res.append(level_probs) 39 | return self._normalize(res) if self.normalize else res 40 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help 2 | help: ## display this help 3 | @grep -E '^[a-z.A-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | .PHONY: mamba 6 | mamba: ## install mamba 7 | conda install -n base -c conda-forge mamba 8 | 9 | .PHONY: snakemake 10 | snakemake: ## install snakemake 11 | -mamba remove -n snakemake --all -y 12 | -mamba create -c conda-forge -c bioconda -n snakemake snakemake -y 13 | 14 | .PHONY: local 15 | local: ## Run pipeline without slurm 16 | -sed -i "s?workdir.*?workdir: `pwd`?" configs/snakemake.yml 17 | -sed -i "s?threads.*?threads: 12?" configs/snakemake.yml 18 | -sed -i "s?#nrows.*?nrows: 2000?" configs/snakemake.yml 19 | -snakemake --printshellcmds --reason --use-conda --cores 12 --conda-frontend mamba 20 | -sed -i "s?nrows.*?#nrows: 2000?" configs/snakemake.yml 21 | 22 | .PHONY: delab 23 | delab: ## Run pipeline on delab with slurm 24 | -sed -i "s?workdir.*?workdir: `pwd`?" configs/snakemake.yml 25 | -sed -i "s?threads.*?threads: 12?" configs/snakemake.yml 26 | -sed -i "s?nrows.*?#nrows: 2000?" configs/snakemake.yml 27 | -srun -A renard --mem=30G --cpus-per-task=12 --time=30-00:00:00 -p cauldron snakemake --restart-times 5 --keep-going --printshellcmds --reason --use-conda --cores 1 --resources mem_gb=2700 --jobs 6 --cluster "sbatch -A renard -p magic --mem=450G --cpus-per-task=12 --time=5-00:00:00 -C 'ARCH:X86'" 28 | -rm -f slurm-* 29 | 30 | .PHONY: clean 31 | clean: ## delete temporary files from Snakemake 32 | -rm -rf .snakemake 33 | 34 | .PHONY: delete-results 35 | delete-results: ## delete results from pipeline 36 | -rm -rf results 37 | 38 | .PHONY: git 39 | git: ## Update local repository 40 | -git reset --hard 41 | -git fetch 42 | -git pull 43 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. hiclass documentation master file, created by 2 | sphinx-quickstart on Tue Jul 20 12:37:27 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to hiclass' documentation! 7 | =================================== 8 | 9 | .. 
image:: https://github.com/scikit-learn-contrib/hiclass/actions/workflows/deploy-pypi.yml/badge.svg?event=push
10 |    :target: https://github.com/scikit-learn-contrib/hiclass/actions/workflows/deploy-pypi.yml
11 |    :alt: Deploy PyPI
12 |
13 | .. image:: https://readthedocs.org/projects/hiclass/badge/?version=latest
14 |    :target: https://hiclass.readthedocs.io/en/latest/?badge=latest
15 |    :alt: Documentation Status
16 |
17 | .. image:: https://codecov.io/gh/scikit-learn-contrib/hiclass/branch/main/graph/badge.svg?token=PR8VLBMMNR
18 |    :target: https://codecov.io/gh/scikit-learn-contrib/hiclass
19 |    :alt: codecov
20 |
21 | .. image:: https://static.pepy.tech/personalized-badge/hiclass?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=pypi
22 |    :target: https://pypi.org/project/hiclass/
23 |    :alt: Downloads pypi
24 |
25 | .. image:: https://img.shields.io/conda/dn/conda-forge/hiclass?label=conda
26 |    :target: https://anaconda.org/conda-forge/hiclass
27 |    :alt: Downloads Conda
28 |
29 | .. image:: https://img.shields.io/badge/License-BSD_3--Clause-blue.svg
30 |    :target: https://opensource.org/licenses/BSD-3-Clause
31 |    :alt: License
32 |
33 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
34 |    :target: https://github.com/psf/black
35 |
36 | .. toctree::
37 |     :includehidden:
38 |     :maxdepth: 3
39 |
40 |     introduction/index
41 |     get_started/index
42 |     auto_examples/index
43 |     algorithms/index
44 |     api/index
45 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## General guidelines
2 |
3 | To contribute, fork the repository and send a pull request.
4 |
5 | When submitting code, please make every effort to follow existing conventions and style in order to keep the code as readable as possible.
6 |
7 | Where appropriate, please provide unit tests or integration tests. Unit tests should be pytest-based and added to /tests.
8 |
9 | Please make sure all tests pass before submitting a pull request. Please also squash your commits and, if needed, add the tag #major or #minor to the pull request title; otherwise, your pull request will be treated as a patch bump. Please check [https://semver.org/](https://semver.org/) for more information about versioning.
10 |
11 | ## Testing the code locally
12 |
13 | To test the code locally you need to install the dependencies for the library in the current environment. Additionally, you need to install the dependencies for testing. All of those dependencies can be installed with:
14 |
15 | ```
16 | pip install -e ".[dev]"
17 | ```
18 |
19 | To run the tests, simply execute:
20 |
21 | ```
22 | pytest -v --cov=hiclass --cov-fail-under=90 --cov-report html
23 | ```
24 |
25 | Lastly, you can set up the git hook scripts to fix formatting errors locally during commits:
26 |
27 | ```
28 | pre-commit install
29 | ```
30 |
31 | If black is not executed locally and there are formatting errors, the CI/CD pipeline will fail.
32 |
33 | ## Building the documentation locally
34 |
35 | To build the documentation locally, you need to install another set of dependencies that are specific to the documentation.
It is easier to create a separate conda environment and run the following command:
36 |
37 | ```
38 | pip install -r docs/requirements.txt
39 | ```
40 |
41 | To build the documentation, run the following commands:
42 |
43 | ```
44 | cd docs
45 | make html
46 | ```
47 |
--------------------------------------------------------------------------------
/hiclass/_calibration/BetaCalibrator.py:
--------------------------------------------------------------------------------
1 | from hiclass._calibration.BinaryCalibrator import _BinaryCalibrator
2 | from sklearn.utils.validation import check_is_fitted
3 | import numpy as np
4 | from sklearn.linear_model import LogisticRegression
5 |
6 |
7 | class _BetaCalibrator(_BinaryCalibrator):
8 |     name = "BetaCalibrator"
9 |
10 |     def __init__(self) -> None:
11 |         super().__init__()
12 |         self.skip_calibration = False
13 |
14 |     def fit(self, y: np.ndarray, scores: np.ndarray, X: np.ndarray = None):
15 |         unique_labels = len(np.unique(y))
16 |         if unique_labels < 2:
17 |             self.skip_calibration = True
18 |             self._is_fitted = True
19 |             return self
20 |
21 |         scores_1 = np.log(scores)
22 |         # replace negative infinity with limit for log(n), n -> -inf
23 |         replace_negative_inf = np.log(1e-300)
24 |         scores_1 = np.nan_to_num(scores_1, neginf=replace_negative_inf)
25 |
26 |         scores_2 = -np.log(1 - scores)
27 |         # replace positive infinity with limit for log(n), n -> inf
28 |         replace_positive_inf = np.log(1e300)
29 |         scores_2 = np.nan_to_num(scores_2, posinf=replace_positive_inf)
30 |
31 |         feature_matrix = np.column_stack((scores_1, scores_2))
32 |
33 |         lr = LogisticRegression()
34 |         lr.fit(feature_matrix, y)
35 |         self.a, self.b = lr.coef_.flatten()
36 |         self.c = lr.intercept_[0]
37 |
38 |         self._is_fitted = True
39 |         return self
40 |
41 |     def predict_proba(self, scores: np.ndarray, X: np.ndarray = None):
42 |         check_is_fitted(self)
43 |         if self.skip_calibration:
44 |             return scores
45 |         return 1 / (
46 |             1
47 |             + 1
48 |             / (
49 |                 np.exp(self.c)
50 |                 * (np.power(scores, self.a) / np.power((1 - scores), self.b))
51 |             )
52 |         )
53 |
--------------------------------------------------------------------------------
/benchmarks/consumer_complaints/scripts/predict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Script to predict with flat approach."""
3 | import argparse
4 | import pickle
5 | import sys
6 | from argparse import Namespace
7 |
8 | import pandas as pd
9 |
10 | from data import load_dataframe, save_dataframe, unflatten_labels
11 |
12 |
13 | def parse_args(args: list) -> Namespace:
14 |     """
15 |     Parse a list of arguments.
16 |
17 |     Parameters
18 |     ----------
19 |     args : list
20 |         Arguments to parse.
21 |
22 |     Returns
23 |     -------
24 |     _ : Namespace
25 |         Parsed arguments.
26 |     """
27 |     parser = argparse.ArgumentParser(description="Predict")
28 |     parser.add_argument(
29 |         "--trained-model",
30 |         type=str,
31 |         required=True,
32 |         help="Path to trained model",
33 |     )
34 |     parser.add_argument(
35 |         "--x-test",
36 |         type=str,
37 |         required=True,
38 |         help="Input CSV file with test features",
39 |     )
40 |     parser.add_argument(
41 |         "--predictions",
42 |         type=str,
43 |         required=True,
44 |         help="Output CSV file to write predictions",
45 |     )
46 |     parser.add_argument(
47 |         "--classifier",
48 |         type=str,
49 |         required=True,
50 |         help="Algorithm used for predicting",
51 |     )
52 |     return parser.parse_args(args)
53 |
54 |
55 | def main():  # pragma: no cover
56 |     """Predict with flat approach."""
57 |     args = parse_args(sys.argv[1:])
58 |     classifier = pickle.load(open(args.trained_model, "rb"))
59 |     x_test = load_dataframe(args.x_test).squeeze()
60 |     predictions = classifier.predict(x_test)
61 |     if args.classifier == "flat":
62 |         predictions = unflatten_labels(predictions)
63 |     else:
64 |         predictions = pd.DataFrame(predictions)
65 |     save_dataframe(predictions, args.predictions)
66 |
67 |
68 | if __name__ == "__main__":
69 |     main()  # pragma: no cover
70 |
--------------------------------------------------------------------------------
/docs/source/get_started/hierarchical_data.rst:
--------------------------------------------------------------------------------
1 | Hierarchical Data
2 | =================
3 |
4 | Many datasets have labels in hierarchical structures, which means that they can be computationally represented as directed acyclic graphs or trees. Two well-known examples of hierarchical data are music genre and phylogeny, which are displayed in the figures below.
5 |
6 | .. figure:: music_genre.svg
7 |    :align: center
8 |
9 |    Music class hierarchy adapted from [1]_.
10 |
11 | .. figure:: phylogeny.svg
12 |    :align: center
13 |
14 |    Animal hierarchy adapted from [2]_.
15 |
16 | HiClass makes it simple to train local hierarchical classifiers. All it needs are hierarchical labels defined in a :math:`m \times n` matrix, where each row is a training example and each column is a level in the hierarchy. This matrix can be represented with Python lists, NumPy arrays or Pandas DataFrames. Training features need to be numerical, hence feature extraction might be necessary depending on the data.
17 |
18 | For this example we will define a short phylogeny tree, with the following numerical features and hierarchical labels:
19 |
20 | .. code-block:: python
21 |
22 |     X_train = [[1], [2], [3], [4]]
23 |     X_test = [[4], [3], [2], [1]]
24 |     Y_train = [
25 |         ['Animal', 'Mammal', 'Sheep'],
26 |         ['Animal', 'Mammal', 'Cow'],
27 |         ['Animal', 'Reptile', 'Snake'],
28 |         ['Animal', 'Reptile', 'Lizard'],
29 |     ]
30 |
31 | Note that the order of the training and test features was reversed to make sure that the model actually works.
32 |
33 | .. [1] Burred, J. J., & Lerch, A. (2003, September). A hierarchical approach to automatic musical genre classification. In Proceedings of the 6th international conference on digital audio effects (pp. 8-11).
34 |
35 | .. [2] Barutcuoglu, Z., & DeCoro, C. (2006, June). Hierarchical shape classification using Bayesian aggregation. In IEEE International Conference on Shape Modeling and Applications 2006 (SMI'06) (pp. 44-44). IEEE.
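
As a small addendum to the representations mentioned above, the same label matrix can also be built as a NumPy array or a Pandas DataFrame. This is a hedged sketch, not part of the original page, and the column names are illustrative only:

.. code-block:: python

    import numpy as np
    import pandas as pd

    # Equivalent representations of the same 4 x 3 label matrix
    Y_train_np = np.array(Y_train)
    # Column names are arbitrary placeholders for the hierarchy levels
    Y_train_df = pd.DataFrame(Y_train, columns=['Kingdom', 'Class', 'Species'])

Any of these objects can be passed to :literal:`fit` in place of the nested lists.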
36 | -------------------------------------------------------------------------------- /docs/examples/plot_parallel_training.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ===================== 4 | Parallel Training 5 | ===================== 6 | 7 | Larger datasets require more time for training. 8 | While by default the models in HiClass are trained using a single core, 9 | it is possible to train each local classifier in parallel by leveraging the library Ray [1]_. 10 | If Ray is not installed, the parallelism defaults to Joblib. 11 | In this example, we demonstrate how to train a hierarchical classifier in parallel by 12 | setting the parameter :literal:`n_jobs` to use all the cores available. Training 13 | is performed on a mock dataset from Kaggle [2]_. 14 | 15 | .. [1] https://www.ray.io/ 16 | .. [2] https://www.kaggle.com/datasets/kashnitsky/hierarchical-text-classification 17 | """ 18 | import sys 19 | from os import cpu_count 20 | from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer 21 | from sklearn.linear_model import LogisticRegression 22 | from sklearn.pipeline import Pipeline 23 | 24 | from hiclass import LocalClassifierPerParentNode 25 | from hiclass.datasets import load_hierarchical_text_classification 26 | 27 | # Load train and test splits 28 | X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification() 29 | 30 | # We will use logistic regression classifiers for every parent node 31 | lr = LogisticRegression(max_iter=1000) 32 | 33 | pipeline = Pipeline( 34 | [ 35 | ("count", CountVectorizer()), 36 | ("tfidf", TfidfTransformer()), 37 | ( 38 | "lcppn", 39 | LocalClassifierPerParentNode(local_classifier=lr, n_jobs=cpu_count()), 40 | ), 41 | ] 42 | ) 43 | 44 | # Fixes bug AttributeError: '_LoggingTee' object has no attribute 'fileno' 45 | # This only happens when building the documentation 46 | # Hence, you don't actually need it for your code to work 47 | sys.stdout.fileno = lambda: False 48 | 49 | # Now, let's train the local classifier per parent node 50 | pipeline.fit(X_train, Y_train) 51 | -------------------------------------------------------------------------------- /hiclass/probability_combiner/ArithmeticMeanCombiner.py: -------------------------------------------------------------------------------- 1 | """Defines the ArithmeticMeanCombiner.""" 2 | 3 | import numpy as np 4 | from hiclass.probability_combiner.ProbabilityCombiner import ProbabilityCombiner 5 | from typing import List 6 | 7 | 8 | class ArithmeticMeanCombiner(ProbabilityCombiner): 9 | """Combine probabilities of multiple levels by taking their arithmetic mean.""" 10 | 11 | def combine(self, proba: List[np.ndarray]): 12 | """Combine probabilities of each level with probabilities of previous levels. 13 | 14 | Calculate the arithmetic mean of node probabilities and the probabilities of its predecessors. 
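In other words, a running sum of the probabilities along the path from the root is kept per level, and the mean at level l is that sum divided by (l + 1).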
15 | """ 16 | res = [proba[0]] 17 | sums = [proba[0]] 18 | for level in range(1, self.classifier.max_levels_): 19 | level_probs = np.zeros_like(proba[level]) 20 | level_sum = np.zeros_like(proba[level]) 21 | # find all predecessors of a node 22 | predecessors = self._find_predecessors(level) 23 | 24 | for node in predecessors.keys(): 25 | index = self.classifier.class_to_index_mapping_[level][node] 26 | # find indices of all predecessors 27 | predecessor_indices = [ 28 | self.classifier.class_to_index_mapping_[level - 1][predecessor] 29 | for predecessor in predecessors[node] 30 | ] 31 | # combine probabilities of all predecessors 32 | predecessors_combined_prob = np.sum( 33 | [ 34 | sums[level - 1][:, pre_index] 35 | for pre_index in predecessor_indices 36 | ], 37 | axis=0, 38 | ) 39 | level_sum[:, index] += ( 40 | proba[level][:, index] + predecessors_combined_prob 41 | ) 42 | level_probs[:, index] = level_sum[:, index] / (level + 1) 43 | 44 | res.append(level_probs) 45 | sums.append(level_sum) 46 | return self._normalize(res) if self.normalize else res 47 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/scripts/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to compute metrics.""" 3 | import argparse 4 | import sys 5 | from argparse import Namespace 6 | from typing import TextIO 7 | 8 | import numpy as np 9 | from hiclass.metrics import f1 10 | 11 | from data import load_dataframe 12 | 13 | 14 | def parse_args(args: list) -> Namespace: 15 | """ 16 | Parse a list of arguments. 17 | 18 | Parameters 19 | ---------- 20 | args : list 21 | Arguments to parse. 22 | 23 | Returns 24 | ------- 25 | _ : Namespace 26 | Parsed arguments. 27 | """ 28 | parser = argparse.ArgumentParser(description="Compute metrics") 29 | parser.add_argument( 30 | "--predictions", 31 | type=str, 32 | required=True, 33 | help="Input TSV file with predictions", 34 | ) 35 | parser.add_argument( 36 | "--ground-truth", 37 | type=str, 38 | required=True, 39 | help="Input TSV file with ground truth", 40 | ) 41 | parser.add_argument( 42 | "--metrics", 43 | type=str, 44 | required=True, 45 | help="Output TSV file with computed metrics", 46 | ) 47 | return parser.parse_args(args) 48 | 49 | 50 | def compute_f1(y_true: np.ndarray, y_pred: np.ndarray, output: TextIO) -> None: 51 | """ 52 | Compute hierarchical f1 score. 53 | 54 | Parameters 55 | ---------- 56 | y_true : np.ndarray 57 | Expected output. 58 | y_pred : np.ndarray 59 | Predicted output. 60 | output : TextIO 61 | File where output will be written. 
62 | """ 63 | output.write("f1_hierarchical\n") 64 | f1_hierarchical = f1(y_true, y_pred) 65 | output.write(f"{f1_hierarchical}\n") 66 | 67 | 68 | def main(): # pragma: no cover 69 | """Compute traditional and hierarchical metrics.""" 70 | args = parse_args(sys.argv[1:]) 71 | predictions = load_dataframe(args.predictions) 72 | ground_truth = load_dataframe(args.ground_truth) 73 | with open(args.metrics, "w") as output: 74 | compute_f1(ground_truth, predictions, output) 75 | 76 | 77 | if __name__ == "__main__": 78 | main() # pragma: no cover 79 | -------------------------------------------------------------------------------- /hiclass/probability_combiner/GeometricMeanCombiner.py: -------------------------------------------------------------------------------- 1 | """Defines the GeometricMeanCombiner.""" 2 | 3 | import numpy as np 4 | from hiclass.probability_combiner.ProbabilityCombiner import ProbabilityCombiner 5 | from typing import List 6 | 7 | 8 | class GeometricMeanCombiner(ProbabilityCombiner): 9 | """Combine probabilities of multiple levels by taking their geometric mean.""" 10 | 11 | def combine(self, proba: List[np.ndarray]): 12 | """Combine probabilities of each level with probabilities of previous levels. 13 | 14 | Calculate the geometric mean of node probabilities and the probabilities of its predecessors. 15 | """ 16 | res = [proba[0]] 17 | log_sum = [np.log(proba[0])] 18 | for level in range(1, self.classifier.max_levels_): 19 | level_probs = np.zeros_like(proba[level]) 20 | level_log_sum = np.zeros_like(proba[level]) 21 | # find all predecessors of a node 22 | predecessors = self._find_predecessors(level) 23 | 24 | for node in predecessors.keys(): 25 | index = self.classifier.class_to_index_mapping_[level][node] 26 | # find indices of all predecessors 27 | predecessor_indices = [ 28 | self.classifier.class_to_index_mapping_[level - 1][predecessor] 29 | for predecessor in predecessors[node] 30 | ] 31 | # combine probabilities of all predecessors 32 | predecessors_combined_log_prob = np.log( 33 | np.sum( 34 | [ 35 | np.exp(log_sum[level - 1][:, pre_index]) 36 | for pre_index in predecessor_indices 37 | ], 38 | axis=0, 39 | ) 40 | ) 41 | 42 | level_log_sum[:, index] += ( 43 | np.log(proba[level][:, index]) + predecessors_combined_log_prob 44 | ) 45 | level_probs[:, index] = np.exp(level_log_sum[:, index] / (level + 1)) 46 | 47 | log_sum.append(level_log_sum) 48 | res.append(level_probs) 49 | return self._normalize(res) if self.normalize else res 50 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/scripts/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to share common functions for data manipulation.""" 3 | import csv 4 | from typing import TextIO 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | def load_dataframe(path: TextIO) -> pd.DataFrame: 11 | """ 12 | Load a dataframe from a CSV file. 13 | 14 | Parameters 15 | ---------- 16 | path : TextIO 17 | Path to CSV file. 18 | 19 | Returns 20 | ------- 21 | df : pd.DataFrame 22 | Loaded dataframe. 23 | """ 24 | return pd.read_csv(path, compression="infer", header=0, sep=",", low_memory=False) 25 | 26 | 27 | def flatten_labels(y, separator: str = ":sep:"): 28 | """ 29 | Flatten hierarchical labels into a single column. 30 | 31 | Parameters 32 | ---------- 33 | y : pd.DataFrame 34 | hierarchical labels. 
35 | separator : str, default=":sep:" 36 | Separator used to differentiate between columns. 37 | 38 | Returns 39 | ------- 40 | y : pd.Series 41 | Joined labels. 42 | """ 43 | y = y[y.columns].apply(lambda x: separator.join(x.dropna().astype(str)), axis=1) 44 | return y 45 | 46 | 47 | def unflatten_labels(y: np.ndarray, separator: str = ":sep:") -> pd.DataFrame: 48 | """ 49 | Separate flat labels back into hierarchical labels. 50 | 51 | Parameters 52 | ---------- 53 | y : np.ndarray 54 | Flat labels. 55 | separator : str, default=":sep:" 56 | Separator used to differentiate between columns. 57 | 58 | Returns 59 | ------- 60 | y : pd.DataFrame 61 | Hierarchical labels. 62 | """ 63 | y = pd.Series(y) 64 | y = y.str.split( 65 | separator, 66 | expand=True, 67 | ) 68 | return y 69 | 70 | 71 | def save_dataframe(dataframe: pd.DataFrame, file_path: TextIO) -> None: 72 | """ 73 | Save dataframe to CSV file. 74 | 75 | Parameters 76 | ---------- 77 | dataframe : pd.DataFrame 78 | Dataframe to save. 79 | file_path : TextIO 80 | Path to save dataframe. 81 | """ 82 | dataframe.to_csv( 83 | file_path, 84 | index=False, 85 | header=True, 86 | sep=",", 87 | compression="infer", 88 | quoting=csv.QUOTE_ALL, 89 | ) 90 | -------------------------------------------------------------------------------- /docs/source/api/utilities.rst: -------------------------------------------------------------------------------- 1 | Data Utilities 2 | ============== 3 | 4 | Binary Policies 5 | --------------- 6 | 7 | ExclusivePolicy 8 | ^^^^^^^^^^^^^^^ 9 | 10 | .. autoclass:: BinaryPolicy.ExclusivePolicy 11 | :members: 12 | :show-inheritance: 13 | :inherited-members: 14 | :special-members: __init__ 15 | 16 | .................................. 17 | 18 | LessExclusivePolicy 19 | ^^^^^^^^^^^^^^^^^^^ 20 | 21 | .. autoclass:: BinaryPolicy.LessExclusivePolicy 22 | :members: 23 | :show-inheritance: 24 | :inherited-members: 25 | :special-members: __init__ 26 | 27 | .................................. 28 | 29 | InclusivePolicy 30 | ^^^^^^^^^^^^^^^ 31 | .. autoclass:: BinaryPolicy.InclusivePolicy 32 | :members: 33 | :show-inheritance: 34 | :inherited-members: 35 | :special-members: __init__ 36 | 37 | .................................. 38 | 39 | LessInclusivePolicy 40 | ^^^^^^^^^^^^^^^^^^^ 41 | .. autoclass:: BinaryPolicy.LessInclusivePolicy 42 | :members: 43 | :show-inheritance: 44 | :inherited-members: 45 | :special-members: __init__ 46 | 47 | .................................. 48 | 49 | SiblingsPolicy 50 | ^^^^^^^^^^^^^^ 51 | .. autoclass:: BinaryPolicy.SiblingsPolicy 52 | :members: 53 | :show-inheritance: 54 | :inherited-members: 55 | :special-members: __init__ 56 | 57 | .................................. 58 | 59 | ExclusiveSiblingsPolicy 60 | ^^^^^^^^^^^^^^^^^^^^^^^ 61 | .. autoclass:: BinaryPolicy.ExclusiveSiblingsPolicy 62 | :members: 63 | :show-inheritance: 64 | :inherited-members: 65 | :special-members: __init__ 66 | 67 | .................................. 68 | 69 | 70 | Hierarchical Metrics 71 | -------------------- 72 | 73 | Precision 74 | ^^^^^^^^^ 75 | 76 | .. autofunction:: metrics.precision 77 | 78 | .................................. 79 | 80 | Recall 81 | ^^^^^^ 82 | 83 | .. autofunction:: metrics.recall 84 | 85 | .................................. 86 | 87 | F-score 88 | ^^^^^^^ 89 | 90 | .. autofunction:: metrics.f1 91 | 92 | .................................. 93 | 94 | 95 | Datasets 96 | ---------- 97 | 98 | Hierarchical text classification dataset 99 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 100 | 101 | .. 
autofunction:: datasets.load_hierarchical_text_classification
102 |
103 | ..................................
104 |
--------------------------------------------------------------------------------
/hiclass/ConstantClassifier.py:
--------------------------------------------------------------------------------
1 | """Constant classifier if there is only one class in the training set."""
2 |
3 | import numpy as np
4 |
5 |
6 | class ConstantClassifier:
7 |     """A classifier that always returns the only label seen during fit."""
8 |
9 |     def fit(self, X, y, sample_weight=None):
10 |         """
11 |         Fit a constant classifier.
12 |
13 |         Parameters
14 |         ----------
15 |         X : {array-like, sparse matrix} of shape (n_samples, n_features)
16 |             The training input samples. Internally, its dtype will be converted
17 |             to ``dtype=np.float32``. If a sparse matrix is provided, it will be
18 |             converted into a sparse ``csc_matrix``.
19 |         y : array-like of shape (n_samples, n_levels)
20 |             The target values, i.e., hierarchical class labels for classification.
21 |         sample_weight : array-like of shape (n_samples,), default=None
22 |             Array of weights that are assigned to individual samples.
23 |             If not provided, then each sample is given unit weight.
24 |
25 |         Returns
26 |         -------
27 |         self : object
28 |             Fitted estimator.
29 |         """
30 |         self.classes_ = np.unique(y)
31 |         if len(self.classes_) != 1:
32 |             raise ValueError(
33 |                 f"Labels should have only one class to fit, but instead found {len(self.classes_)}"
34 |             )
35 |         return self
36 |
37 |     def predict_proba(self, X: np.ndarray) -> np.ndarray:
38 |         """
39 |         Predict class probabilities for the given data.
40 |
41 |         Parameters
42 |         ----------
43 |         X : np.ndarray of shape (n_samples, ...)
44 |             Data that should be predicted. Only the number of samples matters.
45 |
46 |         Returns
47 |         -------
48 |         output : np.ndarray
49 |             Probability of 1 for the previously seen class.
50 |         """
51 |         return np.vstack([1] * X.shape[0])
52 |
53 |     def predict(self, X: np.ndarray) -> np.ndarray:
54 |         """
55 |         Predict classes for the given data.
56 |
57 |         Parameters
58 |         ----------
59 |         X : np.ndarray of shape (n_samples, ...)
60 |             Data that should be predicted. Only the number of samples matters.
61 |
62 |         Returns
63 |         -------
64 |         output : np.ndarray
65 |             The previously seen class, repeated for every sample.
66 | """ 67 | return np.vstack([self.classes_] * X.shape[0]) 68 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/tests/test_statistics.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | import pandas as pd 4 | import pytest 5 | from pandas._testing import assert_frame_equal 6 | from pyfakefs.fake_filesystem_unittest import Patcher 7 | from scripts.statistics import parse_args 8 | 9 | from scripts.statistics import get_file_modification, create_dataframe, save_statistics 10 | 11 | 12 | def test_parser(): 13 | parser = parse_args( 14 | [ 15 | "--data", 16 | "complaints.csv.zip", 17 | "--x-train", 18 | "x_train.csv.zip", 19 | "--y-train", 20 | "y_train.csv.zip", 21 | "--x-test", 22 | "x_test.csv.zip", 23 | "--y-test", 24 | "y_test.csv.zip", 25 | "--statistics", 26 | "statistics.csv", 27 | ] 28 | ) 29 | assert parser.data is not None 30 | assert "complaints.csv.zip" == parser.data 31 | assert parser.x_train is not None 32 | assert "x_train.csv.zip" == parser.x_train 33 | assert parser.y_train is not None 34 | assert "y_train.csv.zip" == parser.y_train 35 | assert parser.x_test is not None 36 | assert "x_test.csv.zip" == parser.x_test 37 | assert parser.y_test is not None 38 | assert "y_test.csv.zip" == parser.y_test 39 | assert parser.statistics is not None 40 | assert "statistics.csv" == parser.statistics 41 | 42 | 43 | def test_get_file_modification(): 44 | with Patcher() as patcher: 45 | patcher.fs.create_file("complaints.csv.zip") 46 | assert date.today().strftime("%d/%m/%Y") == get_file_modification( 47 | "complaints.csv.zip" 48 | ) 49 | 50 | 51 | def test_create_dataframe(): 52 | statistics = create_dataframe("2021-01-01", 70, 30) 53 | assert (1, 3) == statistics.shape 54 | assert "2021-01-01" == statistics["Snapshot"].values[0] 55 | assert 70 == statistics["Training set size"].values[0] 56 | assert 30 == statistics["Test set size"].values[0] 57 | 58 | 59 | @pytest.fixture 60 | def statistics(): 61 | statistics = pd.DataFrame( 62 | { 63 | "Snapshot": ["02/11/2020"], 64 | "Training set size": [70], 65 | "Test set size": [30], 66 | } 67 | ) 68 | return statistics 69 | 70 | 71 | def test_save_statistics(statistics): 72 | with Patcher() as patcher: 73 | save_statistics(statistics, "statistics.csv") 74 | assert patcher.fs.exists("statistics.csv") 75 | result = pd.read_csv("statistics.csv") 76 | assert_frame_equal(statistics, result) 77 | -------------------------------------------------------------------------------- /docs/examples/plot_binary_policies.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | =========================== 4 | Binary Training Policies 5 | =========================== 6 | 7 | The siblings policy is used by default on the local classifier per node, but the remaining ones can be selected with the parameter :literal:`binary_policy`, for example: 8 | 9 | .. tabs:: 10 | 11 | .. code-tab:: python 12 | :caption: Exclusive 13 | 14 | rf = RandomForestClassifier() 15 | classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="exclusive") 16 | 17 | .. code-tab:: python 18 | :caption: Less exclusive 19 | 20 | rf = RandomForestClassifier() 21 | classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="less_exclusive") 22 | 23 | .. 
code-tab:: python 24 | :caption: Less inclusive 25 | 26 | rf = RandomForestClassifier() 27 | classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="less_inclusive") 28 | 29 | .. code-tab:: python 30 | :caption: Inclusive 31 | 32 | rf = RandomForestClassifier() 33 | classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="inclusive") 34 | 35 | .. code-tab:: python 36 | :caption: Siblings 37 | 38 | rf = RandomForestClassifier() 39 | classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="siblings") 40 | 41 | .. code-tab:: python 42 | :caption: Exclusive siblings 43 | 44 | rf = RandomForestClassifier() 45 | classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="exclusive_siblings") 46 | 47 | In the code below, the inclusive policy is selected. 48 | However, the code can be easily updated by replacing lines 20-21 with the examples shown in the tabs above. 49 | 50 | .. seealso:: 51 | 52 | Mathematical definition on the different policies is given at :ref:`Training Policies`. 53 | """ 54 | from sklearn.ensemble import RandomForestClassifier 55 | 56 | from hiclass import LocalClassifierPerNode 57 | 58 | # Define data 59 | X_train = [[1], [2], [3], [4]] 60 | X_test = [[4], [3], [2], [1]] 61 | Y_train = [ 62 | ["Animal", "Mammal", "Sheep"], 63 | ["Animal", "Mammal", "Cow"], 64 | ["Animal", "Reptile", "Snake"], 65 | ["Animal", "Reptile", "Lizard"], 66 | ] 67 | 68 | # Use random forest classifiers for every node 69 | # And inclusive policy to select training examples for binary classifiers. 70 | rf = RandomForestClassifier() 71 | classifier = LocalClassifierPerNode(local_classifier=rf, binary_policy="inclusive") 72 | 73 | # Train local classifier per node 74 | classifier.fit(X_train, Y_train) 75 | 76 | # Predict 77 | predictions = classifier.predict(X_test) 78 | print(predictions) 79 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/tests/test_train.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | 3 | import pytest 4 | from omegaconf import DictConfig 5 | from pyfakefs.fake_filesystem_unittest import Patcher 6 | from scripts.train import parse_args, load_parameters 7 | 8 | 9 | def test_parser(): 10 | parser = parse_args( 11 | [ 12 | "--n-jobs", 13 | "8", 14 | "--x-train", 15 | "x_train.csv.zip", 16 | "--y-train", 17 | "y_train.csv.zip", 18 | "--trained-model", 19 | "model.sav", 20 | "--classifier", 21 | "lightgbm", 22 | "--model", 23 | "flat", 24 | "--best-parameters", 25 | "best_parameters.yml", 26 | ] 27 | ) 28 | assert parser.n_jobs is not None 29 | assert 8 == parser.n_jobs 30 | assert parser.x_train is not None 31 | assert "x_train.csv.zip" == parser.x_train 32 | assert parser.y_train is not None 33 | assert "y_train.csv.zip" == parser.y_train 34 | assert parser.trained_model is not None 35 | assert "model.sav" == parser.trained_model 36 | assert parser.classifier is not None 37 | assert "lightgbm" == parser.classifier 38 | assert parser.model is not None 39 | assert "flat" == parser.model 40 | assert parser.best_parameters is not None 41 | assert "best_parameters.yml" == parser.best_parameters 42 | 43 | 44 | @pytest.fixture 45 | def tuned_parameters(): 46 | cfg = StringIO() 47 | cfg.write("name: optuna\n") 48 | cfg.write("best_params:\n") 49 | cfg.write(" C: 0.001\n") 50 | cfg.write(" class_weight: balanced\n") 51 | cfg.write(" dual: false\n") 52 | cfg.write(" fit_intercept: false\n") 53 | 
cfg.write(" intercept_scaling: 3\n") 54 | cfg.write(" max_iter: 100\n") 55 | cfg.write(" multi_class: auto\n") 56 | cfg.write(" penalty: l2\n") 57 | cfg.write(" solver: liblinear\n") 58 | cfg.write(" tol: 1.0e-06\n") 59 | cfg.write("best_value: 0.9387345438252911\n") 60 | cfg.seek(0) 61 | return cfg 62 | 63 | 64 | def test_load_parameters(tuned_parameters): 65 | expected = DictConfig( 66 | { 67 | "C": 0.001, 68 | "class_weight": "balanced", 69 | "dual": False, 70 | "fit_intercept": False, 71 | "intercept_scaling": 3, 72 | "max_iter": 100, 73 | "multi_class": "auto", 74 | "penalty": "l2", 75 | "solver": "liblinear", 76 | "tol": 1.0e-06, 77 | } 78 | ) 79 | with Patcher() as patcher: 80 | patcher.fs.create_file("best_parameters.yml", contents=tuned_parameters.read()) 81 | parameters = load_parameters("best_parameters.yml") 82 | assert expected == parameters 83 | -------------------------------------------------------------------------------- /tests/test_Datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import numpy as np 5 | import pytest 6 | from numpy.testing import assert_array_equal 7 | 8 | import hiclass.datasets 9 | from hiclass.datasets import load_hierarchical_text_classification 10 | 11 | 12 | def test_load_hierarchical_text_classification_shape(): 13 | X_train, X_test, y_train, y_test = load_hierarchical_text_classification( 14 | test_size=0.2, random_state=42 15 | ) 16 | assert X_train.shape[0] == y_train.shape[0] 17 | assert X_test.shape[0] == y_test.shape[0] 18 | 19 | 20 | def test_load_hierarchical_text_classification_random_state(): 21 | X_train_1, X_test_1, y_train_1, y_test_1 = load_hierarchical_text_classification( 22 | test_size=0.2, random_state=42 23 | ) 24 | X_train_2, X_test_2, y_train_2, y_test_2 = load_hierarchical_text_classification( 25 | test_size=0.2, random_state=42 26 | ) 27 | assert_array_equal(X_train_1, X_train_2) 28 | assert_array_equal(X_test_1, X_test_2) 29 | assert_array_equal(y_train_1, y_train_2) 30 | assert_array_equal(y_test_1, y_test_2) 31 | 32 | 33 | def test_load_hierarchical_text_classification_file_exists(): 34 | dataset_name = "hierarchical_text_classification.csv" 35 | cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name) 36 | 37 | if os.path.exists(cached_file_path): 38 | os.remove(cached_file_path) 39 | 40 | if not os.path.exists(cached_file_path): 41 | load_hierarchical_text_classification() 42 | assert os.path.exists(cached_file_path) 43 | 44 | 45 | def test_download_dataset(): 46 | dataset_name = "hierarchical_text_classification.csv" 47 | url = hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL 48 | cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name) 49 | 50 | if os.path.exists(cached_file_path): 51 | os.remove(cached_file_path) 52 | 53 | if not os.path.exists(cached_file_path): 54 | hiclass.datasets._download_file(url, cached_file_path) 55 | assert os.path.exists(cached_file_path) 56 | 57 | 58 | def test_download_error_load_hierarchical_text(): 59 | dataset_name = "hierarchical_text_classification.csv" 60 | backup_url = hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL 61 | hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = "" 62 | cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name) 63 | 64 | if os.path.exists(cached_file_path): 65 | os.remove(cached_file_path) 66 | 67 | if not os.path.exists(cached_file_path): 68 | with pytest.raises(RuntimeError): 69 | load_hierarchical_text_classification() 
70 | 71 | hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = backup_url 72 | 73 | 74 | def test_url_links(): 75 | assert hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL != "" 76 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # noqa 2 | # Configuration file for the Sphinx documentation builder. 3 | # 4 | # This file only contains a selection of the most common options. For a full 5 | # list see the documentation: 6 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 7 | 8 | # -- Path setup -------------------------------------------------------------- 9 | 10 | # If extensions (or modules to document with autodoc) are in another directory, 11 | # add these directories to sys.path here. If the directory is relative to the 12 | # documentation root, use os.path.abspath to make it absolute, like shown here. 13 | # 14 | import os 15 | import sys 16 | 17 | sys.path.insert(0, os.path.abspath("./../..")) 18 | sys.path.insert(0, os.path.abspath("./../../hiclass")) 19 | print(sys.path) 20 | 21 | import sphinx_code_tabs 22 | 23 | # -- Project information ----------------------------------------------------- 24 | 25 | project = "hiclass" 26 | copyright = "2024, Fabio Malcher Miranda, Niklas Köhnecke" 27 | author = "Fabio Malcher Miranda, Niklas Köhnecke" 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = [ 36 | "sphinx.ext.autodoc", 37 | "sphinx.ext.napoleon", 38 | "sphinx.ext.autosectionlabel", 39 | "sphinx_code_tabs", 40 | "sphinx_gallery.gen_gallery", 41 | ] 42 | 43 | # Add any paths that contain templates here, relative to this directory. 44 | templates_path = ["_templates"] 45 | 46 | # List of patterns, relative to source directory, that match files and 47 | # directories to ignore when looking for source files. 48 | # This pattern also affects html_static_path and html_extra_path. 49 | exclude_patterns = [] 50 | 51 | 52 | # -- Options for HTML output ------------------------------------------------- 53 | 54 | # The theme to use for HTML and HTML Help pages. See the documentation for 55 | # a list of builtin themes. 56 | use_rtd_scheme = False 57 | try: 58 | import sphinx_rtd_theme 59 | 60 | extensions.extend(["sphinx_rtd_theme"]) 61 | use_rtd_scheme = True 62 | except ImportError: 63 | print("sphinx_rtd_theme was not installed, using alabaster as fallback!") 64 | 65 | html_theme = "sphinx_rtd_theme" if use_rtd_scheme else "alabaster" 66 | 67 | 68 | # Add any paths that contain custom static files (such as style sheets) here, 69 | # relative to this directory. They are copied after the builtin static files, 70 | # so a file named "default.css" will overwrite the builtin "default.css". 
71 | html_static_path = [] 72 | 73 | # options 74 | 75 | html_theme_options = {} 76 | 77 | if not use_rtd_scheme: 78 | html_theme_options["sidebar_width"] = "230px" 79 | 80 | sphinx_gallery_conf = { 81 | "examples_dirs": "../examples", 82 | "gallery_dirs": "auto_examples", 83 | } 84 | -------------------------------------------------------------------------------- /docs/source/get_started/local_classifier.rst: -------------------------------------------------------------------------------- 1 | Local Hierarchical Classifier 2 | ============================= 3 | 4 | A :literal:`local hierarchical classifier` is a supervised machine learning model whose output is defined over a pre-established hierarchical class taxonomy. In HiClass, there are three main approaches for local hierarchical classification, i.e., the most common design patterns identified in the literature [1]_: the :ref:`local-classifier-per-node-overview`, the :ref:`local-classifier-per-parent-node-overview` and the :ref:`local-classifier-per-level-overview`. 5 | 6 | In this example, we will be using the :literal:`LocalClassifierPerNode` along with the :literal:`RandomForestClassifier` from scikit-learn, but you can click on the other tabs to see how the code changes for the :literal:`LocalClassifierPerParentNode` and :literal:`LocalClassifierPerLevel`: 7 | 8 | .. tabs:: 9 | 10 | .. code-tab:: python 11 | :caption: LocalClassifierPerNode 12 | 13 | from hiclass import LocalClassifierPerNode 14 | from sklearn.ensemble import RandomForestClassifier 15 | 16 | .. code-tab:: python 17 | :caption: LocalClassifierPerParentNode 18 | 19 | from hiclass import LocalClassifierPerParentNode 20 | from sklearn.ensemble import RandomForestClassifier 21 | 22 | .. code-tab:: python 23 | :caption: LocalClassifierPerLevel 24 | 25 | from hiclass import LocalClassifierPerLevel 26 | from sklearn.ensemble import RandomForestClassifier 27 | 28 | 29 | We will be using a :literal:`RandomForestClassifier` for each node in the :literal:`LocalClassifierPerNode`, except for the root node. This :literal:`LocalClassifierPerNode` model will have the same structure as the hierarchy pre-defined in the data used to train it. This is how we create both objects: 30 | 31 | .. tabs:: 32 | 33 | .. code-tab:: python 34 | :caption: LocalClassifierPerNode 35 | 36 | rf = RandomForestClassifier() 37 | classifier = LocalClassifierPerNode(local_classifier=rf) 38 | 39 | .. code-tab:: python 40 | :caption: LocalClassifierPerParentNode 41 | 42 | rf = RandomForestClassifier() 43 | classifier = LocalClassifierPerParentNode(local_classifier=rf) 44 | 45 | .. code-tab:: python 46 | :caption: LocalClassifierPerLevel 47 | 48 | rf = RandomForestClassifier() 49 | classifier = LocalClassifierPerLevel(local_classifier=rf) 50 | 51 | .. note:: 52 | 53 | The :literal:`LocalClassifierPerParentNode` has a :literal:`RandomForestClassifier` for each parent node existing in the hierarchy, while the :literal:`LocalClassifierPerLevel` contains a :literal:`RandomForestClassifier` for each level in the training labels. More information on the nuances of the hierarchical classifiers is available in the section :ref:`Algorithms Overview`. A minimal end-to-end sketch with toy data is shown at the end of this page. 54 | 55 | .. [1] Silla, C. N., & Freitas, A. A. (2011). A survey of hierarchical classification across different application domains. Data Mining and Knowledge Discovery, 22(1), 31-72.
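To put the pieces together, the sketch below trains the classifier and predicts hierarchical labels. The feature and label values are toy placeholders added for illustration and are not part of the original page:

.. code-block:: python

    from sklearn.ensemble import RandomForestClassifier

    from hiclass import LocalClassifierPerNode

    # Toy data: one row of features and one row of hierarchical labels per sample
    X_train = [[1, 2], [3, 4], [5, 6], [7, 8]]
    Y_train = [
        ["Animal", "Cat"],
        ["Animal", "Dog"],
        ["Plant", "Tree"],
        ["Plant", "Flower"],
    ]

    # One classifier is trained for each node of the hierarchy, except the root
    rf = RandomForestClassifier()
    classifier = LocalClassifierPerNode(local_classifier=rf)
    classifier.fit(X_train, Y_train)

    # Predictions come back as one row of hierarchical labels per sample
    predictions = classifier.predict(X_train)
    print(predictions)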
56 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/tests/test_split_data.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | from io import BytesIO 3 | 4 | import pandas as pd 5 | from pandas.testing import assert_frame_equal 6 | from pandas.testing import assert_series_equal 7 | 8 | from scripts.split_data import ( 9 | parse_args, 10 | get_nrows, 11 | load_data, 12 | split_data, 13 | ) 14 | 15 | 16 | def test_parser(): 17 | parser = parse_args( 18 | [ 19 | "--data", 20 | "complaints.csv.zip", 21 | "--x-train", 22 | "x_train.csv.zip", 23 | "--x-test", 24 | "x_test.csv.zip", 25 | "--y-train", 26 | "y_train.csv.zip", 27 | "--y-test", 28 | "y_test.csv.zip", 29 | "--random-state", 30 | "0", 31 | "--nrows", 32 | "1000", 33 | ] 34 | ) 35 | assert parser.data is not None 36 | assert "complaints.csv.zip" == parser.data 37 | assert parser.x_train is not None 38 | assert "x_train.csv.zip" == parser.x_train 39 | assert parser.x_test is not None 40 | assert "x_test.csv.zip" == parser.x_test 41 | assert parser.y_train is not None 42 | assert "y_train.csv.zip" == parser.y_train 43 | assert parser.y_test is not None 44 | assert "y_test.csv.zip" == parser.y_test 45 | assert parser.random_state is not None 46 | assert 0 == parser.random_state 47 | assert parser.nrows is not None 48 | assert "1000" == parser.nrows 49 | 50 | 51 | def test_get_nrows(): 52 | assert 1000 == get_nrows("1000") 53 | assert 1 == get_nrows("1") 54 | assert get_nrows("None") is None 55 | assert get_nrows("asdf") is None 56 | 57 | 58 | def test_load_data(): 59 | data = BytesIO() 60 | content = "Consumer complaint narrative,Product,Sub-product\n" 61 | content += ",Student loan,Private student loan\n" 62 | content += "Incorrect information on your report,,Private student loan\n" 63 | content += "Incorrect information on your report,Student loan,\n" 64 | content += ( 65 | "Incorrect information on your report,Student loan,Private student loan\n" 66 | ) 67 | with zipfile.ZipFile(data, mode="w", compression=zipfile.ZIP_DEFLATED) as zf: 68 | zf.writestr("complaints.csv", content) 69 | x, y = load_data(data) 70 | ground_truth = pd.DataFrame( 71 | { 72 | "Consumer complaint narrative": ["Incorrect information on your report"], 73 | "Product": ["Student loan"], 74 | "Sub-product": ["Private student loan"], 75 | } 76 | ) 77 | assert_series_equal(ground_truth["Consumer complaint narrative"], x) 78 | assert_frame_equal(ground_truth[["Product", "Sub-product"]], y) 79 | 80 | 81 | def test_split_data(): 82 | x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 83 | y = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] 84 | random_state = 42 85 | x_train, x_test, y_train, y_test = split_data(x, y, random_state) 86 | assert [1, 8, 3, 10, 5, 4, 7] == x_train 87 | assert [9, 2, 6] == x_test 88 | assert ["a", "h", "c", "j", "e", "d", "g"] == y_train 89 | assert ["i", "b", "f"] == y_test 90 | -------------------------------------------------------------------------------- /hiclass/datasets.py: -------------------------------------------------------------------------------- 1 | """Datasets util for downloading and maintaining sample datasets.""" 2 | 3 | import csv 4 | import logging 5 | import os 6 | import tempfile 7 | 8 | import numpy as np 9 | import requests 10 | from sklearn.model_selection import train_test_split 11 | 12 | # Configure logging 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | # Use 
temp directory to store cached datasets 17 | CACHE_DIR = tempfile.gettempdir() 18 | 19 | # Ensure cache directory exists 20 | os.makedirs(CACHE_DIR, exist_ok=True) 21 | 22 | # Dataset urls 23 | HIERARCHICAL_TEXT_CLASSIFICATION_URL = ( 24 | "https://zenodo.org/record/6657410/files/train_40k.csv?download=1" 25 | ) 26 | 27 | 28 | def _download_file(url, destination): 29 | """Download file from given URL to specified destination.""" 30 | try: 31 | response = requests.get(url) 32 | # Raise HTTPError if response code is not OK 33 | response.raise_for_status() 34 | with open(destination, "wb") as f: 35 | f.write(response.content) 36 | except requests.RequestException as e: 37 | raise RuntimeError(f"Failed to download file from {url}: {str(e)}") 38 | 39 | 40 | def load_hierarchical_text_classification(test_size=0.3, random_state=42): 41 | """ 42 | Load hierarchical text classification dataset. 43 | 44 | Parameters 45 | ---------- 46 | test_size : float, default=0.3 47 | The proportion of the dataset to include in the test split. 48 | random_state : int or None, default=42 49 | Controls the randomness of the train-test split. Pass an int for reproducible output across multiple function calls. 50 | 51 | Returns 52 | ------- 53 | list 54 | List containing train-test split of inputs. 55 | 56 | Raises 57 | ------ 58 | RuntimeError 59 | If the dataset could not be accessed or processed. 60 | 61 | Examples 62 | -------- 63 | >>> from hiclass.datasets import load_hierarchical_text_classification 64 | >>> X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification() 65 | >>> X_train[:3] 66 | 38015 Nature's Way Selenium 67 | 2281 Music In Motion Developmental Mobile W Remote 68 | 36629 Twinings Ceylon Orange Pekoe Tea, Tea Bags, 20... 69 | Name: Title, dtype: object 70 | >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape 71 | ((28000,), (12000,), (28000, 3), (12000, 3)) 72 | """ 73 | dataset_name = "hierarchical_text_classification.csv" 74 | cached_file_path = os.path.join(CACHE_DIR, dataset_name) 75 | 76 | # Check if the file exists in the cache 77 | if not os.path.exists(cached_file_path): 78 | try: 79 | logger.info("Downloading hierarchical text classification dataset...") 80 | _download_file(HIERARCHICAL_TEXT_CLASSIFICATION_URL, cached_file_path) 81 | except Exception as e: 82 | raise RuntimeError(f"Failed to access or download dataset: {str(e)}") 83 | 84 | # Read the cached CSV, closing the file handle when done 85 | with open(cached_file_path) as f: 86 | data = list(csv.reader(f)) 87 | data.pop(0) 88 | data = np.array(data, dtype=object) 89 | X = data[:, 1] 90 | y = data[:, 7:] 91 | 92 | # Return tuple (X_train, X_test, y_train, y_test) 93 | return train_test_split(X, y, test_size=test_size, random_state=random_state) 94 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/tests/test_tune_table.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from io import StringIO 3 | 4 | import pytest 5 | from omegaconf import OmegaConf 6 | from pyfakefs.fake_filesystem_unittest import Patcher 7 | from scripts.tune_table import ( 8 | parse_args, 9 | compute, 10 | create_table, 11 | ) 12 | 13 | from scripts.tune import save_trial 14 | 15 | 16 | def test_parser(): 17 | parser = parse_args( 18 | [ 19 | "--folder", 20 | "folder", 21 | "--model", 22 | "flat", 23 | "--classifier", 24 | "lightgbm", 25 | "--output", 26 | "output.md", 27 | ] 28 | ) 29 | assert parser.folder is not None 30 | assert "folder" == parser.folder 31 | assert parser.model is not None
32 | assert "flat" == parser.model 33 | assert parser.classifier is not None 34 | assert "lightgbm" == parser.classifier 35 | assert parser.output is not None 36 | assert "output.md" == parser.output 37 | 38 | 39 | @pytest.fixture 40 | def lightgbm_config(): 41 | cfg = OmegaConf.create( 42 | { 43 | "model": "flat", 44 | "classifier": "lightgbm", 45 | "n_jobs": 12, 46 | "x_train": "x_train.csv", 47 | "y_train": "y_train.csv", 48 | "output_dir": ".", 49 | "mem_gb": 1, 50 | "n_splits": 2, 51 | "num_leaves": 100, 52 | "n_estimators": 200, 53 | "min_child_samples": 6, 54 | } 55 | ) 56 | return cfg 57 | 58 | 59 | def test_compute(lightgbm_config): 60 | expected_hyperparameters = { 61 | "num_leaves": 100, 62 | "n_estimators": 200, 63 | "min_child_samples": 6, 64 | } 65 | with Patcher(): 66 | save_trial(lightgbm_config, [1, 2, 3]) 67 | hyperparameters, scores, avg, std = compute(".") 68 | assert [expected_hyperparameters] == hyperparameters 69 | assert [[1, 2, 3]] == scores 70 | assert [2] == avg 71 | assert [0.816496580927726] == std 72 | 73 | 74 | @pytest.fixture 75 | def args(): 76 | return Namespace( 77 | output="output.md", 78 | model="flat", 79 | classifier="lightgbm", 80 | folder=".", 81 | ) 82 | 83 | 84 | @pytest.fixture 85 | def expected_content(): 86 | content = StringIO() 87 | content.write("# Model: flat\n") 88 | content.write("## Base classifier: lightgbm\n") 89 | content.write("|Parameters|Scores|Average|Standard deviation|\n") 90 | content.write("|----------|------|-------|------------------|\n") 91 | content.write( 92 | "|{'num_leaves': 100, 'n_estimators': 200, 'min_child_samples': 6}|[1, 2, 3]|2.000|0.816|\n" 93 | ) 94 | return content.getvalue() 95 | 96 | 97 | def test_create_table(lightgbm_config, args, expected_content): 98 | with Patcher() as patcher: 99 | save_trial(lightgbm_config, [1, 2, 3]) 100 | create_table(args) 101 | assert patcher.fs.exists("output.md") 102 | with open("output.md", "r") as f: 103 | content = f.read() 104 | print(expected_content) 105 | print(content) 106 | assert expected_content == content 107 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to train with flat or hierarchical approaches.""" 3 | import argparse 4 | import pickle 5 | import sys 6 | from argparse import Namespace 7 | 8 | from joblib import parallel_backend 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | from data import load_dataframe, flatten_labels 12 | from tune import configure_pipeline 13 | 14 | 15 | def parse_args(args: list) -> Namespace: 16 | """ 17 | Parse a list of arguments. 18 | 19 | Parameters 20 | ---------- 21 | args : list 22 | Arguments to parse. 23 | 24 | Returns 25 | ------- 26 | _ : Namespace 27 | Parsed arguments. 
28 | """ 29 | parser = argparse.ArgumentParser(description="Train classifier") 30 | parser.add_argument( 31 | "--n-jobs", 32 | type=int, 33 | required=True, 34 | help="Number of jobs to run training in parallel", 35 | ) 36 | parser.add_argument( 37 | "--x-train", 38 | type=str, 39 | required=True, 40 | help="Input CSV file with training features", 41 | ) 42 | parser.add_argument( 43 | "--y-train", 44 | type=str, 45 | required=True, 46 | help="Input CSV file with training labels", 47 | ) 48 | parser.add_argument( 49 | "--trained-model", 50 | type=str, 51 | required=True, 52 | help="Path to store trained model", 53 | ) 54 | parser.add_argument( 55 | "--classifier", 56 | type=str, 57 | required=True, 58 | help="Algorithm used for fitting, e.g., logistic_regression or random_forest", 59 | ) 60 | parser.add_argument( 61 | "--model", 62 | type=str, 63 | required=True, 64 | help="Model used for training, e.g., flat, lcpl, lcpn or lcppn", 65 | ) 66 | parser.add_argument( 67 | "--best-parameters", 68 | type=str, 69 | required=True, 70 | help="Path to optuna's tuned parameters", 71 | ) 72 | return parser.parse_args(args) 73 | 74 | 75 | def load_parameters(yml: str) -> DictConfig: 76 | """ 77 | Load parameters from a YAML file. 78 | 79 | Parameters 80 | ---------- 81 | yml : str 82 | Path to YAML file containing tuned parameters. 83 | 84 | Returns 85 | ------- 86 | cfg : DictConfig 87 | Dictionary containing all configuration information. 88 | """ 89 | cfg = OmegaConf.load(yml) 90 | return cfg["best_params"] 91 | 92 | 93 | def train() -> None: # pragma: no cover 94 | """Train with flat or hierarchical approaches.""" 95 | args = parse_args(sys.argv[1:]) 96 | x_train = load_dataframe(args.x_train).squeeze() 97 | y_train = load_dataframe(args.y_train) 98 | if args.model == "flat": 99 | y_train = flatten_labels(y_train) 100 | best_params = load_parameters(args.best_parameters) 101 | best_params.model = args.model 102 | best_params.classifier = args.classifier 103 | best_params.n_jobs = args.n_jobs 104 | pipeline = configure_pipeline(best_params) 105 | with parallel_backend("threading", n_jobs=args.n_jobs): 106 | pipeline.fit(x_train, y_train) 107 | pickle.dump(pipeline, open(args.trained_model, "wb")) 108 | 109 | 110 | if __name__ == "__main__": 111 | train() # pragma: no cover 112 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/tests/test_data.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from pandas._testing import assert_series_equal 6 | from pandas.testing import assert_frame_equal 7 | 8 | from scripts.data import ( 9 | load_dataframe, 10 | save_dataframe, 11 | flatten_labels, 12 | unflatten_labels, 13 | ) 14 | 15 | 16 | def test_load_dataframe(): 17 | data = StringIO() 18 | data.write("a,b\n") 19 | data.write("1,2\n") 20 | data.write("3,4\n") 21 | data.seek(0) 22 | ground_truth = pd.DataFrame({"a": [1, 3], "b": [2, 4]}) 23 | metadata = load_dataframe(data) 24 | assert_frame_equal(ground_truth, metadata) 25 | 26 | 27 | def test_flatten_labels_1(): 28 | y = pd.DataFrame( 29 | { 30 | "Product": ["Debt collection", "Checking or savings account"], 31 | "Sub-product": ["I do not know", "Checking account"], 32 | } 33 | ) 34 | flat_y = flatten_labels(y) 35 | ground_truth = pd.Series( 36 | [ 37 | "Debt collection:sep:I do not know", 38 | "Checking or savings account:sep:Checking account", 39 | ] 40 | ) 41 | 
assert_series_equal(ground_truth, flat_y) 42 | 43 | 44 | def test_flatten_labels_2(): 45 | y = pd.DataFrame( 46 | { 47 | "Product": ["Debt collection", "Checking or savings account"], 48 | "Sub-product": ["I do not know", "Checking account"], 49 | } 50 | ) 51 | separator = "," 52 | flat_y = flatten_labels(y, separator) 53 | ground_truth = pd.Series( 54 | [ 55 | "Debt collection,I do not know", 56 | "Checking or savings account,Checking account", 57 | ] 58 | ) 59 | assert_series_equal(ground_truth, flat_y) 60 | 61 | 62 | def test_unflatten_labels_1(): 63 | y = np.array( 64 | [ 65 | "Debt collection:sep:I do not know", 66 | "Checking or savings account:sep:Checking account", 67 | ] 68 | ) 69 | y = unflatten_labels(y) 70 | ground_truth = pd.DataFrame( 71 | [ 72 | ["Debt collection", "I do not know"], 73 | ["Checking or savings account", "Checking account"], 74 | ] 75 | ) 76 | assert_frame_equal(ground_truth, y) 77 | 78 | 79 | def test_separate_2(): 80 | y = np.array( 81 | [ 82 | "Debt collection/I do not know", 83 | "Checking or savings account/Checking account", 84 | ] 85 | ) 86 | separator = "/" 87 | y = unflatten_labels(y, separator) 88 | ground_truth = pd.DataFrame( 89 | [ 90 | ["Debt collection", "I do not know"], 91 | ["Checking or savings account", "Checking account"], 92 | ] 93 | ) 94 | assert_frame_equal(ground_truth, y) 95 | 96 | 97 | def test_save_dataframe(): 98 | ground_truth = '"Narrative","Product","Sub-product"\n' 99 | ground_truth += '"Incorrect information","Student loan","Private student loan"\n' 100 | data = pd.DataFrame( 101 | { 102 | "Narrative": ["Incorrect information"], 103 | "Product": ["Student loan"], 104 | "Sub-product": ["Private student loan"], 105 | } 106 | ) 107 | output = StringIO() 108 | save_dataframe(data, output) 109 | output.seek(0) 110 | assert ground_truth == output.read() 111 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/scripts/tune_table.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to create table with tuning results for flat and hierarchical approaches.""" 3 | import argparse 4 | import glob 5 | import pickle 6 | import sys 7 | from argparse import Namespace 8 | from typing import Tuple, List 9 | 10 | import numpy as np 11 | 12 | 13 | def parse_args(args: list) -> Namespace: 14 | """ 15 | Parse a list of arguments. 16 | 17 | Parameters 18 | ---------- 19 | args : list 20 | Arguments to parse. 21 | 22 | Returns 23 | ------- 24 | _ : Namespace 25 | Parsed arguments. 26 | """ 27 | parser = argparse.ArgumentParser( 28 | description="Create table with hyper-parameter tuning results" 29 | ) 30 | parser.add_argument( 31 | "--folder", 32 | type=str, 33 | required=True, 34 | help="Folder where the tuning results are stored", 35 | ) 36 | parser.add_argument( 37 | "--model", 38 | type=str, 39 | required=True, 40 | help="Model used for tuning", 41 | ) 42 | parser.add_argument( 43 | "--classifier", 44 | type=str, 45 | required=True, 46 | help="Classifier used for tuning", 47 | ) 48 | parser.add_argument( 49 | "--output", 50 | type=str, 51 | required=True, 52 | help="Output to write the table in markdown format (.md)", 53 | ) 54 | return parser.parse_args(args) 55 | 56 | 57 | def compute( 58 | folder: str, 59 | ) -> Tuple[List[dict], List[list], List[np.ndarray], List[np.ndarray]]: 60 | """ 61 | Compute average and standard deviation of the tuning results. 
62 | 63 | Parameters 64 | ---------- 65 | folder : str 66 | Folder where the tuning results are stored. 67 | 68 | Returns 69 | ------- 70 | hyperparameters : List[dict] 71 | Hyperparameters tested for tuning. 72 | scores : List[list] 73 | Scores for each hyperparameter combination tested. 74 | avg : List[np.ndarray] 75 | Averages of k-fold cross-validation. 76 | std : List[np.ndarray] 77 | Standard deviations of k-fold cross-validation. 78 | """ 79 | results = glob.glob(f"{folder}/*.sav") 80 | if f"{folder}/trained_model.sav" in results: 81 | results.remove(f"{folder}/trained_model.sav") 82 | hyperparameters = [] 83 | scores = [] 84 | avg = [] 85 | std = [] 86 | for result in results: 87 | parameters, s = pickle.load(open(result, "rb")) 88 | hyperparameters.append(parameters) 89 | scores.append([round(i, 3) for i in s]) 90 | avg.append(np.mean(s)) 91 | std.append(np.std(s)) 92 | return hyperparameters, scores, avg, std 93 | 94 | 95 | def create_table(args): 96 | """Create table with tuning results for flat and hierarchical approaches.""" 97 | with open(args.output, "w") as fout: 98 | fout.write(f"# Model: {args.model}\n") 99 | fout.write(f"## Base classifier: {args.classifier}\n") 100 | fout.write("|Parameters|Scores|Average|Standard deviation|\n") 101 | fout.write("|----------|------|-------|------------------|\n") 102 | hyperparameters, scores, avg, std = compute(args.folder) 103 | for hp, sc, av, st in zip(hyperparameters, scores, avg, std): 104 | fout.write(f"|{hp}|{sc}|{av:.3f}|{st:.3f}|\n") 105 | 106 | 107 | if __name__ == "__main__": # pragma: no cover 108 | args = parse_args(sys.argv[1:]) 109 | create_table(args) 110 | -------------------------------------------------------------------------------- /hiclass/FlatClassifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Flat classifier approach, used for comparison purposes. 3 | 4 | Implementation by @lpfgarcia 5 | """ 6 | 7 | import numpy as np 8 | from sklearn.base import BaseEstimator 9 | from sklearn.linear_model import LogisticRegression 10 | from sklearn.utils.validation import check_is_fitted 11 | 12 | 13 | class FlatClassifier(BaseEstimator): 14 | """ 15 | A flat classifier utility that accepts as input a hierarchy and flattens it internally. 16 | 17 | Examples 18 | -------- 19 | >>> from hiclass import FlatClassifier 20 | >>> y = [['1', '1.1'], ['2', '2.1']] 21 | >>> X = [[1, 2], [3, 4]] 22 | >>> flat = FlatClassifier() 23 | >>> flat = flat.fit(X, y) 24 | >>> flat.predict(X) 25 | array([['1', '1.1'], 26 | ['2', '2.1']]) 27 | """ 28 | 29 | def __init__( 30 | self, 31 | local_classifier: BaseEstimator = LogisticRegression(), 32 | ): 33 | """ 34 | Initialize a flat classifier. 35 | 36 | Parameters 37 | ---------- 38 | local_classifier : BaseEstimator, default=LogisticRegression 39 | The scikit-learn model used for the flat classification. Needs to have fit, predict and clone methods. 40 | """ 41 | self.local_classifier = local_classifier 42 | 43 | def fit(self, X, y, sample_weight=None): 44 | """ 45 | Fit a flat classifier. 46 | 47 | Parameters 48 | ---------- 49 | X : {array-like, sparse matrix} of shape (n_samples, n_features) 50 | The training input samples. Internally, its dtype will be converted 51 | to ``dtype=np.float32``. If a sparse matrix is provided, it will be 52 | converted into a sparse ``csc_matrix``. 53 | y : array-like of shape (n_samples, n_levels) 54 | The target values, i.e., hierarchical class labels for classification.
55 | sample_weight : array-like of shape (n_samples,), default=None 56 | Array of weights that are assigned to individual samples. 57 | If not provided, then each sample is given unit weight. 58 | 59 | Returns 60 | ------- 61 | self : object 62 | Fitted estimator. 63 | """ 64 | # Convert from hierarchical labels to flat labels 65 | self.separator_ = "::HiClass::Separator::" 66 | y = [self.separator_.join(i) for i in y] 67 | 68 | # Fit flat classifier 69 | self.local_classifier.fit(X, y, sample_weight=sample_weight) 70 | 71 | # Return the classifier 72 | return self 73 | 74 | def predict(self, X): 75 | """ 76 | Predict classes for the given data. 77 | 78 | Hierarchical labels are returned. 79 | 80 | Parameters 81 | ---------- 82 | X : {array-like, sparse matrix} of shape (n_samples, n_features) 83 | The input samples. Internally, its dtype will be converted 84 | to ``dtype=np.float32``. If a sparse matrix is provided, it will be 85 | converted into a sparse ``csr_matrix``. 86 | 87 | Returns 88 | ------- 89 | y : ndarray of shape (n_samples,) or (n_samples, n_outputs) 90 | The predicted classes. 91 | """ 92 | # Check if fit has been called 93 | check_is_fitted(self) 94 | 95 | # Predict and remove separator 96 | predictions = [ 97 | i.split(self.separator_) for i in self.local_classifier.predict(X) 98 | ] 99 | 100 | return np.array(predictions) 101 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/scripts/statistics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to store basic statistical information.""" 3 | import argparse 4 | import os 5 | import sys 6 | from argparse import Namespace 7 | from datetime import datetime 8 | 9 | import pandas as pd 10 | 11 | from data import load_dataframe 12 | 13 | 14 | def parse_args(args: list) -> Namespace: 15 | """ 16 | Parse a list of arguments. 17 | 18 | Parameters 19 | ---------- 20 | args : list 21 | Arguments to parse. 22 | 23 | Returns 24 | ------- 25 | _ : Namespace 26 | Parsed arguments. 27 | """ 28 | parser = argparse.ArgumentParser(description="Compute and store basic dataset statistics") 29 | parser.add_argument( 30 | "--data", 31 | type=str, 32 | required=True, 33 | help="Full dataset for timestamp extraction", 34 | ) 35 | parser.add_argument( 36 | "--x-train", 37 | type=str, 38 | required=True, 39 | help="Input CSV file with training features", 40 | ) 41 | parser.add_argument( 42 | "--y-train", 43 | type=str, 44 | required=True, 45 | help="Input CSV file with training labels", 46 | ) 47 | parser.add_argument( 48 | "--x-test", 49 | type=str, 50 | required=True, 51 | help="Input CSV file with testing features", 52 | ) 53 | parser.add_argument( 54 | "--y-test", 55 | type=str, 56 | required=True, 57 | help="Input CSV file with testing labels", 58 | ) 59 | parser.add_argument( 60 | "--statistics", 61 | type=str, 62 | required=True, 63 | help="Path to store statistics in CSV format", 64 | ) 65 | return parser.parse_args(args) 66 | 67 | 68 | def get_file_modification(file_path: str) -> str: 69 | """ 70 | Get the modification date of a file. 71 | 72 | Parameters 73 | ---------- 74 | file_path : str 75 | Path to file. 76 | 77 | Returns 78 | ------- 79 | _ : str 80 | Modification date. 81 | """ 82 | date = datetime.fromtimestamp(os.path.getmtime(file_path)).strftime("%d/%m/%Y") 83 | return date 84 | 85 | 86 | def create_dataframe(snapshot: str, x_train: int, x_test: int) -> pd.DataFrame: 87 | """ 88 | Create dataframe with statistics.
89 | 90 | Parameters 91 | ---------- 92 | snapshot : str 93 | Snapshot date. 94 | x_train : int 95 | Number of training examples. 96 | x_test : int 97 | Number of testing examples. 98 | 99 | Returns 100 | ------- 101 | statistics : pd.DataFrame 102 | Basic statistics. 103 | """ 104 | statistics = pd.DataFrame( 105 | { 106 | "Snapshot": [snapshot], 107 | "Training set size": [x_train], 108 | "Test set size": [x_test], 109 | } 110 | ) 111 | return statistics 112 | 113 | 114 | def save_statistics(stats: pd.DataFrame, output_path: str) -> None: 115 | """ 116 | Save statistics to CSV file. 117 | 118 | Parameters 119 | ---------- 120 | stats : pd.DataFrame 121 | Basic statistics. 122 | output_path : str 123 | Path to store statistics. 124 | """ 125 | stats.to_csv(output_path, index=False) 126 | 127 | 128 | if __name__ == "__main__": # pragma: no cover 129 | args = parse_args(sys.argv[1:]) 130 | snapshot = get_file_modification(args.data) 131 | x_train = load_dataframe(args.x_train).squeeze().shape[0] 132 | y_train = load_dataframe(args.y_train).shape[0] 133 | x_test = load_dataframe(args.x_test).squeeze().shape[0] 134 | y_test = load_dataframe(args.y_test).shape[0] 135 | assert x_train == y_train 136 | assert x_test == y_test 137 | statistics = create_dataframe(snapshot, x_train, x_test) 138 | save_statistics(statistics, args.statistics) 139 | -------------------------------------------------------------------------------- /tests/test_LocalClassifiers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import pytest 6 | from numpy.testing import assert_array_equal 7 | from pyfakefs.fake_filesystem_unittest import Patcher 8 | from sklearn.linear_model import LogisticRegression 9 | from sklearn.neighbors import KNeighborsClassifier 10 | from sklearn.utils.validation import check_is_fitted 11 | 12 | from hiclass import ( 13 | LocalClassifierPerLevel, 14 | LocalClassifierPerNode, 15 | LocalClassifierPerParentNode, 16 | ) 17 | from hiclass.ConstantClassifier import ConstantClassifier 18 | 19 | classifiers = [ 20 | LocalClassifierPerLevel, 21 | LocalClassifierPerParentNode, 22 | LocalClassifierPerNode, 23 | ] 24 | 25 | 26 | @pytest.mark.parametrize("classifier", classifiers) 27 | def test_fit_1_class(classifier): 28 | clf = classifier(local_classifier=LogisticRegression(), n_jobs=2) 29 | y = np.array([["1", "2"]]) 30 | X = np.array([[1, 2]]) 31 | ground_truth = np.array([["1", "2"]]) 32 | clf.fit(X, y) 33 | prediction = clf.predict(X) 34 | assert_array_equal(ground_truth, prediction) 35 | 36 | 37 | @pytest.fixture 38 | def empty_levels(): 39 | X = [ 40 | [1], 41 | [2], 42 | [3], 43 | ] 44 | y = np.array( 45 | [ 46 | ["1"], 47 | ["2", "2.1"], 48 | ["3", "3.1", "3.1.2"], 49 | ], 50 | dtype=object, 51 | ) 52 | return X, y 53 | 54 | 55 | @pytest.mark.parametrize("classifier", classifiers) 56 | def test_empty_levels(empty_levels, classifier): 57 | clf = classifier() 58 | X, y = empty_levels 59 | clf.fit(X, y) 60 | predictions = clf.predict(X) 61 | ground_truth = [ 62 | ["1", "", ""], 63 | ["2", "2.1", ""], 64 | ["3", "3.1", "3.1.2"], 65 | ] 66 | assert list(clf.hierarchy_.nodes) == [ 67 | "1", 68 | "2", 69 | "2" + clf.separator_ + "2.1", 70 | "3", 71 | "3" + clf.separator_ + "3.1", 72 | "3" + clf.separator_ + "3.1" + clf.separator_ + "3.1.2", 73 | clf.root_, 74 | ] 75 | assert_array_equal(ground_truth, predictions) 76 | 77 | 78 | @pytest.mark.parametrize("classifier", classifiers) 79 | def test_knn(classifier): 
80 | knn = KNeighborsClassifier( 81 | n_neighbors=2, 82 | ) 83 | clf = classifier( 84 | local_classifier=knn, 85 | ) 86 | y = np.array([["a", "b"], ["a", "c"]]) 87 | X = np.array([[1, 2], [3, 4]]) 88 | clf.fit(X, y) 89 | check_is_fitted(clf) 90 | # predictions = lcpn.predict(X) 91 | # assert_array_equal(y, predictions) 92 | 93 | 94 | @pytest.mark.parametrize("classifier", classifiers) 95 | def test_fit_multiple_dim_input(classifier): 96 | clf = classifier() 97 | X = np.random.rand(1, 275, 3) 98 | y = np.array([["a", "b", "c"]]) 99 | clf.fit(X, y) 100 | check_is_fitted(clf) 101 | 102 | 103 | @pytest.mark.parametrize("classifier", classifiers) 104 | def test_predict_multiple_dim_input(classifier): 105 | clf = classifier() 106 | X = np.random.rand(1, 275, 3) 107 | y = np.array([["a", "b", "c"]]) 108 | clf.fit(X, y) 109 | predictions = clf.predict(X) 110 | assert predictions is not None 111 | 112 | 113 | @pytest.mark.parametrize("classifier", classifiers) 114 | def test_tmp_dir(classifier): 115 | clf = classifier(tmp_dir=".") 116 | with Patcher() as patcher: 117 | x = np.array([[1, 2], [3, 4]]) 118 | y = np.array([["a", "b"], ["c", "d"]]) 119 | clf.fit(x, y) 120 | if isinstance(clf, LocalClassifierPerLevel): 121 | filename = "cfcd208495d565ef66e7dff9f98764da.sav" 122 | expected_name = 0 123 | else: 124 | filename = "0cc175b9c0f1b6a831c399e269772661.sav" 125 | expected_name = "a" 126 | assert patcher.fs.exists(filename) 127 | (name, classifier) = pickle.load(open(filename, "rb")) 128 | assert expected_name == name 129 | check_is_fitted(classifier) 130 | clf.fit(x, y) 131 | -------------------------------------------------------------------------------- /docs/source/algorithms/training_policies.rst: -------------------------------------------------------------------------------- 1 | Training Policies 2 | ================= 3 | 4 | There are multiple ways to define the sets of positive and negative examples used to train the binary classifiers. In HiClass, we implemented the 6 policies described in [1]_, which were based on previous work from [2]_ and [3]_; a short example of selecting a policy in code is sketched below. The table that follows presents the notation used to define the sets of positive and negative examples, as described by [1]_.
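A policy is chosen when the classifier is constructed. The sketch below assumes that the :literal:`binary_policy` parameter of the :literal:`LocalClassifierPerNode` accepts the snake_case policy names; see :ref:`Binary Training Policies` for the authoritative usage:

.. code-block:: python

    from sklearn.linear_model import LogisticRegression

    from hiclass import LocalClassifierPerNode

    # Train the binary classifier of every node with the "inclusive" policy
    classifier = LocalClassifierPerNode(
        local_classifier=LogisticRegression(),
        binary_policy="inclusive",
    )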
5 | 6 | ============================= =============================================================== 7 | **Symbol** **Meaning** 8 | ----------------------------- --------------------------------------------------------------- 9 | :math:`Tr` The set of all training examples 10 | :math:`Tr^+(c_i)` The set of positive training examples of :math:`c_i` 11 | :math:`Tr^-(c_i)` The set of negative training examples of :math:`c_i` 12 | :math:`\uparrow (c_i)` The parent category of :math:`c_i` 13 | :math:`\downarrow (c_i)` The set of children categories of :math:`c_i` 14 | :math:`\Uparrow (c_i)` The set of ancestor categories of :math:`c_i` 15 | :math:`\Downarrow (c_i)` The set of descendant categories of :math:`c_i` 16 | :math:`\leftrightarrow (c_i)` The set of sibling categories of :math:`c_i` 17 | :math:`*(c_i)` Denotes examples whose most specific known class is :math:`c_i` 18 | ============================= =============================================================== 19 | 20 | Based on this notation, we can define the different policies and their sets of positive and negative examples as follows: 21 | 22 | ====================== ================================================ ============================================================= 23 | **Policy** **Positive examples** **Negative examples** 24 | ---------------------- ------------------------------------------------ ------------------------------------------------------------- 25 | **Exclusive** :math:`Tr^+(c_i) = *(c_i)` :math:`Tr^-(c_i) = Tr \setminus *(c_i)` 26 | **Less exclusive** :math:`Tr^+(c_i) = *(c_i)` :math:`Tr^-(c_i) = Tr \setminus (*(c_i) \cup \Downarrow (c_i))` 27 | **Less inclusive** :math:`Tr^+(c_i) = *(c_i) \cup \Downarrow (c_i)` :math:`Tr^-(c_i) = Tr \setminus (*(c_i) \cup \Downarrow (c_i))` 28 | **Inclusive** :math:`Tr^+(c_i) = *(c_i) \cup \Downarrow (c_i)` :math:`Tr^-(c_i) = Tr \setminus (*(c_i) \cup \Downarrow (c_i) \cup \Uparrow (c_i))` 29 | **Siblings** :math:`Tr^+(c_i) = *(c_i) \cup \Downarrow (c_i)` :math:`Tr^-(c_i) = \leftrightarrow (c_i) \cup \Downarrow (\leftrightarrow (c_i))` 30 | **Exclusive siblings** :math:`Tr^+(c_i) = *(c_i)` :math:`Tr^-(c_i) = \leftrightarrow (c_i)` 31 | ====================== ================================================ ============================================================= 32 | 33 | Using the class "Wolf" from the hierarchy represented in the image below as an example, we have the following sets of positive and negative examples for each policy: 34 | 35 | .. figure:: local_classifier_per_node.svg 36 | :align: center 37 | 38 | Visual representation of the local classifier per node approach. 39 | 40 | ====================== ====================== =============================================== 41 | **Policy** :math:`Tr^+(c_{Wolf})` :math:`Tr^-(c_{Wolf})` 42 | ---------------------- ---------------------- ----------------------------------------------- 43 | **Exclusive** Wolf Reptile, Snake, Lizard, Mammal, Cat, Dog 44 | **Less exclusive** Wolf Reptile, Snake, Lizard, Mammal, Cat 45 | **Less inclusive** Wolf, Dog Reptile, Snake, Lizard, Mammal, Cat 46 | **Inclusive** Wolf, Dog Reptile, Snake, Lizard, Cat 47 | **Siblings** Wolf, Dog Cat 48 | **Exclusive siblings** Wolf Cat 49 | ====================== ====================== =============================================== 50 | 51 | .. seealso:: 52 | 53 | In terms of code, we explain how to select those different policies here: :ref:`Binary Training Policies`. 54 | 55 | .. [1] Silla, C. N., & Freitas, A. A. (2011).
A survey of hierarchical classification across different application domains. Data Mining and Knowledge Discovery, 22(1), 31-72. 56 | 57 | .. [2] Eisner, R., Poulin, B., Szafron, D., Lu, P., & Greiner, R. (2005, November). Improving protein function prediction using the hierarchical structure of the gene ontology. In 2005 IEEE symposium on computational intelligence in bioinformatics and computational biology (pp. 1-10). IEEE. 58 | 59 | .. [3] Fagni, T., & Sebastiani, F. (2007, October). On the selection of negative examples for hierarchical text categorization. In Proceedings of the 3rd language technology conference (pp. 24-28). 60 | -------------------------------------------------------------------------------- /benchmarks/consumer_complaints/scripts/split_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to split train and test data.""" 3 | import argparse 4 | import sys 5 | from argparse import Namespace 6 | from io import BytesIO 7 | from typing import Union 8 | 9 | import pandas as pd 10 | from sklearn.model_selection import train_test_split 11 | 12 | from data import save_dataframe 13 | 14 | 15 | def parse_args(args: list) -> Namespace: 16 | """ 17 | Parse a list of arguments. 18 | 19 | Parameters 20 | ---------- 21 | args : list 22 | Arguments to parse. 23 | 24 | Returns 25 | ------- 26 | _ : Namespace 27 | Parsed arguments. 28 | """ 29 | parser = argparse.ArgumentParser(description="Split data into train and test sets.") 30 | parser.add_argument( 31 | "--data", 32 | type=str, 33 | required=True, 34 | help="Input CSV file containing consumer complaints", 35 | ) 36 | parser.add_argument( 37 | "--x-train", 38 | type=str, 39 | required=True, 40 | help="Output CSV file to write training features", 41 | ) 42 | parser.add_argument( 43 | "--x-test", 44 | type=str, 45 | required=True, 46 | help="Output CSV file to write test features", 47 | ) 48 | parser.add_argument( 49 | "--y-train", 50 | type=str, 51 | required=True, 52 | help="Output CSV file to write training labels", 53 | ) 54 | parser.add_argument( 55 | "--y-test", 56 | type=str, 57 | required=True, 58 | help="Output CSV file to write test labels", 59 | ) 60 | parser.add_argument( 61 | "--random-state", 62 | type=int, 63 | required=True, 64 | help="Random state to enable reproducibility", 65 | ) 66 | parser.add_argument( 67 | "--nrows", 68 | type=str, 69 | required=True, 70 | help="Number of rows to read from CSV file", 71 | ) 72 | return parser.parse_args(args) 73 | 74 | 75 | def get_nrows(nrows: str): 76 | """ 77 | Convert a nrows string either to integer or None. 78 | 79 | Parameters 80 | ---------- 81 | nrows : str 82 | String with number of rows or 'None'. 83 | 84 | Returns 85 | ------- 86 | nrows : Union[int, None] 87 | Number of rows as int or None if conversion fails. 88 | """ 89 | try: 90 | return int(nrows) 91 | except ValueError: 92 | return None 93 | 94 | 95 | def load_data(file_path: Union[str, BytesIO], nrows: int = None) -> tuple: 96 | """ 97 | Load data for training and test. 98 | 99 | Parameters 100 | ---------- 101 | file_path : Union[str, BytesIO] 102 | Path for zipped CSV file with consumer complaints. 103 | nrows : int, default=None 104 | Number of rows to read from CSV file. 105 | 106 | Returns 107 | ------- 108 | x, y : tuple 109 | Consumer complaint narrative and hierarchical labels. 
110 | """ 111 | data = pd.read_csv( 112 | file_path, 113 | compression="zip", 114 | header=0, 115 | sep=",", 116 | low_memory=False, 117 | usecols=["Consumer complaint narrative", "Product", "Sub-product"], 118 | nrows=nrows, 119 | ) 120 | # Remove rows with NaN in any column 121 | data.dropna( 122 | subset=["Consumer complaint narrative", "Product", "Sub-product"], inplace=True 123 | ) 124 | # Rebuild index 125 | data.reset_index(drop=True, inplace=True) 126 | x = data["Consumer complaint narrative"] 127 | y = data[["Product", "Sub-product"]] 128 | # Alternative y can be built with columns "Issue" and "Sub-issue" 129 | return x, y 130 | 131 | 132 | def split_data(x: pd.Series, y: pd.DataFrame, random_state: int) -> tuple: 133 | """ 134 | Split data in train and test subsets. 135 | 136 | Parameters 137 | ---------- 138 | x : pd.Series 139 | Consumer complaint narrative. 140 | y : pd.DataFrame 141 | hierarchical labels. 142 | random_state : int 143 | Random state to enable reproducibility. 144 | 145 | Returns 146 | ------- 147 | x_train, x_test, y_train, y_test : tuple 148 | Train and test split. 149 | """ 150 | x_train, x_test, y_train, y_test = train_test_split( 151 | x, y, test_size=0.3, random_state=random_state 152 | ) 153 | return x_train, x_test, y_train, y_test 154 | 155 | 156 | def main(): # pragma: no cover 157 | """Split train and test data.""" 158 | args = parse_args(sys.argv[1:]) 159 | x, y = load_data(args.data, get_nrows(args.nrows)) 160 | x_train, x_test, y_train, y_test = split_data(x, y, args.random_state) 161 | save_dataframe(x_train, args.x_train) 162 | save_dataframe(x_test, args.x_test) 163 | save_dataframe(y_train, args.y_train) 164 | save_dataframe(y_test, args.y_test) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() # pragma: no cover 169 | -------------------------------------------------------------------------------- /docs/source/algorithms/calibration.rst: -------------------------------------------------------------------------------- 1 | .. _calibration-overview: 2 | 3 | =========================== 4 | Classifier Calibration 5 | =========================== 6 | HiClass provides support for probability calibration using various post-hoc calibration methods. 7 | 8 | ++++++++++++++++++++++++++ 9 | Motivation 10 | ++++++++++++++++++++++++++ 11 | While many machine learning models can output uncertainty scores, these scores are known to be often poorly calibrated [1]_ [2]_. Model calibration aims to improve the quality of probabilistic forecasts by learning a transformation of the scores, using a separate dataset. 12 | 13 | ++++++++++++++++++++++++++ 14 | Methods 15 | ++++++++++++++++++++++++++ 16 | 17 | HiClass supports the following calibration methods: 18 | 19 | * Isotonic Regression [3]_ 20 | 21 | * Platt Scaling [4]_ 22 | 23 | * Beta Calibration [5]_ 24 | 25 | * Inductive Venn-Abers Calibration [6]_ 26 | 27 | * Cross Venn-Abers Calibration [6]_ 28 | 29 | ++++++++++++++++++++++++++ 30 | Probability Aggregation 31 | ++++++++++++++++++++++++++ 32 | 33 | Combining probabilities over multiple levels is another method to improve probabilistic forecasts. 
The following methods are supported: 34 | 35 | Conditional Probability Aggregation (Multiply Aggregation) 36 | ----------------------------------------------------------- 37 | Given a node hierarchy with :math:`n` levels, the probability of a node :math:`A_i`, where :math:`i` denotes the level, is calculated as: 38 | 39 | :math:`\displaystyle{\mathbb{P}(A_1 \cap A_2 \cap \ldots \cap A_i) = \mathbb{P}(A_1) \cdot \mathbb{P}(A_2 \mid A_1) \cdot \mathbb{P}(A_3 \mid A_1 \cap A_2) \cdot \ldots}` 40 | :math:`\displaystyle{\cdot \mathbb{P}(A_i \mid A_1 \cap A_2 \cap \ldots \cap A_{i-1})}` 41 | 42 | Arithmetic Mean Aggregation 43 | --------------------------- 44 | :math:`\displaystyle{\mathbb{P}(A_i) = \frac{1}{i} \sum_{j=1}^{i} \mathbb{P}(A_{j})}` 45 | 46 | Geometric Mean Aggregation 47 | -------------------------- 48 | :math:`\displaystyle{\mathbb{P}(A_i) = \exp{\left(\frac{1}{i} \sum_{j=1}^{i} \ln \mathbb{P}(A_{j})\right)}}` 49 | 50 | ++++++++++++++++++++++++++ 51 | Code sample 52 | ++++++++++++++++++++++++++ 53 | 54 | .. code-block:: python 55 | 56 | from sklearn.ensemble import RandomForestClassifier 57 | 58 | from hiclass import LocalClassifierPerNode 59 | 60 | # Define data 61 | X_train = [[1], [2], [3], [4]] 62 | X_test = [[4], [3], [2], [1]] 63 | X_cal = [[5], [6], [7], [8]] 64 | Y_train = [ 65 | ["Animal", "Mammal", "Sheep"], 66 | ["Animal", "Mammal", "Cow"], 67 | ["Animal", "Reptile", "Snake"], 68 | ["Animal", "Reptile", "Lizard"], 69 | ] 70 | 71 | Y_cal = [ 72 | ["Animal", "Mammal", "Cow"], 73 | ["Animal", "Mammal", "Sheep"], 74 | ["Animal", "Reptile", "Lizard"], 75 | ["Animal", "Reptile", "Snake"], 76 | ] 77 | 78 | # Use random forest classifiers for every node 79 | rf = RandomForestClassifier() 80 | 81 | # Use local classifier per node with isotonic regression as calibration method 82 | classifier = LocalClassifierPerNode( 83 | local_classifier=rf, calibration_method="isotonic", probability_combiner="multiply" 84 | ) 85 | 86 | # Train local classifier per node 87 | classifier.fit(X_train, Y_train) 88 | 89 | # Calibrate local classifier per node 90 | classifier.calibrate(X_cal, Y_cal) 91 | 92 | # Predict probabilities 93 | probabilities = classifier.predict_proba(X_test) 94 | 95 | # Print probabilities and labels for the last level 96 | print(classifier.classes_[2]) 97 | print(probabilities) 98 | 99 | .. [1] Niculescu-Mizil, Alexandru; Caruana, Rich (2005): Predicting good probabilities with supervised learning. In: Saso Dzeroski (ed.): Proceedings of the 22nd international conference on Machine learning - ICML '05. the 22nd international conference. Bonn, Germany, 07.08.2005 - 11.08.2005. New York, New York, USA: ACM Press, pp. 625-632. 100 | 101 | .. [2] Chuan Guo; Geoff Pleiss; Yu Sun; Kilian Q. Weinberger (2017): On Calibration of Modern Neural Networks. In: Doina Precup and Yee Whye Teh (eds.): Proceedings of the 34th International Conference on Machine Learning, vol. 70: PMLR (Proceedings of Machine Learning Research), pp. 1321-1330. 102 | 103 | .. [3] Zadrozny, Bianca; Elkan, Charles (2002): Transforming classifier scores into accurate multiclass probability estimates. In: Proceedings of the Eighth ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. New York, NY, USA: Association for Computing Machinery (KDD ’02), pp. 694-699. 104 | 105 | .. [4] Platt, John (2000): Probabilistic Outputs for Support Vector Machines and Comparisons to Regularized Likelihood Methods. In: Adv. Large Margin Classif. 10. 106 | 107 | .. 
[5] Kull, Meelis; Filho, Telmo Silva; Flach, Peter (2017): Beta calibration: a well-founded and easily implemented improvement on logistic calibration for binary classifiers. In: Aarti Singh and Jerry Zhu (eds.): Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, vol. 54: PMLR (Proceedings of Machine Learning Research), pp. 623-631. 108 | 109 | .. [6] Vovk, Vladimir; Petej, Ivan; Fedorova, Valentina (2015): Large-scale probabilistic predictors with and without guarantees of validity. In: C. Cortes, N. Lawrence, D. Lee, M. Sugiyama and R. Garnett (eds.): Advances in Neural Information Processing Systems, vol. 28: Curran Associates, Inc. 110 | 111 | -------------------------------------------------------------------------------- /docs/examples/plot_calibration.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | ======================== 4 | Calibrating a Classifier 5 | ======================== 6 | 7 | A minimalist example showing how to calibrate a HiClass LCN (:literal:`LocalClassifierPerNode`) model. The calibration method can be selected with the :literal:`calibration_method` parameter, for example: 8 | 9 | .. tabs:: 10 | 11 | .. code-tab:: python 12 | :caption: Isotonic Regression 13 | 14 | rf = RandomForestClassifier() 15 | classifier = LocalClassifierPerNode( 16 | local_classifier=rf, 17 | calibration_method='isotonic' 18 | ) 19 | 20 | .. code-tab:: python 21 | :caption: Platt scaling 22 | 23 | rf = RandomForestClassifier() 24 | classifier = LocalClassifierPerNode( 25 | local_classifier=rf, 26 | calibration_method='platt' 27 | ) 28 | 29 | .. code-tab:: python 30 | :caption: Beta scaling 31 | 32 | rf = RandomForestClassifier() 33 | classifier = LocalClassifierPerNode( 34 | local_classifier=rf, 35 | calibration_method='beta' 36 | ) 37 | 38 | .. code-tab:: python 39 | :caption: IVAP 40 | 41 | rf = RandomForestClassifier() 42 | classifier = LocalClassifierPerNode( 43 | local_classifier=rf, 44 | calibration_method='ivap' 45 | ) 46 | 47 | .. code-tab:: python 48 | :caption: CVAP 49 | 50 | rf = RandomForestClassifier() 51 | classifier = LocalClassifierPerNode( 52 | local_classifier=rf, 53 | calibration_method='cvap' 54 | ) 55 | 56 | Furthermore, probabilities of multiple levels can be aggregated by defining a probability combiner: 57 | 58 | .. tabs:: 59 | 60 | .. code-tab:: python 61 | :caption: Multiply (Default) 62 | 63 | rf = RandomForestClassifier() 64 | classifier = LocalClassifierPerNode( 65 | local_classifier=rf, 66 | calibration_method='isotonic', 67 | probability_combiner='multiply' 68 | ) 69 | 70 | .. code-tab:: python 71 | :caption: Geometric Mean 72 | 73 | rf = RandomForestClassifier() 74 | classifier = LocalClassifierPerNode( 75 | local_classifier=rf, 76 | calibration_method='isotonic', 77 | probability_combiner='geometric' 78 | ) 79 | 80 | .. code-tab:: python 81 | :caption: Arithmetic Mean 82 | 83 | rf = RandomForestClassifier() 84 | classifier = LocalClassifierPerNode( 85 | local_classifier=rf, 86 | calibration_method='isotonic', 87 | probability_combiner='arithmetic' 88 | ) 89 | 90 | .. code-tab:: python 91 | :caption: No Aggregation 92 | 93 | rf = RandomForestClassifier() 94 | classifier = LocalClassifierPerNode( 95 | local_classifier=rf, 96 | calibration_method='isotonic', 97 | probability_combiner=None 98 | ) 99 | 100 | 101 | A hierarchical classifier can be calibrated by calling calibrate on the model or by using a Pipeline: 102 | 103 | .. tabs:: 104 | 105 | .. 
code-tab:: python 106 | :caption: Default 107 | 108 | rf = RandomForestClassifier() 109 | classifier = LocalClassifierPerNode( 110 | local_classifier=rf, 111 | calibration_method='isotonic' 112 | ) 113 | 114 | classifier.fit(X_train, Y_train) 115 | classifier.calibrate(X_cal, Y_cal) 116 | classifier.predict_proba(X_test) 117 | 118 | .. code-tab:: python 119 | :caption: Pipeline 120 | 121 | from hiclass import Pipeline 122 | 123 | rf = RandomForestClassifier() 124 | classifier = LocalClassifierPerNode( 125 | local_classifier=rf, 126 | calibration_method='isotonic' 127 | ) 128 | 129 | pipeline = Pipeline([ 130 | ('classifier', classifier), 131 | ]) 132 | 133 | pipeline.fit(X_train, Y_train) 134 | pipeline.calibrate(X_cal, Y_cal) 135 | pipeline.predict_proba(X_test) 136 | 137 | In the code below, isotonic regression is used to calibrate the model. 138 | 139 | """ 140 | from sklearn.ensemble import RandomForestClassifier 141 | 142 | from hiclass import LocalClassifierPerNode 143 | 144 | # Define data 145 | X_train = [[1], [2], [3], [4]] 146 | X_test = [[4], [3], [2], [1]] 147 | X_cal = [[5], [6], [7], [8]] 148 | Y_train = [ 149 | ["Animal", "Mammal", "Sheep"], 150 | ["Animal", "Mammal", "Cow"], 151 | ["Animal", "Reptile", "Snake"], 152 | ["Animal", "Reptile", "Lizard"], 153 | ] 154 | 155 | Y_cal = [ 156 | ["Animal", "Mammal", "Cow"], 157 | ["Animal", "Mammal", "Sheep"], 158 | ["Animal", "Reptile", "Lizard"], 159 | ["Animal", "Reptile", "Snake"], 160 | ] 161 | 162 | # Use random forest classifiers for every node 163 | rf = RandomForestClassifier() 164 | 165 | # Use local classifier per node with isotonic regression as calibration method 166 | classifier = LocalClassifierPerNode( 167 | local_classifier=rf, calibration_method="isotonic", probability_combiner="multiply" 168 | ) 169 | 170 | # Train local classifier per node 171 | classifier.fit(X_train, Y_train) 172 | 173 | # Calibrate local classifier per node 174 | classifier.calibrate(X_cal, Y_cal) 175 | 176 | # Predict probabilities 177 | probabilities = classifier.predict_proba(X_test) 178 | 179 | # Print probabilities and labels for the last level 180 | print(classifier.classes_[2]) 181 | print(probabilities) 182 | -------------------------------------------------------------------------------- /hiclass/_calibration/Calibrator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.base import BaseEstimator 3 | from sklearn.preprocessing import LabelBinarizer 4 | from sklearn.preprocessing import LabelEncoder 5 | from hiclass._calibration.VennAbersCalibrator import ( 6 | _InductiveVennAbersCalibrator, 7 | _CrossVennAbersCalibrator, 8 | ) 9 | from hiclass._calibration.IsotonicRegression import _IsotonicRegression 10 | from hiclass._calibration.PlattScaling import _PlattScaling 11 | from hiclass._calibration.BetaCalibrator import _BetaCalibrator 12 | from hiclass._calibration.calibration_utils import _one_vs_rest_split 13 | from hiclass._hiclass_utils import _normalize_probabilities 14 | 15 | 16 | class _Calibrator(BaseEstimator): 17 | available_methods = ["ivap", "cvap", "sigmoid", "isotonic", "beta"] 18 | _multiclass_methods = ["cvap"] 19 | 20 | def __init__( 21 | self, estimator: BaseEstimator, method: str = "ivap", **method_params 22 | ) -> None: 23 | assert callable(getattr(estimator, "predict_proba", None)) 24 | self.estimator = estimator 25 | self.method_params = method_params 26 | # self.classes_ = self.estimator.classes_ 27 | self.multiclass = False 28 | 
self.multiclass_support = method in self._multiclass_methods 29 | if method not in self.available_methods: 30 | raise ValueError(f"{method} is not a valid calibration method.") 31 | self.method = method 32 | 33 | def fit(self, X: np.ndarray, y: np.ndarray): 34 | """ 35 | Fit a calibrator. 36 | 37 | Parameters 38 | ---------- 39 | X : {array-like, sparse matrix} of shape (n_samples, n_features) 40 | The calibration input samples. Internally, its dtype will be converted 41 | to ``dtype=np.float32``. If a sparse matrix is provided, it will be 42 | converted into a sparse ``csr_matrix``. 43 | y : array-like of shape (n_samples, n_levels) 44 | The target values, i.e., hierarchical class labels for classification. 45 | 46 | Returns 47 | ------- 48 | self : object 49 | Calibrated estimator. 50 | """ 51 | self.classes_ = self.estimator.classes_ 52 | calibration_scores = self.estimator.predict_proba(X) 53 | 54 | if calibration_scores.shape[1] > 2: 55 | self.multiclass = True 56 | 57 | self.calibrators = [] 58 | 59 | if self.multiclass: 60 | if self.multiclass_support: 61 | # only cvap 62 | self.label_encoder = LabelEncoder() 63 | encoded_y = self.label_encoder.fit_transform(y) 64 | calibrator = self._create_calibrator(self.method, self.method_params) 65 | calibrator.fit(encoded_y, calibration_scores, X) 66 | self.calibrators.append(calibrator) 67 | 68 | else: 69 | # do one vs rest calibration 70 | score_splits, label_splits = _one_vs_rest_split( 71 | y, calibration_scores, self.estimator 72 | ) 73 | for i in range(len(score_splits)): 74 | # create a calibrator for each split 75 | calibrator = self._create_calibrator( 76 | self.method, self.method_params 77 | ) 78 | calibrator.fit(label_splits[i], score_splits[i], X) 79 | self.calibrators.append(calibrator) 80 | 81 | else: 82 | self.label_encoder = LabelEncoder() 83 | encoded_y = self.label_encoder.fit_transform(y) 84 | calibrator = self._create_calibrator(self.method, self.method_params) 85 | calibrator.fit(encoded_y, calibration_scores[:, 1], X) 86 | self.calibrators.append(calibrator) 87 | self._is_fitted = True 88 | return self 89 | 90 | def predict_proba(self, X: np.ndarray): 91 | test_scores = self.estimator.predict_proba(X) 92 | 93 | if self.multiclass: 94 | if self.multiclass_support: 95 | # only cvap 96 | return self.calibrators[0].predict_proba(test_scores) 97 | 98 | else: 99 | # one vs rest calibration 100 | score_splits = [test_scores[:, i] for i in range(test_scores.shape[1])] 101 | 102 | probabilities = np.zeros((X.shape[0], len(self.estimator.classes_))) 103 | for idx, split in enumerate(score_splits): 104 | probabilities[:, idx] = self.calibrators[idx].predict_proba(split) 105 | 106 | probabilities = _normalize_probabilities(probabilities) 107 | 108 | else: 109 | probabilities = np.zeros((X.shape[0], 2)) 110 | probabilities[:, 1] = self.calibrators[0].predict_proba(test_scores[:, 1]) 111 | probabilities[:, 0] = 1.0 - probabilities[:, 1] 112 | 113 | return probabilities 114 | 115 | def _create_calibrator(self, name: str, params): 116 | if name == "ivap": 117 | return _InductiveVennAbersCalibrator(**params) 118 | elif name == "cvap": 119 | return _CrossVennAbersCalibrator(self.estimator, **params) 120 | elif name == "sigmoid" or name == "platt": 121 | return _PlattScaling() 122 | elif name == "isotonic": 123 | return _IsotonicRegression(params) 124 | elif name == "beta": 125 | return _BetaCalibrator() 126 | 127 | def __sklearn_is_fitted__(self): 128 | """Check fitted status and return a Boolean value.""" 129 | return hasattr(self, 
"_is_fitted") and self._is_fitted 130 | --------------------------------------------------------------------------------