├── imbens ├── datasets │ ├── tests │ │ ├── __init__.py │ │ ├── test_imbalance.py │ │ └── test_zenodo.py │ └── __init__.py ├── ensemble │ ├── tests │ │ ├── __init__.py │ │ └── test_base_ensemble.py │ ├── _compatible │ │ ├── tests │ │ │ ├── __init__.py │ │ │ └── test_adabost_compatible.py │ │ └── __init__.py │ ├── _over_sampling │ │ ├── tests │ │ │ ├── __init__.py │ │ │ └── test_over_boost.py │ │ └── __init__.py │ ├── _reweighting │ │ ├── tests │ │ │ └── __init__.py │ │ └── __init__.py │ ├── _under_sampling │ │ ├── tests │ │ │ ├── __init__.py │ │ │ └── test_rus_boost.py │ │ └── __init__.py │ └── __init__.py ├── metrics │ ├── tests │ │ ├── __init__.py │ │ └── test_score_objects.py │ └── __init__.py ├── sampler │ ├── tests │ │ └── __init__.py │ ├── _over_sampling │ │ ├── tests │ │ │ └── __init__.py │ │ ├── _smote │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ ├── test_borderline_smote.py │ │ │ │ └── test_svm_smote.py │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── base.py │ ├── _under_sampling │ │ ├── _prototype_generation │ │ │ ├── tests │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── _prototype_selection │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ ├── test_tomek_links.py │ │ │ │ ├── test_neighbourhood_cleaning_rule.py │ │ │ │ ├── test_instance_hardness_threshold.py │ │ │ │ ├── test_one_sided_selection.py │ │ │ │ └── test_condensed_nearest_neighbour.py │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── base.py │ └── __init__.py ├── utils │ ├── tests │ │ ├── __init__.py │ │ ├── test_deprecation.py │ │ ├── test_testing.py │ │ ├── test_plot.py │ │ ├── test_docstring.py │ │ └── test_show_versions.py │ ├── __init__.py │ ├── deprecation.py │ ├── _show_versions.py │ ├── _plot.py │ └── testing.py ├── visualizer │ ├── tests │ │ └── __init__.py │ └── __init__.py ├── base.py ├── exceptions.py ├── _version.py └── __init__.py ├── docs ├── source │ ├── sphinxext │ │ ├── MANIFEST.in │ │ ├── README.txt │ │ └── github_link.py │ ├── _static │ │ ├── css │ │ │ └── 
my_theme.css │ │ ├── thumbnail.png │ │ └── training_log_thumbnail.png │ ├── api │ │ ├── datasets │ │ │ ├── api.rst │ │ │ ├── _autosummary │ │ │ │ ├── imbens.datasets.make_imbalance.rst │ │ │ │ ├── imbens.datasets.fetch_openml_datasets.rst │ │ │ │ ├── imbens.datasets.fetch_zenodo_datasets.rst │ │ │ │ └── imbens.datasets.generate_imbalance_data.rst │ │ │ └── datasets.rst │ │ ├── pipeline │ │ │ ├── api.rst │ │ │ ├── _autosummary │ │ │ │ ├── imbens.pipeline.make_pipeline.rst │ │ │ │ └── imbens.pipeline.Pipeline.rst │ │ │ └── pipeline.rst │ │ ├── visualizer │ │ │ ├── api.rst │ │ │ ├── visualizer.rst │ │ │ └── _autosummary │ │ │ │ └── imbens.visualizer.ImbalancedEnsembleVisualizer.rst │ │ ├── metrics │ │ │ ├── api.rst │ │ │ ├── pairwise.rst │ │ │ ├── _autosummary │ │ │ │ ├── imbens.metrics.sensitivity_score.rst │ │ │ │ ├── imbens.metrics.specificity_score.rst │ │ │ │ ├── imbens.metrics.geometric_mean_score.rst │ │ │ │ ├── imbens.metrics.make_index_balanced_accuracy.rst │ │ │ │ ├── imbens.metrics.sensitivity_specificity_support.rst │ │ │ │ ├── imbens.metrics.classification_report_imbalanced.rst │ │ │ │ ├── imbens.metrics.macro_averaged_mean_absolute_error.rst │ │ │ │ └── imbens.metrics.ValueDifferenceMetric.rst │ │ │ └── classification.rst │ │ ├── sampler │ │ │ ├── api.rst │ │ │ ├── over-samplers.rst │ │ │ ├── _autosummary │ │ │ │ ├── imbens.sampler.SMOTE.rst │ │ │ │ ├── imbens.sampler.ADASYN.rst │ │ │ │ ├── imbens.sampler.AllKNN.rst │ │ │ │ ├── imbens.sampler.NearMiss.rst │ │ │ │ ├── imbens.sampler.SVMSMOTE.rst │ │ │ │ ├── imbens.sampler.KMeansSMOTE.rst │ │ │ │ ├── imbens.sampler.BorderlineSMOTE.rst │ │ │ │ ├── imbens.sampler.ClusterCentroids.rst │ │ │ │ ├── imbens.sampler.TomekLinks.rst │ │ │ │ ├── imbens.sampler.OneSidedSelection.rst │ │ │ │ ├── imbens.sampler.RandomOverSampler.rst │ │ │ │ ├── imbens.sampler.RandomUnderSampler.rst │ │ │ │ ├── imbens.sampler.SelfPacedUnderSampler.rst │ │ │ │ ├── imbens.sampler.EditedNearestNeighbours.rst │ │ │ │ ├── 
imbens.sampler.CondensedNearestNeighbour.rst │ │ │ │ ├── imbens.sampler.InstanceHardnessThreshold.rst │ │ │ │ ├── imbens.sampler.NeighbourhoodCleaningRule.rst │ │ │ │ ├── imbens.sampler.BalanceCascadeUnderSampler.rst │ │ │ │ └── imbens.sampler.RepeatedEditedNearestNeighbours.rst │ │ │ └── under-samplers.rst │ │ ├── utils │ │ │ ├── api.rst │ │ │ ├── _autosummary │ │ │ │ ├── imbens.utils.evaluate_print.rst │ │ │ │ ├── imbens.utils.check_target_type.rst │ │ │ │ ├── imbens.utils.check_eval_metrics.rst │ │ │ │ ├── imbens.utils.check_eval_datasets.rst │ │ │ │ ├── imbens.utils.check_neighbors_object.rst │ │ │ │ ├── imbens.utils.check_sampling_strategy.rst │ │ │ │ ├── imbens.utils.check_balancing_schedule.rst │ │ │ │ └── imbens.utils.check_target_label_and_n_target_samples.rst │ │ │ ├── evaluation.rst │ │ │ ├── validation_sampler.rst │ │ │ └── validation_ensemble.rst │ │ └── ensemble │ │ │ ├── api.rst │ │ │ ├── compatible.rst │ │ │ ├── reweighting.rst │ │ │ ├── over-sampling.rst │ │ │ ├── under-sampling.rst │ │ │ └── _autosummary │ │ │ ├── imbens.ensemble.BalanceCascadeClassifier.rst │ │ │ ├── imbens.ensemble.OverBaggingClassifier.rst │ │ │ ├── imbens.ensemble.EasyEnsembleClassifier.rst │ │ │ ├── imbens.ensemble.SMOTEBaggingClassifier.rst │ │ │ ├── imbens.ensemble.SelfPacedEnsembleClassifier.rst │ │ │ ├── imbens.ensemble.UnderBaggingClassifier.rst │ │ │ ├── imbens.ensemble.CompatibleBaggingClassifier.rst │ │ │ ├── imbens.ensemble.BalancedRandomForestClassifier.rst │ │ │ ├── imbens.ensemble.AdaCostClassifier.rst │ │ │ ├── imbens.ensemble.RUSBoostClassifier.rst │ │ │ ├── imbens.ensemble.AdaUBoostClassifier.rst │ │ │ ├── imbens.ensemble.AsymBoostClassifier.rst │ │ │ ├── imbens.ensemble.OverBoostClassifier.rst │ │ │ ├── imbens.ensemble.SMOTEBoostClassifier.rst │ │ │ ├── imbens.ensemble.KmeansSMOTEBoostClassifier.rst │ │ │ └── imbens.ensemble.CompatibleAdaBoostClassifier.rst │ ├── _templates │ │ ├── function.rst │ │ ├── numpydoc_docstring.rst │ │ ├── ensemble_class.rst │ │ └── 
class.rst │ ├── doc_requirements.txt │ ├── install.rst │ ├── get_start.rst │ └── sg_execution_times.rst ├── requirements.txt ├── Makefile └── make.bat ├── MANIFEST.in ├── examples ├── basic │ ├── README.txt │ ├── plot_basic_visualize.py │ ├── plot_training_log.py │ └── plot_basic_example.py ├── datasets │ ├── README.txt │ ├── plot_generate_imbalance.py │ ├── plot_make_imbalance.py │ └── plot_make_imbalance_digits.py ├── visualizer │ └── README.txt ├── evaluation │ ├── README.txt │ ├── plot_classification_report.py │ └── plot_metrics.py ├── pipeline │ ├── README.txt │ └── plot_pipeline_classification.py ├── classification │ ├── README.txt │ └── plot_probability.py └── README.txt ├── requirements.txt ├── readthedocs.yml ├── .circleci └── config.yml ├── LICENSE ├── .all-contributorsrc ├── .gitignore └── setup.py /imbens/datasets/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/ensemble/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/metrics/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/sampler/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/utils/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/visualizer/tests/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /imbens/ensemble/_compatible/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/ensemble/_over_sampling/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/ensemble/_reweighting/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/sampler/_over_sampling/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/ensemble/_under_sampling/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/sampler/_over_sampling/_smote/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/sampler/_under_sampling/_prototype_generation/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imbens/sampler/_under_sampling/_prototype_selection/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/sphinxext/MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include 
tests *.py 2 | include *.txt 3 | -------------------------------------------------------------------------------- /docs/source/_static/css/my_theme.css: -------------------------------------------------------------------------------- 1 | .wy-nav-content { 2 | max-width: 1000px !important; 3 | } -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include docs * 2 | recursive-include examples * 3 | include LICENSE 4 | include README.md -------------------------------------------------------------------------------- /docs/source/_static/thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiningLiu1998/imbalanced-ensemble/HEAD/docs/source/_static/thumbnail.png -------------------------------------------------------------------------------- /examples/basic/README.txt: -------------------------------------------------------------------------------- 1 | .. _basic_examples: 2 | 3 | Basic usage examples 4 | -------------------- 5 | 6 | Quick start with :mod:`imbens`. 7 | -------------------------------------------------------------------------------- /docs/source/_static/training_log_thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhiningLiu1998/imbalanced-ensemble/HEAD/docs/source/_static/training_log_thumbnail.png -------------------------------------------------------------------------------- /examples/datasets/README.txt: -------------------------------------------------------------------------------- 1 | .. _dataset_examples: 2 | 3 | Dataset examples 4 | ---------------- 5 | 6 | Examples concerning the :mod:`imbens.datasets` module. 
7 | -------------------------------------------------------------------------------- /examples/visualizer/README.txt: -------------------------------------------------------------------------------- 1 | .. _visualizer_examples: 2 | 3 | Visualizer examples 4 | ------------------- 5 | 6 | Examples concerning the :mod:`imbens.visualizer` module. 7 | -------------------------------------------------------------------------------- /examples/evaluation/README.txt: -------------------------------------------------------------------------------- 1 | .. _evaluation_examples: 2 | 3 | Evaluation examples 4 | ------------------- 5 | 6 | Examples illustrating how classification using an imbalanced dataset can be done. 7 | -------------------------------------------------------------------------------- /examples/pipeline/README.txt: -------------------------------------------------------------------------------- 1 | .. _pipeline_examples: 2 | 3 | Pipeline examples 4 | ----------------- 5 | 6 | Example of how to use a pipeline to include under-sampling with `scikit-learn` estimators. -------------------------------------------------------------------------------- /examples/classification/README.txt: -------------------------------------------------------------------------------- 1 | .. _classification_examples: 2 | 3 | Classification examples 4 | ----------------------- 5 | 6 | Examples about using classification algorithms in the :mod:`imbens.ensemble` module.
7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.0 2 | scipy>=1.9.1 3 | pandas>=2.1.1 4 | joblib>=0.11 5 | scikit-learn==1.6.0 6 | matplotlib>=3.3.2 7 | seaborn>=0.13.2 8 | tqdm>=4.50.2 9 | openml>=0.14.0 10 | platformdirs>=3.0.0 -------------------------------------------------------------------------------- /docs/source/api/datasets/api.rst: -------------------------------------------------------------------------------- 1 | ``imbens.datasets`` 2 | ********************************* 3 | 4 | This is the full API documentation of the `imbens.datasets` module. 5 | 6 | .. toctree:: 7 | :maxdepth: 3 8 | 9 | datasets 10 | -------------------------------------------------------------------------------- /docs/source/api/pipeline/api.rst: -------------------------------------------------------------------------------- 1 | ``imbens.pipeline`` 2 | ********************************* 3 | 4 | This is the full API documentation of the `imbens.pipeline` module. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | pipeline 10 | -------------------------------------------------------------------------------- /docs/source/api/visualizer/api.rst: -------------------------------------------------------------------------------- 1 | ``imbens.visualizer`` 2 | *********************************** 3 | 4 | This is the full API documentation of the `imbens.visualizer` module. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | visualizer 10 | -------------------------------------------------------------------------------- /docs/source/api/metrics/api.rst: -------------------------------------------------------------------------------- 1 | ``imbens.metrics`` 2 | ********************************* 3 | 4 | This is the full API documentation of the `imbens.metrics` module. 5 | 6 | .. 
toctree:: 7 | :maxdepth: 2 8 | 9 | classification 10 | pairwise 11 | -------------------------------------------------------------------------------- /docs/source/api/sampler/api.rst: -------------------------------------------------------------------------------- 1 | ``imbens.sampler`` 2 | ********************************* 3 | 4 | This is the full API documentation of the `imbens.sampler` module. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | under-samplers 10 | over-samplers 11 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | .. _examples-index: 2 | 3 | **General-purpose and introductory examples for the imbalanced-ensemble toolbox.** 4 | 5 | **The examples gallery is still under construction.** Please refer to APIs for more detailed guidelines of how to use 6 | ``imbens``. -------------------------------------------------------------------------------- /docs/source/api/utils/api.rst: -------------------------------------------------------------------------------- 1 | ``imbens.utils`` 2 | ********************************* 3 | 4 | This is the full API documentation of the `imbens.utils` module. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | evaluation 10 | validation_ensemble 11 | validation_sampler -------------------------------------------------------------------------------- /imbens/visualizer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.visualizer` module include a visualizer 3 | for visualizing ensemble estimators. 
4 | """ 5 | 6 | from .visualizer import ImbalancedEnsembleVisualizer 7 | 8 | __all__ = [ 9 | "ImbalancedEnsembleVisualizer", 10 | ] 11 | -------------------------------------------------------------------------------- /docs/source/api/ensemble/api.rst: -------------------------------------------------------------------------------- 1 | ``imbens.ensemble`` 2 | ********************************* 3 | 4 | This is the full API documentation of the `imbens.ensemble` module. 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | under-sampling 10 | over-sampling 11 | reweighting 12 | compatible -------------------------------------------------------------------------------- /imbens/sampler/_over_sampling/_smote/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SMOTE 2 | 3 | from .cluster import KMeansSMOTE 4 | 5 | from .filter import BorderlineSMOTE 6 | from .filter import SVMSMOTE 7 | 8 | __all__ = [ 9 | "SMOTE", 10 | "KMeansSMOTE", 11 | "BorderlineSMOTE", 12 | "SVMSMOTE", 13 | ] -------------------------------------------------------------------------------- /docs/source/_templates/function.rst: -------------------------------------------------------------------------------- 1 | {{objname}} 2 | {{ underline }}==================== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autofunction:: {{ objname }} 7 | 8 | .. include:: ../../../back_references/{{module}}.{{objname}}.examples 9 | 10 | .. raw:: html 11 | 12 |
13 | -------------------------------------------------------------------------------- /docs/source/_templates/numpydoc_docstring.rst: -------------------------------------------------------------------------------- 1 | {{index}} 2 | {{summary}} 3 | {{extended_summary}} 4 | {{parameters}} 5 | {{returns}} 6 | {{yields}} 7 | {{other_parameters}} 8 | {{attributes}} 9 | {{raises}} 10 | {{warns}} 11 | {{warnings}} 12 | {{see_also}} 13 | {{notes}} 14 | {{references}} 15 | {{examples}} 16 | {{methods}} -------------------------------------------------------------------------------- /imbens/sampler/_under_sampling/_prototype_generation/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.sampler._under_sampling.prototype_generation` 3 | submodule contains methods that generate new samples in order to balance 4 | the dataset. 5 | """ 6 | 7 | from ._cluster_centroids import ClusterCentroids 8 | 9 | __all__ = ["ClusterCentroids"] 10 | -------------------------------------------------------------------------------- /docs/source/doc_requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.0 2 | scipy>=1.9.1 3 | pandas>=2.1.1 4 | joblib>=0.11 5 | scikit-learn>=1.5.0 6 | matplotlib>=3.3.2 7 | seaborn>=0.13.2 8 | tqdm>=4.50.2 9 | openml>=0.14.0 10 | platformdirs 11 | sphinx 12 | sphinx-gallery 13 | sphinx_rtd_theme 14 | pydata-sphinx-theme 15 | numpydoc 16 | sphinxcontrib-bibtex 17 | torch 18 | pytest -------------------------------------------------------------------------------- /docs/source/api/utils/_autosummary/imbens.utils.evaluate_print.rst: -------------------------------------------------------------------------------- 1 | evaluate_print 2 | =============================================== 3 | 4 | .. currentmodule:: imbens.utils 5 | 6 | .. autofunction:: evaluate_print 7 | 8 | .. 
include:: ../../../back_references/imbens.utils.evaluate_print.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/metrics/pairwise.rst: -------------------------------------------------------------------------------- 1 | .. _metrics_api: 2 | 3 | Pairwise Metrics 4 | ================================ 5 | 6 | .. automodule:: imbens.metrics 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.metrics 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | ValueDifferenceMetric -------------------------------------------------------------------------------- /docs/source/api/pipeline/_autosummary/imbens.pipeline.make_pipeline.rst: -------------------------------------------------------------------------------- 1 | make_pipeline 2 | ================================================= 3 | 4 | .. currentmodule:: imbens.pipeline 5 | 6 | .. autofunction:: make_pipeline 7 | 8 | .. include:: ../../../back_references/imbens.pipeline.make_pipeline.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/datasets/_autosummary/imbens.datasets.make_imbalance.rst: -------------------------------------------------------------------------------- 1 | make_imbalance 2 | ================================================== 3 | 4 | .. currentmodule:: imbens.datasets 5 | 6 | .. autofunction:: make_imbalance 7 | 8 | .. include:: ../../../back_references/imbens.datasets.make_imbalance.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/utils/_autosummary/imbens.utils.check_target_type.rst: -------------------------------------------------------------------------------- 1 | check_target_type 2 | ================================================== 3 | 4 | .. currentmodule:: imbens.utils 5 | 6 | .. autofunction:: check_target_type 7 | 8 | .. include:: ../../../back_references/imbens.utils.check_target_type.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/utils/evaluation.rst: -------------------------------------------------------------------------------- 1 | .. _evaluation_api: 2 | 3 | Utilities for evaluation 4 | =================================== 5 | 6 | .. automodule:: imbens.utils 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.utils 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: function.rst 15 | 16 | evaluate_print -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.0 2 | scipy>=0.19.1 3 | pandas>=1.1.3 4 | joblib>=0.11 5 | scikit-learn>=0.24 6 | imbalanced-learn>=0.7.0 7 | matplotlib>=3.3.2 8 | seaborn>=0.11.0 9 | tqdm>=4.50.2 10 | openml>=0.14.0 11 | platformdirs 12 | sphinx 13 | sphinx-gallery 14 | numpydoc 15 | pydata-sphinx-theme 16 | sphinxcontrib-bibtex 17 | torch 18 | torchvision 19 | pytest 20 | pytest-cov -------------------------------------------------------------------------------- /docs/source/api/utils/_autosummary/imbens.utils.check_eval_metrics.rst: -------------------------------------------------------------------------------- 1 | check_eval_metrics 2 | =================================================== 3 | 4 | .. currentmodule:: imbens.utils 5 | 6 | .. autofunction:: check_eval_metrics 7 | 8 | .. include:: ../../../back_references/imbens.utils.check_eval_metrics.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | #conda: 2 | # file: docs/environment.yml 3 | 4 | version: 2 5 | 6 | # Set the OS, Python version and other tools you might need 7 | build: 8 | os: ubuntu-22.04 9 | tools: 10 | python: "3.11" 11 | 12 | sphinx: 13 | configuration: docs/source/conf.py 14 | 15 | python: 16 | install: 17 | - requirements: docs/source/doc_requirements.txt 18 | -------------------------------------------------------------------------------- /docs/source/api/metrics/_autosummary/imbens.metrics.sensitivity_score.rst: -------------------------------------------------------------------------------- 1 | sensitivity_score 2 | ==================================================== 3 | 4 | .. currentmodule:: imbens.metrics 5 | 6 | .. autofunction:: sensitivity_score 7 | 8 | .. include:: ../../../back_references/imbens.metrics.sensitivity_score.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/metrics/_autosummary/imbens.metrics.specificity_score.rst: -------------------------------------------------------------------------------- 1 | specificity_score 2 | ==================================================== 3 | 4 | .. currentmodule:: imbens.metrics 5 | 6 | .. autofunction:: specificity_score 7 | 8 | .. include:: ../../../back_references/imbens.metrics.specificity_score.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/utils/_autosummary/imbens.utils.check_eval_datasets.rst: -------------------------------------------------------------------------------- 1 | check_eval_datasets 2 | ==================================================== 3 | 4 | .. currentmodule:: imbens.utils 5 | 6 | .. autofunction:: check_eval_datasets 7 | 8 | .. include:: ../../../back_references/imbens.utils.check_eval_datasets.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/visualizer/visualizer.rst: -------------------------------------------------------------------------------- 1 | .. _visualizer_api: 2 | 3 | Visualizer 4 | ================================ 5 | 6 | .. automodule:: imbens.visualizer 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.visualizer 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | ImbalancedEnsembleVisualizer -------------------------------------------------------------------------------- /docs/source/api/metrics/_autosummary/imbens.metrics.geometric_mean_score.rst: -------------------------------------------------------------------------------- 1 | geometric_mean_score 2 | ======================================================= 3 | 4 | .. currentmodule:: imbens.metrics 5 | 6 | .. autofunction:: geometric_mean_score 7 | 8 | .. include:: ../../../back_references/imbens.metrics.geometric_mean_score.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/utils/_autosummary/imbens.utils.check_neighbors_object.rst: -------------------------------------------------------------------------------- 1 | check_neighbors_object 2 | ======================================================= 3 | 4 | .. currentmodule:: imbens.utils 5 | 6 | .. autofunction:: check_neighbors_object 7 | 8 | .. include:: ../../../back_references/imbens.utils.check_neighbors_object.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/utils/_autosummary/imbens.utils.check_sampling_strategy.rst: -------------------------------------------------------------------------------- 1 | check_sampling_strategy 2 | ======================================================== 3 | 4 | .. currentmodule:: imbens.utils 5 | 6 | .. autofunction:: check_sampling_strategy 7 | 8 | .. include:: ../../../back_references/imbens.utils.check_sampling_strategy.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/datasets/_autosummary/imbens.datasets.fetch_openml_datasets.rst: -------------------------------------------------------------------------------- 1 | fetch_openml_datasets 2 | ========================================================= 3 | 4 | .. currentmodule:: imbens.datasets 5 | 6 | .. autofunction:: fetch_openml_datasets 7 | 8 | .. include:: ../../../back_references/imbens.datasets.fetch_openml_datasets.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/datasets/_autosummary/imbens.datasets.fetch_zenodo_datasets.rst: -------------------------------------------------------------------------------- 1 | fetch_zenodo_datasets 2 | ========================================================= 3 | 4 | .. currentmodule:: imbens.datasets 5 | 6 | .. autofunction:: fetch_zenodo_datasets 7 | 8 | .. include:: ../../../back_references/imbens.datasets.fetch_zenodo_datasets.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/utils/_autosummary/imbens.utils.check_balancing_schedule.rst: -------------------------------------------------------------------------------- 1 | check_balancing_schedule 2 | ========================================================= 3 | 4 | .. currentmodule:: imbens.utils 5 | 6 | .. autofunction:: check_balancing_schedule 7 | 8 | .. include:: ../../../back_references/imbens.utils.check_balancing_schedule.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/datasets/_autosummary/imbens.datasets.generate_imbalance_data.rst: -------------------------------------------------------------------------------- 1 | generate_imbalance_data 2 | =========================================================== 3 | 4 | .. currentmodule:: imbens.datasets 5 | 6 | .. autofunction:: generate_imbalance_data 7 | 8 | .. include:: ../../../back_references/imbens.datasets.generate_imbalance_data.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/compatible.rst: -------------------------------------------------------------------------------- 1 | .. _compatible_api: 2 | 3 | Compatible ensembles 4 | ====================== 5 | 6 | .. automodule:: imbens.ensemble 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.ensemble 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | CompatibleAdaBoostClassifier 17 | CompatibleBaggingClassifier 18 | -------------------------------------------------------------------------------- /imbens/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Includes all possible values of an imbens classifier's 3 | properties. 4 | """ 5 | 6 | # Authors: Zhining Liu 7 | # License: MIT 8 | 9 | 10 | ENSEMBLE_TYPES = ('boosting', 'bagging', 'random-forest', 'general') 11 | 12 | TRAINING_TYPES = ('iterative', 'parallel') 13 | 14 | SOLUTION_TYPES = ('resampling', 'reweighting') 15 | 16 | SAMPLING_TYPES = ('under-sampling', 'over-sampling') 17 | -------------------------------------------------------------------------------- /docs/source/api/ensemble/reweighting.rst: -------------------------------------------------------------------------------- 1 | .. _reweighting_api: 2 | 3 | Reweighting-based ensembles 4 | ================================ 5 | 6 | .. automodule:: imbens.ensemble 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.ensemble 11 | 12 | .. 
autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | AdaCostClassifier 17 | AdaUBoostClassifier 18 | AsymBoostClassifier -------------------------------------------------------------------------------- /docs/source/api/metrics/_autosummary/imbens.metrics.make_index_balanced_accuracy.rst: -------------------------------------------------------------------------------- 1 | make_index_balanced_accuracy 2 | =============================================================== 3 | 4 | .. currentmodule:: imbens.metrics 5 | 6 | .. autofunction:: make_index_balanced_accuracy 7 | 8 | .. include:: ../../../back_references/imbens.metrics.make_index_balanced_accuracy.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/datasets/datasets.rst: -------------------------------------------------------------------------------- 1 | .. _datasets_api: 2 | 3 | Datasets 4 | ================================ 5 | 6 | .. automodule:: imbens.datasets 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.datasets 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: function.rst 15 | 16 | make_imbalance 17 | generate_imbalance_data 18 | fetch_zenodo_datasets 19 | fetch_openml_datasets -------------------------------------------------------------------------------- /docs/source/api/metrics/_autosummary/imbens.metrics.sensitivity_specificity_support.rst: -------------------------------------------------------------------------------- 1 | sensitivity_specificity_support 2 | ================================================================== 3 | 4 | .. currentmodule:: imbens.metrics 5 | 6 | .. autofunction:: sensitivity_specificity_support 7 | 8 | .. include:: ../../../back_references/imbens.metrics.sensitivity_specificity_support.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /imbens/ensemble/_compatible/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.ensemble.compatible` submodule contains 3 | a set of `sklearn.ensemble` learning methods that were re-implemented in 4 | `imbens` style. 5 | """ 6 | 7 | from .adaboost_compatible import CompatibleAdaBoostClassifier 8 | from .bagging_compatible import CompatibleBaggingClassifier 9 | 10 | __all__ = [ 11 | "CompatibleAdaBoostClassifier", 12 | "CompatibleBaggingClassifier", 13 | ] 14 | -------------------------------------------------------------------------------- /imbens/ensemble/_reweighting/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.ensemble._reweighting` submodule contains 3 | a set of reweighting-based ensemble imbalanced learning methods. 4 | """ 5 | 6 | from .adacost import AdaCostClassifier 7 | from .adauboost import AdaUBoostClassifier 8 | from .asymmetric_boost import AsymBoostClassifier 9 | 10 | __all__ = [ 11 | "AdaCostClassifier", 12 | "AdaUBoostClassifier", 13 | "AsymBoostClassifier", 14 | ] 15 | -------------------------------------------------------------------------------- /docs/source/api/metrics/_autosummary/imbens.metrics.classification_report_imbalanced.rst: -------------------------------------------------------------------------------- 1 | classification_report_imbalanced 2 | =================================================================== 3 | 4 | .. currentmodule:: imbens.metrics 5 | 6 | .. autofunction:: classification_report_imbalanced 7 | 8 | .. include:: ../../../back_references/imbens.metrics.classification_report_imbalanced.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/utils/validation_sampler.rst: -------------------------------------------------------------------------------- 1 | .. _utils_sampler_api: 2 | 3 | Validation checks used in samplers 4 | ================================== 5 | 6 | .. automodule:: imbens.utils 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.utils 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: function.rst 15 | 16 | check_neighbors_object 17 | check_sampling_strategy 18 | check_target_type -------------------------------------------------------------------------------- /docs/source/api/metrics/_autosummary/imbens.metrics.macro_averaged_mean_absolute_error.rst: -------------------------------------------------------------------------------- 1 | macro_averaged_mean_absolute_error 2 | ===================================================================== 3 | 4 | .. currentmodule:: imbens.metrics 5 | 6 | .. autofunction:: macro_averaged_mean_absolute_error 7 | 8 | .. include:: ../../../back_references/imbens.metrics.macro_averaged_mean_absolute_error.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /docs/source/api/pipeline/pipeline.rst: -------------------------------------------------------------------------------- 1 | .. _pipeline_api: 2 | 3 | Pipeline 4 | ================================ 5 | 6 | .. automodule:: imbens.pipeline 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.pipeline 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | Pipeline 17 | 18 | .. autosummary:: 19 | :toctree: _autosummary 20 | :template: function.rst 21 | 22 | make_pipeline -------------------------------------------------------------------------------- /docs/source/api/sampler/over-samplers.rst: -------------------------------------------------------------------------------- 1 | .. _over_sampling_sampler_api: 2 | 3 | Over-sampling Samplers 4 | ================================ 5 | 6 | .. automodule:: imbens.sampler 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.sampler 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | ADASYN 17 | RandomOverSampler 18 | KMeansSMOTE 19 | SMOTE 20 | BorderlineSMOTE 21 | SVMSMOTE -------------------------------------------------------------------------------- /imbens/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.datasets` provides methods to generate 3 | imbalanced data. 
4 | """ 5 | 6 | from ._imbalance import make_imbalance 7 | from ._imbalance import generate_imbalance_data 8 | 9 | from ._zenodo import fetch_zenodo_datasets 10 | from ._openml import fetch_openml_datasets 11 | 12 | __all__ = [ 13 | "make_imbalance", 14 | "generate_imbalance_data", 15 | "fetch_zenodo_datasets", 16 | "fetch_openml_datasets", 17 | ] 18 | -------------------------------------------------------------------------------- /docs/source/api/utils/_autosummary/imbens.utils.check_target_label_and_n_target_samples.rst: -------------------------------------------------------------------------------- 1 | check_target_label_and_n_target_samples 2 | ======================================================================== 3 | 4 | .. currentmodule:: imbens.utils 5 | 6 | .. autofunction:: check_target_label_and_n_target_samples 7 | 8 | .. include:: ../../../back_references/imbens.utils.check_target_label_and_n_target_samples.examples 9 | 10 | .. raw:: html 11 | 12 |
-------------------------------------------------------------------------------- /imbens/exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.exceptions` module includes all custom warnings and error 3 | classes and functions used across imbalanced-ensemble. 4 | """ 5 | # Adapted from imbalanced-learn 6 | 7 | # Authors: Guillaume Lemaitre 8 | # License: MIT 9 | 10 | 11 | def raise_isinstance_error(variable_name, possible_type, variable): 12 | raise ValueError( 13 | f"{variable_name} has to be one of {possible_type}. " 14 | f"Got {type(variable)} instead." 15 | ) 16 | -------------------------------------------------------------------------------- /docs/source/api/utils/validation_ensemble.rst: -------------------------------------------------------------------------------- 1 | .. _utils_ensemble_api: 2 | 3 | Validation checks used in ensembles 4 | =================================== 5 | 6 | .. automodule:: imbens.utils 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.utils 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: function.rst 15 | 16 | check_eval_datasets 17 | check_eval_metrics 18 | check_target_label_and_n_target_samples 19 | check_balancing_schedule -------------------------------------------------------------------------------- /docs/source/api/ensemble/over-sampling.rst: -------------------------------------------------------------------------------- 1 | .. _over_sampling_api: 2 | 3 | Over-sampling-based ensembles 4 | ================================ 5 | 6 | .. automodule:: imbens.ensemble 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.ensemble 11 | 12 | .. 
autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | OverBoostClassifier 17 | SMOTEBoostClassifier 18 | KmeansSMOTEBoostClassifier 19 | OverBaggingClassifier 20 | SMOTEBaggingClassifier -------------------------------------------------------------------------------- /docs/source/_templates/ensemble_class.rst: -------------------------------------------------------------------------------- 1 | {{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | 10 | {% if methods %} 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | {% for item in methods %} 15 | {% if '__init__' not in item %} 16 | ~{{ name }}.{{ item }} 17 | {% endif %} 18 | {%- endfor %} 19 | {% endif %} 20 | {% endblock %} 21 | 22 | .. raw:: html 23 | 24 |
25 | -------------------------------------------------------------------------------- /docs/source/api/ensemble/under-sampling.rst: -------------------------------------------------------------------------------- 1 | .. _under_sampling_api: 2 | 3 | Under-sampling-based ensembles 4 | ================================ 5 | 6 | .. automodule:: imbens.ensemble 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.ensemble 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | SelfPacedEnsembleClassifier 17 | BalanceCascadeClassifier 18 | BalancedRandomForestClassifier 19 | EasyEnsembleClassifier 20 | RUSBoostClassifier 21 | UnderBaggingClassifier -------------------------------------------------------------------------------- /docs/source/api/metrics/classification.rst: -------------------------------------------------------------------------------- 1 | .. _metrics_api: 2 | 3 | Classification Metrics 4 | ================================ 5 | 6 | .. automodule:: imbens.metrics 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.metrics 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: function.rst 15 | 16 | sensitivity_specificity_support 17 | sensitivity_score 18 | specificity_score 19 | geometric_mean_score 20 | make_index_balanced_accuracy 21 | classification_report_imbalanced 22 | macro_averaged_mean_absolute_error -------------------------------------------------------------------------------- /docs/source/_templates/class.rst: -------------------------------------------------------------------------------- 1 | {{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | 10 | {% if methods %} 11 | .. rubric:: Methods 12 | 13 | .. 
autosummary:: 14 | {% for item in methods %} 15 | {% if '__init__' not in item %} 16 | ~{{ name }}.{{ item }} 17 | {% endif %} 18 | {%- endfor %} 19 | {% endif %} 20 | {% endblock %} 21 | 22 | .. include:: ../../../back_references/{{module}}.{{objname}}.examples 23 | 24 | .. raw:: html 25 | 26 |
27 | -------------------------------------------------------------------------------- /imbens/ensemble/_over_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.ensemble._over_sampling` submodule contains 3 | a set of over-sampling-based ensemble imbalanced learning methods. 4 | """ 5 | 6 | from .over_boost import OverBoostClassifier 7 | from .smote_boost import SMOTEBoostClassifier 8 | from .kmeans_smote_boost import KmeansSMOTEBoostClassifier 9 | from .smote_bagging import SMOTEBaggingClassifier 10 | from .over_bagging import OverBaggingClassifier 11 | 12 | __all__ = [ 13 | "OverBoostClassifier", 14 | "SMOTEBoostClassifier", 15 | "KmeansSMOTEBoostClassifier", 16 | "OverBaggingClassifier", 17 | "SMOTEBaggingClassifier", 18 | ] 19 | -------------------------------------------------------------------------------- /imbens/utils/tests/test_deprecation.py: -------------------------------------------------------------------------------- 1 | """Test for the deprecation helper""" 2 | 3 | # Authors: Guillaume Lemaitre 4 | # License: MIT 5 | 6 | import pytest 7 | 8 | from imbens.utils.deprecation import deprecate_parameter 9 | 10 | 11 | class Sampler: 12 | def __init__(self): 13 | self.a = "something" 14 | self.b = "something" 15 | 16 | 17 | def test_deprecate_parameter(): 18 | with pytest.warns(DeprecationWarning, match="is deprecated from"): 19 | deprecate_parameter(Sampler(), "0.2", "a") 20 | with pytest.warns(DeprecationWarning, match="Use 'b' instead."): 21 | deprecate_parameter(Sampler(), "0.2", "a", "b") 22 | -------------------------------------------------------------------------------- /imbens/sampler/_over_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.sampler._over_sampling` submodule provides a 3 | set of methods to perform over-sampling. 
4 | """ 5 | 6 | from ._adasyn import ADASYN 7 | from ._random_over_sampler import RandomOverSampler 8 | from ._smote import SMOTE 9 | from ._smote import BorderlineSMOTE 10 | from ._smote import KMeansSMOTE 11 | from ._smote import SVMSMOTE 12 | # from ._smote import SMOTENC 13 | # from ._smote import SMOTEN 14 | 15 | __all__ = [ 16 | "ADASYN", 17 | "RandomOverSampler", 18 | "KMeansSMOTE", 19 | "SMOTE", 20 | "BorderlineSMOTE", 21 | "SVMSMOTE", 22 | # "SMOTENC", 23 | # "SMOTEN", 24 | ] 25 | -------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.SMOTE.rst: -------------------------------------------------------------------------------- 1 | SMOTE 2 | ================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: SMOTE 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~SMOTE.fit 18 | 19 | 20 | ~SMOTE.fit_resample 21 | 22 | 23 | ~SMOTE.get_metadata_routing 24 | 25 | 26 | ~SMOTE.get_params 27 | 28 | 29 | ~SMOTE.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.SMOTE.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.ADASYN.rst: -------------------------------------------------------------------------------- 1 | ADASYN 2 | =================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: ADASYN 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~ADASYN.fit 18 | 19 | 20 | ~ADASYN.fit_resample 21 | 22 | 23 | ~ADASYN.get_metadata_routing 24 | 25 | 26 | ~ADASYN.get_params 27 | 28 | 29 | ~ADASYN.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.ADASYN.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.AllKNN.rst: -------------------------------------------------------------------------------- 1 | AllKNN 2 | =================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: AllKNN 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~AllKNN.fit 18 | 19 | 20 | ~AllKNN.fit_resample 21 | 22 | 23 | ~AllKNN.get_metadata_routing 24 | 25 | 26 | ~AllKNN.get_params 27 | 28 | 29 | ~AllKNN.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.AllKNN.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/under-samplers.rst: -------------------------------------------------------------------------------- 1 | .. _under_sampling_sampler_api: 2 | 3 | Under-sampling Samplers 4 | ================================ 5 | 6 | .. automodule:: imbens.sampler 7 | :no-members: 8 | :no-inherited-members: 9 | 10 | .. currentmodule:: imbens.sampler 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | :template: class.rst 15 | 16 | ClusterCentroids 17 | RandomUnderSampler 18 | InstanceHardnessThreshold 19 | NearMiss 20 | TomekLinks 21 | EditedNearestNeighbours 22 | RepeatedEditedNearestNeighbours 23 | AllKNN 24 | OneSidedSelection 25 | CondensedNearestNeighbour 26 | NeighbourhoodCleaningRule 27 | BalanceCascadeUnderSampler 28 | SelfPacedUnderSampler -------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.NearMiss.rst: -------------------------------------------------------------------------------- 1 | NearMiss 2 | ===================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: NearMiss 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~NearMiss.fit 18 | 19 | 20 | ~NearMiss.fit_resample 21 | 22 | 23 | ~NearMiss.get_metadata_routing 24 | 25 | 26 | ~NearMiss.get_params 27 | 28 | 29 | ~NearMiss.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.NearMiss.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.SVMSMOTE.rst: -------------------------------------------------------------------------------- 1 | SVMSMOTE 2 | ===================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: SVMSMOTE 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~SVMSMOTE.fit 18 | 19 | 20 | ~SVMSMOTE.fit_resample 21 | 22 | 23 | ~SVMSMOTE.get_metadata_routing 24 | 25 | 26 | ~SVMSMOTE.get_params 27 | 28 | 29 | ~SVMSMOTE.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.SVMSMOTE.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.KMeansSMOTE.rst: -------------------------------------------------------------------------------- 1 | KMeansSMOTE 2 | ======================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: KMeansSMOTE 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~KMeansSMOTE.fit 18 | 19 | 20 | ~KMeansSMOTE.fit_resample 21 | 22 | 23 | ~KMeansSMOTE.get_metadata_routing 24 | 25 | 26 | ~KMeansSMOTE.get_params 27 | 28 | 29 | ~KMeansSMOTE.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.KMeansSMOTE.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /imbens/ensemble/_under_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.ensemble._under_sampling` submodule contains 3 | a set of under-sampling-based ensemble imbalanced learning methods. 4 | """ 5 | 6 | from .self_paced_ensemble import SelfPacedEnsembleClassifier 7 | from .balance_cascade import BalanceCascadeClassifier 8 | from .balanced_random_forest import BalancedRandomForestClassifier 9 | from .easy_ensemble import EasyEnsembleClassifier 10 | from .rus_boost import RUSBoostClassifier 11 | from .under_bagging import UnderBaggingClassifier 12 | 13 | __all__ = [ 14 | "SelfPacedEnsembleClassifier", 15 | "BalanceCascadeClassifier", 16 | "BalancedRandomForestClassifier", 17 | "EasyEnsembleClassifier", 18 | "RUSBoostClassifier", 19 | "UnderBaggingClassifier", 20 | ] 21 | -------------------------------------------------------------------------------- /imbens/_version.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``imbalanced-ensemble`` is a set of python-based ensemble learning methods for 3 | dealing with class-imbalanced classification problems in machine learning. 4 | """ 5 | 6 | # Based on NiLearn, imblearn package 7 | # License: simplified BSD, MIT 8 | 9 | # PEP0440 compatible formatted version, see: 10 | # https://www.python.org/dev/peps/pep-0440/ 11 | # 12 | # Generic release markers: 13 | # X.Y 14 | # X.Y.Z # For bugfix releases 15 | # 16 | # Admissible pre-release markers: 17 | # X.YaN # Alpha release 18 | # X.YbN # Beta release 19 | # X.YrcN # Release Candidate 20 | # X.Y # Final release 21 | # 22 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. 
23 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev' 24 | # 25 | 26 | __version__ = "0.2.3" 27 | -------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.BorderlineSMOTE.rst: -------------------------------------------------------------------------------- 1 | BorderlineSMOTE 2 | ============================================ 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: BorderlineSMOTE 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~BorderlineSMOTE.fit 18 | 19 | 20 | ~BorderlineSMOTE.fit_resample 21 | 22 | 23 | ~BorderlineSMOTE.get_metadata_routing 24 | 25 | 26 | ~BorderlineSMOTE.get_params 27 | 28 | 29 | ~BorderlineSMOTE.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.BorderlineSMOTE.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/visualizer/_autosummary/imbens.visualizer.ImbalancedEnsembleVisualizer.rst: -------------------------------------------------------------------------------- 1 | ImbalancedEnsembleVisualizer 2 | ============================================================ 3 | 4 | .. currentmodule:: imbens.visualizer 5 | 6 | .. autoclass:: ImbalancedEnsembleVisualizer 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~ImbalancedEnsembleVisualizer.confusion_matrix_heatmap 18 | 19 | 20 | ~ImbalancedEnsembleVisualizer.fit 21 | 22 | 23 | ~ImbalancedEnsembleVisualizer.performance_lineplot 24 | 25 | 26 | 27 | 28 | .. include:: ../../../back_references/imbens.visualizer.ImbalancedEnsembleVisualizer.examples 29 | 30 | .. raw:: html 31 | 32 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.ClusterCentroids.rst: -------------------------------------------------------------------------------- 1 | ClusterCentroids 2 | ============================================= 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: ClusterCentroids 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~ClusterCentroids.fit 18 | 19 | 20 | ~ClusterCentroids.fit_resample 21 | 22 | 23 | ~ClusterCentroids.get_metadata_routing 24 | 25 | 26 | ~ClusterCentroids.get_params 27 | 28 | 29 | ~ClusterCentroids.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.ClusterCentroids.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.TomekLinks.rst: -------------------------------------------------------------------------------- 1 | TomekLinks 2 | ======================================= 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: TomekLinks 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~TomekLinks.fit 18 | 19 | 20 | ~TomekLinks.fit_resample 21 | 22 | 23 | ~TomekLinks.get_metadata_routing 24 | 25 | 26 | ~TomekLinks.get_params 27 | 28 | 29 | ~TomekLinks.is_tomek 30 | 31 | 32 | ~TomekLinks.set_params 33 | 34 | 35 | 36 | 37 | .. include:: ../../../back_references/imbens.sampler.TomekLinks.examples 38 | 39 | .. raw:: html 40 | 41 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.OneSidedSelection.rst: -------------------------------------------------------------------------------- 1 | OneSidedSelection 2 | ============================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: OneSidedSelection 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~OneSidedSelection.fit 18 | 19 | 20 | ~OneSidedSelection.fit_resample 21 | 22 | 23 | ~OneSidedSelection.get_metadata_routing 24 | 25 | 26 | ~OneSidedSelection.get_params 27 | 28 | 29 | ~OneSidedSelection.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.OneSidedSelection.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.RandomOverSampler.rst: -------------------------------------------------------------------------------- 1 | RandomOverSampler 2 | ============================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: RandomOverSampler 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~RandomOverSampler.fit 18 | 19 | 20 | ~RandomOverSampler.fit_resample 21 | 22 | 23 | ~RandomOverSampler.get_metadata_routing 24 | 25 | 26 | ~RandomOverSampler.get_params 27 | 28 | 29 | ~RandomOverSampler.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.RandomOverSampler.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.RandomUnderSampler.rst: -------------------------------------------------------------------------------- 1 | RandomUnderSampler 2 | =============================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: RandomUnderSampler 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~RandomUnderSampler.fit 18 | 19 | 20 | ~RandomUnderSampler.fit_resample 21 | 22 | 23 | ~RandomUnderSampler.get_metadata_routing 24 | 25 | 26 | ~RandomUnderSampler.get_params 27 | 28 | 29 | ~RandomUnderSampler.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.RandomUnderSampler.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/metrics/_autosummary/imbens.metrics.ValueDifferenceMetric.rst: -------------------------------------------------------------------------------- 1 | ValueDifferenceMetric 2 | ================================================== 3 | 4 | .. currentmodule:: imbens.metrics 5 | 6 | .. autoclass:: ValueDifferenceMetric 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~ValueDifferenceMetric.fit 18 | 19 | 20 | ~ValueDifferenceMetric.get_metadata_routing 21 | 22 | 23 | ~ValueDifferenceMetric.get_params 24 | 25 | 26 | ~ValueDifferenceMetric.pairwise 27 | 28 | 29 | ~ValueDifferenceMetric.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.metrics.ValueDifferenceMetric.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.SelfPacedUnderSampler.rst: -------------------------------------------------------------------------------- 1 | SelfPacedUnderSampler 2 | ================================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: SelfPacedUnderSampler 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~SelfPacedUnderSampler.fit 18 | 19 | 20 | ~SelfPacedUnderSampler.fit_resample 21 | 22 | 23 | ~SelfPacedUnderSampler.get_metadata_routing 24 | 25 | 26 | ~SelfPacedUnderSampler.get_params 27 | 28 | 29 | ~SelfPacedUnderSampler.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.SelfPacedUnderSampler.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.EditedNearestNeighbours.rst: -------------------------------------------------------------------------------- 1 | EditedNearestNeighbours 2 | ==================================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: EditedNearestNeighbours 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~EditedNearestNeighbours.fit 18 | 19 | 20 | ~EditedNearestNeighbours.fit_resample 21 | 22 | 23 | ~EditedNearestNeighbours.get_metadata_routing 24 | 25 | 26 | ~EditedNearestNeighbours.get_params 27 | 28 | 29 | ~EditedNearestNeighbours.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.EditedNearestNeighbours.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.CondensedNearestNeighbour.rst: -------------------------------------------------------------------------------- 1 | CondensedNearestNeighbour 2 | ====================================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: CondensedNearestNeighbour 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~CondensedNearestNeighbour.fit 18 | 19 | 20 | ~CondensedNearestNeighbour.fit_resample 21 | 22 | 23 | ~CondensedNearestNeighbour.get_metadata_routing 24 | 25 | 26 | ~CondensedNearestNeighbour.get_params 27 | 28 | 29 | ~CondensedNearestNeighbour.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.CondensedNearestNeighbour.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.InstanceHardnessThreshold.rst: -------------------------------------------------------------------------------- 1 | InstanceHardnessThreshold 2 | ====================================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: InstanceHardnessThreshold 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~InstanceHardnessThreshold.fit 18 | 19 | 20 | ~InstanceHardnessThreshold.fit_resample 21 | 22 | 23 | ~InstanceHardnessThreshold.get_metadata_routing 24 | 25 | 26 | ~InstanceHardnessThreshold.get_params 27 | 28 | 29 | ~InstanceHardnessThreshold.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.InstanceHardnessThreshold.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.NeighbourhoodCleaningRule.rst: -------------------------------------------------------------------------------- 1 | NeighbourhoodCleaningRule 2 | ====================================================== 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: NeighbourhoodCleaningRule 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~NeighbourhoodCleaningRule.fit 18 | 19 | 20 | ~NeighbourhoodCleaningRule.fit_resample 21 | 22 | 23 | ~NeighbourhoodCleaningRule.get_metadata_routing 24 | 25 | 26 | ~NeighbourhoodCleaningRule.get_params 27 | 28 | 29 | ~NeighbourhoodCleaningRule.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.NeighbourhoodCleaningRule.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.BalanceCascadeUnderSampler.rst: -------------------------------------------------------------------------------- 1 | BalanceCascadeUnderSampler 2 | ======================================================= 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: BalanceCascadeUnderSampler 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~BalanceCascadeUnderSampler.fit 18 | 19 | 20 | ~BalanceCascadeUnderSampler.fit_resample 21 | 22 | 23 | ~BalanceCascadeUnderSampler.get_metadata_routing 24 | 25 | 26 | ~BalanceCascadeUnderSampler.get_params 27 | 28 | 29 | ~BalanceCascadeUnderSampler.set_params 30 | 31 | 32 | 33 | 34 | .. 
include:: ../../../back_references/imbens.sampler.BalanceCascadeUnderSampler.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /imbens/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.utils` module includes various utilities. 3 | """ 4 | 5 | from ._docstring import Substitution 6 | 7 | from ._evaluate import evaluate_print 8 | 9 | from ._validation import check_neighbors_object 10 | from ._validation import check_target_type 11 | from ._validation import check_sampling_strategy 12 | 13 | from ._validation_data import check_eval_datasets 14 | 15 | from ._validation_param import check_eval_metrics 16 | from ._validation_param import check_target_label_and_n_target_samples 17 | from ._validation_param import check_balancing_schedule 18 | 19 | __all__ = [ 20 | "evaluate_print", 21 | "check_neighbors_object", 22 | "check_sampling_strategy", 23 | "check_target_type", 24 | "check_eval_datasets", 25 | "check_eval_metrics", 26 | "check_target_label_and_n_target_samples", 27 | "check_balancing_schedule", 28 | "Substitution", 29 | ] 30 | -------------------------------------------------------------------------------- /imbens/utils/tests/test_testing.py: -------------------------------------------------------------------------------- 1 | """Test for the testing module""" 2 | # Authors: Guillaume Lemaitre 3 | # Christos Aridas 4 | # License: MIT 5 | 6 | import pytest 7 | 8 | from imbens.sampler.base import SamplerMixin 9 | from imbens.utils.testing import all_estimators 10 | 11 | 12 | def test_all_estimators(): 13 | # check if the filtering is working with a list or a single string 14 | type_filter = "sampler" 15 | all_estimators(type_filter=type_filter) 16 | type_filter = ["sampler"] 17 | estimators = all_estimators(type_filter=type_filter) 18 | for estimator in estimators: 19 | # check that all estimators are sampler 20 | assert issubclass(estimator[1], SamplerMixin) 21 | 22 | # check that an error is raised when the type is unknown 23 | 
type_filter = "rnd" 24 | with pytest.raises(ValueError, match="Parameter type_filter must be 'sampler'"): 25 | all_estimators(type_filter=type_filter) 26 | -------------------------------------------------------------------------------- /docs/source/api/sampler/_autosummary/imbens.sampler.RepeatedEditedNearestNeighbours.rst: -------------------------------------------------------------------------------- 1 | RepeatedEditedNearestNeighbours 2 | ============================================================ 3 | 4 | .. currentmodule:: imbens.sampler 5 | 6 | .. autoclass:: RepeatedEditedNearestNeighbours 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~RepeatedEditedNearestNeighbours.fit 18 | 19 | 20 | ~RepeatedEditedNearestNeighbours.fit_resample 21 | 22 | 23 | ~RepeatedEditedNearestNeighbours.get_metadata_routing 24 | 25 | 26 | ~RepeatedEditedNearestNeighbours.get_params 27 | 28 | 29 | ~RepeatedEditedNearestNeighbours.set_params 30 | 31 | 32 | 33 | 34 | .. include:: ../../../back_references/imbens.sampler.RepeatedEditedNearestNeighbours.examples 35 | 36 | .. raw:: html 37 | 38 |
-------------------------------------------------------------------------------- /imbens/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.metrics` module includes score functions, performance 3 | metrics and pairwise metrics and distance computations. 4 | """ 5 | 6 | from ._classification import sensitivity_specificity_support 7 | from ._classification import sensitivity_score 8 | from ._classification import specificity_score 9 | from ._classification import geometric_mean_score 10 | from ._classification import make_index_balanced_accuracy 11 | from ._classification import classification_report_imbalanced 12 | from ._classification import macro_averaged_mean_absolute_error 13 | 14 | from .pairwise import ValueDifferenceMetric 15 | 16 | __all__ = [ 17 | "sensitivity_specificity_support", 18 | "sensitivity_score", 19 | "specificity_score", 20 | "geometric_mean_score", 21 | "make_index_balanced_accuracy", 22 | "classification_report_imbalanced", 23 | "macro_averaged_mean_absolute_error", 24 | "ValueDifferenceMetric", 25 | ] 26 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | codecov: codecov/codecov@3.2.4 5 | 6 | jobs: 7 | test39: &test-template 8 | docker: 9 | - image: cimg/python:3.9 10 | # parallelism: 10 11 | steps: 12 | - checkout 13 | - run: 14 | name: Install dependencies 15 | command: | 16 | pip install -r requirements.txt 17 | pip install pytest pytest-cov 18 | pip install pytest-circleci-parallelized 19 | - run: 20 | name: Run tests with coverage 21 | command: | 22 | pytest --cov=imbens --cov-report=xml 23 | - codecov/upload: 24 | file: coverage.xml 25 | 26 | test310: 27 | <<: *test-template 28 | docker: 29 | - image: cimg/python:3.10 30 | 31 | test311: 32 | <<: *test-template 33 | docker: 34 | - image: 
cimg/python:3.11 35 | 36 | workflows: 37 | version: 2 38 | build-test-deploy: 39 | jobs: 40 | - test39 41 | - test310 42 | - test311 -------------------------------------------------------------------------------- /docs/source/install.rst: -------------------------------------------------------------------------------- 1 | Install imbalanced-ensemble 2 | *************************** 3 | 4 | Prerequisites 5 | ============= 6 | 7 | The following packages are requirements: 8 | 9 | * ``numpy`` 10 | * ``scipy`` 11 | * ``pandas`` 12 | * ``joblib`` 13 | * ``scikit-learn`` 14 | * ``matplotlib`` 15 | * ``seaborn`` 16 | 17 | Installation 18 | ============ 19 | 20 | Install from PyPI 21 | ^^^^^^^^^^^^^^^^^ 22 | 23 | You can install imbalanced-ensemble from 24 | `PyPI <https://pypi.org/project/imbalanced-ensemble/>`__ by running: 25 | 26 | .. code-block:: bash 27 | 28 | > pip install imbalanced-ensemble 29 | 30 | Please make sure the latest version is installed to avoid potential problems: 31 | 32 | .. code-block:: bash 33 | 34 | > pip install --upgrade imbalanced-ensemble 35 | 36 | Clone from GitHub 37 | ^^^^^^^^^^^^^^^^^ 38 | 39 | Or you can install imbalanced-ensemble locally: 40 | 41 | .. code-block:: bash 42 | 43 | > git clone https://github.com/ZhiningLiu1998/imbalanced-ensemble.git 44 | > cd imbalanced-ensemble 45 | > pip install .
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Zhining Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/datasets/plot_generate_imbalance.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================== 3 | Generate an imbalanced dataset 4 | =============================== 5 | 6 | An illustration of using the 7 | :func:`~imbens.datasets.generate_imbalance_data` 8 | function to create an imbalanced dataset. 
9 | """ 10 | 11 | # Authors: Zhining Liu 12 | # License: MIT 13 | 14 | # %% 15 | print(__doc__) 16 | 17 | from imbens.datasets import generate_imbalance_data 18 | from imbens.utils._plot import plot_2Dprojection_and_cardinality 19 | from collections import Counter 20 | 21 | # %% [markdown] 22 | # Generate the dataset 23 | # -------------------- 24 | # 25 | 26 | # %% 27 | X_train, X_test, y_train, y_test = generate_imbalance_data( 28 | n_samples=1000, 29 | weights=[0.7, 0.2, 0.1], 30 | test_size=0.5, 31 | kwargs={'n_informative': 3}, 32 | ) 33 | 34 | print("Train class distribution: ", Counter(y_train)) 35 | print("Test class distribution: ", Counter(y_test)) 36 | 37 | # %% [markdown] 38 | # Plot the generated (training) data 39 | # ---------------------------------- 40 | # 41 | 42 | plot_2Dprojection_and_cardinality(X_train, y_train) 43 | -------------------------------------------------------------------------------- /imbens/utils/tests/test_plot.py: -------------------------------------------------------------------------------- 1 | """Test utilities for plot.""" 2 | 3 | # Authors: Zhining Liu 4 | # License: MIT 5 | 6 | 7 | import numpy as np 8 | 9 | from imbens.utils._plot import * 10 | 11 | X = np.array( 12 | [ 13 | [2.45166, 1.86760], 14 | [1.34450, -1.30331], 15 | [1.02989, 2.89408], 16 | [-1.94577, -1.75057], 17 | [1.21726, 1.90146], 18 | [2.00194, 1.25316], 19 | [2.31968, 2.33574], 20 | [1.14769, 1.41303], 21 | [1.32018, 2.17595], 22 | [-1.74686, -1.66665], 23 | [-2.17373, -1.91466], 24 | [2.41436, 1.83542], 25 | [1.97295, 2.55534], 26 | [-2.12126, -2.43786], 27 | [1.20494, 3.20696], 28 | [-2.30158, -2.39903], 29 | [1.76006, 1.94323], 30 | [2.35825, 1.77962], 31 | [-2.06578, -2.07671], 32 | [0.00245, -0.99528], 33 | ] 34 | ) 35 | y = np.array([2, 0, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 1, 2, 2, 1, 0]) 36 | 37 | 38 | def test_plot_scatter(): 39 | plot_scatter(X, y) 40 | 41 | 42 | def test_plot_class_distribution(): 43 | plot_class_distribution(y) 44 | 45 | 
46 | def test_plot_2Dprojection_and_cardinality(): 47 | plot_2Dprojection_and_cardinality(X, y) 48 | -------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.BalanceCascadeClassifier.rst: -------------------------------------------------------------------------------- 1 | BalanceCascadeClassifier 2 | ====================================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: BalanceCascadeClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~BalanceCascadeClassifier.decision_function 18 | 19 | 20 | ~BalanceCascadeClassifier.fit 21 | 22 | 23 | ~BalanceCascadeClassifier.get_metadata_routing 24 | 25 | 26 | ~BalanceCascadeClassifier.get_params 27 | 28 | 29 | ~BalanceCascadeClassifier.predict 30 | 31 | 32 | ~BalanceCascadeClassifier.predict_proba 33 | 34 | 35 | ~BalanceCascadeClassifier.score 36 | 37 | 38 | ~BalanceCascadeClassifier.set_fit_request 39 | 40 | 41 | ~BalanceCascadeClassifier.set_params 42 | 43 | 44 | ~BalanceCascadeClassifier.set_score_request 45 | 46 | 47 | 48 | 49 | .. include:: ../../../back_references/imbens.ensemble.BalanceCascadeClassifier.examples 50 | 51 | .. raw:: html 52 | 53 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.OverBaggingClassifier.rst: -------------------------------------------------------------------------------- 1 | OverBaggingClassifier 2 | =================================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: OverBaggingClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~OverBaggingClassifier.decision_function 18 | 19 | 20 | ~OverBaggingClassifier.fit 21 | 22 | 23 | ~OverBaggingClassifier.get_metadata_routing 24 | 25 | 26 | ~OverBaggingClassifier.get_params 27 | 28 | 29 | ~OverBaggingClassifier.predict 30 | 31 | 32 | ~OverBaggingClassifier.predict_log_proba 33 | 34 | 35 | ~OverBaggingClassifier.predict_proba 36 | 37 | 38 | ~OverBaggingClassifier.score 39 | 40 | 41 | ~OverBaggingClassifier.set_fit_request 42 | 43 | 44 | ~OverBaggingClassifier.set_params 45 | 46 | 47 | ~OverBaggingClassifier.set_score_request 48 | 49 | 50 | 51 | 52 | .. include:: ../../../back_references/imbens.ensemble.OverBaggingClassifier.examples 53 | 54 | .. raw:: html 55 | 56 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.EasyEnsembleClassifier.rst: -------------------------------------------------------------------------------- 1 | EasyEnsembleClassifier 2 | ==================================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: EasyEnsembleClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~EasyEnsembleClassifier.decision_function 18 | 19 | 20 | ~EasyEnsembleClassifier.fit 21 | 22 | 23 | ~EasyEnsembleClassifier.get_metadata_routing 24 | 25 | 26 | ~EasyEnsembleClassifier.get_params 27 | 28 | 29 | ~EasyEnsembleClassifier.predict 30 | 31 | 32 | ~EasyEnsembleClassifier.predict_log_proba 33 | 34 | 35 | ~EasyEnsembleClassifier.predict_proba 36 | 37 | 38 | ~EasyEnsembleClassifier.score 39 | 40 | 41 | ~EasyEnsembleClassifier.set_fit_request 42 | 43 | 44 | ~EasyEnsembleClassifier.set_params 45 | 46 | 47 | ~EasyEnsembleClassifier.set_score_request 48 | 49 | 50 | 51 | 52 | .. include:: ../../../back_references/imbens.ensemble.EasyEnsembleClassifier.examples 53 | 54 | .. raw:: html 55 | 56 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.SMOTEBaggingClassifier.rst: -------------------------------------------------------------------------------- 1 | SMOTEBaggingClassifier 2 | ==================================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: SMOTEBaggingClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~SMOTEBaggingClassifier.decision_function 18 | 19 | 20 | ~SMOTEBaggingClassifier.fit 21 | 22 | 23 | ~SMOTEBaggingClassifier.get_metadata_routing 24 | 25 | 26 | ~SMOTEBaggingClassifier.get_params 27 | 28 | 29 | ~SMOTEBaggingClassifier.predict 30 | 31 | 32 | ~SMOTEBaggingClassifier.predict_log_proba 33 | 34 | 35 | ~SMOTEBaggingClassifier.predict_proba 36 | 37 | 38 | ~SMOTEBaggingClassifier.score 39 | 40 | 41 | ~SMOTEBaggingClassifier.set_fit_request 42 | 43 | 44 | ~SMOTEBaggingClassifier.set_params 45 | 46 | 47 | ~SMOTEBaggingClassifier.set_score_request 48 | 49 | 50 | 51 | 52 | .. include:: ../../../back_references/imbens.ensemble.SMOTEBaggingClassifier.examples 53 | 54 | .. raw:: html 55 | 56 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.SelfPacedEnsembleClassifier.rst: -------------------------------------------------------------------------------- 1 | SelfPacedEnsembleClassifier 2 | ========================================================= 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: SelfPacedEnsembleClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~SelfPacedEnsembleClassifier.decision_function 18 | 19 | 20 | ~SelfPacedEnsembleClassifier.fit 21 | 22 | 23 | ~SelfPacedEnsembleClassifier.get_metadata_routing 24 | 25 | 26 | ~SelfPacedEnsembleClassifier.get_params 27 | 28 | 29 | ~SelfPacedEnsembleClassifier.predict 30 | 31 | 32 | ~SelfPacedEnsembleClassifier.predict_proba 33 | 34 | 35 | ~SelfPacedEnsembleClassifier.score 36 | 37 | 38 | ~SelfPacedEnsembleClassifier.set_fit_request 39 | 40 | 41 | ~SelfPacedEnsembleClassifier.set_params 42 | 43 | 44 | ~SelfPacedEnsembleClassifier.set_score_request 45 | 46 | 47 | 48 | 49 | .. include:: ../../../back_references/imbens.ensemble.SelfPacedEnsembleClassifier.examples 50 | 51 | .. raw:: html 52 | 53 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.UnderBaggingClassifier.rst: -------------------------------------------------------------------------------- 1 | UnderBaggingClassifier 2 | ==================================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: UnderBaggingClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~UnderBaggingClassifier.decision_function 18 | 19 | 20 | ~UnderBaggingClassifier.fit 21 | 22 | 23 | ~UnderBaggingClassifier.get_metadata_routing 24 | 25 | 26 | ~UnderBaggingClassifier.get_params 27 | 28 | 29 | ~UnderBaggingClassifier.predict 30 | 31 | 32 | ~UnderBaggingClassifier.predict_log_proba 33 | 34 | 35 | ~UnderBaggingClassifier.predict_proba 36 | 37 | 38 | ~UnderBaggingClassifier.score 39 | 40 | 41 | ~UnderBaggingClassifier.set_fit_request 42 | 43 | 44 | ~UnderBaggingClassifier.set_params 45 | 46 | 47 | ~UnderBaggingClassifier.set_score_request 48 | 49 | 50 | 51 | 52 | .. include:: ../../../back_references/imbens.ensemble.UnderBaggingClassifier.examples 53 | 54 | .. raw:: html 55 | 56 |
-------------------------------------------------------------------------------- /imbens/sampler/_under_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.sampler._under_sampling` submodule contains 3 | methods to under-sample a dataset. 4 | """ 5 | 6 | from ._prototype_generation import ClusterCentroids 7 | 8 | from ._prototype_selection import RandomUnderSampler 9 | from ._prototype_selection import TomekLinks 10 | from ._prototype_selection import NearMiss 11 | from ._prototype_selection import CondensedNearestNeighbour 12 | from ._prototype_selection import OneSidedSelection 13 | from ._prototype_selection import NeighbourhoodCleaningRule 14 | from ._prototype_selection import EditedNearestNeighbours 15 | from ._prototype_selection import RepeatedEditedNearestNeighbours 16 | from ._prototype_selection import AllKNN 17 | from ._prototype_selection import InstanceHardnessThreshold 18 | from ._prototype_selection import BalanceCascadeUnderSampler 19 | from ._prototype_selection import SelfPacedUnderSampler 20 | 21 | __all__ = [ 22 | "ClusterCentroids", 23 | "RandomUnderSampler", 24 | "InstanceHardnessThreshold", 25 | "NearMiss", 26 | "TomekLinks", 27 | "EditedNearestNeighbours", 28 | "RepeatedEditedNearestNeighbours", 29 | "AllKNN", 30 | "OneSidedSelection", 31 | "CondensedNearestNeighbour", 32 | "NeighbourhoodCleaningRule", 33 | "BalanceCascadeUnderSampler", 34 | "SelfPacedUnderSampler", 35 | ] 36 | -------------------------------------------------------------------------------- /imbens/sampler/_under_sampling/_prototype_selection/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.sampler._under_sampling.prototype_selection` 3 | submodule contains methods that select samples in order to balance the dataset. 
4 | """ 5 | 6 | from ._random_under_sampler import RandomUnderSampler 7 | from ._tomek_links import TomekLinks 8 | from ._nearmiss import NearMiss 9 | from ._condensed_nearest_neighbour import CondensedNearestNeighbour 10 | from ._one_sided_selection import OneSidedSelection 11 | from ._neighbourhood_cleaning_rule import NeighbourhoodCleaningRule 12 | from ._edited_nearest_neighbours import EditedNearestNeighbours 13 | from ._edited_nearest_neighbours import RepeatedEditedNearestNeighbours 14 | from ._edited_nearest_neighbours import AllKNN 15 | from ._instance_hardness_threshold import InstanceHardnessThreshold 16 | from ._balance_cascade_under_sampler import BalanceCascadeUnderSampler 17 | from ._self_paced_under_sampler import SelfPacedUnderSampler 18 | 19 | __all__ = [ 20 | "RandomUnderSampler", 21 | "InstanceHardnessThreshold", 22 | "NearMiss", 23 | "TomekLinks", 24 | "EditedNearestNeighbours", 25 | "RepeatedEditedNearestNeighbours", 26 | "AllKNN", 27 | "OneSidedSelection", 28 | "CondensedNearestNeighbour", 29 | "NeighbourhoodCleaningRule", 30 | "BalanceCascadeUnderSampler", 31 | "SelfPacedUnderSampler", 32 | ] -------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.CompatibleBaggingClassifier.rst: -------------------------------------------------------------------------------- 1 | CompatibleBaggingClassifier 2 | ========================================================= 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: CompatibleBaggingClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. 
autosummary:: 14 | 15 | 16 | 17 | ~CompatibleBaggingClassifier.decision_function 18 | 19 | 20 | ~CompatibleBaggingClassifier.fit 21 | 22 | 23 | ~CompatibleBaggingClassifier.get_metadata_routing 24 | 25 | 26 | ~CompatibleBaggingClassifier.get_params 27 | 28 | 29 | ~CompatibleBaggingClassifier.predict 30 | 31 | 32 | ~CompatibleBaggingClassifier.predict_log_proba 33 | 34 | 35 | ~CompatibleBaggingClassifier.predict_proba 36 | 37 | 38 | ~CompatibleBaggingClassifier.score 39 | 40 | 41 | ~CompatibleBaggingClassifier.set_fit_request 42 | 43 | 44 | ~CompatibleBaggingClassifier.set_params 45 | 46 | 47 | ~CompatibleBaggingClassifier.set_score_request 48 | 49 | 50 | 51 | 52 | .. include:: ../../../back_references/imbens.ensemble.CompatibleBaggingClassifier.examples 53 | 54 | .. raw:: html 55 | 56 |
-------------------------------------------------------------------------------- /imbens/utils/tests/test_docstring.py: -------------------------------------------------------------------------------- 1 | """Test utilities for docstring.""" 2 | 3 | # Authors: Guillaume Lemaitre 4 | # License: MIT 5 | 6 | import pytest 7 | 8 | from imbens.utils import Substitution 9 | from imbens.utils._docstring import _n_jobs_docstring, _random_state_docstring 10 | 11 | func_docstring = """A function. 12 | 13 | Parameters 14 | ---------- 15 | xxx 16 | 17 | yyy 18 | """ 19 | 20 | 21 | def func(param_1, param_2): 22 | """A function. 23 | 24 | Parameters 25 | ---------- 26 | {param_1} 27 | 28 | {param_2} 29 | """ 30 | return param_1, param_2 31 | 32 | 33 | cls_docstring = """A class. 34 | 35 | Parameters 36 | ---------- 37 | xxx 38 | 39 | yyy 40 | """ 41 | 42 | 43 | class cls: 44 | """A class. 45 | 46 | Parameters 47 | ---------- 48 | {param_1} 49 | 50 | {param_2} 51 | """ 52 | 53 | def __init__(self, param_1, param_2): 54 | self.param_1 = param_1 55 | self.param_2 = param_2 56 | 57 | 58 | @pytest.mark.parametrize( 59 | "obj, obj_docstring", [(func, func_docstring), (cls, cls_docstring)] 60 | ) 61 | def test_docstring_inject(obj, obj_docstring): 62 | obj_injected_docstring = Substitution(param_1="xxx", param_2="yyy")(obj) 63 | assert obj_injected_docstring.__doc__ == obj_docstring 64 | 65 | 66 | def test_docstring_template(): 67 | assert "random_state" in _random_state_docstring 68 | assert "n_jobs" in _n_jobs_docstring 69 | -------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.BalancedRandomForestClassifier.rst: -------------------------------------------------------------------------------- 1 | BalancedRandomForestClassifier 2 | ============================================================ 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. 
autoclass:: BalancedRandomForestClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~BalancedRandomForestClassifier.apply 18 | 19 | 20 | ~BalancedRandomForestClassifier.decision_path 21 | 22 | 23 | ~BalancedRandomForestClassifier.fit 24 | 25 | 26 | ~BalancedRandomForestClassifier.get_metadata_routing 27 | 28 | 29 | ~BalancedRandomForestClassifier.get_params 30 | 31 | 32 | ~BalancedRandomForestClassifier.predict 33 | 34 | 35 | ~BalancedRandomForestClassifier.predict_log_proba 36 | 37 | 38 | ~BalancedRandomForestClassifier.predict_proba 39 | 40 | 41 | ~BalancedRandomForestClassifier.score 42 | 43 | 44 | ~BalancedRandomForestClassifier.set_fit_request 45 | 46 | 47 | ~BalancedRandomForestClassifier.set_params 48 | 49 | 50 | ~BalancedRandomForestClassifier.set_score_request 51 | 52 | 53 | 54 | 55 | .. include:: ../../../back_references/imbens.ensemble.BalancedRandomForestClassifier.examples 56 | 57 | .. raw:: html 58 | 59 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.AdaCostClassifier.rst: -------------------------------------------------------------------------------- 1 | AdaCostClassifier 2 | =============================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: AdaCostClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~AdaCostClassifier.decision_function 18 | 19 | 20 | ~AdaCostClassifier.fit 21 | 22 | 23 | ~AdaCostClassifier.get_metadata_routing 24 | 25 | 26 | ~AdaCostClassifier.get_params 27 | 28 | 29 | ~AdaCostClassifier.predict 30 | 31 | 32 | ~AdaCostClassifier.predict_log_proba 33 | 34 | 35 | ~AdaCostClassifier.predict_proba 36 | 37 | 38 | ~AdaCostClassifier.score 39 | 40 | 41 | ~AdaCostClassifier.set_fit_request 42 | 43 | 44 | ~AdaCostClassifier.set_params 45 | 46 | 47 | ~AdaCostClassifier.set_score_request 48 | 49 | 50 | ~AdaCostClassifier.staged_decision_function 51 | 52 | 53 | ~AdaCostClassifier.staged_predict 54 | 55 | 56 | ~AdaCostClassifier.staged_predict_proba 57 | 58 | 59 | ~AdaCostClassifier.staged_score 60 | 61 | 62 | 63 | 64 | .. include:: ../../../back_references/imbens.ensemble.AdaCostClassifier.examples 65 | 66 | .. raw:: html 67 | 68 |
-------------------------------------------------------------------------------- /imbens/__init__.py: -------------------------------------------------------------------------------- 1 | """Toolbox for ensemble learning on class-imbalanced dataset. 2 | 3 | ``imbalanced-ensemble`` is a set of python-based ensemble learning methods for 4 | dealing with class-imbalanced classification problems in machine learning. 5 | 6 | Subpackages 7 | ----------- 8 | ensemble 9 | Module which provides ensemble imbalanced learning methods. 10 | sampler 11 | Module which provides samplers for resampling class-imbalanced data. 12 | visualizer 13 | Module which provides a visualizer for convenient visualization of 14 | ensemble learning process and results. 15 | metrics 16 | Module which provides metrics to quantified the classification performance 17 | with imbalanced dataset. 18 | utils 19 | Module including various utilities. 20 | exceptions 21 | Module including custom warnings and error classes used across 22 | imbalanced-learn. 23 | pipeline 24 | Module which allowing to create pipeline with scikit-learn estimators. 25 | """ 26 | 27 | from . import ensemble 28 | from . import sampler 29 | from . import visualizer 30 | from . import metrics 31 | from . import utils 32 | from . import exceptions 33 | from . import pipeline 34 | from . import datasets 35 | 36 | from .sampler.base import FunctionSampler 37 | 38 | from ._version import __version__ 39 | 40 | __all__ = [ 41 | "ensemble", 42 | "sampler", 43 | "visualizer", 44 | "metrics", 45 | "utils", 46 | "exceptions", 47 | "pipeline", 48 | "datasets", 49 | "FunctionSampler", 50 | "__version__", 51 | ] 52 | -------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.RUSBoostClassifier.rst: -------------------------------------------------------------------------------- 1 | RUSBoostClassifier 2 | ================================================ 3 | 4 | .. 
currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: RUSBoostClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~RUSBoostClassifier.decision_function 18 | 19 | 20 | ~RUSBoostClassifier.fit 21 | 22 | 23 | ~RUSBoostClassifier.get_metadata_routing 24 | 25 | 26 | ~RUSBoostClassifier.get_params 27 | 28 | 29 | ~RUSBoostClassifier.predict 30 | 31 | 32 | ~RUSBoostClassifier.predict_log_proba 33 | 34 | 35 | ~RUSBoostClassifier.predict_proba 36 | 37 | 38 | ~RUSBoostClassifier.score 39 | 40 | 41 | ~RUSBoostClassifier.set_fit_request 42 | 43 | 44 | ~RUSBoostClassifier.set_params 45 | 46 | 47 | ~RUSBoostClassifier.set_score_request 48 | 49 | 50 | ~RUSBoostClassifier.staged_decision_function 51 | 52 | 53 | ~RUSBoostClassifier.staged_predict 54 | 55 | 56 | ~RUSBoostClassifier.staged_predict_proba 57 | 58 | 59 | ~RUSBoostClassifier.staged_score 60 | 61 | 62 | 63 | 64 | .. include:: ../../../back_references/imbens.ensemble.RUSBoostClassifier.examples 65 | 66 | .. raw:: html 67 | 68 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.AdaUBoostClassifier.rst: -------------------------------------------------------------------------------- 1 | AdaUBoostClassifier 2 | ================================================= 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: AdaUBoostClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~AdaUBoostClassifier.decision_function 18 | 19 | 20 | ~AdaUBoostClassifier.fit 21 | 22 | 23 | ~AdaUBoostClassifier.get_metadata_routing 24 | 25 | 26 | ~AdaUBoostClassifier.get_params 27 | 28 | 29 | ~AdaUBoostClassifier.predict 30 | 31 | 32 | ~AdaUBoostClassifier.predict_log_proba 33 | 34 | 35 | ~AdaUBoostClassifier.predict_proba 36 | 37 | 38 | ~AdaUBoostClassifier.score 39 | 40 | 41 | ~AdaUBoostClassifier.set_fit_request 42 | 43 | 44 | ~AdaUBoostClassifier.set_params 45 | 46 | 47 | ~AdaUBoostClassifier.set_score_request 48 | 49 | 50 | ~AdaUBoostClassifier.staged_decision_function 51 | 52 | 53 | ~AdaUBoostClassifier.staged_predict 54 | 55 | 56 | ~AdaUBoostClassifier.staged_predict_proba 57 | 58 | 59 | ~AdaUBoostClassifier.staged_score 60 | 61 | 62 | 63 | 64 | .. include:: ../../../back_references/imbens.ensemble.AdaUBoostClassifier.examples 65 | 66 | .. raw:: html 67 | 68 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.AsymBoostClassifier.rst: -------------------------------------------------------------------------------- 1 | AsymBoostClassifier 2 | ================================================= 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: AsymBoostClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~AsymBoostClassifier.decision_function 18 | 19 | 20 | ~AsymBoostClassifier.fit 21 | 22 | 23 | ~AsymBoostClassifier.get_metadata_routing 24 | 25 | 26 | ~AsymBoostClassifier.get_params 27 | 28 | 29 | ~AsymBoostClassifier.predict 30 | 31 | 32 | ~AsymBoostClassifier.predict_log_proba 33 | 34 | 35 | ~AsymBoostClassifier.predict_proba 36 | 37 | 38 | ~AsymBoostClassifier.score 39 | 40 | 41 | ~AsymBoostClassifier.set_fit_request 42 | 43 | 44 | ~AsymBoostClassifier.set_params 45 | 46 | 47 | ~AsymBoostClassifier.set_score_request 48 | 49 | 50 | ~AsymBoostClassifier.staged_decision_function 51 | 52 | 53 | ~AsymBoostClassifier.staged_predict 54 | 55 | 56 | ~AsymBoostClassifier.staged_predict_proba 57 | 58 | 59 | ~AsymBoostClassifier.staged_score 60 | 61 | 62 | 63 | 64 | .. include:: ../../../back_references/imbens.ensemble.AsymBoostClassifier.examples 65 | 66 | .. raw:: html 67 | 68 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.OverBoostClassifier.rst: -------------------------------------------------------------------------------- 1 | OverBoostClassifier 2 | ================================================= 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: OverBoostClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~OverBoostClassifier.decision_function 18 | 19 | 20 | ~OverBoostClassifier.fit 21 | 22 | 23 | ~OverBoostClassifier.get_metadata_routing 24 | 25 | 26 | ~OverBoostClassifier.get_params 27 | 28 | 29 | ~OverBoostClassifier.predict 30 | 31 | 32 | ~OverBoostClassifier.predict_log_proba 33 | 34 | 35 | ~OverBoostClassifier.predict_proba 36 | 37 | 38 | ~OverBoostClassifier.score 39 | 40 | 41 | ~OverBoostClassifier.set_fit_request 42 | 43 | 44 | ~OverBoostClassifier.set_params 45 | 46 | 47 | ~OverBoostClassifier.set_score_request 48 | 49 | 50 | ~OverBoostClassifier.staged_decision_function 51 | 52 | 53 | ~OverBoostClassifier.staged_predict 54 | 55 | 56 | ~OverBoostClassifier.staged_predict_proba 57 | 58 | 59 | ~OverBoostClassifier.staged_score 60 | 61 | 62 | 63 | 64 | .. include:: ../../../back_references/imbens.ensemble.OverBoostClassifier.examples 65 | 66 | .. raw:: html 67 | 68 |
-------------------------------------------------------------------------------- /docs/source/api/pipeline/_autosummary/imbens.pipeline.Pipeline.rst: -------------------------------------------------------------------------------- 1 | Pipeline 2 | ====================================== 3 | 4 | .. currentmodule:: imbens.pipeline 5 | 6 | .. autoclass:: Pipeline 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~Pipeline.decision_function 18 | 19 | 20 | ~Pipeline.fit 21 | 22 | 23 | ~Pipeline.fit_predict 24 | 25 | 26 | ~Pipeline.fit_resample 27 | 28 | 29 | ~Pipeline.fit_transform 30 | 31 | 32 | ~Pipeline.get_feature_names_out 33 | 34 | 35 | ~Pipeline.get_metadata_routing 36 | 37 | 38 | ~Pipeline.get_params 39 | 40 | 41 | ~Pipeline.inverse_transform 42 | 43 | 44 | ~Pipeline.predict 45 | 46 | 47 | ~Pipeline.predict_log_proba 48 | 49 | 50 | ~Pipeline.predict_proba 51 | 52 | 53 | ~Pipeline.score 54 | 55 | 56 | ~Pipeline.score_samples 57 | 58 | 59 | ~Pipeline.set_fit_request 60 | 61 | 62 | ~Pipeline.set_output 63 | 64 | 65 | ~Pipeline.set_params 66 | 67 | 68 | ~Pipeline.set_score_request 69 | 70 | 71 | ~Pipeline.transform 72 | 73 | 74 | 75 | 76 | .. include:: ../../../back_references/imbens.pipeline.Pipeline.examples 77 | 78 | .. raw:: html 79 | 80 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.SMOTEBoostClassifier.rst: -------------------------------------------------------------------------------- 1 | SMOTEBoostClassifier 2 | ================================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: SMOTEBoostClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~SMOTEBoostClassifier.decision_function 18 | 19 | 20 | ~SMOTEBoostClassifier.fit 21 | 22 | 23 | ~SMOTEBoostClassifier.get_metadata_routing 24 | 25 | 26 | ~SMOTEBoostClassifier.get_params 27 | 28 | 29 | ~SMOTEBoostClassifier.predict 30 | 31 | 32 | ~SMOTEBoostClassifier.predict_log_proba 33 | 34 | 35 | ~SMOTEBoostClassifier.predict_proba 36 | 37 | 38 | ~SMOTEBoostClassifier.score 39 | 40 | 41 | ~SMOTEBoostClassifier.set_fit_request 42 | 43 | 44 | ~SMOTEBoostClassifier.set_params 45 | 46 | 47 | ~SMOTEBoostClassifier.set_score_request 48 | 49 | 50 | ~SMOTEBoostClassifier.staged_decision_function 51 | 52 | 53 | ~SMOTEBoostClassifier.staged_predict 54 | 55 | 56 | ~SMOTEBoostClassifier.staged_predict_proba 57 | 58 | 59 | ~SMOTEBoostClassifier.staged_score 60 | 61 | 62 | 63 | 64 | .. include:: ../../../back_references/imbens.ensemble.SMOTEBoostClassifier.examples 65 | 66 | .. raw:: html 67 | 68 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.KmeansSMOTEBoostClassifier.rst: -------------------------------------------------------------------------------- 1 | KmeansSMOTEBoostClassifier 2 | ======================================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: KmeansSMOTEBoostClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~KmeansSMOTEBoostClassifier.decision_function 18 | 19 | 20 | ~KmeansSMOTEBoostClassifier.fit 21 | 22 | 23 | ~KmeansSMOTEBoostClassifier.get_metadata_routing 24 | 25 | 26 | ~KmeansSMOTEBoostClassifier.get_params 27 | 28 | 29 | ~KmeansSMOTEBoostClassifier.predict 30 | 31 | 32 | ~KmeansSMOTEBoostClassifier.predict_log_proba 33 | 34 | 35 | ~KmeansSMOTEBoostClassifier.predict_proba 36 | 37 | 38 | ~KmeansSMOTEBoostClassifier.score 39 | 40 | 41 | ~KmeansSMOTEBoostClassifier.set_fit_request 42 | 43 | 44 | ~KmeansSMOTEBoostClassifier.set_params 45 | 46 | 47 | ~KmeansSMOTEBoostClassifier.set_score_request 48 | 49 | 50 | ~KmeansSMOTEBoostClassifier.staged_decision_function 51 | 52 | 53 | ~KmeansSMOTEBoostClassifier.staged_predict 54 | 55 | 56 | ~KmeansSMOTEBoostClassifier.staged_predict_proba 57 | 58 | 59 | ~KmeansSMOTEBoostClassifier.staged_score 60 | 61 | 62 | 63 | 64 | .. include:: ../../../back_references/imbens.ensemble.KmeansSMOTEBoostClassifier.examples 65 | 66 | .. raw:: html 67 | 68 |
-------------------------------------------------------------------------------- /docs/source/api/ensemble/_autosummary/imbens.ensemble.CompatibleAdaBoostClassifier.rst: -------------------------------------------------------------------------------- 1 | CompatibleAdaBoostClassifier 2 | ========================================================== 3 | 4 | .. currentmodule:: imbens.ensemble 5 | 6 | .. autoclass:: CompatibleAdaBoostClassifier 7 | 8 | 9 | 10 | 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | 15 | 16 | 17 | ~CompatibleAdaBoostClassifier.decision_function 18 | 19 | 20 | ~CompatibleAdaBoostClassifier.fit 21 | 22 | 23 | ~CompatibleAdaBoostClassifier.get_metadata_routing 24 | 25 | 26 | ~CompatibleAdaBoostClassifier.get_params 27 | 28 | 29 | ~CompatibleAdaBoostClassifier.predict 30 | 31 | 32 | ~CompatibleAdaBoostClassifier.predict_log_proba 33 | 34 | 35 | ~CompatibleAdaBoostClassifier.predict_proba 36 | 37 | 38 | ~CompatibleAdaBoostClassifier.score 39 | 40 | 41 | ~CompatibleAdaBoostClassifier.set_fit_request 42 | 43 | 44 | ~CompatibleAdaBoostClassifier.set_params 45 | 46 | 47 | ~CompatibleAdaBoostClassifier.set_score_request 48 | 49 | 50 | ~CompatibleAdaBoostClassifier.staged_decision_function 51 | 52 | 53 | ~CompatibleAdaBoostClassifier.staged_predict 54 | 55 | 56 | ~CompatibleAdaBoostClassifier.staged_predict_proba 57 | 58 | 59 | ~CompatibleAdaBoostClassifier.staged_score 60 | 61 | 62 | 63 | 64 | .. include:: ../../../back_references/imbens.ensemble.CompatibleAdaBoostClassifier.examples 65 | 66 | .. raw:: html 67 | 68 |
-------------------------------------------------------------------------------- /imbens/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.sampler` submodule provides a 3 | set of methods to perform resampling. 4 | """ 5 | 6 | from . import _under_sampling 7 | from . import _over_sampling 8 | 9 | from ._under_sampling import ClusterCentroids 10 | from ._under_sampling import RandomUnderSampler 11 | from ._under_sampling import TomekLinks 12 | from ._under_sampling import NearMiss 13 | from ._under_sampling import CondensedNearestNeighbour 14 | from ._under_sampling import OneSidedSelection 15 | from ._under_sampling import NeighbourhoodCleaningRule 16 | from ._under_sampling import EditedNearestNeighbours 17 | from ._under_sampling import RepeatedEditedNearestNeighbours 18 | from ._under_sampling import AllKNN 19 | from ._under_sampling import InstanceHardnessThreshold 20 | from ._under_sampling import BalanceCascadeUnderSampler 21 | from ._under_sampling import SelfPacedUnderSampler 22 | 23 | from ._over_sampling import ADASYN 24 | from ._over_sampling import RandomOverSampler 25 | from ._over_sampling import SMOTE 26 | from ._over_sampling import BorderlineSMOTE 27 | from ._over_sampling import KMeansSMOTE 28 | from ._over_sampling import SVMSMOTE 29 | 30 | 31 | __all__ = [ 32 | "_under_sampling", 33 | "_over_sampling", 34 | 35 | "ClusterCentroids", 36 | "RandomUnderSampler", 37 | "InstanceHardnessThreshold", 38 | "NearMiss", 39 | "TomekLinks", 40 | "EditedNearestNeighbours", 41 | "RepeatedEditedNearestNeighbours", 42 | "AllKNN", 43 | "OneSidedSelection", 44 | "CondensedNearestNeighbour", 45 | "NeighbourhoodCleaningRule", 46 | "BalanceCascadeUnderSampler", 47 | "SelfPacedUnderSampler", 48 | 49 | "ADASYN", 50 | "RandomOverSampler", 51 | "KMeansSMOTE", 52 | "SMOTE", 53 | "BorderlineSMOTE", 54 | "SVMSMOTE", 55 | ] 56 | 
-------------------------------------------------------------------------------- /imbens/ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`imbens.ensemble` module contains a set of 3 | ensemble imbalanced learning methods. 4 | """ 5 | 6 | from . import _under_sampling 7 | from . import _over_sampling 8 | from . import _reweighting 9 | from . import _compatible 10 | 11 | from ._under_sampling import SelfPacedEnsembleClassifier 12 | from ._under_sampling import BalanceCascadeClassifier 13 | from ._under_sampling import BalancedRandomForestClassifier 14 | from ._under_sampling import EasyEnsembleClassifier 15 | from ._under_sampling import RUSBoostClassifier 16 | from ._under_sampling import UnderBaggingClassifier 17 | 18 | from ._over_sampling import OverBoostClassifier 19 | from ._over_sampling import SMOTEBoostClassifier 20 | from ._over_sampling import KmeansSMOTEBoostClassifier 21 | from ._over_sampling import SMOTEBaggingClassifier 22 | from ._over_sampling import OverBaggingClassifier 23 | 24 | from ._reweighting import AdaCostClassifier 25 | from ._reweighting import AdaUBoostClassifier 26 | from ._reweighting import AsymBoostClassifier 27 | 28 | from ._compatible import CompatibleAdaBoostClassifier 29 | from ._compatible import CompatibleBaggingClassifier 30 | 31 | __all__ = [ 32 | "_under_sampling", 33 | "_over_sampling", 34 | "_reweighting", 35 | "_compatible", 36 | 37 | "SelfPacedEnsembleClassifier", 38 | "BalanceCascadeClassifier", 39 | "BalancedRandomForestClassifier", 40 | "EasyEnsembleClassifier", 41 | "RUSBoostClassifier", 42 | "UnderBaggingClassifier", 43 | 44 | "OverBoostClassifier", 45 | "SMOTEBoostClassifier", 46 | "KmeansSMOTEBoostClassifier", 47 | "OverBaggingClassifier", 48 | "SMOTEBaggingClassifier", 49 | 50 | "AdaCostClassifier", 51 | "AdaUBoostClassifier", 52 | "AsymBoostClassifier", 53 | 54 | "CompatibleAdaBoostClassifier", 55 | "CompatibleBaggingClassifier", 56 | ] 
Specific metrics have been developed to evaluate classifiers that have been
trained using imbalanced data. :mod:`imbens` provides a classification report
(:func:`imbens.metrics.classification_report_imbalanced`)
similar to :mod:`sklearn`, with additional metrics specific to imbalanced
learning problems.
def deprecate_parameter(sampler, version_deprecation, param_deprecated, new_param=None):
    """Helper to deprecate a parameter by another one.

    Parameters
    ----------
    sampler : sampler object,
        The object which will be inspected.

    version_deprecation : str,
        The version from which the parameter will be deprecated. The format
        should be ``'x.y'``

    param_deprecated : str,
        The parameter being deprecated.

    new_param : str,
        The parameter used instead of the deprecated parameter. By default, no
        parameter is expected.

    Returns
    -------
    None

    """
    x, y = version_deprecation.split(".")
    # Deprecation policy: the parameter is removed two minor versions after
    # the version in which it was deprecated.
    version_removed = x + "." + str(int(y) + 2)
    # Only warn when the deprecated parameter was actually set by the user
    # (a value of None means "not used").
    if getattr(sampler, param_deprecated) is not None:
        if new_param is None:
            extra = ""
        else:
            extra = f" Use '{new_param}' instead."
            # Forward the user-provided value to the replacement parameter so
            # downstream code can rely solely on `new_param`.
            setattr(sampler, new_param, getattr(sampler, param_deprecated))
        warnings.warn(
            f"'{param_deprecated}' is deprecated from {version_deprecation} and "
            f"will be removed in {version_removed} for the estimator "
            f"{sampler.__class__}.{extra}",
            category=DeprecationWarning,
        )
Class-imbalance (also known as the long-tail problem) is the fact that the
classes are not represented equally in a classification problem, which is
quite common in practice. For instance, fraud detection, prediction of
rare adverse drug reactions and prediction of gene families. Failure to account
for the class imbalance often causes inaccurate and decreased predictive
performance of many classification algorithms.

Imbalanced learning (IL) aims
to tackle the class imbalance problem to learn an unbiased model from
imbalanced data. This is usually achieved by changing the training data
distribution by resampling or reweighting. However, naive resampling or
reweighting may introduce bias/variance to the training data, especially
when the data has class-overlapping or contains noise.

Ensemble imbalanced learning (EIL) is known to effectively improve typical
IL solutions by combining the outputs of multiple classifiers, thereby
reducing the variance introduced by resampling/reweighting.
def test_get_deps_info():
    """All core dependencies must be reported by ``_get_deps_info``."""
    deps_info = _get_deps_info()
    expected_packages = (
        "pip",
        "setuptools",
        "imblearn",
        "imbens",
        "sklearn",
        "numpy",
        "scipy",
        "Cython",
        "pandas",
        "joblib",
    )
    for package in expected_packages:
        assert package in deps_info
System, Dependency Information" in out 46 | assert "**System Information**" in out 47 | assert "* python" in out 48 | assert "* executable" in out 49 | assert "* machine" in out 50 | assert "**Python Dependencies**" in out 51 | assert "* pip" in out 52 | assert "* setuptools" in out 53 | assert "* imblearn" in out 54 | assert "* imbens" in out 55 | assert "* sklearn" in out 56 | assert "* numpy" in out 57 | assert "* scipy" in out 58 | assert "* Cython" in out 59 | assert "* pandas" in out 60 | assert "* joblib" in out 61 | assert "
" in out 62 | 63 | 64 | # %% 65 | -------------------------------------------------------------------------------- /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "badgeTemplate": "-orange.svg\">", 8 | "contributors": [ 9 | { 10 | "login": "ZhiningLiu1998", 11 | "name": "Zhining Liu", 12 | "avatar_url": "https://avatars.githubusercontent.com/u/26108487?v=4", 13 | "profile": "http://zhiningliu.com", 14 | "contributions": [ 15 | "code", 16 | "ideas", 17 | "maintenance", 18 | "bug", 19 | "doc" 20 | ] 21 | }, 22 | { 23 | "login": "leaphan", 24 | "name": "leaphan", 25 | "avatar_url": "https://avatars.githubusercontent.com/u/35593707?v=4", 26 | "profile": "https://github.com/leaphan", 27 | "contributions": [ 28 | "bug" 29 | ] 30 | }, 31 | { 32 | "login": "hannanhtang", 33 | "name": "hannanhtang", 34 | "avatar_url": "https://avatars.githubusercontent.com/u/23587399?v=4", 35 | "profile": "https://github.com/hannanhtang", 36 | "contributions": [ 37 | "bug" 38 | ] 39 | }, 40 | { 41 | "login": "huajuanren", 42 | "name": "H.J.Ren", 43 | "avatar_url": "https://avatars.githubusercontent.com/u/37321841?v=4", 44 | "profile": "https://github.com/huajuanren", 45 | "contributions": [ 46 | "bug" 47 | ] 48 | }, 49 | { 50 | "login": "MarcSkovMadsen", 51 | "name": "Marc Skov Madsen", 52 | "avatar_url": "https://avatars.githubusercontent.com/u/42288570?v=4", 53 | "profile": "http://datamodelsanalytics.com", 54 | "contributions": [ 55 | "bug" 56 | ] 57 | } 58 | ], 59 | "contributorsPerLine": 7, 60 | "projectName": "imbalanced-ensemble", 61 | "projectOwner": "ZhiningLiu1998", 62 | "repoType": "github", 63 | "repoHost": "https://github.com", 64 | "skipCi": true, 65 | "commitConvention": "angular" 66 | } 67 | -------------------------------------------------------------------------------- 
@pytest.mark.parametrize("kind", ["borderline-1", "borderline-2"])
def test_borderline_smote(kind, data):
    """Explicit NN estimators must give the same result as the int defaults."""
    sampler_default = BorderlineSMOTE(kind=kind, random_state=42)
    sampler_custom_nn = BorderlineSMOTE(
        kind=kind,
        random_state=42,
        k_neighbors=NearestNeighbors(n_neighbors=6),
        m_neighbors=NearestNeighbors(n_neighbors=11),
    )

    X_default, y_default = sampler_default.fit_resample(*data)
    X_custom, y_custom = sampler_custom_nn.fit_resample(*data)

    # Same seed + equivalent neighbor settings -> identical resampled output.
    assert_allclose(X_default, X_custom)
    assert_array_equal(y_default, y_custom)
[-0.32635887, -0.29299653], 64 | [-0.00288378, 0.84259929], 65 | [1.79580611, -0.02219234], 66 | ] 67 | ) 68 | y_gt = np.array([1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0]) 69 | assert_array_equal(X_resampled, X_gt) 70 | assert_array_equal(y_resampled, y_gt) 71 | -------------------------------------------------------------------------------- /imbens/utils/_show_versions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility method which prints system info to help with debugging, 3 | and filing issues on GitHub. 4 | Adapted from :func:`sklearn.show_versions`, 5 | which was adapted from :func:`pandas.show_versions` 6 | """ 7 | # Adapted from imbalanced-learn 8 | 9 | # Author: Alexander L. Hayes 10 | # License: MIT 11 | 12 | from .. import __version__ 13 | 14 | 15 | def _get_deps_info(): 16 | """Overview of the installed version of main dependencies 17 | Returns 18 | ------- 19 | deps_info: dict 20 | version information on relevant Python libraries 21 | """ 22 | deps = [ 23 | "pip", 24 | "setuptools", 25 | "imblearn", 26 | "imbens", 27 | "sklearn", 28 | "numpy", 29 | "scipy", 30 | "Cython", 31 | "pandas", 32 | "joblib", 33 | ] 34 | 35 | deps_info = { 36 | "imbalanced-ensemble": __version__, 37 | } 38 | 39 | from importlib.metadata import PackageNotFoundError, version 40 | 41 | for modname in deps: 42 | try: 43 | deps_info[modname] = version(modname) 44 | except PackageNotFoundError: 45 | deps_info[modname] = None 46 | return deps_info 47 | 48 | 49 | def show_versions(github=False): 50 | """Print debugging information. 51 | 52 | .. versionadded:: 0.5 53 | 54 | Parameters 55 | ---------- 56 | github : bool, 57 | If true, wrap system info with GitHub markup. 58 | """ 59 | 60 | from sklearn.utils._show_versions import _get_sys_info 61 | 62 | _sys_info = _get_sys_info() 63 | _deps_info = _get_deps_info() 64 | _github_markup = ( 65 | "
" 66 | "System, Dependency Information\n\n" 67 | "**System Information**\n\n" 68 | "{0}\n" 69 | "**Python Dependencies**\n\n" 70 | "{1}\n" 71 | "
" 72 | ) 73 | 74 | if github: 75 | _sys_markup = "" 76 | _deps_markup = "" 77 | 78 | for k, stat in _sys_info.items(): 79 | _sys_markup += f"* {k:<10}: `{stat}`\n" 80 | for k, stat in _deps_info.items(): 81 | _deps_markup += f"* {k:<10}: `{stat}`\n" 82 | 83 | print(_github_markup.format(_sys_markup, _deps_markup)) 84 | 85 | else: 86 | print("\nSystem:") 87 | for k, stat in _sys_info.items(): 88 | print(f"{k:>11}: {stat}") 89 | 90 | print("\nPython dependencies:") 91 | for k, stat in _deps_info.items(): 92 | print(f"{k:>11}: {stat}") 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | docs/source/auto_examples/ 74 | docs/source/back_references/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
"""
Base class for the over-sampling method.
"""
# Adapted from imbalanced-learn

# Authors: Guillaume Lemaitre
#          Christos Aridas
#          Zhining Liu
# License: MIT

# Development-only toggle: when flipped to True the module is importable as a
# plain script (path hack below) instead of as part of the package.
LOCAL_DEBUG = False

if not LOCAL_DEBUG:
    from ..base import BaseSampler
else:  # pragma: no cover
    import sys  # For local test
    sys.path.append("../..")
    from sampler.base import BaseSampler


class BaseOverSampler(BaseSampler):
    """Base class for over-sampling algorithms.

    Warning: This class should not be used directly. Use the derive classes
    instead.
    """

    # Tag consumed by the shared sampling-strategy validation machinery in
    # BaseSampler to select over-sampling semantics.
    _sampling_type = "over-sampling"

    # Reusable docstring fragment substituted into the ``sampling_strategy``
    # parameter documentation of concrete over-samplers.
    _sampling_strategy_docstring = """sampling_strategy : float, str, dict or callable, default='auto'
        Sampling information to resample the data set.

        - When ``float``, it corresponds to the desired ratio of the number of
          samples in the minority class over the number of samples in the
          majority class after resampling. Therefore, the ratio is expressed as
          :math:`\\alpha_{os} = N_{rm} / N_{M}` where :math:`N_{rm}` is the
          number of samples in the minority class after resampling and
          :math:`N_{M}` is the number of samples in the majority class.

          .. warning::
             ``float`` is only available for **binary** classification. An
             error is raised for multi-class classification.

        - When ``str``, specify the class targeted by the resampling. The
          number of samples in the different classes will be equalized.
          Possible choices are:

            ``'minority'``: resample only the minority class;

            ``'not minority'``: resample all classes but the minority class;

            ``'not majority'``: resample all classes but the majority class;

            ``'all'``: resample all classes;

            ``'auto'``: equivalent to ``'not majority'``.

        - When ``dict``, the keys correspond to the targeted classes. The
          values correspond to the desired number of samples for each targeted
          class.

        - When callable, function taking ``y`` and returns a ``dict``. The keys
          correspond to the targeted classes. The values correspond to the
          desired number of samples for each class.
        """.strip()
"""
====================================
Usage of pipeline embedding samplers
====================================

An example of the :class:`~imbens.pipeline.Pipeline` object (or
:func:`~imbens.pipeline.make_pipeline` helper function) working with
transformers (:class:`~sklearn.decomposition.PCA`,
:class:`~sklearn.neighbors.KNeighborsClassifier` from *scikit-learn*) and resamplers
(:class:`~imbens.sampler.EditedNearestNeighbours`,
:class:`~imbens.sampler.SMOTE`).
"""

# Adapted from imbalanced-learn
# Authors: Christos Aridas
#          Guillaume Lemaitre
# License: MIT

# %%
print(__doc__)

# sphinx_gallery_thumbnail_path = '../../docs/source/_static/thumbnail.png'

# %% [markdown]
# Let's first create an imbalanced dataset and split in to two sets.

# %%
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Binary dataset with a ~30%/70% class distribution.
X, y = make_classification(
    n_classes=2,
    class_sep=1.25,
    weights=[0.3, 0.7],
    n_informative=3,
    n_redundant=1,
    flip_y=0,
    n_features=5,
    n_clusters_per_class=1,
    n_samples=5000,
    random_state=10,
)

# Stratified split preserves the class ratio in both subsets.
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

# %% [markdown]
# Now, we will create each individual steps
# that we would like later to combine

# %%
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from imbens.sampler import EditedNearestNeighbours
from imbens.sampler import SMOTE

pca = PCA(n_components=2)  # transformer: dimensionality reduction
enn = EditedNearestNeighbours()  # resampler: data cleaning
smote = SMOTE(random_state=0)  # resampler: minority over-sampling
knn = KNeighborsClassifier(n_neighbors=1)  # final estimator

# %% [markdown]
# Now, we can finally create a pipeline to specify in which order the different
# transformers and samplers should be executed before to provide the data to
# the final classifier.

# %%
from imbens.pipeline import make_pipeline

model = make_pipeline(pca, enn, smote, knn)

# %% [markdown]
# We can now use the pipeline created as a normal classifier where resampling
# will happen when calling `fit` and disabled when calling `decision_function`,
# `predict_proba`, or `predict`.

# %%
from sklearn.metrics import classification_report

model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred))
def test_svm_smote_sample_weight(data):
    """With uniform sample weights, default and explicitly-configured
    SVMSMOTE instances must produce identical resampled outputs."""
    X, y = data
    uniform_weights = np.ones_like(y)

    default_sampler = SVMSMOTE(random_state=42)
    explicit_sampler = SVMSMOTE(
        random_state=42,
        k_neighbors=NearestNeighbors(n_neighbors=6),
        m_neighbors=NearestNeighbors(n_neighbors=11),
        svm_estimator=SVC(gamma="scale", random_state=42),
    )

    resampled_default = default_sampler.fit_resample(X, y, sample_weight=uniform_weights)
    resampled_explicit = explicit_sampler.fit_resample(X, y, sample_weight=uniform_weights)

    # Compare features, labels, and propagated weights element-wise.
    assert_allclose(resampled_default[0], resampled_explicit[0])
    assert_array_equal(resampled_default[1], resampled_explicit[1])
    assert_array_equal(resampled_default[2], resampled_explicit[2])
"""
=========================================================
Visualize an ensemble classifier
=========================================================

This example illustrates how to quickly visualize an
:mod:`imbens.ensemble` classifier with
the :mod:`imbens.visualizer` module.

This example uses:

- :class:`imbens.ensemble.SelfPacedEnsembleClassifier`
- :class:`imbens.visualizer.ImbalancedEnsembleVisualizer`
"""

# Authors: Zhining Liu
# License: MIT


# %%
print(__doc__)

# Import imbalanced-ensemble
import imbens

# Import utilities
import sklearn
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

RANDOM_STATE = 42

# sphinx_gallery_thumbnail_number = 2

# %% [markdown]
# Prepare data
# ------------
# Make a toy 3-class imbalanced classification task.

# make dataset (class weights 10%/30%/60% give the imbalance)
X, y = make_classification(
    n_classes=3,
    class_sep=2,
    weights=[0.1, 0.3, 0.6],
    n_informative=3,
    n_redundant=1,
    flip_y=0,
    n_features=20,
    n_clusters_per_class=2,
    n_samples=2000,
    random_state=0,
)

# train valid split (stratified, so both halves keep the imbalance ratio)
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=RANDOM_STATE
)


# %% [markdown]
# Train an ensemble classifier
# ----------------------------
# Take ``SelfPacedEnsembleClassifier`` as example

# Initialize and train an SPE classifier
clf = imbens.ensemble.SelfPacedEnsembleClassifier(random_state=RANDOM_STATE).fit(
    X_train, y_train
)

# Store the fitted SelfPacedEnsembleClassifier
# (the visualizer accepts a name -> fitted-ensemble mapping)
fitted_ensembles = {'SPE': clf}


# %% [markdown]
# Fit an ImbalancedEnsembleVisualizer
# -----------------------------------------------------

# Initialize visualizer with the datasets to evaluate on
visualizer = imbens.visualizer.ImbalancedEnsembleVisualizer(
    eval_datasets={
        'training': (X_train, y_train),
        'validation': (X_valid, y_valid),
    },
)

# Fit visualizer
visualizer.fit(fitted_ensembles)


# %% [markdown]
# Plot performance curve
# ----------------------
# **performance w.r.t. number of base estimators**

fig, axes = visualizer.performance_lineplot()

# %% [markdown]
# Plot confusion matrix
# ---------------------

fig, axes = visualizer.confusion_matrix_heatmap(
    on_datasets=['validation'],  # only on validation set
    sup_title=False,
)
def ratio_func(y, multiplier, minority_class):
    """Build a ``sampling_strategy`` dict scaling one class by *multiplier*.

    Counts the occurrences of ``minority_class`` in ``y`` and returns a
    single-entry dict mapping that class to its scaled (truncated) size.
    """
    class_counts = Counter(y)
    target_size = int(multiplier * class_counts[minority_class])
    return {minority_class: target_size}
@pytest.mark.parametrize(
    "ncr_params, err_msg",
    [
        ({"threshold_cleaning": -10}, "value between 0 and 1"),
        ({"threshold_cleaning": 10}, "value between 0 and 1"),
        ({"n_neighbors": "rnd"}, "has to be one of"),
    ],
)
def test_ncr_error(ncr_params, err_msg):
    """Invalid constructor parameters must raise ValueError at fit time."""
    # Construction succeeds; validation happens inside fit_resample.
    with pytest.raises(ValueError, match=err_msg):
        NeighbourhoodCleaningRule(**ncr_params).fit_resample(X, Y)
def test_ncr_fit_resample_mode():
    """With ``kind_sel='mode'`` the resampled set matches the ground truth."""
    X_resampled, y_resampled = NeighbourhoodCleaningRule(kind_sel="mode").fit_resample(X, Y)

    expected_rows = [
        [0.34096173, 0.50947647],
        [-0.91735824, 0.93110278],
        [-0.20413357, 0.64628718],
        [0.35967591, 2.61186964],
        [0.90701028, -0.57636928],
        [-1.20809175, -1.49917302],
        [-0.60497017, -0.66630228],
        [1.39272351, -0.51631728],
        [-1.55581933, 1.09609604],
        [1.55157493, -1.6981518],
    ]
    expected_labels = [1, 1, 1, 2, 2, 0, 0, 2, 1, 2]

    assert_array_equal(X_resampled, np.array(expected_rows))
    assert_array_equal(y_resampled, np.array(expected_labels))
@pytest.mark.xfail
def test_fetch():
    """Two differently-shuffled fetches agree on every dataset's shape."""
    try:
        bunch_a = fetch(shuffle=True, random_state=42)
    except IOError:
        # Data not locally available and download failed: skip, don't fail.
        raise SkipTest("Zenodo dataset can not be loaded.")

    bunch_b = fetch(shuffle=True, random_state=37)

    for name, expected_shape in DATASET_SHAPE.items():
        data_a, data_b = bunch_a[name].data, bunch_b[name].data
        assert expected_shape == data_a.shape
        assert data_a.shape == data_b.shape

        target_a, target_b = bunch_a[name].target, bunch_b[name].target
        n_samples = data_a.shape[0]
        assert (n_samples,) == target_a.shape
        assert (n_samples,) == target_b.shape
@pytest.mark.parametrize(
    "filter_data, err_msg",
    [
        # Unknown dataset name.
        (("rnf",), "is not a dataset available"),
        # Out-of-range dataset IDs.
        ((-1,), "dataset with the ID="),
        ((100,), "dataset with the ID="),
        # Wrong element type (neither a str name nor an int ID).
        ((1.00,), "value in the tuple"),
    ],
)
def test_fetch_error(filter_data, err_msg):
    """Invalid ``filter_data`` entries raise an informative ValueError."""
    with pytest.raises(ValueError, match=err_msg):
        fetch_zenodo_datasets(filter_data=filter_data)
@pytest.mark.parametrize("algorithm", ["SAMME"]) 38 | def test_adaboost(imbalanced_dataset, algorithm): 39 | X, y = imbalanced_dataset 40 | X_train, X_test, y_train, y_test = train_test_split( 41 | X, y, stratify=y, random_state=1 42 | ) 43 | classes = np.unique(y) 44 | 45 | n_estimators = 500 46 | adaboost = CompatibleAdaBoostClassifier( 47 | n_estimators=n_estimators, algorithm=algorithm, random_state=0 48 | ) 49 | adaboost.fit(X_train, y_train) 50 | assert_array_equal(classes, adaboost.classes_) 51 | 52 | # check that we have an ensemble of estimators with a 53 | # consistent size 54 | assert len(adaboost.estimators_) > 1 55 | 56 | # each estimator in the ensemble should have different random state 57 | assert len({est.random_state for est in adaboost.estimators_}) == len( 58 | adaboost.estimators_ 59 | ) 60 | 61 | # check the consistency of the feature importances 62 | assert len(adaboost.feature_importances_) == imbalanced_dataset[0].shape[1] 63 | 64 | # check the consistency of the prediction outpus 65 | y_pred = adaboost.predict_proba(X_test) 66 | assert y_pred.shape[1] == len(classes) 67 | assert adaboost.decision_function(X_test).shape[1] == len(classes) 68 | 69 | score = adaboost.score(X_test, y_test) 70 | assert score > 0.6, f"Failed with algorithm {algorithm} and score {score}" 71 | 72 | y_pred = adaboost.predict(X_test) 73 | assert y_pred.shape == y_test.shape 74 | 75 | 76 | @pytest.mark.parametrize("algorithm", ["SAMME"]) 77 | def test_adaboost_sample_weight(imbalanced_dataset, algorithm): 78 | X, y = imbalanced_dataset 79 | sample_weight = np.ones_like(y) 80 | adaboost = CompatibleAdaBoostClassifier(algorithm=algorithm, random_state=0) 81 | 82 | # Predictions should be the same when sample_weight are all ones 83 | y_pred_sample_weight = adaboost.fit(X, y, sample_weight=sample_weight).predict(X) 84 | y_pred_no_sample_weight = adaboost.fit(X, y).predict(X) 85 | 86 | assert_array_equal(y_pred_sample_weight, y_pred_no_sample_weight) 87 | 88 | rng = 
np.random.RandomState(42) 89 | sample_weight = rng.rand(y.shape[0]) 90 | y_pred_sample_weight = adaboost.fit(X, y, sample_weight=sample_weight).predict(X) 91 | 92 | with pytest.raises(AssertionError): 93 | assert_array_equal(y_pred_no_sample_weight, y_pred_sample_weight) 94 | -------------------------------------------------------------------------------- /examples/evaluation/plot_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | ======================================= 3 | Metrics specific to imbalanced learning 4 | ======================================= 5 | 6 | Specific metrics have been developed to evaluate classifier which 7 | has been trained using imbalanced data. :mod:`imbens` provides mainly 8 | two additional metrics which are not implemented in :mod:`sklearn`: (i) 9 | geometric mean (:func:`imbens.metrics.geometric_mean_score`) 10 | and (ii) index balanced accuracy (:func:`imbens.metrics.make_index_balanced_accuracy`). 11 | """ 12 | 13 | # Adapted from imbalanced-learn 14 | # Authors: Guillaume Lemaitre 15 | # License: MIT 16 | 17 | # %% 18 | print(__doc__) 19 | 20 | RANDOM_STATE = 42 21 | 22 | # sphinx_gallery_thumbnail_path = '../../docs/source/_static/thumbnail.png' 23 | 24 | # %% [markdown] 25 | # First, we will generate some imbalanced dataset. 26 | 27 | # %% 28 | from sklearn.datasets import make_classification 29 | 30 | X, y = make_classification( 31 | n_classes=3, 32 | class_sep=2, 33 | weights=[0.1, 0.9], 34 | n_informative=10, 35 | n_redundant=1, 36 | flip_y=0, 37 | n_features=20, 38 | n_clusters_per_class=4, 39 | n_samples=5000, 40 | random_state=RANDOM_STATE, 41 | ) 42 | 43 | # %% [markdown] 44 | # We will split the data into a training and testing set. 
45 | 46 | # %% 47 | from sklearn.model_selection import train_test_split 48 | 49 | X_train, X_test, y_train, y_test = train_test_split( 50 | X, y, stratify=y, random_state=RANDOM_STATE 51 | ) 52 | 53 | # %% [markdown] 54 | # We will create a pipeline made of a :class:`~imbens.sampler.SMOTE` 55 | # over-sampler followed by a :class:`~sklearn.svm.LinearSVC` classifier. 56 | 57 | # %% 58 | from imbens.pipeline import make_pipeline 59 | from imbens.sampler import SMOTE 60 | from sklearn.svm import LinearSVC 61 | 62 | model = make_pipeline( 63 | SMOTE(random_state=RANDOM_STATE), LinearSVC(random_state=RANDOM_STATE) 64 | ) 65 | 66 | # %% [markdown] 67 | # Now, we will train the model on the training set and get the prediction 68 | # associated with the testing set. Be aware that the resampling will happen 69 | # only when calling `fit`: the number of samples in `y_pred` is the same than 70 | # in `y_test`. 71 | 72 | # %% 73 | model.fit(X_train, y_train) 74 | y_pred = model.predict(X_test) 75 | 76 | # %% [markdown] 77 | # The geometric mean corresponds to the square root of the product of the 78 | # sensitivity and specificity. Combining the two metrics should account for 79 | # the balancing of the dataset. 80 | 81 | # %% 82 | from imbens.metrics import geometric_mean_score 83 | 84 | print(f"The geometric mean is {geometric_mean_score(y_test, y_pred):.3f}") 85 | 86 | # %% [markdown] 87 | # The index balanced accuracy can transform any metric to be used in 88 | # imbalanced learning problems. 
89 | 90 | # %% 91 | from imbens.metrics import make_index_balanced_accuracy 92 | 93 | alpha = 0.1 94 | geo_mean = make_index_balanced_accuracy(alpha=alpha, squared=True)(geometric_mean_score) 95 | 96 | print( 97 | f"The IBA using alpha={alpha} and the geometric mean: " 98 | f"{geo_mean(y_test, y_pred):.3f}" 99 | ) 100 | 101 | # %% 102 | alpha = 0.5 103 | geo_mean = make_index_balanced_accuracy(alpha=alpha, squared=True)(geometric_mean_score) 104 | 105 | print( 106 | f"The IBA using alpha={alpha} and the geometric mean: " 107 | f"{geo_mean(y_test, y_pred):.3f}" 108 | ) 109 | -------------------------------------------------------------------------------- /imbens/sampler/_under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py: -------------------------------------------------------------------------------- 1 | """Test the module .""" 2 | # Authors: Guillaume Lemaitre 3 | # Christos Aridas 4 | # License: MIT 5 | 6 | import numpy as np 7 | import pytest 8 | from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier 9 | from sklearn.naive_bayes import GaussianNB as NB 10 | from sklearn.utils._testing import assert_array_equal 11 | 12 | from imbens.sampler._under_sampling import InstanceHardnessThreshold 13 | 14 | RND_SEED = 0 15 | X = np.array( 16 | [ 17 | [-0.3879569, 0.6894251], 18 | [-0.09322739, 1.28177189], 19 | [-0.77740357, 0.74097941], 20 | [0.91542919, -0.65453327], 21 | [-0.03852113, 0.40910479], 22 | [-0.43877303, 1.07366684], 23 | [-0.85795321, 0.82980738], 24 | [-0.18430329, 0.52328473], 25 | [-0.30126957, -0.66268378], 26 | [-0.65571327, 0.42412021], 27 | [-0.28305528, 0.30284991], 28 | [0.20246714, -0.34727125], 29 | [1.06446472, -1.09279772], 30 | [0.30543283, -0.02589502], 31 | [-0.00717161, 0.00318087], 32 | ] 33 | ) 34 | Y = np.array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0]) 35 | ESTIMATOR = GradientBoostingClassifier(random_state=RND_SEED) 36 | 37 | 38 | def test_iht_init(): 39 | sampling_strategy = 
"auto" 40 | iht = InstanceHardnessThreshold( 41 | estimator=ESTIMATOR, 42 | sampling_strategy=sampling_strategy, 43 | random_state=RND_SEED, 44 | ) 45 | 46 | assert iht.sampling_strategy == sampling_strategy 47 | assert iht.random_state == RND_SEED 48 | 49 | 50 | def test_iht_fit_resample(): 51 | iht = InstanceHardnessThreshold(estimator=ESTIMATOR, random_state=RND_SEED) 52 | X_resampled, y_resampled = iht.fit_resample(X, Y) 53 | assert X_resampled.shape == (12, 2) 54 | assert y_resampled.shape == (12,) 55 | 56 | 57 | def test_iht_fit_resample_half(): 58 | sampling_strategy = {0: 3, 1: 3} 59 | iht = InstanceHardnessThreshold( 60 | estimator=NB(), 61 | sampling_strategy=sampling_strategy, 62 | random_state=RND_SEED, 63 | ) 64 | X_resampled, y_resampled = iht.fit_resample(X, Y) 65 | assert X_resampled.shape == (6, 2) 66 | assert y_resampled.shape == (6,) 67 | 68 | 69 | def test_iht_fit_resample_class_obj(): 70 | est = GradientBoostingClassifier(random_state=RND_SEED) 71 | iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED) 72 | X_resampled, y_resampled = iht.fit_resample(X, Y) 73 | assert X_resampled.shape == (12, 2) 74 | assert y_resampled.shape == (12,) 75 | 76 | 77 | def test_iht_fit_resample_wrong_class_obj(): 78 | from sklearn.cluster import KMeans 79 | 80 | est = KMeans(n_init='auto') 81 | iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED) 82 | with pytest.raises(ValueError, match="Invalid parameter `estimator`"): 83 | iht.fit_resample(X, Y) 84 | 85 | 86 | def test_iht_reproducibility(): 87 | from sklearn.datasets import load_digits 88 | 89 | X_digits, y_digits = load_digits(return_X_y=True) 90 | idx_sampled = [] 91 | for seed in range(5): 92 | est = RandomForestClassifier(n_estimators=10, random_state=seed) 93 | iht = InstanceHardnessThreshold(estimator=est, random_state=RND_SEED) 94 | iht.fit_resample(X_digits, y_digits) 95 | idx_sampled.append(iht.sample_indices_.copy()) 96 | for idx_1, idx_2 in zip(idx_sampled, 
idx_sampled[1:]): 97 | assert_array_equal(idx_1, idx_2) 98 | -------------------------------------------------------------------------------- /imbens/ensemble/_over_sampling/tests/test_over_boost.py: -------------------------------------------------------------------------------- 1 | """Test OverBoostClassifier.""" 2 | 3 | # Authors: Zhining Liu 4 | # License: MIT 5 | 6 | import numpy as np 7 | import pytest 8 | import sklearn 9 | from sklearn.datasets import load_iris, make_classification 10 | from sklearn.model_selection import train_test_split 11 | from sklearn.utils._testing import assert_array_equal 12 | from sklearn.utils.fixes import parse_version 13 | 14 | from imbens.ensemble import OverBoostClassifier 15 | 16 | sklearn_version = parse_version(sklearn.__version__) 17 | 18 | 19 | @pytest.fixture 20 | def imbalanced_dataset(): 21 | return make_classification( 22 | n_samples=10000, 23 | n_features=3, 24 | n_informative=2, 25 | n_redundant=0, 26 | n_repeated=0, 27 | n_classes=3, 28 | n_clusters_per_class=1, 29 | weights=[0.01, 0.05, 0.94], 30 | class_sep=0.8, 31 | random_state=0, 32 | ) 33 | 34 | 35 | @pytest.mark.parametrize("algorithm", ["SAMME"]) 36 | def test_overboost(imbalanced_dataset, algorithm): 37 | X, y = imbalanced_dataset 38 | X_train, X_test, y_train, y_test = train_test_split( 39 | X, y, stratify=y, random_state=1 40 | ) 41 | classes = np.unique(y) 42 | 43 | n_estimators = 100 44 | overboost = OverBoostClassifier( 45 | n_estimators=n_estimators, algorithm=algorithm, random_state=0 46 | ) 47 | overboost.fit(X_train, y_train) 48 | assert_array_equal(classes, overboost.classes_) 49 | 50 | # check that we have an ensemble of samplers and estimators with a 51 | # consistent size 52 | assert len(overboost.estimators_) > 1 53 | assert len(overboost.estimators_) == len(overboost.samplers_) 54 | 55 | # each sampler in the ensemble should have different random state 56 | assert len({sampler.random_state for sampler in overboost.samplers_}) == len( 57 | 
overboost.samplers_ 58 | ) 59 | # each estimator in the ensemble should have different random state 60 | assert len({est.random_state for est in overboost.estimators_}) == len( 61 | overboost.estimators_ 62 | ) 63 | 64 | # check the consistency of the feature importances 65 | assert len(overboost.feature_importances_) == imbalanced_dataset[0].shape[1] 66 | 67 | # check the consistency of the prediction outpus 68 | y_pred = overboost.predict_proba(X_test) 69 | assert y_pred.shape[1] == len(classes) 70 | assert overboost.decision_function(X_test).shape[1] == len(classes) 71 | 72 | score = overboost.score(X_test, y_test) 73 | assert score > 0.6, f"Failed with algorithm {algorithm} and score {score}" 74 | 75 | y_pred = overboost.predict(X_test) 76 | assert y_pred.shape == y_test.shape 77 | 78 | 79 | @pytest.mark.parametrize("algorithm", ["SAMME"]) 80 | def test_overboost_sample_weight(imbalanced_dataset, algorithm): 81 | X, y = imbalanced_dataset 82 | sample_weight = np.ones_like(y) 83 | overboost = OverBoostClassifier(algorithm=algorithm, random_state=0) 84 | 85 | # Predictions should be the same when sample_weight are all ones 86 | y_pred_sample_weight = overboost.fit(X, y, sample_weight=sample_weight).predict(X) 87 | y_pred_no_sample_weight = overboost.fit(X, y).predict(X) 88 | 89 | assert_array_equal(y_pred_sample_weight, y_pred_no_sample_weight) 90 | 91 | rng = np.random.RandomState(42) 92 | sample_weight = rng.rand(y.shape[0]) 93 | y_pred_sample_weight = overboost.fit(X, y, sample_weight=sample_weight).predict(X) 94 | 95 | with pytest.raises(AssertionError): 96 | assert_array_equal(y_pred_no_sample_weight, y_pred_sample_weight) 97 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | """Toolbox for ensemble learning on class-imbalanced dataset.""" 3 | 4 | # import codecs 5 | 6 | import io 7 | import os 8 | 9 | from setuptools import find_packages, setup, Command 10 | 11 | # get __version__ from _version.py 12 | ver_file = os.path.join("imbens", "_version.py") 13 | with open(ver_file) as f: 14 | exec(f.read()) 15 | 16 | with open("requirements.txt") as f: 17 | requirements = f.read().splitlines() 18 | 19 | DISTNAME = "imbalanced-ensemble" 20 | DESCRIPTION = "Toolbox for ensemble learning on class-imbalanced dataset." 21 | 22 | # with codecs.open("README.rst", encoding="utf-8-sig") as f: 23 | # LONG_DESCRIPTION = f.read() 24 | 25 | # Import the README and use it as the long-description. 26 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 27 | here = os.path.abspath(os.path.dirname(__file__)) 28 | try: 29 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: 30 | LONG_DESCRIPTION = "\n" + f.read() 31 | except FileNotFoundError: 32 | LONG_DESCRIPTION = DESCRIPTION 33 | 34 | AUTHOR = "Zhining Liu" 35 | AUTHOR_EMAIL = "zhining.liu@outlook.com" 36 | MAINTAINER = "Zhining Liu" 37 | MAINTAINER_EMAIL = "zhining.liu@outlook.com" 38 | URL = "https://github.com/ZhiningLiu1998/imbalanced-ensemble" 39 | PROJECT_URLS = { 40 | "Documentation": "https://imbalanced-ensemble.readthedocs.io/", 41 | "Source": "https://github.com/ZhiningLiu1998/imbalanced-ensemble", 42 | "Tracker": "https://github.com/ZhiningLiu1998/imbalanced-ensemble/issues", 43 | "Changelog": "https://imbalanced-ensemble.readthedocs.io/en/latest/release_history.html", 44 | "Download": "https://pypi.org/project/imbalanced-ensemble/#files", 45 | } 46 | LICENSE = "MIT" 47 | VERSION = __version__ 48 | CLASSIFIERS = [ 49 | "Intended Audience :: Science/Research", 50 | "Intended Audience :: Developers", 51 | "License :: OSI Approved :: MIT License", 52 | "Programming Language :: C", 53 | "Programming Language :: Python", 54 | 
"Topic :: Software Development", 55 | "Topic :: Scientific/Engineering", 56 | "Operating System :: Microsoft :: Windows", 57 | "Operating System :: POSIX", 58 | "Operating System :: Unix", 59 | "Operating System :: MacOS", 60 | "Programming Language :: Python :: 3.9", 61 | "Programming Language :: Python :: 3.10", 62 | "Programming Language :: Python :: 3.11", 63 | "Programming Language :: Python :: 3 :: Only", 64 | ] 65 | INSTALL_REQUIRES = requirements 66 | EXTRAS_REQUIRE = { 67 | "dev": [ 68 | "black", 69 | "flake8", 70 | ], 71 | "test": [ 72 | "pytest", 73 | "pytest-cov", 74 | ], 75 | "doc": [ 76 | "sphinx", 77 | "sphinx-gallery", 78 | "sphinx_rtd_theme", 79 | "pydata-sphinx-theme", 80 | "numpydoc", 81 | "sphinxcontrib-bibtex", 82 | "torch", 83 | "pytest", 84 | ], 85 | } 86 | 87 | setup( 88 | name=DISTNAME, 89 | author=AUTHOR, 90 | author_email=AUTHOR_EMAIL, 91 | maintainer=MAINTAINER, 92 | maintainer_email=MAINTAINER_EMAIL, 93 | description=DESCRIPTION, 94 | license=LICENSE, 95 | url=URL, 96 | version=VERSION, 97 | project_urls=PROJECT_URLS, 98 | long_description=LONG_DESCRIPTION, 99 | long_description_content_type="text/markdown", 100 | zip_safe=False, # the package can run out of an .egg file 101 | classifiers=CLASSIFIERS, 102 | packages=find_packages(), 103 | install_requires=INSTALL_REQUIRES, 104 | extras_require=EXTRAS_REQUIRE, 105 | ) 106 | -------------------------------------------------------------------------------- /docs/source/sphinxext/github_link.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | from operator import attrgetter 4 | import inspect 5 | import subprocess 6 | import os 7 | import sys 8 | from functools import partial 9 | 10 | REVISION_CMD = "git rev-parse --short HEAD" 11 | 12 | # %% 13 | 14 | 15 | def _get_git_revision(): 16 | try: 17 | revision = subprocess.check_output(REVISION_CMD.split()).strip() 18 | except (subprocess.CalledProcessError, OSError): 19 | print("Failed to execute 
git to get revision") 20 | return None 21 | return revision.decode("utf-8") 22 | 23 | 24 | def _linkcode_resolve(domain, info, package, url_fmt, revision): 25 | """Determine a link to online source for a class/method/function 26 | 27 | This is called by sphinx.ext.linkcode 28 | 29 | An example with a long-untouched module that everyone has 30 | >>> _linkcode_resolve('py', {'module': 'tty', 31 | ... 'fullname': 'setraw'}, 32 | ... package='tty', 33 | ... url_fmt='http://hg.python.org/cpython/file/' 34 | ... '{revision}/Lib/{package}/{path}#L{lineno}', 35 | ... revision='xxxx') 36 | 'http://hg.python.org/cpython/file/xxxx/Lib/tty/tty.py#L18' 37 | """ 38 | 39 | if revision is None: 40 | return 41 | if domain not in ("py", "pyx"): 42 | return 43 | if not info.get("module") or not info.get("fullname"): 44 | return 45 | 46 | class_name = info["fullname"].split(".")[0] 47 | if type(class_name) != str: 48 | # Python 2 only 49 | class_name = class_name.encode("utf-8") 50 | module = __import__(info["module"], fromlist=[class_name]) 51 | obj = attrgetter(info["fullname"])(module) 52 | 53 | # try: 54 | # fn = inspect.getsourcefile(obj) 55 | # except Exception: 56 | # fn = None 57 | # if not fn: 58 | # try: 59 | # fn = inspect.getsourcefile(sys.modules[obj.__module__]) 60 | # except Exception: 61 | # fn = None 62 | try: 63 | fn = inspect.getsourcefile(sys.modules[obj.__module__]) 64 | except Exception: 65 | fn = None 66 | if not fn: 67 | return 68 | 69 | fn = os.path.relpath(fn, start=os.path.dirname(__import__(package).__file__)) 70 | try: 71 | src_code_lines, lineno = inspect.getsourcelines(obj) 72 | i = 0 73 | for l in src_code_lines: 74 | if 'def' in l or 'class' in l: 75 | break 76 | i += 1 77 | lineno += i 78 | except Exception: 79 | lineno = "" 80 | return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno) 81 | 82 | 83 | def make_linkcode_resolve(package, url_fmt): 84 | """Returns a linkcode_resolve function for the given URL format 85 | 86 | 
revision is a git commit reference (hash or name) 87 | 88 | package is the name of the root module of the package 89 | 90 | url_fmt is along the lines of ('https://github.com/USER/PROJECT/' 91 | 'blob/{revision}/{package}/' 92 | '{path}#L{lineno}') 93 | """ 94 | revision = _get_git_revision() 95 | return partial( 96 | _linkcode_resolve, revision=revision, package=package, url_fmt=url_fmt 97 | ) 98 | 99 | 100 | # %% 101 | 102 | linkcode_resolve = make_linkcode_resolve( 103 | "imbens", 104 | "https://github.com/ZhiningLiu1998/" 105 | "imbalanced-ensemble/blob/{revision}/" 106 | "{package}/{path}#L{lineno}", 107 | ) 108 | 109 | linkcode_resolve 110 | -------------------------------------------------------------------------------- /imbens/ensemble/_under_sampling/tests/test_rus_boost.py: -------------------------------------------------------------------------------- 1 | """Test RUSBoostClassifier.""" 2 | 3 | # Authors: Guillaume Lemaitre 4 | # Christos Aridas 5 | # Zhining Liu 6 | # License: MIT 7 | 8 | import numpy as np 9 | import pytest 10 | import sklearn 11 | from sklearn.datasets import load_iris, make_classification 12 | from sklearn.model_selection import train_test_split 13 | from sklearn.utils._testing import assert_array_equal 14 | from sklearn.utils.fixes import parse_version 15 | 16 | from imbens.ensemble import RUSBoostClassifier 17 | 18 | sklearn_version = parse_version(sklearn.__version__) 19 | 20 | 21 | @pytest.fixture 22 | def imbalanced_dataset(): 23 | return make_classification( 24 | n_samples=10000, 25 | n_features=3, 26 | n_informative=2, 27 | n_redundant=0, 28 | n_repeated=0, 29 | n_classes=3, 30 | n_clusters_per_class=1, 31 | weights=[0.01, 0.05, 0.94], 32 | class_sep=0.8, 33 | random_state=0, 34 | ) 35 | 36 | 37 | @pytest.mark.parametrize("algorithm", ["SAMME"]) 38 | def test_rusboost(imbalanced_dataset, algorithm): 39 | X, y = imbalanced_dataset 40 | X_train, X_test, y_train, y_test = train_test_split( 41 | X, y, stratify=y, random_state=1 42 
| ) 43 | classes = np.unique(y) 44 | 45 | n_estimators = 500 46 | rusboost = RUSBoostClassifier( 47 | n_estimators=n_estimators, algorithm=algorithm, random_state=0 48 | ) 49 | rusboost.fit(X_train, y_train) 50 | assert_array_equal(classes, rusboost.classes_) 51 | 52 | # check that we have an ensemble of samplers and estimators with a 53 | # consistent size 54 | assert len(rusboost.estimators_) > 1 55 | assert len(rusboost.estimators_) == len(rusboost.samplers_) 56 | 57 | # each sampler in the ensemble should have different random state 58 | assert len({sampler.random_state for sampler in rusboost.samplers_}) == len( 59 | rusboost.samplers_ 60 | ) 61 | # each estimator in the ensemble should have different random state 62 | assert len({est.random_state for est in rusboost.estimators_}) == len( 63 | rusboost.estimators_ 64 | ) 65 | 66 | # check the consistency of the feature importances 67 | assert len(rusboost.feature_importances_) == imbalanced_dataset[0].shape[1] 68 | 69 | # check the consistency of the prediction outpus 70 | y_pred = rusboost.predict_proba(X_test) 71 | assert y_pred.shape[1] == len(classes) 72 | assert rusboost.decision_function(X_test).shape[1] == len(classes) 73 | 74 | score = rusboost.score(X_test, y_test) 75 | assert score > 0.6, f"Failed with algorithm {algorithm} and score {score}" 76 | 77 | y_pred = rusboost.predict(X_test) 78 | assert y_pred.shape == y_test.shape 79 | 80 | 81 | @pytest.mark.parametrize("algorithm", ["SAMME"]) 82 | def test_rusboost_sample_weight(imbalanced_dataset, algorithm): 83 | X, y = imbalanced_dataset 84 | sample_weight = np.ones_like(y) 85 | rusboost = RUSBoostClassifier(algorithm=algorithm, random_state=0) 86 | 87 | # Predictions should be the same when sample_weight are all ones 88 | y_pred_sample_weight = rusboost.fit(X, y, sample_weight=sample_weight).predict(X) 89 | y_pred_no_sample_weight = rusboost.fit(X, y).predict(X) 90 | 91 | assert_array_equal(y_pred_sample_weight, y_pred_no_sample_weight) 92 | 93 | rng 
= np.random.RandomState(42) 94 | sample_weight = rng.rand(y.shape[0]) 95 | y_pred_sample_weight = rusboost.fit(X, y, sample_weight=sample_weight).predict(X) 96 | 97 | with pytest.raises(AssertionError): 98 | assert_array_equal(y_pred_no_sample_weight, y_pred_sample_weight) 99 | -------------------------------------------------------------------------------- /imbens/sampler/_under_sampling/_prototype_selection/tests/test_one_sided_selection.py: -------------------------------------------------------------------------------- 1 | """Test the module one-sided selection.""" 2 | # Authors: Guillaume Lemaitre 3 | # Christos Aridas 4 | # License: MIT 5 | 6 | import numpy as np 7 | import pytest 8 | from sklearn.neighbors import KNeighborsClassifier 9 | from sklearn.utils._testing import assert_array_equal 10 | 11 | from imbens.sampler._under_sampling import OneSidedSelection 12 | 13 | RND_SEED = 0 14 | X = np.array( 15 | [ 16 | [-0.3879569, 0.6894251], 17 | [-0.09322739, 1.28177189], 18 | [-0.77740357, 0.74097941], 19 | [0.91542919, -0.65453327], 20 | [-0.03852113, 0.40910479], 21 | [-0.43877303, 1.07366684], 22 | [-0.85795321, 0.82980738], 23 | [-0.18430329, 0.52328473], 24 | [-0.30126957, -0.66268378], 25 | [-0.65571327, 0.42412021], 26 | [-0.28305528, 0.30284991], 27 | [0.20246714, -0.34727125], 28 | [1.06446472, -1.09279772], 29 | [0.30543283, -0.02589502], 30 | [-0.00717161, 0.00318087], 31 | ] 32 | ) 33 | Y = np.array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0]) 34 | 35 | 36 | def test_oss_init(): 37 | oss = OneSidedSelection(random_state=RND_SEED) 38 | 39 | assert oss.n_seeds_S == 1 40 | assert oss.n_jobs is None 41 | assert oss.random_state == RND_SEED 42 | 43 | 44 | def test_oss_fit_resample(): 45 | oss = OneSidedSelection(random_state=RND_SEED) 46 | X_resampled, y_resampled = oss.fit_resample(X, Y) 47 | 48 | X_gt = np.array( 49 | [ 50 | [-0.3879569, 0.6894251], 51 | [0.91542919, -0.65453327], 52 | [-0.65571327, 0.42412021], 53 | [1.06446472, -1.09279772], 54 | 
[0.30543283, -0.02589502], 55 | [-0.00717161, 0.00318087], 56 | [-0.09322739, 1.28177189], 57 | [-0.77740357, 0.74097941], 58 | [-0.43877303, 1.07366684], 59 | [-0.85795321, 0.82980738], 60 | [-0.30126957, -0.66268378], 61 | [0.20246714, -0.34727125], 62 | ] 63 | ) 64 | y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) 65 | assert_array_equal(X_resampled, X_gt) 66 | assert_array_equal(y_resampled, y_gt) 67 | 68 | 69 | def test_oss_with_object(): 70 | knn = KNeighborsClassifier(n_neighbors=1) 71 | oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=knn) 72 | X_resampled, y_resampled = oss.fit_resample(X, Y) 73 | 74 | X_gt = np.array( 75 | [ 76 | [-0.3879569, 0.6894251], 77 | [0.91542919, -0.65453327], 78 | [-0.65571327, 0.42412021], 79 | [1.06446472, -1.09279772], 80 | [0.30543283, -0.02589502], 81 | [-0.00717161, 0.00318087], 82 | [-0.09322739, 1.28177189], 83 | [-0.77740357, 0.74097941], 84 | [-0.43877303, 1.07366684], 85 | [-0.85795321, 0.82980738], 86 | [-0.30126957, -0.66268378], 87 | [0.20246714, -0.34727125], 88 | ] 89 | ) 90 | y_gt = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) 91 | assert_array_equal(X_resampled, X_gt) 92 | assert_array_equal(y_resampled, y_gt) 93 | knn = 1 94 | oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=knn) 95 | X_resampled, y_resampled = oss.fit_resample(X, Y) 96 | assert_array_equal(X_resampled, X_gt) 97 | assert_array_equal(y_resampled, y_gt) 98 | 99 | 100 | def test_oss_with_wrong_object(): 101 | knn = "rnd" 102 | oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=knn) 103 | with pytest.raises(ValueError, match="has to be a int"): 104 | oss.fit_resample(X, Y) 105 | -------------------------------------------------------------------------------- /imbens/sampler/_under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py: -------------------------------------------------------------------------------- 1 | """Test the module condensed nearest neighbour.""" 2 | # Authors: Guillaume 
Lemaitre 3 | # Christos Aridas 4 | # License: MIT 5 | 6 | import numpy as np 7 | import pytest 8 | from sklearn.neighbors import KNeighborsClassifier 9 | from sklearn.utils._testing import assert_array_equal 10 | 11 | from imbens.sampler._under_sampling import CondensedNearestNeighbour 12 | 13 | RND_SEED = 0 14 | X = np.array( 15 | [ 16 | [2.59928271, 0.93323465], 17 | [0.25738379, 0.95564169], 18 | [1.42772181, 0.526027], 19 | [1.92365863, 0.82718767], 20 | [-0.10903849, -0.12085181], 21 | [-0.284881, -0.62730973], 22 | [0.57062627, 1.19528323], 23 | [0.03394306, 0.03986753], 24 | [0.78318102, 2.59153329], 25 | [0.35831463, 1.33483198], 26 | [-0.14313184, -1.0412815], 27 | [0.01936241, 0.17799828], 28 | [-1.25020462, -0.40402054], 29 | [-0.09816301, -0.74662486], 30 | [-0.01252787, 0.34102657], 31 | [0.52726792, -0.38735648], 32 | [0.2821046, -0.07862747], 33 | [0.05230552, 0.09043907], 34 | [0.15198585, 0.12512646], 35 | [0.70524765, 0.39816382], 36 | ] 37 | ) 38 | Y = np.array([1, 2, 1, 1, 0, 2, 2, 2, 2, 2, 2, 0, 1, 2, 2, 2, 2, 1, 2, 1]) 39 | 40 | 41 | def test_cnn_init(): 42 | cnn = CondensedNearestNeighbour(random_state=RND_SEED) 43 | 44 | assert cnn.n_seeds_S == 1 45 | assert cnn.n_jobs is None 46 | 47 | 48 | def test_cnn_fit_resample(): 49 | cnn = CondensedNearestNeighbour(random_state=RND_SEED) 50 | X_resampled, y_resampled = cnn.fit_resample(X, Y) 51 | 52 | X_gt = np.array( 53 | [ 54 | [-0.10903849, -0.12085181], 55 | [0.01936241, 0.17799828], 56 | [0.05230552, 0.09043907], 57 | [-1.25020462, -0.40402054], 58 | [0.70524765, 0.39816382], 59 | [0.35831463, 1.33483198], 60 | [-0.284881, -0.62730973], 61 | [0.03394306, 0.03986753], 62 | [-0.01252787, 0.34102657], 63 | [0.15198585, 0.12512646], 64 | ] 65 | ) 66 | y_gt = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2]) 67 | assert_array_equal(X_resampled, X_gt) 68 | assert_array_equal(y_resampled, y_gt) 69 | 70 | 71 | def test_cnn_fit_resample_with_object(): 72 | knn = KNeighborsClassifier(n_neighbors=1) 73 | cnn = 
CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=knn)
    X_resampled, y_resampled = cnn.fit_resample(X, Y)

    X_gt = np.array(
        [
            [-0.10903849, -0.12085181],
            [0.01936241, 0.17799828],
            [0.05230552, 0.09043907],
            [-1.25020462, -0.40402054],
            [0.70524765, 0.39816382],
            [0.35831463, 1.33483198],
            [-0.284881, -0.62730973],
            [0.03394306, 0.03986753],
            [-0.01252787, 0.34102657],
            [0.15198585, 0.12512646],
        ]
    )
    y_gt = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
    assert_array_equal(X_resampled, X_gt)
    assert_array_equal(y_resampled, y_gt)

    cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=1)
    X_resampled, y_resampled = cnn.fit_resample(X, Y)
    assert_array_equal(X_resampled, X_gt)
    assert_array_equal(y_resampled, y_gt)


def test_cnn_fit_resample_with_wrong_object():
    # Passing a non-int, non-estimator `n_neighbors` must be rejected at
    # resample time with a descriptive ValueError.
    knn = "rnd"
    cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=knn)
    with pytest.raises(ValueError, match="has to be a int or an "):
        cnn.fit_resample(X, Y)
"""
Base class for the under-sampling method.
"""
# Adapted from imbalanced-learn

# Authors: Guillaume Lemaitre
# License: MIT

# %%
# Development switch: when True, this module is being run directly from its
# own directory, so the package-relative import below would fail and a
# sys.path hack is used instead.  Must stay False in the installed package.
LOCAL_DEBUG = False

if not LOCAL_DEBUG:
    from ..base import BaseSampler
else:  # pragma: no cover
    import sys  # For local test

    sys.path.append("../..")
    from sampler.base import BaseSampler


class BaseUnderSampler(BaseSampler):
    """Base class for under-sampling algorithms.

    Warning: This class should not be used directly. Use the derived classes
    instead.
    """

    # Tag identifying this sampler family; presumably consumed by the
    # BaseSampler machinery when validating `sampling_strategy` --
    # NOTE(review): confirm against sampler/base.py.
    _sampling_type = "under-sampling"

    # Reusable docstring fragment describing the `sampling_strategy`
    # parameter, shared by the derived under-samplers.  Keep the text and
    # indentation stable: it is injected verbatim into subclass docstrings.
    _sampling_strategy_docstring = """sampling_strategy : float, str, dict, callable, default='auto'
        Sampling information to sample the data set.

        - When ``float``, it corresponds to the desired ratio of the number of
          samples in the minority class over the number of samples in the
          majority class after resampling. Therefore, the ratio is expressed as
          :math:`\\alpha_{us} = N_{m} / N_{rM}` where :math:`N_{m}` is the
          number of samples in the minority class and
          :math:`N_{rM}` is the number of samples in the majority class
          after resampling.

        .. warning::
           ``float`` is only available for **binary** classification. An
           error is raised for multi-class classification.

        - When ``str``, specify the class targeted by the resampling. The
          number of samples in the different classes will be equalized.
          Possible choices are:

            ``'majority'``: resample only the majority class;

            ``'not minority'``: resample all classes but the minority class;

            ``'not majority'``: resample all classes but the majority class;

            ``'all'``: resample all classes;

            ``'auto'``: equivalent to ``'not minority'``.

        - When ``dict``, the keys correspond to the targeted classes. The
          values correspond to the desired number of samples for each targeted
          class.

        - When callable, function taking ``y`` and returns a ``dict``. The keys
          correspond to the targeted classes. The values correspond to the
          desired number of samples for each class.
        """.rstrip()


class BaseCleaningSampler(BaseSampler):
    """Base class for under-sampling algorithms.

    Warning: This class should not be used directly. Use the derived classes
    instead.
    """

    # Cleaning samplers remove "noisy"/borderline samples rather than
    # targeting exact per-class counts, hence the distinct tag.
    _sampling_type = "clean-sampling"

    # Reusable `sampling_strategy` docstring fragment for cleaning samplers.
    _sampling_strategy_docstring = """sampling_strategy : str, list or callable
        Sampling information to sample the data set.

        - When ``str``, specify the class targeted by the resampling. Note the
          the number of samples will not be equal in each. Possible choices
          are:

            ``'majority'``: resample only the majority class;

            ``'not minority'``: resample all classes but the minority class;

            ``'not majority'``: resample all classes but the majority class;

            ``'all'``: resample all classes;

            ``'auto'``: equivalent to ``'not minority'``.

        - When ``list``, the list contains the classes targeted by the
          resampling.

        - When callable, function taking ``y`` and returns a ``dict``. The keys
          correspond to the targeted classes. The values correspond to the
          desired number of samples for each class.
        """.rstrip()


# %%
"""
=========================================================
Make digits dataset class-imbalanced
=========================================================

An illustration of the :func:`~imbens.datasets.make_imbalance`
function to create an imbalanced version of the digits dataset.
8 | """ 9 | 10 | # Authors: Zhining Liu 11 | # License: MIT 12 | 13 | # %% 14 | print(__doc__) 15 | 16 | # Import imbalanced-ensemble 17 | import imbens 18 | 19 | # Import utilities 20 | import sklearn 21 | from imbens.datasets import make_imbalance 22 | from imbens.utils._plot import plot_2Dprojection_and_cardinality, plot_scatter 23 | import matplotlib.pyplot as plt 24 | import seaborn as sns 25 | 26 | RANDOM_STATE = 42 27 | 28 | # sphinx_gallery_thumbnail_number = -1 29 | 30 | # %% [markdown] 31 | # Digits dataset 32 | # -------------- 33 | # The digits dataset consists of 8x8 pixel images of digits. The images attribute of the dataset stores 8x8 arrays of grayscale values for each image. We will use these arrays to visualize the first 4 images. The target attribute of the dataset stores the digit each image represents and this is included in the title of the 4 plots below. 34 | 35 | digits = sklearn.datasets.load_digits() 36 | 37 | # flatten the images 38 | n_samples = len(digits.images) 39 | X, y = digits.images.reshape((n_samples, -1)), digits.target 40 | 41 | _, axes = plt.subplots(nrows=3, ncols=4, figsize=(10, 9)) 42 | for ax, image, label in zip(axes.flatten(), digits.images, digits.target): 43 | ax.set_axis_off() 44 | ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') 45 | ax.set_title('Training: %i' % label) 46 | 47 | 48 | # %% [markdown] 49 | # **The original digits dataset** 50 | 51 | fig = plot_2Dprojection_and_cardinality(X, y, figsize=(8, 4)) 52 | 53 | 54 | # %% [markdown] 55 | # **Make class-imbalanced digits dataset** 56 | 57 | imbalance_distr = { 58 | 0: 178, 59 | 1: 120, 60 | 2: 80, 61 | 3: 60, 62 | 4: 50, 63 | 5: 44, 64 | 6: 40, 65 | 7: 40, 66 | 8: 40, 67 | 9: 40, 68 | } 69 | 70 | X_imb, y_imb = make_imbalance( 71 | X, y, sampling_strategy=imbalance_distr, random_state=RANDOM_STATE 72 | ) 73 | 74 | fig = plot_2Dprojection_and_cardinality(X_imb, y_imb, figsize=(8, 4)) 75 | 76 | 77 | # %% [markdown] 78 | # Use TSNE to compare the 
original & imbalanced Digits datasets 79 | # ------------------------------------------------------------- 80 | # We can observe that it is more difficult to distinguish the tail classes from each other in the imbalanced Digits dataset. 81 | # These tailed classes are not well represented, thus it is harder for a learning model to learn their patterns. 82 | 83 | sns.set_context('talk') 84 | 85 | tsne = sklearn.manifold.TSNE( 86 | n_components=2, perplexity=100, n_iter=500, random_state=RANDOM_STATE 87 | ) 88 | 89 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) 90 | 91 | # Plot original digits data 92 | plot_scatter( 93 | tsne.fit_transform(X), 94 | y, 95 | title='Original Digits Data', 96 | weights=100, 97 | vis_params={'edgecolor': 'black', 'alpha': 0.8}, 98 | ax=ax1, 99 | ) 100 | ax1.legend( 101 | ncol=2, 102 | loc=2, 103 | columnspacing=0.01, 104 | borderaxespad=0.1, 105 | handletextpad=0.01, 106 | labelspacing=0.01, 107 | handlelength=None, 108 | ) 109 | 110 | # Plot imbalanced digits data 111 | plot_scatter( 112 | tsne.fit_transform(X_imb), 113 | y_imb, 114 | title='Imbalanced Digits Data', 115 | weights=100, 116 | vis_params={'edgecolor': 'black', 'alpha': 0.8}, 117 | ax=ax2, 118 | ) 119 | ax2.legend( 120 | ncol=2, 121 | loc=2, 122 | columnspacing=0.01, 123 | borderaxespad=0.1, 124 | handletextpad=0.01, 125 | labelspacing=0.01, 126 | handlelength=None, 127 | ) 128 | 129 | fig.tight_layout() 130 | -------------------------------------------------------------------------------- /examples/basic/plot_training_log.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================= 3 | Customize ensemble training log 4 | ========================================================= 5 | 6 | This example illustrates how to enable and customize the training 7 | log when training an :mod:`imbens.ensemble` classifier. 
8 | 9 | This example uses: 10 | 11 | - :class:`imbens.ensemble.SelfPacedEnsembleClassifier` 12 | """ 13 | 14 | # Authors: Zhining Liu 15 | # License: MIT 16 | 17 | 18 | # %% 19 | print(__doc__) 20 | 21 | # Import imbalanced-ensemble 22 | import imbens 23 | 24 | # Import utilities 25 | import sklearn 26 | from sklearn.datasets import make_classification 27 | from sklearn.model_selection import train_test_split 28 | 29 | RANDOM_STATE = 42 30 | 31 | # sphinx_gallery_thumbnail_path = '../../docs/source/_static/training_log_thumbnail.png' 32 | 33 | # %% [markdown] 34 | # Prepare data 35 | # ---------------------------- 36 | # Make a toy 3-class imbalanced classification task. 37 | 38 | # make dataset 39 | X, y = make_classification( 40 | n_classes=3, 41 | class_sep=2, 42 | weights=[0.1, 0.3, 0.6], 43 | n_informative=3, 44 | n_redundant=1, 45 | flip_y=0, 46 | n_features=20, 47 | n_clusters_per_class=2, 48 | n_samples=2000, 49 | random_state=0, 50 | ) 51 | 52 | # train valid split 53 | X_train, X_valid, y_train, y_valid = train_test_split( 54 | X, y, test_size=0.5, stratify=y, random_state=RANDOM_STATE 55 | ) 56 | 57 | # %% [markdown] 58 | # Customize training log 59 | # --------------------------------------------------------------------------- 60 | # Take ``SelfPacedEnsembleClassifier`` as example, training log is controlled by 3 parameters of the ``fit()`` method: 61 | # 62 | # - ``eval_datasets``: Dataset(s) used for evaluation during the ensemble training. 63 | # - ``eval_metrics``: Metric(s) used for evaluation during the ensemble training. 64 | # - ``train_verbose``: Controls the granularity and content of the training log. 
65 | 66 | clf = imbens.ensemble.SelfPacedEnsembleClassifier(random_state=RANDOM_STATE) 67 | 68 | # %% [markdown] 69 | # Set training log format 70 | # ----------------------- 71 | # (``fit()`` parameter: ``train_verbose``: bool, int or dict) 72 | 73 | # %% [markdown] 74 | # **Enable auto training log** 75 | 76 | clf.fit( 77 | X_train, 78 | y_train, 79 | train_verbose=True, 80 | ) 81 | 82 | 83 | # %% [markdown] 84 | # **Customize training log granularity** 85 | 86 | clf.fit( 87 | X_train, 88 | y_train, 89 | train_verbose={ 90 | 'granularity': 10, 91 | }, 92 | ) 93 | 94 | 95 | # %% [markdown] 96 | # **Customize training log content column** 97 | 98 | clf.fit( 99 | X_train, 100 | y_train, 101 | train_verbose={ 102 | 'granularity': 10, 103 | 'print_distribution': False, 104 | 'print_metrics': True, 105 | }, 106 | ) 107 | 108 | 109 | # %% [markdown] 110 | # Add additional evaluation dataset(s) 111 | # ------------------------------------ 112 | # (``fit()`` parameter: ``eval_datasets``: dict) 113 | 114 | clf.fit( 115 | X_train, 116 | y_train, 117 | eval_datasets={ 118 | 'valid': (X_valid, y_valid), # add validation data 119 | }, 120 | train_verbose={ 121 | 'granularity': 10, 122 | }, 123 | ) 124 | 125 | 126 | # %% [markdown] 127 | # Specify evaluation metric(s) 128 | # ---------------------------- 129 | # (``fit()`` parameter: ``eval_metrics``: dict) 130 | 131 | clf.fit( 132 | X_train, 133 | y_train, 134 | eval_datasets={ 135 | 'valid': (X_valid, y_valid), 136 | }, 137 | eval_metrics={ 138 | 'weighted_f1': ( 139 | sklearn.metrics.f1_score, 140 | {'average': 'weighted'}, 141 | ), # use weighted_f1 142 | }, 143 | train_verbose={ 144 | 'granularity': 10, 145 | }, 146 | ) 147 | 148 | # %% 149 | -------------------------------------------------------------------------------- /docs/source/sg_execution_times.rst: -------------------------------------------------------------------------------- 1 | 2 | :orphan: 3 | 4 | .. 
_sphx_glr_sg_execution_times: 5 | 6 | 7 | Computation times 8 | ================= 9 | **00:02.812** total execution time for 18 files **from all galleries**: 10 | 11 | .. container:: 12 | 13 | .. raw:: html 14 | 15 | 19 | 20 | 21 | 22 | 27 | 28 | .. list-table:: 29 | :header-rows: 1 30 | :class: table table-striped sg-datatable 31 | 32 | * - Example 33 | - Time 34 | - Mem (MB) 35 | * - :ref:`sphx_glr_auto_examples_classification_plot_torch.py` (``..\..\examples\classification\plot_torch.py``) 36 | - 00:02.812 37 | - 0.0 38 | * - :ref:`sphx_glr_auto_examples_basic_plot_basic_example.py` (``..\..\examples\basic\plot_basic_example.py``) 39 | - 00:00.000 40 | - 0.0 41 | * - :ref:`sphx_glr_auto_examples_basic_plot_basic_visualize.py` (``..\..\examples\basic\plot_basic_visualize.py``) 42 | - 00:00.000 43 | - 0.0 44 | * - :ref:`sphx_glr_auto_examples_basic_plot_training_log.py` (``..\..\examples\basic\plot_training_log.py``) 45 | - 00:00.000 46 | - 0.0 47 | * - :ref:`sphx_glr_auto_examples_classification_plot_classifier_comparison.py` (``..\..\examples\classification\plot_classifier_comparison.py``) 48 | - 00:00.000 49 | - 0.0 50 | * - :ref:`sphx_glr_auto_examples_classification_plot_cost_matrix.py` (``..\..\examples\classification\plot_cost_matrix.py``) 51 | - 00:00.000 52 | - 0.0 53 | * - :ref:`sphx_glr_auto_examples_classification_plot_digits.py` (``..\..\examples\classification\plot_digits.py``) 54 | - 00:00.000 55 | - 0.0 56 | * - :ref:`sphx_glr_auto_examples_classification_plot_probability.py` (``..\..\examples\classification\plot_probability.py``) 57 | - 00:00.000 58 | - 0.0 59 | * - :ref:`sphx_glr_auto_examples_classification_plot_resampling_target.py` (``..\..\examples\classification\plot_resampling_target.py``) 60 | - 00:00.000 61 | - 0.0 62 | * - :ref:`sphx_glr_auto_examples_classification_plot_sampling_schedule.py` (``..\..\examples\classification\plot_sampling_schedule.py``) 63 | - 00:00.000 64 | - 0.0 65 | * - 
:ref:`sphx_glr_auto_examples_datasets_plot_generate_imbalance.py` (``..\..\examples\datasets\plot_generate_imbalance.py``) 66 | - 00:00.000 67 | - 0.0 68 | * - :ref:`sphx_glr_auto_examples_datasets_plot_make_imbalance.py` (``..\..\examples\datasets\plot_make_imbalance.py``) 69 | - 00:00.000 70 | - 0.0 71 | * - :ref:`sphx_glr_auto_examples_datasets_plot_make_imbalance_digits.py` (``..\..\examples\datasets\plot_make_imbalance_digits.py``) 72 | - 00:00.000 73 | - 0.0 74 | * - :ref:`sphx_glr_auto_examples_evaluation_plot_classification_report.py` (``..\..\examples\evaluation\plot_classification_report.py``) 75 | - 00:00.000 76 | - 0.0 77 | * - :ref:`sphx_glr_auto_examples_evaluation_plot_metrics.py` (``..\..\examples\evaluation\plot_metrics.py``) 78 | - 00:00.000 79 | - 0.0 80 | * - :ref:`sphx_glr_auto_examples_pipeline_plot_pipeline_classification.py` (``..\..\examples\pipeline\plot_pipeline_classification.py``) 81 | - 00:00.000 82 | - 0.0 83 | * - :ref:`sphx_glr_auto_examples_visualizer_plot_confusion_matrix.py` (``..\..\examples\visualizer\plot_confusion_matrix.py``) 84 | - 00:00.000 85 | - 0.0 86 | * - :ref:`sphx_glr_auto_examples_visualizer_plot_performance_curve.py` (``..\..\examples\visualizer\plot_performance_curve.py``) 87 | - 00:00.000 88 | - 0.0 89 | -------------------------------------------------------------------------------- /examples/classification/plot_probability.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================================================= 3 | Plot probabilities with different base classifiers 4 | ================================================================= 5 | 6 | Plot the classification probability for ensemble models with different base classifiers. 7 | 8 | We use a 3-class imbalanced dataset, and we classify it with a ``SelfPacedEnsembleClassifier`` (ensemble size = 5). 
9 | We use Decision Tree, Support Vector Machine (rbf kernel), and Gaussian process classifier as the base classifier. 10 | 11 | This example uses: 12 | 13 | - :class:`imbens.ensemble.SelfPacedEnsembleClassifier` 14 | """ 15 | 16 | # Adapted from sklearn 17 | # Author: Zhining Liu 18 | # Alexandre Gramfort 19 | # License: BSD 3 clause 20 | 21 | # %% 22 | print(__doc__) 23 | 24 | # Import imbalanced-ensemble 25 | import imbens 26 | 27 | # Import utilities 28 | import numpy as np 29 | from collections import Counter 30 | import sklearn 31 | from imbens.datasets import make_imbalance 32 | from imbens.ensemble.base import sort_dict_by_key 33 | 34 | RANDOM_STATE = 42 35 | 36 | # %% [markdown] 37 | # Preparation 38 | # ----------- 39 | # **Make 3 imbalanced iris classification tasks.** 40 | 41 | iris = sklearn.datasets.load_iris() 42 | X = iris.data[:, 0:2] # we only take the first two features for visualization 43 | y = iris.target 44 | 45 | X, y = make_imbalance( 46 | X, y, sampling_strategy={0: 50, 1: 30, 2: 10}, random_state=RANDOM_STATE 47 | ) 48 | print( 49 | 'Class distribution of imbalanced iris dataset: \n%s' % sort_dict_by_key(Counter(y)) 50 | ) 51 | 52 | 53 | # %% [markdown] 54 | # **Create SPE (ensemble size = 5) with different base classifiers.** 55 | 56 | from sklearn.svm import SVC 57 | from sklearn.tree import DecisionTreeClassifier 58 | from sklearn.gaussian_process import GaussianProcessClassifier 59 | from sklearn.gaussian_process.kernels import RBF 60 | 61 | classifiers = { 62 | 'SPE-DT': imbens.ensemble.SelfPacedEnsembleClassifier( 63 | n_estimators=5, 64 | estimator=DecisionTreeClassifier(), 65 | ), 66 | 'SPE-SVM-rbf': imbens.ensemble.SelfPacedEnsembleClassifier( 67 | n_estimators=5, 68 | estimator=SVC(kernel='rbf', probability=True), 69 | ), 70 | 'SPE-GPC': imbens.ensemble.SelfPacedEnsembleClassifier( 71 | n_estimators=5, 72 | estimator=GaussianProcessClassifier(1.0 * RBF([1.0, 1.0])), 73 | ), 74 | } 75 | 76 | n_classifiers = len(classifiers) 77 | 
78 | 79 | # %% [markdown] 80 | # Plot classification probabilities 81 | # --------------------------------- 82 | 83 | import matplotlib.pyplot as plt 84 | 85 | n_features = X.shape[1] 86 | 87 | plt.figure(figsize=(3 * 2, n_classifiers * 2)) 88 | plt.subplots_adjust(bottom=0.2, top=0.95) 89 | 90 | xx = np.linspace(3, 9, 100) 91 | yy = np.linspace(1, 5, 100).T 92 | xx, yy = np.meshgrid(xx, yy) 93 | Xfull = np.c_[xx.ravel(), yy.ravel()] 94 | 95 | for index, (name, classifier) in enumerate(classifiers.items()): 96 | classifier.fit(X, y) 97 | 98 | y_pred = classifier.predict(X) 99 | accuracy = sklearn.metrics.balanced_accuracy_score(y, y_pred) 100 | print("Balanced Accuracy (train) for %s: %0.1f%% " % (name, accuracy * 100)) 101 | 102 | # View probabilities: 103 | probas = classifier.predict_proba(Xfull) 104 | n_classes = np.unique(y_pred).size 105 | for k in range(n_classes): 106 | plt.subplot(n_classifiers, n_classes, index * n_classes + k + 1) 107 | plt.title("Class %d" % k) 108 | if k == 0: 109 | plt.ylabel(name) 110 | imshow_handle = plt.imshow( 111 | probas[:, k].reshape((100, 100)), extent=(3, 9, 1, 5), origin='lower' 112 | ) 113 | plt.xticks(()) 114 | plt.yticks(()) 115 | idx = y_pred == k 116 | if idx.any(): 117 | plt.scatter(X[idx, 0], X[idx, 1], marker='o', c='w', edgecolor='k') 118 | 119 | ax = plt.axes([0.15, 0.04, 0.7, 0.05]) 120 | plt.title("Probability") 121 | plt.colorbar(imshow_handle, cax=ax, orientation='horizontal') 122 | plt.show() 123 | -------------------------------------------------------------------------------- /imbens/utils/_plot.py: -------------------------------------------------------------------------------- 1 | """Utilities for data visualization.""" 2 | 3 | # Authors: Zhining Liu 4 | # License: MIT 5 | 6 | from collections import Counter 7 | from copy import copy 8 | 9 | import matplotlib.image as mpimg 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import pandas as pd 13 | import seaborn as sns 14 | from 
sklearn.decomposition import KernelPCA 15 | 16 | DEFAULT_VIS_KWARGS = { 17 | # 'cmap': plt.cm.rainbow, 18 | 'edgecolor': 'black', 19 | 'alpha': 0.6, 20 | } 21 | 22 | 23 | def set_ax_border(ax, border_color='black', border_width=2): 24 | '''Set border color and width''' 25 | for _, spine in ax.spines.items(): 26 | spine.set_color(border_color) 27 | spine.set_linewidth(border_width) 28 | 29 | return ax 30 | 31 | 32 | def plot_scatter( 33 | X, y, ax=None, weights=None, title='', projection=None, vis_params=None 34 | ): 35 | '''Plot scatter with given projection''' 36 | 37 | if ax is None: 38 | ax = plt.axes() 39 | if projection is None: 40 | projection = KernelPCA(n_components=2).fit(X, y) 41 | if vis_params is None: 42 | vis_params = copy(DEFAULT_VIS_KWARGS) 43 | 44 | if X.shape[1] > 2: 45 | X_vis = projection.transform(X) 46 | title += ' (2D projection by {})'.format( 47 | str(projection.__class__).split('.')[-1][:-2] 48 | ) 49 | else: 50 | X_vis = X 51 | 52 | size = 50 if weights is None else weights 53 | if np.unique(y).shape[0] > 2: 54 | vis_params['palette'] = plt.cm.rainbow 55 | sns.scatterplot( 56 | x=X_vis[:, 0], 57 | y=X_vis[:, 1], 58 | hue=y, 59 | style=y, 60 | s=size, 61 | **vis_params, 62 | legend='full', 63 | ax=ax 64 | ) 65 | 66 | ax.set_title(title) 67 | ax = set_ax_border(ax, border_color='black', border_width=2) 68 | ax.grid(color='black', linestyle='-.', alpha=0.5) 69 | 70 | return ax 71 | 72 | 73 | def plot_class_distribution(y, ax=None, title='', sort_values=True, plot_average=True): 74 | '''Plot class distribution of a given dataset''' 75 | count = pd.DataFrame(list(Counter(y).items()), columns=['Class', 'Frequency']) 76 | if sort_values: 77 | count = count.sort_values(by='Frequency', ascending=False) 78 | if ax is None: 79 | ax = plt.axes() 80 | count.plot.bar(x='Class', y='Frequency', title=title, ax=ax) 81 | 82 | ax.set_title(title) 83 | ax = set_ax_border(ax, border_color='black', border_width=2) 84 | ax.grid(color='black', linestyle='-.', 
alpha=0.5, axis='y') 85 | 86 | if plot_average: 87 | ax.axhline(y=count['Frequency'].mean(), ls="dashdot", c="red") 88 | xlim_min, xlim_max, ylim_min, ylim_max = ax.axis() 89 | ax.text( 90 | x=xlim_min + (xlim_max - xlim_min) * 0.82, 91 | y=count['Frequency'].mean() + (ylim_max - ylim_min) * 0.03, 92 | c="red", 93 | s='Average', 94 | ) 95 | 96 | return ax 97 | 98 | 99 | def plot_2Dprojection_and_cardinality( 100 | X, 101 | y, 102 | figsize=(10, 4), 103 | vis_params=None, 104 | projection=None, 105 | weights=None, 106 | plot_average=True, 107 | title1='Dataset', 108 | title2='Class Distribution', 109 | ): 110 | '''Plot the distribution of a given dataset''' 111 | 112 | if vis_params is None: 113 | vis_params = copy(DEFAULT_VIS_KWARGS) 114 | 115 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) 116 | 117 | ax1 = plot_scatter( 118 | X, 119 | y, 120 | ax=ax1, 121 | weights=weights, 122 | title=title1, 123 | projection=projection, 124 | vis_params=vis_params, 125 | ) 126 | ax2 = plot_class_distribution( 127 | y, ax=ax2, title=title2, sort_values=True, plot_average=plot_average 128 | ) 129 | plt.tight_layout() 130 | 131 | return fig, (ax1, ax2) 132 | 133 | 134 | def plot_online_figure(url: str = None): # pragma: no cover 135 | '''Plot an online figure''' 136 | figure = mpimg.imread(url) 137 | plt.axis('off') 138 | plt.imshow(figure) 139 | plt.tight_layout() 140 | -------------------------------------------------------------------------------- /examples/basic/plot_basic_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================================= 3 | Train and predict with an ensemble classifier 4 | ========================================================= 5 | 6 | This example shows the basic usage of an 7 | :mod:`imbens.ensemble` classifier. 
8 | 9 | This example uses: 10 | 11 | - :class:`imbens.ensemble.SelfPacedEnsembleClassifier` 12 | """ 13 | 14 | # Authors: Zhining Liu 15 | # License: MIT 16 | 17 | 18 | # %% 19 | print(__doc__) 20 | 21 | # Import imbalanced-ensemble 22 | import imbens 23 | 24 | # Import utilities 25 | from collections import Counter 26 | import sklearn 27 | from sklearn.datasets import make_classification 28 | from sklearn.model_selection import train_test_split 29 | from imbens.ensemble.base import sort_dict_by_key 30 | 31 | # Import plot utilities 32 | import matplotlib.pyplot as plt 33 | from imbens.utils._plot import plot_2Dprojection_and_cardinality 34 | 35 | RANDOM_STATE = 42 36 | 37 | # %% [markdown] 38 | # Prepare & visualize the data 39 | # ---------------------------- 40 | # Make a toy 3-class imbalanced classification task. 41 | 42 | # Generate and split a synthetic dataset 43 | X, y = make_classification( 44 | n_classes=3, 45 | n_samples=2000, 46 | class_sep=2, 47 | weights=[0.1, 0.3, 0.6], 48 | n_informative=3, 49 | n_redundant=1, 50 | flip_y=0, 51 | n_features=20, 52 | n_clusters_per_class=2, 53 | random_state=RANDOM_STATE, 54 | ) 55 | X_train, X_valid, y_train, y_valid = train_test_split( 56 | X, y, test_size=0.5, stratify=y, random_state=RANDOM_STATE 57 | ) 58 | 59 | # Visualize the training dataset 60 | fig = plot_2Dprojection_and_cardinality(X_train, y_train, figsize=(8, 4)) 61 | plt.show() 62 | 63 | # Print class distribution 64 | print('Training dataset distribution %s' % sort_dict_by_key(Counter(y_train))) 65 | print('Validation dataset distribution %s' % sort_dict_by_key(Counter(y_valid))) 66 | 67 | # %% [markdown] 68 | # Using ensemble classifiers in ``imbens`` 69 | # ----------------------------------------------------- 70 | # Take ``SelfPacedEnsembleClassifier`` as example 71 | 72 | # Initialize an SelfPacedEnsembleClassifier 73 | clf = imbens.ensemble.SelfPacedEnsembleClassifier(random_state=RANDOM_STATE) 74 | 75 | # Train an SelfPacedEnsembleClassifier 76 
| clf.fit(X_train, y_train) 77 | 78 | # Make predictions 79 | y_pred_proba = clf.predict_proba(X_valid) 80 | y_pred = clf.predict(X_valid) 81 | 82 | # Evaluate 83 | balanced_acc_score = sklearn.metrics.balanced_accuracy_score(y_valid, y_pred) 84 | print(f'SPE: ensemble of {clf.n_estimators} {clf.estimator_}') 85 | print('Validation Balanced Accuracy: {:.3f}'.format(balanced_acc_score)) 86 | 87 | 88 | # %% [markdown] 89 | # Set the ensemble size 90 | # --------------------- 91 | # (parameter ``n_estimators``: int) 92 | 93 | from imbens.ensemble import SelfPacedEnsembleClassifier as SPE 94 | from sklearn.metrics import balanced_accuracy_score 95 | 96 | clf = SPE( 97 | n_estimators=5, # Set ensemble size to 5 98 | random_state=RANDOM_STATE, 99 | ).fit(X_train, y_train) 100 | 101 | # Evaluate 102 | balanced_acc_score = balanced_accuracy_score(y_valid, clf.predict(X_valid)) 103 | print(f'SPE: ensemble of {clf.n_estimators} {clf.estimator_}') 104 | print('Validation Balanced Accuracy: {:.3f}'.format(balanced_acc_score)) 105 | 106 | 107 | # %% [markdown] 108 | # Use different base estimator 109 | # ---------------------------- 110 | # (parameter ``estimator``: estimator object) 111 | 112 | from sklearn.svm import SVC 113 | 114 | clf = SPE( 115 | n_estimators=5, 116 | estimator=SVC(probability=True), # Use SVM as the base estimator 117 | random_state=RANDOM_STATE, 118 | ).fit(X_train, y_train) 119 | 120 | # Evaluate 121 | balanced_acc_score = balanced_accuracy_score(y_valid, clf.predict(X_valid)) 122 | print(f'SPE: ensemble of {clf.n_estimators} {clf.estimator_}') 123 | print('Validation Balanced Accuracy: {:.3f}'.format(balanced_acc_score)) 124 | 125 | 126 | # %% [markdown] 127 | # Enable training log 128 | # ------------------- 129 | # (``fit()`` parameter ``train_verbose``: bool, int or dict) 130 | 131 | clf = SPE(random_state=RANDOM_STATE).fit( 132 | X_train, 133 | y_train, 134 | train_verbose=True, # Enable training log 135 | ) 136 | 
-------------------------------------------------------------------------------- /imbens/ensemble/tests/test_base_ensemble.py: -------------------------------------------------------------------------------- 1 | """Test SelfPacedEnsembleClassifier.""" 2 | 3 | # Authors: Zhining Liu 4 | # License: MIT 5 | 6 | import pytest 7 | import sklearn 8 | from sklearn.datasets import load_iris 9 | from sklearn.model_selection import train_test_split 10 | from sklearn.utils.fixes import parse_version 11 | 12 | from imbens.datasets import make_imbalance 13 | from imbens.utils.testing import all_estimators 14 | 15 | sklearn_version = parse_version(sklearn.__version__) 16 | iris = load_iris() 17 | all_ensembles = all_estimators(type_filter='ensemble') 18 | 19 | X, y = make_imbalance( 20 | iris.data, 21 | iris.target, 22 | sampling_strategy={0: 20, 1: 25, 2: 50}, 23 | random_state=0, 24 | ) 25 | X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0) 26 | init_param = {'random_state': 0, 'n_estimators': 20} 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "ensemble", 31 | all_ensembles, 32 | ) 33 | def test_evaluate(ensemble): 34 | """Check classification with dynamic logging.""" 35 | (ensemble_name, EnsembleCLass) = ensemble 36 | clf = EnsembleCLass(**init_param) 37 | clf.fit( 38 | X_train, 39 | y_train, 40 | train_verbose=True, 41 | ) 42 | clf._evaluate('train', return_value_dict=True) 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "ensemble", 47 | all_ensembles, 48 | ) 49 | def test_evaluate_verbose(ensemble): 50 | """Check classification with dynamic logging.""" 51 | (ensemble_name, EnsembleCLass) = ensemble 52 | clf = EnsembleCLass(**init_param) 53 | if clf._properties['training_type'] == 'parallel': 54 | with pytest.raises(TypeError, match="can only be of type `bool`"): 55 | clf.fit( 56 | X_train, 57 | y_train, 58 | train_verbose={ 59 | 'granularity': 10, 60 | }, 61 | ) 62 | else: 63 | clf.fit( 64 | X_train, 65 | y_train, 66 | train_verbose={ 67 | 
'granularity': 10, 68 | }, 69 | ) 70 | clf.fit( 71 | X_train, 72 | y_train, 73 | train_verbose={ 74 | 'granularity': 10, 75 | 'print_distribution': False, 76 | 'print_metrics': True, 77 | }, 78 | ) 79 | 80 | 81 | @pytest.mark.parametrize( 82 | "ensemble", 83 | all_ensembles, 84 | ) 85 | def test_evaluate_eval_datasets(ensemble): 86 | """Check classification with dynamic logging.""" 87 | (ensemble_name, EnsembleCLass) = ensemble 88 | clf = EnsembleCLass(**init_param) 89 | if clf._properties['training_type'] == 'parallel': 90 | with pytest.raises(TypeError, match="can only be of type `bool`"): 91 | clf.fit( 92 | X_train, 93 | y_train, 94 | eval_datasets={ 95 | 'valid': (X_valid, y_valid), # add validation data 96 | }, 97 | train_verbose={ 98 | 'granularity': 10, 99 | }, 100 | ) 101 | else: 102 | clf.fit( 103 | X_train, 104 | y_train, 105 | eval_datasets={ 106 | 'valid': (X_valid, y_valid), # add validation data 107 | }, 108 | train_verbose={ 109 | 'granularity': 10, 110 | }, 111 | ) 112 | 113 | 114 | @pytest.mark.parametrize( 115 | "ensemble", 116 | all_ensembles, 117 | ) 118 | def test_evaluate_eval_metrics(ensemble): 119 | """Check classification with dynamic logging.""" 120 | (ensemble_name, EnsembleCLass) = ensemble 121 | clf = EnsembleCLass(**init_param) 122 | clf.fit( 123 | X_train, 124 | y_train, 125 | eval_datasets={ 126 | 'valid': (X_valid, y_valid), 127 | }, 128 | eval_metrics={ 129 | 'weighted_f1': ( 130 | sklearn.metrics.f1_score, 131 | {'average': 'weighted'}, 132 | ), # use weighted_f1 133 | 'roc': ( 134 | sklearn.metrics.roc_auc_score, 135 | {'multi_class': 'ovr', 'average': 'macro'}, 136 | ), # use roc_auc score 137 | }, 138 | train_verbose=True, 139 | ) 140 | -------------------------------------------------------------------------------- /imbens/utils/testing.py: -------------------------------------------------------------------------------- 1 | """Test utilities. 
"""
# Adapted from imbalanced-learn

# Adapted from scikit-learn
# Authors: Guillaume Lemaitre
# License: MIT

import inspect
import pkgutil
from importlib import import_module
from operator import itemgetter
from pathlib import Path

from sklearn.base import BaseEstimator
from sklearn.utils._testing import ignore_warnings


def all_estimators(
    type_filter=None,
):
    """Get a list of all estimators from imbens.

    This function crawls the module and gets all classes that inherit
    from BaseEstimator. Classes that are defined in test-modules are not
    included.
    By default meta_estimators are also not included.
    This function is adapted from sklearn.

    Parameters
    ----------
    type_filter : string, list of string, or None, default=None
        Which kind of estimators should be returned. If None, no
        filter is applied and all estimators are returned. Possible
        values are 'sampler' or 'ensemble' to get estimators only of
        these specific types, or a list of these to get the estimators
        that fit at least one of the types.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class.

    Raises
    ------
    ValueError
        If `type_filter` contains a value other than 'sampler' or
        'ensemble'.
    """
    # Imported lazily to avoid a circular import at module load time.
    from ..ensemble.base import ImbalancedEnsembleClassifierMixin
    from ..sampler.base import SamplerMixin

    def is_abstract(c):
        # A class counts as abstract only if it declares a non-empty
        # __abstractmethods__ set (i.e. it still has unimplemented
        # abstract methods).
        if not (hasattr(c, "__abstractmethods__")):
            return False
        if not len(c.__abstractmethods__):
            return False
        return True

    all_classes = []
    # Skip test packages and any private module ("._" in the dotted path).
    modules_to_ignore = {"tests"}
    root = str(Path(__file__).parent.parent)
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with ignore_warnings(category=FutureWarning):
        for importer, modname, ispkg in pkgutil.walk_packages(
            path=[root], prefix="imbens."
        ):
            mod_parts = modname.split(".")
            if any(part in modules_to_ignore for part in mod_parts) or "._" in modname:
                continue
            module = import_module(modname)
            classes = inspect.getmembers(module, inspect.isclass)
            # Drop private classes (leading underscore).
            classes = [
                (name, est_cls) for name, est_cls in classes if not name.startswith("_")
            ]

            all_classes.extend(classes)

    # Deduplicate: the same class is discovered once per module re-export.
    all_classes = set(all_classes)

    estimators = [
        c
        for c in all_classes
        if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
    ]
    # get rid of abstract base classes
    estimators = [c for c in estimators if not is_abstract(c[1])]

    # get rid of sklearn estimators which have been imported in some classes
    estimators = [c for c in estimators if "sklearn" not in c[1].__module__]

    if type_filter is not None:
        if not isinstance(type_filter, list):
            type_filter = [type_filter]
        else:
            type_filter = list(type_filter)  # copy
        filtered_estimators = []
        # Map each recognized filter keyword to its marker mixin class.
        filters = {
            "sampler": SamplerMixin,
            "ensemble": ImbalancedEnsembleClassifierMixin,
        }
        for name, mixin in filters.items():
            if name in type_filter:
                type_filter.remove(name)
                filtered_estimators.extend(
                    [est for est in estimators if issubclass(est[1], mixin)]
                )
        estimators = filtered_estimators
        # Anything left in type_filter was not a recognized keyword.
        if type_filter:
            raise ValueError(
                "Parameter type_filter must be 'sampler', 'ensemble' or "
                "None, got"
                " %s." % repr(type_filter)
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))