├── imbens
├── datasets
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_imbalance.py
│ │ └── test_zenodo.py
│ └── __init__.py
├── ensemble
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_base_ensemble.py
│ ├── _compatible
│ │ ├── tests
│ │ │ ├── __init__.py
│ │ │ └── test_adabost_compatible.py
│ │ └── __init__.py
│ ├── _over_sampling
│ │ ├── tests
│ │ │ ├── __init__.py
│ │ │ └── test_over_boost.py
│ │ └── __init__.py
│ ├── _reweighting
│ │ ├── tests
│ │ │ └── __init__.py
│ │ └── __init__.py
│ ├── _under_sampling
│ │ ├── tests
│ │ │ ├── __init__.py
│ │ │ └── test_rus_boost.py
│ │ └── __init__.py
│ └── __init__.py
├── metrics
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_score_objects.py
│ └── __init__.py
├── sampler
│ ├── tests
│ │ └── __init__.py
│ ├── _over_sampling
│ │ ├── tests
│ │ │ └── __init__.py
│ │ ├── _smote
│ │ │ ├── tests
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_borderline_smote.py
│ │ │ │ └── test_svm_smote.py
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ └── base.py
│ ├── _under_sampling
│ │ ├── _prototype_generation
│ │ │ ├── tests
│ │ │ │ └── __init__.py
│ │ │ └── __init__.py
│ │ ├── _prototype_selection
│ │ │ ├── tests
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_tomek_links.py
│ │ │ │ ├── test_neighbourhood_cleaning_rule.py
│ │ │ │ ├── test_instance_hardness_threshold.py
│ │ │ │ ├── test_one_sided_selection.py
│ │ │ │ └── test_condensed_nearest_neighbour.py
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ └── base.py
│ └── __init__.py
├── utils
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_deprecation.py
│ │ ├── test_testing.py
│ │ ├── test_plot.py
│ │ ├── test_docstring.py
│ │ └── test_show_versions.py
│ ├── __init__.py
│ ├── deprecation.py
│ ├── _show_versions.py
│ ├── _plot.py
│ └── testing.py
├── visualizer
│ ├── tests
│ │ └── __init__.py
│ └── __init__.py
├── base.py
├── exceptions.py
├── _version.py
└── __init__.py
├── docs
├── source
│ ├── sphinxext
│ │ ├── MANIFEST.in
│ │ ├── README.txt
│ │ └── github_link.py
│ ├── _static
│ │ ├── css
│ │ │ └── my_theme.css
│ │ ├── thumbnail.png
│ │ └── training_log_thumbnail.png
│ ├── api
│ │ ├── datasets
│ │ │ ├── api.rst
│ │ │ ├── _autosummary
│ │ │ │ ├── imbens.datasets.make_imbalance.rst
│ │ │ │ ├── imbens.datasets.fetch_openml_datasets.rst
│ │ │ │ ├── imbens.datasets.fetch_zenodo_datasets.rst
│ │ │ │ └── imbens.datasets.generate_imbalance_data.rst
│ │ │ └── datasets.rst
│ │ ├── pipeline
│ │ │ ├── api.rst
│ │ │ ├── _autosummary
│ │ │ │ ├── imbens.pipeline.make_pipeline.rst
│ │ │ │ └── imbens.pipeline.Pipeline.rst
│ │ │ └── pipeline.rst
│ │ ├── visualizer
│ │ │ ├── api.rst
│ │ │ ├── visualizer.rst
│ │ │ └── _autosummary
│ │ │ │ └── imbens.visualizer.ImbalancedEnsembleVisualizer.rst
│ │ ├── metrics
│ │ │ ├── api.rst
│ │ │ ├── pairwise.rst
│ │ │ ├── _autosummary
│ │ │ │ ├── imbens.metrics.sensitivity_score.rst
│ │ │ │ ├── imbens.metrics.specificity_score.rst
│ │ │ │ ├── imbens.metrics.geometric_mean_score.rst
│ │ │ │ ├── imbens.metrics.make_index_balanced_accuracy.rst
│ │ │ │ ├── imbens.metrics.sensitivity_specificity_support.rst
│ │ │ │ ├── imbens.metrics.classification_report_imbalanced.rst
│ │ │ │ ├── imbens.metrics.macro_averaged_mean_absolute_error.rst
│ │ │ │ └── imbens.metrics.ValueDifferenceMetric.rst
│ │ │ └── classification.rst
│ │ ├── sampler
│ │ │ ├── api.rst
│ │ │ ├── over-samplers.rst
│ │ │ ├── _autosummary
│ │ │ │ ├── imbens.sampler.SMOTE.rst
│ │ │ │ ├── imbens.sampler.ADASYN.rst
│ │ │ │ ├── imbens.sampler.AllKNN.rst
│ │ │ │ ├── imbens.sampler.NearMiss.rst
│ │ │ │ ├── imbens.sampler.SVMSMOTE.rst
│ │ │ │ ├── imbens.sampler.KMeansSMOTE.rst
│ │ │ │ ├── imbens.sampler.BorderlineSMOTE.rst
│ │ │ │ ├── imbens.sampler.ClusterCentroids.rst
│ │ │ │ ├── imbens.sampler.TomekLinks.rst
│ │ │ │ ├── imbens.sampler.OneSidedSelection.rst
│ │ │ │ ├── imbens.sampler.RandomOverSampler.rst
│ │ │ │ ├── imbens.sampler.RandomUnderSampler.rst
│ │ │ │ ├── imbens.sampler.SelfPacedUnderSampler.rst
│ │ │ │ ├── imbens.sampler.EditedNearestNeighbours.rst
│ │ │ │ ├── imbens.sampler.CondensedNearestNeighbour.rst
│ │ │ │ ├── imbens.sampler.InstanceHardnessThreshold.rst
│ │ │ │ ├── imbens.sampler.NeighbourhoodCleaningRule.rst
│ │ │ │ ├── imbens.sampler.BalanceCascadeUnderSampler.rst
│ │ │ │ └── imbens.sampler.RepeatedEditedNearestNeighbours.rst
│ │ │ └── under-samplers.rst
│ │ ├── utils
│ │ │ ├── api.rst
│ │ │ ├── _autosummary
│ │ │ │ ├── imbens.utils.evaluate_print.rst
│ │ │ │ ├── imbens.utils.check_target_type.rst
│ │ │ │ ├── imbens.utils.check_eval_metrics.rst
│ │ │ │ ├── imbens.utils.check_eval_datasets.rst
│ │ │ │ ├── imbens.utils.check_neighbors_object.rst
│ │ │ │ ├── imbens.utils.check_sampling_strategy.rst
│ │ │ │ ├── imbens.utils.check_balancing_schedule.rst
│ │ │ │ └── imbens.utils.check_target_label_and_n_target_samples.rst
│ │ │ ├── evaluation.rst
│ │ │ ├── validation_sampler.rst
│ │ │ └── validation_ensemble.rst
│ │ └── ensemble
│ │ │ ├── api.rst
│ │ │ ├── compatible.rst
│ │ │ ├── reweighting.rst
│ │ │ ├── over-sampling.rst
│ │ │ ├── under-sampling.rst
│ │ │ └── _autosummary
│ │ │ ├── imbens.ensemble.BalanceCascadeClassifier.rst
│ │ │ ├── imbens.ensemble.OverBaggingClassifier.rst
│ │ │ ├── imbens.ensemble.EasyEnsembleClassifier.rst
│ │ │ ├── imbens.ensemble.SMOTEBaggingClassifier.rst
│ │ │ ├── imbens.ensemble.SelfPacedEnsembleClassifier.rst
│ │ │ ├── imbens.ensemble.UnderBaggingClassifier.rst
│ │ │ ├── imbens.ensemble.CompatibleBaggingClassifier.rst
│ │ │ ├── imbens.ensemble.BalancedRandomForestClassifier.rst
│ │ │ ├── imbens.ensemble.AdaCostClassifier.rst
│ │ │ ├── imbens.ensemble.RUSBoostClassifier.rst
│ │ │ ├── imbens.ensemble.AdaUBoostClassifier.rst
│ │ │ ├── imbens.ensemble.AsymBoostClassifier.rst
│ │ │ ├── imbens.ensemble.OverBoostClassifier.rst
│ │ │ ├── imbens.ensemble.SMOTEBoostClassifier.rst
│ │ │ ├── imbens.ensemble.KmeansSMOTEBoostClassifier.rst
│ │ │ └── imbens.ensemble.CompatibleAdaBoostClassifier.rst
│ ├── _templates
│ │ ├── function.rst
│ │ ├── numpydoc_docstring.rst
│ │ ├── ensemble_class.rst
│ │ └── class.rst
│ ├── doc_requirements.txt
│ ├── install.rst
│ ├── get_start.rst
│ └── sg_execution_times.rst
├── requirements.txt
├── Makefile
└── make.bat
├── MANIFEST.in
├── examples
├── basic
│ ├── README.txt
│ ├── plot_basic_visualize.py
│ ├── plot_training_log.py
│ └── plot_basic_example.py
├── datasets
│ ├── README.txt
│ ├── plot_generate_imbalance.py
│ ├── plot_make_imbalance.py
│ └── plot_make_imbalance_digits.py
├── visualizer
│ └── README.txt
├── evaluation
│ ├── README.txt
│ ├── plot_classification_report.py
│ └── plot_metrics.py
├── pipeline
│ ├── README.txt
│ └── plot_pipeline_classification.py
├── classification
│ ├── README.txt
│ └── plot_probability.py
└── README.txt
├── requirements.txt
├── readthedocs.yml
├── .circleci
└── config.yml
├── LICENSE
├── .all-contributorsrc
├── .gitignore
└── setup.py
/imbens/datasets/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/ensemble/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/metrics/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/sampler/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/utils/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/visualizer/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/ensemble/_compatible/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/ensemble/_over_sampling/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/ensemble/_reweighting/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/sampler/_over_sampling/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/ensemble/_under_sampling/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/sampler/_over_sampling/_smote/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_generation/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_selection/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/sphinxext/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include tests *.py
2 | include *.txt
3 |
--------------------------------------------------------------------------------
/docs/source/_static/css/my_theme.css:
--------------------------------------------------------------------------------
1 | .wy-nav-content {
2 | max-width: 1000px !important;
3 | }
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include docs *
2 | recursive-include examples *
3 | include LICENSE
4 | include README.md
--------------------------------------------------------------------------------
/docs/source/_static/thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhiningLiu1998/imbalanced-ensemble/HEAD/docs/source/_static/thumbnail.png
--------------------------------------------------------------------------------
/examples/basic/README.txt:
--------------------------------------------------------------------------------
1 | .. _basic_examples:
2 |
3 | Basic usage examples
4 | --------------------
5 |
6 | Quick start with :mod:`imbens`.
7 |
--------------------------------------------------------------------------------
/docs/source/_static/training_log_thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhiningLiu1998/imbalanced-ensemble/HEAD/docs/source/_static/training_log_thumbnail.png
--------------------------------------------------------------------------------
/examples/datasets/README.txt:
--------------------------------------------------------------------------------
1 | .. _dataset_examples:
2 |
3 | Dataset examples
4 | ----------------
5 |
6 | Examples concerning the :mod:`imbens.datasets` module.
7 |
--------------------------------------------------------------------------------
/examples/visualizer/README.txt:
--------------------------------------------------------------------------------
1 | .. _visualizer_examples:
2 |
3 | Visualizer examples
4 | -------------------
5 |
6 | Examples concerning the :mod:`imbens.visualizer` module.
7 |
--------------------------------------------------------------------------------
/examples/evaluation/README.txt:
--------------------------------------------------------------------------------
1 | .. _evaluation_examples:
2 |
3 | Evaluation examples
4 | -------------------
5 |
6 | Examples illustrating how classification using an imbalanced dataset can be done.
7 |
--------------------------------------------------------------------------------
/examples/pipeline/README.txt:
--------------------------------------------------------------------------------
1 | .. _pipeline_examples:
2 |
3 | Pipeline examples
4 | -----------------
5 |
6 | Example of how to use a pipeline to include under-sampling with `scikit-learn` estimators.
--------------------------------------------------------------------------------
/examples/classification/README.txt:
--------------------------------------------------------------------------------
1 | .. _classification_examples:
2 |
3 | Classification examples
4 | -----------------------
5 |
6 | Examples about using classification algorithms in :mod:`imbens.ensemble` module.
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.16.0
2 | scipy>=1.9.1
3 | pandas>=2.1.1
4 | joblib>=0.11
5 | scikit-learn==1.6.0
6 | matplotlib>=3.3.2
7 | seaborn>=0.13.2
8 | tqdm>=4.50.2
9 | openml>=0.14.0
10 | platformdirs>=3.0.0
--------------------------------------------------------------------------------
/docs/source/api/datasets/api.rst:
--------------------------------------------------------------------------------
1 | ``imbens.datasets``
2 | *********************************
3 |
4 | This is the full API documentation of the `imbens.datasets` module.
5 |
6 | .. toctree::
7 | :maxdepth: 3
8 |
9 | datasets
10 |
--------------------------------------------------------------------------------
/docs/source/api/pipeline/api.rst:
--------------------------------------------------------------------------------
1 | ``imbens.pipeline``
2 | *********************************
3 |
4 | This is the full API documentation of the `imbens.pipeline` module.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | pipeline
10 |
--------------------------------------------------------------------------------
/docs/source/api/visualizer/api.rst:
--------------------------------------------------------------------------------
1 | ``imbens.visualizer``
2 | ***********************************
3 |
4 | This is the full API documentation of the `imbens.visualizer` module.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | visualizer
10 |
--------------------------------------------------------------------------------
/docs/source/api/metrics/api.rst:
--------------------------------------------------------------------------------
1 | ``imbens.metrics``
2 | *********************************
3 |
4 | This is the full API documentation of the `imbens.metrics` module.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | classification
10 | pairwise
11 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/api.rst:
--------------------------------------------------------------------------------
1 | ``imbens.sampler``
2 | *********************************
3 |
4 | This is the full API documentation of the `imbens.sampler` module.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | under-samplers
10 | over-samplers
11 |
--------------------------------------------------------------------------------
/examples/README.txt:
--------------------------------------------------------------------------------
1 | .. _examples-index:
2 |
3 | **General-purpose and introductory examples for the imbalanced-ensemble toolbox.**
4 |
5 | **The examples gallery is still under construction.** Please refer to APIs for more detailed guidelines of how to use
6 | ``imbens``.
--------------------------------------------------------------------------------
/docs/source/api/utils/api.rst:
--------------------------------------------------------------------------------
1 | ``imbens.utils``
2 | *********************************
3 |
4 | This is the full API documentation of the `imbens.utils` module.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | evaluation
10 | validation_ensemble
11 | validation_sampler
--------------------------------------------------------------------------------
/imbens/visualizer/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.visualizer` module includes a visualizer
3 | for visualizing ensemble estimators.
4 | """
5 |
6 | from .visualizer import ImbalancedEnsembleVisualizer
7 |
8 | __all__ = [
9 | "ImbalancedEnsembleVisualizer",
10 | ]
11 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/api.rst:
--------------------------------------------------------------------------------
1 | ``imbens.ensemble``
2 | *********************************
3 |
4 | This is the full API documentation of the `imbens.ensemble` module.
5 |
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | under-sampling
10 | over-sampling
11 | reweighting
12 | compatible
--------------------------------------------------------------------------------
/imbens/sampler/_over_sampling/_smote/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import SMOTE
2 |
3 | from .cluster import KMeansSMOTE
4 |
5 | from .filter import BorderlineSMOTE
6 | from .filter import SVMSMOTE
7 |
8 | __all__ = [
9 | "SMOTE",
10 | "KMeansSMOTE",
11 | "BorderlineSMOTE",
12 | "SVMSMOTE",
13 | ]
--------------------------------------------------------------------------------
/docs/source/_templates/function.rst:
--------------------------------------------------------------------------------
1 | {{objname}}
2 | {{ underline }}====================
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autofunction:: {{ objname }}
7 |
8 | .. include:: ../../../back_references/{{module}}.{{objname}}.examples
9 |
10 | .. raw:: html
11 |
12 |
13 |
--------------------------------------------------------------------------------
/docs/source/_templates/numpydoc_docstring.rst:
--------------------------------------------------------------------------------
1 | {{index}}
2 | {{summary}}
3 | {{extended_summary}}
4 | {{parameters}}
5 | {{returns}}
6 | {{yields}}
7 | {{other_parameters}}
8 | {{attributes}}
9 | {{raises}}
10 | {{warns}}
11 | {{warnings}}
12 | {{see_also}}
13 | {{notes}}
14 | {{references}}
15 | {{examples}}
16 | {{methods}}
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_generation/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.sampler._under_sampling._prototype_generation`
3 | submodule contains methods that generate new samples in order to balance
4 | the dataset.
5 | """
6 |
7 | from ._cluster_centroids import ClusterCentroids
8 |
9 | __all__ = ["ClusterCentroids"]
10 |
--------------------------------------------------------------------------------
/docs/source/doc_requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.16.0
2 | scipy>=1.9.1
3 | pandas>=2.1.1
4 | joblib>=0.11
5 | scikit-learn>=1.5.0
6 | matplotlib>=3.3.2
7 | seaborn>=0.13.2
8 | tqdm>=4.50.2
9 | openml>=0.14.0
10 | platformdirs
11 | sphinx
12 | sphinx-gallery
13 | sphinx_rtd_theme
14 | pydata-sphinx-theme
15 | numpydoc
16 | sphinxcontrib-bibtex
17 | torch
18 | pytest
--------------------------------------------------------------------------------
/docs/source/api/utils/_autosummary/imbens.utils.evaluate_print.rst:
--------------------------------------------------------------------------------
1 | evaluate_print
2 | ===============================================
3 |
4 | .. currentmodule:: imbens.utils
5 |
6 | .. autofunction:: evaluate_print
7 |
8 | .. include:: ../../../back_references/imbens.utils.evaluate_print.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/metrics/pairwise.rst:
--------------------------------------------------------------------------------
1 | .. _metrics_api:
2 |
3 | Pairwise Metrics
4 | ================================
5 |
6 | .. automodule:: imbens.metrics
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.metrics
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | ValueDifferenceMetric
--------------------------------------------------------------------------------
/docs/source/api/pipeline/_autosummary/imbens.pipeline.make_pipeline.rst:
--------------------------------------------------------------------------------
1 | make_pipeline
2 | =================================================
3 |
4 | .. currentmodule:: imbens.pipeline
5 |
6 | .. autofunction:: make_pipeline
7 |
8 | .. include:: ../../../back_references/imbens.pipeline.make_pipeline.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/datasets/_autosummary/imbens.datasets.make_imbalance.rst:
--------------------------------------------------------------------------------
1 | make_imbalance
2 | ==================================================
3 |
4 | .. currentmodule:: imbens.datasets
5 |
6 | .. autofunction:: make_imbalance
7 |
8 | .. include:: ../../../back_references/imbens.datasets.make_imbalance.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/utils/_autosummary/imbens.utils.check_target_type.rst:
--------------------------------------------------------------------------------
1 | check_target_type
2 | ==================================================
3 |
4 | .. currentmodule:: imbens.utils
5 |
6 | .. autofunction:: check_target_type
7 |
8 | .. include:: ../../../back_references/imbens.utils.check_target_type.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/utils/evaluation.rst:
--------------------------------------------------------------------------------
1 | .. _evaluation_api:
2 |
3 | Utilities for evaluation
4 | ===================================
5 |
6 | .. automodule:: imbens.utils
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.utils
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: function.rst
15 |
16 | evaluate_print
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.16.0
2 | scipy>=0.19.1
3 | pandas>=1.1.3
4 | joblib>=0.11
5 | scikit-learn>=0.24
6 | imbalanced-learn>=0.7.0
7 | matplotlib>=3.3.2
8 | seaborn>=0.11.0
9 | tqdm>=4.50.2
10 | openml>=0.14.0
11 | platformdirs
12 | sphinx
13 | sphinx-gallery
14 | numpydoc
15 | pydata-sphinx-theme
16 | sphinxcontrib-bibtex
17 | torch
18 | torchvision
19 | pytest
20 | pytest-cov
--------------------------------------------------------------------------------
/docs/source/api/utils/_autosummary/imbens.utils.check_eval_metrics.rst:
--------------------------------------------------------------------------------
1 | check_eval_metrics
2 | ===================================================
3 |
4 | .. currentmodule:: imbens.utils
5 |
6 | .. autofunction:: check_eval_metrics
7 |
8 | .. include:: ../../../back_references/imbens.utils.check_eval_metrics.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | #conda:
2 | # file: docs/environment.yml
3 |
4 | version: 2
5 |
6 | # Set the OS, Python version and other tools you might need
7 | build:
8 | os: ubuntu-22.04
9 | tools:
10 | python: "3.11"
11 |
12 | sphinx:
13 | configuration: docs/source/conf.py
14 |
15 | python:
16 | install:
17 | - requirements: docs/source/doc_requirements.txt
18 |
--------------------------------------------------------------------------------
/docs/source/api/metrics/_autosummary/imbens.metrics.sensitivity_score.rst:
--------------------------------------------------------------------------------
1 | sensitivity_score
2 | ====================================================
3 |
4 | .. currentmodule:: imbens.metrics
5 |
6 | .. autofunction:: sensitivity_score
7 |
8 | .. include:: ../../../back_references/imbens.metrics.sensitivity_score.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/metrics/_autosummary/imbens.metrics.specificity_score.rst:
--------------------------------------------------------------------------------
1 | specificity_score
2 | ====================================================
3 |
4 | .. currentmodule:: imbens.metrics
5 |
6 | .. autofunction:: specificity_score
7 |
8 | .. include:: ../../../back_references/imbens.metrics.specificity_score.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/utils/_autosummary/imbens.utils.check_eval_datasets.rst:
--------------------------------------------------------------------------------
1 | check_eval_datasets
2 | ====================================================
3 |
4 | .. currentmodule:: imbens.utils
5 |
6 | .. autofunction:: check_eval_datasets
7 |
8 | .. include:: ../../../back_references/imbens.utils.check_eval_datasets.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/visualizer/visualizer.rst:
--------------------------------------------------------------------------------
1 | .. _visualizer_api:
2 |
3 | Visualizer
4 | ================================
5 |
6 | .. automodule:: imbens.visualizer
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.visualizer
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | ImbalancedEnsembleVisualizer
--------------------------------------------------------------------------------
/docs/source/api/metrics/_autosummary/imbens.metrics.geometric_mean_score.rst:
--------------------------------------------------------------------------------
1 | geometric_mean_score
2 | =======================================================
3 |
4 | .. currentmodule:: imbens.metrics
5 |
6 | .. autofunction:: geometric_mean_score
7 |
8 | .. include:: ../../../back_references/imbens.metrics.geometric_mean_score.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/utils/_autosummary/imbens.utils.check_neighbors_object.rst:
--------------------------------------------------------------------------------
1 | check_neighbors_object
2 | =======================================================
3 |
4 | .. currentmodule:: imbens.utils
5 |
6 | .. autofunction:: check_neighbors_object
7 |
8 | .. include:: ../../../back_references/imbens.utils.check_neighbors_object.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/utils/_autosummary/imbens.utils.check_sampling_strategy.rst:
--------------------------------------------------------------------------------
1 | check_sampling_strategy
2 | ========================================================
3 |
4 | .. currentmodule:: imbens.utils
5 |
6 | .. autofunction:: check_sampling_strategy
7 |
8 | .. include:: ../../../back_references/imbens.utils.check_sampling_strategy.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/datasets/_autosummary/imbens.datasets.fetch_openml_datasets.rst:
--------------------------------------------------------------------------------
1 | fetch_openml_datasets
2 | =========================================================
3 |
4 | .. currentmodule:: imbens.datasets
5 |
6 | .. autofunction:: fetch_openml_datasets
7 |
8 | .. include:: ../../../back_references/imbens.datasets.fetch_openml_datasets.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/datasets/_autosummary/imbens.datasets.fetch_zenodo_datasets.rst:
--------------------------------------------------------------------------------
1 | fetch_zenodo_datasets
2 | =========================================================
3 |
4 | .. currentmodule:: imbens.datasets
5 |
6 | .. autofunction:: fetch_zenodo_datasets
7 |
8 | .. include:: ../../../back_references/imbens.datasets.fetch_zenodo_datasets.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/utils/_autosummary/imbens.utils.check_balancing_schedule.rst:
--------------------------------------------------------------------------------
1 | check_balancing_schedule
2 | =========================================================
3 |
4 | .. currentmodule:: imbens.utils
5 |
6 | .. autofunction:: check_balancing_schedule
7 |
8 | .. include:: ../../../back_references/imbens.utils.check_balancing_schedule.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/datasets/_autosummary/imbens.datasets.generate_imbalance_data.rst:
--------------------------------------------------------------------------------
1 | generate_imbalance_data
2 | ===========================================================
3 |
4 | .. currentmodule:: imbens.datasets
5 |
6 | .. autofunction:: generate_imbalance_data
7 |
8 | .. include:: ../../../back_references/imbens.datasets.generate_imbalance_data.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/compatible.rst:
--------------------------------------------------------------------------------
1 | .. _compatible_api:
2 |
3 | Compatible ensembles
4 | ======================
5 |
6 | .. automodule:: imbens.ensemble
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.ensemble
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | CompatibleAdaBoostClassifier
17 | CompatibleBaggingClassifier
18 |
--------------------------------------------------------------------------------
/imbens/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Includes all possible values of an imbens classifier's
3 | properties.
4 | """
5 |
6 | # Authors: Zhining Liu
7 | # License: MIT
8 |
9 |
10 | ENSEMBLE_TYPES = ('boosting', 'bagging', 'random-forest', 'general')
11 |
12 | TRAINING_TYPES = ('iterative', 'parallel')
13 |
14 | SOLUTION_TYPES = ('resampling', 'reweighting')
15 |
16 | SAMPLING_TYPES = ('under-sampling', 'over-sampling')
17 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/reweighting.rst:
--------------------------------------------------------------------------------
1 | .. _reweighting_api:
2 |
3 | Reweighting-based ensembles
4 | ================================
5 |
6 | .. automodule:: imbens.ensemble
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.ensemble
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | AdaCostClassifier
17 | AdaUBoostClassifier
18 | AsymBoostClassifier
--------------------------------------------------------------------------------
/docs/source/api/metrics/_autosummary/imbens.metrics.make_index_balanced_accuracy.rst:
--------------------------------------------------------------------------------
1 | make_index_balanced_accuracy
2 | ===============================================================
3 |
4 | .. currentmodule:: imbens.metrics
5 |
6 | .. autofunction:: make_index_balanced_accuracy
7 |
8 | .. include:: ../../../back_references/imbens.metrics.make_index_balanced_accuracy.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/datasets/datasets.rst:
--------------------------------------------------------------------------------
1 | .. _datasets_api:
2 |
3 | Datasets
4 | ================================
5 |
6 | .. automodule:: imbens.datasets
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.datasets
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: function.rst
15 |
16 | make_imbalance
17 | generate_imbalance_data
18 | fetch_zenodo_datasets
19 | fetch_openml_datasets
--------------------------------------------------------------------------------
/docs/source/api/metrics/_autosummary/imbens.metrics.sensitivity_specificity_support.rst:
--------------------------------------------------------------------------------
1 | sensitivity_specificity_support
2 | ==================================================================
3 |
4 | .. currentmodule:: imbens.metrics
5 |
6 | .. autofunction:: sensitivity_specificity_support
7 |
8 | .. include:: ../../../back_references/imbens.metrics.sensitivity_specificity_support.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/imbens/ensemble/_compatible/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.ensemble.compatible` submodule contains
3 | a set of `sklearn.ensemble` learning methods that were re-implemented in
4 | `imbens` style.
5 | """
6 |
7 | from .adaboost_compatible import CompatibleAdaBoostClassifier
8 | from .bagging_compatible import CompatibleBaggingClassifier
9 |
10 | __all__ = [
11 | "CompatibleAdaBoostClassifier",
12 | "CompatibleBaggingClassifier",
13 | ]
14 |
--------------------------------------------------------------------------------
/imbens/ensemble/_reweighting/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.ensemble._reweighting` submodule contains
3 | a set of reweighting-based ensemble imbalanced learning methods.
4 | """
5 |
6 | from .adacost import AdaCostClassifier
7 | from .adauboost import AdaUBoostClassifier
8 | from .asymmetric_boost import AsymBoostClassifier
9 |
10 | __all__ = [
11 | "AdaCostClassifier",
12 | "AdaUBoostClassifier",
13 | "AsymBoostClassifier",
14 | ]
15 |
--------------------------------------------------------------------------------
/docs/source/api/metrics/_autosummary/imbens.metrics.classification_report_imbalanced.rst:
--------------------------------------------------------------------------------
1 | classification_report_imbalanced
2 | ===================================================================
3 |
4 | .. currentmodule:: imbens.metrics
5 |
6 | .. autofunction:: classification_report_imbalanced
7 |
8 | .. include:: ../../../back_references/imbens.metrics.classification_report_imbalanced.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/utils/validation_sampler.rst:
--------------------------------------------------------------------------------
1 | .. _utils_sampler_api:
2 |
3 | Validation checks used in samplers
4 | ==================================
5 |
6 | .. automodule:: imbens.utils
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.utils
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: function.rst
15 |
16 | check_neighbors_object
17 | check_sampling_strategy
18 | check_target_type
--------------------------------------------------------------------------------
/docs/source/api/metrics/_autosummary/imbens.metrics.macro_averaged_mean_absolute_error.rst:
--------------------------------------------------------------------------------
1 | macro_averaged_mean_absolute_error
2 | =====================================================================
3 |
4 | .. currentmodule:: imbens.metrics
5 |
6 | .. autofunction:: macro_averaged_mean_absolute_error
7 |
8 | .. include:: ../../../back_references/imbens.metrics.macro_averaged_mean_absolute_error.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/docs/source/api/pipeline/pipeline.rst:
--------------------------------------------------------------------------------
1 | .. _pipeline_api:
2 |
3 | Pipeline
4 | ================================
5 |
6 | .. automodule:: imbens.pipeline
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.pipeline
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | Pipeline
17 |
18 | .. autosummary::
19 | :toctree: _autosummary
20 | :template: function.rst
21 |
22 | make_pipeline
--------------------------------------------------------------------------------
/docs/source/api/sampler/over-samplers.rst:
--------------------------------------------------------------------------------
1 | .. _over_sampling_sampler_api:
2 |
3 | Over-sampling Samplers
4 | ================================
5 |
6 | .. automodule:: imbens.sampler
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.sampler
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | ADASYN
17 | RandomOverSampler
18 | KMeansSMOTE
19 | SMOTE
20 | BorderlineSMOTE
21 | SVMSMOTE
--------------------------------------------------------------------------------
/imbens/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.datasets` provides methods to generate
3 | imbalanced data.
4 | """
5 |
6 | from ._imbalance import make_imbalance
7 | from ._imbalance import generate_imbalance_data
8 |
9 | from ._zenodo import fetch_zenodo_datasets
10 | from ._openml import fetch_openml_datasets
11 |
12 | __all__ = [
13 | "make_imbalance",
14 | "generate_imbalance_data",
15 | "fetch_zenodo_datasets",
16 | "fetch_openml_datasets",
17 | ]
18 |
--------------------------------------------------------------------------------
/docs/source/api/utils/_autosummary/imbens.utils.check_target_label_and_n_target_samples.rst:
--------------------------------------------------------------------------------
1 | check_target_label_and_n_target_samples
2 | ========================================================================
3 |
4 | .. currentmodule:: imbens.utils
5 |
6 | .. autofunction:: check_target_label_and_n_target_samples
7 |
8 | .. include:: ../../../back_references/imbens.utils.check_target_label_and_n_target_samples.examples
9 |
10 | .. raw:: html
11 |
12 |
--------------------------------------------------------------------------------
/imbens/exceptions.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.exceptions` module includes all custom warnings and error
classes and functions used across imbens.
4 | """
5 | # Adapted from imbalanced-learn
6 |
7 | # Authors: Guillaume Lemaitre
8 | # License: MIT
9 |
10 |
def raise_isinstance_error(variable_name, possible_type, variable):
    """Raise a ``ValueError`` reporting that a variable has an unexpected type.

    Parameters
    ----------
    variable_name : str
        Name of the offending variable, used in the error message.
    possible_type : object
        The expected type(s), rendered verbatim in the error message.
    variable : object
        The value whose actual type is reported.

    Raises
    ------
    ValueError
        Always raised; the message names the variable, the accepted
        type(s), and the actual type received.
    """
    message = "%s has to be one of %s. Got %s instead." % (
        variable_name,
        possible_type,
        type(variable),
    )
    raise ValueError(message)
16 |
--------------------------------------------------------------------------------
/docs/source/api/utils/validation_ensemble.rst:
--------------------------------------------------------------------------------
1 | .. _utils_ensemble_api:
2 |
3 | Validation checks used in ensembles
4 | ===================================
5 |
6 | .. automodule:: imbens.utils
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.utils
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: function.rst
15 |
16 | check_eval_datasets
17 | check_eval_metrics
18 | check_target_label_and_n_target_samples
19 | check_balancing_schedule
--------------------------------------------------------------------------------
/docs/source/api/ensemble/over-sampling.rst:
--------------------------------------------------------------------------------
1 | .. _over_sampling_api:
2 |
3 | Over-sampling-based ensembles
4 | ================================
5 |
6 | .. automodule:: imbens.ensemble
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.ensemble
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | OverBoostClassifier
17 | SMOTEBoostClassifier
18 | KmeansSMOTEBoostClassifier
19 | OverBaggingClassifier
20 | SMOTEBaggingClassifier
--------------------------------------------------------------------------------
/docs/source/_templates/ensemble_class.rst:
--------------------------------------------------------------------------------
1 | {{objname}}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | {% block methods %}
9 |
10 | {% if methods %}
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 | {% for item in methods %}
15 | {% if '__init__' not in item %}
16 | ~{{ name }}.{{ item }}
17 | {% endif %}
18 | {%- endfor %}
19 | {% endif %}
20 | {% endblock %}
21 |
22 | .. raw:: html
23 |
24 |
25 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/under-sampling.rst:
--------------------------------------------------------------------------------
1 | .. _under_sampling_api:
2 |
3 | Under-sampling-based ensembles
4 | ================================
5 |
6 | .. automodule:: imbens.ensemble
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.ensemble
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | SelfPacedEnsembleClassifier
17 | BalanceCascadeClassifier
18 | BalancedRandomForestClassifier
19 | EasyEnsembleClassifier
20 | RUSBoostClassifier
21 | UnderBaggingClassifier
--------------------------------------------------------------------------------
/docs/source/api/metrics/classification.rst:
--------------------------------------------------------------------------------
1 | .. _metrics_api:
2 |
3 | Classification Metrics
4 | ================================
5 |
6 | .. automodule:: imbens.metrics
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.metrics
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: function.rst
15 |
16 | sensitivity_specificity_support
17 | sensitivity_score
18 | specificity_score
19 | geometric_mean_score
20 | make_index_balanced_accuracy
21 | classification_report_imbalanced
22 | macro_averaged_mean_absolute_error
--------------------------------------------------------------------------------
/docs/source/_templates/class.rst:
--------------------------------------------------------------------------------
1 | {{objname}}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | {% block methods %}
9 |
10 | {% if methods %}
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 | {% for item in methods %}
15 | {% if '__init__' not in item %}
16 | ~{{ name }}.{{ item }}
17 | {% endif %}
18 | {%- endfor %}
19 | {% endif %}
20 | {% endblock %}
21 |
22 | .. include:: ../../../back_references/{{module}}.{{objname}}.examples
23 |
24 | .. raw:: html
25 |
26 |
27 |
--------------------------------------------------------------------------------
/imbens/ensemble/_over_sampling/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.ensemble._over_sampling` submodule contains
3 | a set of over-sampling-based ensemble imbalanced learning methods.
4 | """
5 |
6 | from .over_boost import OverBoostClassifier
7 | from .smote_boost import SMOTEBoostClassifier
8 | from .kmeans_smote_boost import KmeansSMOTEBoostClassifier
9 | from .smote_bagging import SMOTEBaggingClassifier
10 | from .over_bagging import OverBaggingClassifier
11 |
12 | __all__ = [
13 | "OverBoostClassifier",
14 | "SMOTEBoostClassifier",
15 | "KmeansSMOTEBoostClassifier",
16 | "OverBaggingClassifier",
17 | "SMOTEBaggingClassifier",
18 | ]
19 |
--------------------------------------------------------------------------------
/imbens/utils/tests/test_deprecation.py:
--------------------------------------------------------------------------------
1 | """Test for the deprecation helper"""
2 |
3 | # Authors: Guillaume Lemaitre
4 | # License: MIT
5 |
6 | import pytest
7 |
8 | from imbens.utils.deprecation import deprecate_parameter
9 |
10 |
class Sampler:
    """Minimal stub carrying two attributes for the deprecation tests.

    Both attributes hold the same placeholder string; only their
    presence (and names 'a' / 'b') matters to the tests below.
    """

    def __init__(self):
        self.b = "something"
        self.a = "something"
15 |
16 |
def test_deprecate_parameter():
    """Check that deprecate_parameter emits the expected DeprecationWarning.

    Covers both call forms: deprecation without a replacement attribute,
    and deprecation that points the user to a new attribute name.
    """
    # No replacement given: warning only states the parameter is deprecated.
    with pytest.warns(DeprecationWarning, match="is deprecated from"):
        deprecate_parameter(Sampler(), "0.2", "a")
    # Replacement 'b' given: warning also tells the user what to use instead.
    with pytest.warns(DeprecationWarning, match="Use 'b' instead."):
        deprecate_parameter(Sampler(), "0.2", "a", "b")
22 |
--------------------------------------------------------------------------------
/imbens/sampler/_over_sampling/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.sampler._over_sampling` submodule provides a
3 | set of methods to perform over-sampling.
4 | """
5 |
6 | from ._adasyn import ADASYN
7 | from ._random_over_sampler import RandomOverSampler
8 | from ._smote import SMOTE
9 | from ._smote import BorderlineSMOTE
10 | from ._smote import KMeansSMOTE
11 | from ._smote import SVMSMOTE
12 | # from ._smote import SMOTENC
13 | # from ._smote import SMOTEN
14 |
15 | __all__ = [
16 | "ADASYN",
17 | "RandomOverSampler",
18 | "KMeansSMOTE",
19 | "SMOTE",
20 | "BorderlineSMOTE",
21 | "SVMSMOTE",
22 | # "SMOTENC",
23 | # "SMOTEN",
24 | ]
25 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.SMOTE.rst:
--------------------------------------------------------------------------------
1 | SMOTE
2 | ==================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: SMOTE
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~SMOTE.fit
18 |
19 |
20 | ~SMOTE.fit_resample
21 |
22 |
23 | ~SMOTE.get_metadata_routing
24 |
25 |
26 | ~SMOTE.get_params
27 |
28 |
29 | ~SMOTE.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.SMOTE.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
# NOTE: recipe lines below must keep their leading TAB character.
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.ADASYN.rst:
--------------------------------------------------------------------------------
1 | ADASYN
2 | ===================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: ADASYN
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~ADASYN.fit
18 |
19 |
20 | ~ADASYN.fit_resample
21 |
22 |
23 | ~ADASYN.get_metadata_routing
24 |
25 |
26 | ~ADASYN.get_params
27 |
28 |
29 | ~ADASYN.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.ADASYN.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.AllKNN.rst:
--------------------------------------------------------------------------------
1 | AllKNN
2 | ===================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: AllKNN
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~AllKNN.fit
18 |
19 |
20 | ~AllKNN.fit_resample
21 |
22 |
23 | ~AllKNN.get_metadata_routing
24 |
25 |
26 | ~AllKNN.get_params
27 |
28 |
29 | ~AllKNN.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.AllKNN.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/under-samplers.rst:
--------------------------------------------------------------------------------
1 | .. _under_sampling_sampler_api:
2 |
3 | Under-sampling Samplers
4 | ================================
5 |
6 | .. automodule:: imbens.sampler
7 | :no-members:
8 | :no-inherited-members:
9 |
10 | .. currentmodule:: imbens.sampler
11 |
12 | .. autosummary::
13 | :toctree: _autosummary
14 | :template: class.rst
15 |
16 | ClusterCentroids
17 | RandomUnderSampler
18 | InstanceHardnessThreshold
19 | NearMiss
20 | TomekLinks
21 | EditedNearestNeighbours
22 | RepeatedEditedNearestNeighbours
23 | AllKNN
24 | OneSidedSelection
25 | CondensedNearestNeighbour
26 | NeighbourhoodCleaningRule
27 | BalanceCascadeUnderSampler
28 | SelfPacedUnderSampler
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.NearMiss.rst:
--------------------------------------------------------------------------------
1 | NearMiss
2 | =====================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: NearMiss
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~NearMiss.fit
18 |
19 |
20 | ~NearMiss.fit_resample
21 |
22 |
23 | ~NearMiss.get_metadata_routing
24 |
25 |
26 | ~NearMiss.get_params
27 |
28 |
29 | ~NearMiss.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.NearMiss.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.SVMSMOTE.rst:
--------------------------------------------------------------------------------
1 | SVMSMOTE
2 | =====================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: SVMSMOTE
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~SVMSMOTE.fit
18 |
19 |
20 | ~SVMSMOTE.fit_resample
21 |
22 |
23 | ~SVMSMOTE.get_metadata_routing
24 |
25 |
26 | ~SVMSMOTE.get_params
27 |
28 |
29 | ~SVMSMOTE.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.SVMSMOTE.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.KMeansSMOTE.rst:
--------------------------------------------------------------------------------
1 | KMeansSMOTE
2 | ========================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: KMeansSMOTE
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~KMeansSMOTE.fit
18 |
19 |
20 | ~KMeansSMOTE.fit_resample
21 |
22 |
23 | ~KMeansSMOTE.get_metadata_routing
24 |
25 |
26 | ~KMeansSMOTE.get_params
27 |
28 |
29 | ~KMeansSMOTE.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.KMeansSMOTE.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/imbens/ensemble/_under_sampling/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.ensemble._under_sampling` submodule contains
3 | a set of under-sampling-based ensemble imbalanced learning methods.
4 | """
5 |
6 | from .self_paced_ensemble import SelfPacedEnsembleClassifier
7 | from .balance_cascade import BalanceCascadeClassifier
8 | from .balanced_random_forest import BalancedRandomForestClassifier
9 | from .easy_ensemble import EasyEnsembleClassifier
10 | from .rus_boost import RUSBoostClassifier
11 | from .under_bagging import UnderBaggingClassifier
12 |
13 | __all__ = [
14 | "SelfPacedEnsembleClassifier",
15 | "BalanceCascadeClassifier",
16 | "BalancedRandomForestClassifier",
17 | "EasyEnsembleClassifier",
18 | "RUSBoostClassifier",
19 | "UnderBaggingClassifier",
20 | ]
21 |
--------------------------------------------------------------------------------
/imbens/_version.py:
--------------------------------------------------------------------------------
1 | """
2 | ``imbalanced-ensemble`` is a set of python-based ensemble learning methods for
3 | dealing with class-imbalanced classification problems in machine learning.
4 | """
5 |
6 | # Based on NiLearn, imblearn package
7 | # License: simplified BSD, MIT
8 |
9 | # PEP0440 compatible formatted version, see:
10 | # https://www.python.org/dev/peps/pep-0440/
11 | #
12 | # Generic release markers:
13 | # X.Y
14 | # X.Y.Z # For bugfix releases
15 | #
16 | # Admissible pre-release markers:
17 | # X.YaN # Alpha release
18 | # X.YbN # Beta release
19 | # X.YrcN # Release Candidate
20 | # X.Y # Final release
21 | #
22 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
23 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
24 | #
25 |
# Current release version, PEP 440 formatted (see the markers described above).
__version__ = "0.2.3"
27 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.BorderlineSMOTE.rst:
--------------------------------------------------------------------------------
1 | BorderlineSMOTE
2 | ============================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: BorderlineSMOTE
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~BorderlineSMOTE.fit
18 |
19 |
20 | ~BorderlineSMOTE.fit_resample
21 |
22 |
23 | ~BorderlineSMOTE.get_metadata_routing
24 |
25 |
26 | ~BorderlineSMOTE.get_params
27 |
28 |
29 | ~BorderlineSMOTE.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.BorderlineSMOTE.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/visualizer/_autosummary/imbens.visualizer.ImbalancedEnsembleVisualizer.rst:
--------------------------------------------------------------------------------
1 | ImbalancedEnsembleVisualizer
2 | ============================================================
3 |
4 | .. currentmodule:: imbens.visualizer
5 |
6 | .. autoclass:: ImbalancedEnsembleVisualizer
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~ImbalancedEnsembleVisualizer.confusion_matrix_heatmap
18 |
19 |
20 | ~ImbalancedEnsembleVisualizer.fit
21 |
22 |
23 | ~ImbalancedEnsembleVisualizer.performance_lineplot
24 |
25 |
26 |
27 |
28 | .. include:: ../../../back_references/imbens.visualizer.ImbalancedEnsembleVisualizer.examples
29 |
30 | .. raw:: html
31 |
32 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.ClusterCentroids.rst:
--------------------------------------------------------------------------------
1 | ClusterCentroids
2 | =============================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: ClusterCentroids
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~ClusterCentroids.fit
18 |
19 |
20 | ~ClusterCentroids.fit_resample
21 |
22 |
23 | ~ClusterCentroids.get_metadata_routing
24 |
25 |
26 | ~ClusterCentroids.get_params
27 |
28 |
29 | ~ClusterCentroids.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.ClusterCentroids.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.TomekLinks.rst:
--------------------------------------------------------------------------------
1 | TomekLinks
2 | =======================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: TomekLinks
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~TomekLinks.fit
18 |
19 |
20 | ~TomekLinks.fit_resample
21 |
22 |
23 | ~TomekLinks.get_metadata_routing
24 |
25 |
26 | ~TomekLinks.get_params
27 |
28 |
29 | ~TomekLinks.is_tomek
30 |
31 |
32 | ~TomekLinks.set_params
33 |
34 |
35 |
36 |
37 | .. include:: ../../../back_references/imbens.sampler.TomekLinks.examples
38 |
39 | .. raw:: html
40 |
41 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.OneSidedSelection.rst:
--------------------------------------------------------------------------------
1 | OneSidedSelection
2 | ==============================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: OneSidedSelection
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~OneSidedSelection.fit
18 |
19 |
20 | ~OneSidedSelection.fit_resample
21 |
22 |
23 | ~OneSidedSelection.get_metadata_routing
24 |
25 |
26 | ~OneSidedSelection.get_params
27 |
28 |
29 | ~OneSidedSelection.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.OneSidedSelection.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.RandomOverSampler.rst:
--------------------------------------------------------------------------------
1 | RandomOverSampler
2 | ==============================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: RandomOverSampler
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~RandomOverSampler.fit
18 |
19 |
20 | ~RandomOverSampler.fit_resample
21 |
22 |
23 | ~RandomOverSampler.get_metadata_routing
24 |
25 |
26 | ~RandomOverSampler.get_params
27 |
28 |
29 | ~RandomOverSampler.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.RandomOverSampler.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.RandomUnderSampler.rst:
--------------------------------------------------------------------------------
1 | RandomUnderSampler
2 | ===============================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: RandomUnderSampler
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~RandomUnderSampler.fit
18 |
19 |
20 | ~RandomUnderSampler.fit_resample
21 |
22 |
23 | ~RandomUnderSampler.get_metadata_routing
24 |
25 |
26 | ~RandomUnderSampler.get_params
27 |
28 |
29 | ~RandomUnderSampler.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.RandomUnderSampler.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/metrics/_autosummary/imbens.metrics.ValueDifferenceMetric.rst:
--------------------------------------------------------------------------------
1 | ValueDifferenceMetric
2 | ==================================================
3 |
4 | .. currentmodule:: imbens.metrics
5 |
6 | .. autoclass:: ValueDifferenceMetric
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~ValueDifferenceMetric.fit
18 |
19 |
20 | ~ValueDifferenceMetric.get_metadata_routing
21 |
22 |
23 | ~ValueDifferenceMetric.get_params
24 |
25 |
26 | ~ValueDifferenceMetric.pairwise
27 |
28 |
29 | ~ValueDifferenceMetric.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.metrics.ValueDifferenceMetric.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.SelfPacedUnderSampler.rst:
--------------------------------------------------------------------------------
1 | SelfPacedUnderSampler
2 | ==================================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: SelfPacedUnderSampler
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~SelfPacedUnderSampler.fit
18 |
19 |
20 | ~SelfPacedUnderSampler.fit_resample
21 |
22 |
23 | ~SelfPacedUnderSampler.get_metadata_routing
24 |
25 |
26 | ~SelfPacedUnderSampler.get_params
27 |
28 |
29 | ~SelfPacedUnderSampler.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.SelfPacedUnderSampler.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.EditedNearestNeighbours.rst:
--------------------------------------------------------------------------------
1 | EditedNearestNeighbours
2 | ====================================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: EditedNearestNeighbours
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~EditedNearestNeighbours.fit
18 |
19 |
20 | ~EditedNearestNeighbours.fit_resample
21 |
22 |
23 | ~EditedNearestNeighbours.get_metadata_routing
24 |
25 |
26 | ~EditedNearestNeighbours.get_params
27 |
28 |
29 | ~EditedNearestNeighbours.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.EditedNearestNeighbours.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.CondensedNearestNeighbour.rst:
--------------------------------------------------------------------------------
1 | CondensedNearestNeighbour
2 | ======================================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: CondensedNearestNeighbour
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~CondensedNearestNeighbour.fit
18 |
19 |
20 | ~CondensedNearestNeighbour.fit_resample
21 |
22 |
23 | ~CondensedNearestNeighbour.get_metadata_routing
24 |
25 |
26 | ~CondensedNearestNeighbour.get_params
27 |
28 |
29 | ~CondensedNearestNeighbour.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.CondensedNearestNeighbour.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.InstanceHardnessThreshold.rst:
--------------------------------------------------------------------------------
1 | InstanceHardnessThreshold
2 | ======================================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: InstanceHardnessThreshold
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~InstanceHardnessThreshold.fit
18 |
19 |
20 | ~InstanceHardnessThreshold.fit_resample
21 |
22 |
23 | ~InstanceHardnessThreshold.get_metadata_routing
24 |
25 |
26 | ~InstanceHardnessThreshold.get_params
27 |
28 |
29 | ~InstanceHardnessThreshold.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.InstanceHardnessThreshold.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.NeighbourhoodCleaningRule.rst:
--------------------------------------------------------------------------------
1 | NeighbourhoodCleaningRule
2 | ======================================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: NeighbourhoodCleaningRule
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~NeighbourhoodCleaningRule.fit
18 |
19 |
20 | ~NeighbourhoodCleaningRule.fit_resample
21 |
22 |
23 | ~NeighbourhoodCleaningRule.get_metadata_routing
24 |
25 |
26 | ~NeighbourhoodCleaningRule.get_params
27 |
28 |
29 | ~NeighbourhoodCleaningRule.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.NeighbourhoodCleaningRule.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

REM Default to the standard sphinx-build binary when the caller has not
REM pointed SPHINXBUILD at a specific executable.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

REM No build target supplied: fall through to the Sphinx help listing.
if "%1" == "" goto help

REM Probe that sphinx-build is runnable; errorlevel 9009 means the
REM command was not found on PATH.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

REM Forward the requested target (html, latex, ...) to sphinx-build -M.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
36 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.BalanceCascadeUnderSampler.rst:
--------------------------------------------------------------------------------
1 | BalanceCascadeUnderSampler
2 | =======================================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: BalanceCascadeUnderSampler
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~BalanceCascadeUnderSampler.fit
18 |
19 |
20 | ~BalanceCascadeUnderSampler.fit_resample
21 |
22 |
23 | ~BalanceCascadeUnderSampler.get_metadata_routing
24 |
25 |
26 | ~BalanceCascadeUnderSampler.get_params
27 |
28 |
29 | ~BalanceCascadeUnderSampler.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.BalanceCascadeUnderSampler.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/imbens/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.utils` module includes various utilities.
3 | """
4 |
5 | from ._docstring import Substitution
6 |
7 | from ._evaluate import evaluate_print
8 |
9 | from ._validation import check_neighbors_object
10 | from ._validation import check_target_type
11 | from ._validation import check_sampling_strategy
12 |
13 | from ._validation_data import check_eval_datasets
14 |
15 | from ._validation_param import check_eval_metrics
16 | from ._validation_param import check_target_label_and_n_target_samples
17 | from ._validation_param import check_balancing_schedule
18 |
19 | __all__ = [
20 | "evaluate_print",
21 | "check_neighbors_object",
22 | "check_sampling_strategy",
23 | "check_target_type",
24 | "check_eval_datasets",
25 | "check_eval_metrics",
26 | "check_target_label_and_n_target_samples",
27 | "check_balancing_schedule",
28 | "Substitution",
29 | ]
30 |
--------------------------------------------------------------------------------
/imbens/utils/tests/test_testing.py:
--------------------------------------------------------------------------------
1 | """Test for the testing module"""
2 | # Authors: Guillaume Lemaitre
3 | # Christos Aridas
4 | # License: MIT
5 |
6 | import pytest
7 |
8 | from imbens.sampler.base import SamplerMixin
9 | from imbens.utils.testing import all_estimators
10 |
11 |
def test_all_estimators():
    """Check `all_estimators` type filtering and its error handling."""
    # A bare string filter must be accepted.
    all_estimators(type_filter="sampler")

    # A list filter must yield only sampler classes.
    for name, Estimator in all_estimators(type_filter=["sampler"]):
        assert issubclass(Estimator, SamplerMixin)

    # An unknown filter value must raise a ValueError.
    with pytest.raises(ValueError, match="Parameter type_filter must be 'sampler'"):
        all_estimators(type_filter="rnd")
26 |
--------------------------------------------------------------------------------
/docs/source/api/sampler/_autosummary/imbens.sampler.RepeatedEditedNearestNeighbours.rst:
--------------------------------------------------------------------------------
1 | RepeatedEditedNearestNeighbours
2 | ============================================================
3 |
4 | .. currentmodule:: imbens.sampler
5 |
6 | .. autoclass:: RepeatedEditedNearestNeighbours
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~RepeatedEditedNearestNeighbours.fit
18 |
19 |
20 | ~RepeatedEditedNearestNeighbours.fit_resample
21 |
22 |
23 | ~RepeatedEditedNearestNeighbours.get_metadata_routing
24 |
25 |
26 | ~RepeatedEditedNearestNeighbours.get_params
27 |
28 |
29 | ~RepeatedEditedNearestNeighbours.set_params
30 |
31 |
32 |
33 |
34 | .. include:: ../../../back_references/imbens.sampler.RepeatedEditedNearestNeighbours.examples
35 |
36 | .. raw:: html
37 |
38 |
--------------------------------------------------------------------------------
/imbens/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.metrics` module includes score functions, performance
3 | metrics and pairwise metrics and distance computations.
4 | """
5 |
6 | from ._classification import sensitivity_specificity_support
7 | from ._classification import sensitivity_score
8 | from ._classification import specificity_score
9 | from ._classification import geometric_mean_score
10 | from ._classification import make_index_balanced_accuracy
11 | from ._classification import classification_report_imbalanced
12 | from ._classification import macro_averaged_mean_absolute_error
13 |
14 | from .pairwise import ValueDifferenceMetric
15 |
16 | __all__ = [
17 | "sensitivity_specificity_support",
18 | "sensitivity_score",
19 | "specificity_score",
20 | "geometric_mean_score",
21 | "make_index_balanced_accuracy",
22 | "classification_report_imbalanced",
23 | "macro_averaged_mean_absolute_error",
24 | "ValueDifferenceMetric",
25 | ]
26 |
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 |
3 | orbs:
4 | codecov: codecov/codecov@3.2.4
5 |
6 | jobs:
7 | test39: &test-template
8 | docker:
9 | - image: cimg/python:3.9
10 | # parallelism: 10
11 | steps:
12 | - checkout
13 | - run:
14 | name: Install dependencies
15 | command: |
16 | pip install -r requirements.txt
17 | pip install pytest pytest-cov
18 | pip install pytest-circleci-parallelized
19 | - run:
20 | name: Run tests with coverage
21 | command: |
22 | pytest --cov=imbens --cov-report=xml
23 | - codecov/upload:
24 | file: coverage.xml
25 |
26 | test310:
27 | <<: *test-template
28 | docker:
29 | - image: cimg/python:3.10
30 |
31 | test311:
32 | <<: *test-template
33 | docker:
34 | - image: cimg/python:3.11
35 |
36 | workflows:
37 | version: 2
38 | build-test-deploy:
39 | jobs:
40 | - test39
41 | - test310
42 | - test311
--------------------------------------------------------------------------------
/docs/source/install.rst:
--------------------------------------------------------------------------------
1 | Install imbalanced-ensemble
2 | ***************************
3 |
4 | Prerequisites
5 | =============
6 |
7 | The following packages are requirements:
8 |
9 | * ``numpy``
10 | * ``scipy``
11 | * ``pandas``
12 | * ``joblib``
* ``scikit-learn``
14 | * ``matplotlib``
15 | * ``seaborn``
16 |
17 | Installation
18 | ============
19 |
20 | Install from PyPI
21 | ^^^^^^^^^^^^^^^^^
22 |
You can install imbalanced-ensemble from
`PyPI <https://pypi.org/project/imbalanced-ensemble/>`__ by running:
25 |
26 | .. code-block:: bash
27 |
28 | > pip install imbalanced-ensemble
29 |
30 | Please make sure the latest version is installed to avoid potential problems:
31 |
32 | .. code-block:: bash
33 |
34 | > pip install --upgrade imbalanced-ensemble
35 |
36 | Clone from GitHub
37 | ^^^^^^^^^^^^^^^^^
38 |
39 | Or you can install imbalanced-ensemble locally:
40 |
41 | .. code-block:: bash
42 |
43 | > git clone https://github.com/ZhiningLiu1998/imbalanced-ensemble.git
44 | > cd imbalanced-ensemble
45 | > pip install .
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Zhining Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/datasets/plot_generate_imbalance.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================
3 | Generate an imbalanced dataset
4 | ===============================
5 |
6 | An illustration of using the
7 | :func:`~imbens.datasets.generate_imbalance_data`
8 | function to create an imbalanced dataset.
9 | """
10 |
11 | # Authors: Zhining Liu
12 | # License: MIT
13 |
14 | # %%
15 | print(__doc__)
16 |
17 | from imbens.datasets import generate_imbalance_data
18 | from imbens.utils._plot import plot_2Dprojection_and_cardinality
19 | from collections import Counter
20 |
21 | # %% [markdown]
22 | # Generate the dataset
23 | # --------------------
24 | #
25 |
26 | # %%
27 | X_train, X_test, y_train, y_test = generate_imbalance_data(
28 | n_samples=1000,
29 | weights=[0.7, 0.2, 0.1],
30 | test_size=0.5,
31 | kwargs={'n_informative': 3},
32 | )
33 |
34 | print("Train class distribution: ", Counter(y_train))
35 | print("Test class distribution: ", Counter(y_test))
36 |
37 | # %% [markdown]
38 | # Plot the generated (training) data
39 | # ----------------------------------
40 | #
41 |
42 | plot_2Dprojection_and_cardinality(X_train, y_train)
43 |
--------------------------------------------------------------------------------
/imbens/utils/tests/test_plot.py:
--------------------------------------------------------------------------------
1 | """Test utilities for plot."""
2 |
3 | # Authors: Zhining Liu
4 | # License: MIT
5 |
6 |
7 | import numpy as np
8 |
9 | from imbens.utils._plot import *
10 |
11 | X = np.array(
12 | [
13 | [2.45166, 1.86760],
14 | [1.34450, -1.30331],
15 | [1.02989, 2.89408],
16 | [-1.94577, -1.75057],
17 | [1.21726, 1.90146],
18 | [2.00194, 1.25316],
19 | [2.31968, 2.33574],
20 | [1.14769, 1.41303],
21 | [1.32018, 2.17595],
22 | [-1.74686, -1.66665],
23 | [-2.17373, -1.91466],
24 | [2.41436, 1.83542],
25 | [1.97295, 2.55534],
26 | [-2.12126, -2.43786],
27 | [1.20494, 3.20696],
28 | [-2.30158, -2.39903],
29 | [1.76006, 1.94323],
30 | [2.35825, 1.77962],
31 | [-2.06578, -2.07671],
32 | [0.00245, -0.99528],
33 | ]
34 | )
35 | y = np.array([2, 0, 2, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 1, 2, 2, 1, 0])
36 |
37 |
def test_plot_scatter():
    """Smoke test: `plot_scatter` runs without error on the 2D fixture."""
    # NOTE(review): only absence of exceptions is checked; the produced
    # matplotlib objects are not inspected (and figures are not closed).
    plot_scatter(X, y)
40 |
41 |
def test_plot_class_distribution():
    """Smoke test: `plot_class_distribution` runs without error on labels."""
    # NOTE(review): only absence of exceptions is checked; the produced
    # matplotlib objects are not inspected (and figures are not closed).
    plot_class_distribution(y)
44 |
45 |
def test_plot_2Dprojection_and_cardinality():
    """Smoke test: `plot_2Dprojection_and_cardinality` runs without error."""
    # NOTE(review): only absence of exceptions is checked; the produced
    # matplotlib objects are not inspected (and figures are not closed).
    plot_2Dprojection_and_cardinality(X, y)
48 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.BalanceCascadeClassifier.rst:
--------------------------------------------------------------------------------
1 | BalanceCascadeClassifier
2 | ======================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: BalanceCascadeClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~BalanceCascadeClassifier.decision_function
18 |
19 |
20 | ~BalanceCascadeClassifier.fit
21 |
22 |
23 | ~BalanceCascadeClassifier.get_metadata_routing
24 |
25 |
26 | ~BalanceCascadeClassifier.get_params
27 |
28 |
29 | ~BalanceCascadeClassifier.predict
30 |
31 |
32 | ~BalanceCascadeClassifier.predict_proba
33 |
34 |
35 | ~BalanceCascadeClassifier.score
36 |
37 |
38 | ~BalanceCascadeClassifier.set_fit_request
39 |
40 |
41 | ~BalanceCascadeClassifier.set_params
42 |
43 |
44 | ~BalanceCascadeClassifier.set_score_request
45 |
46 |
47 |
48 |
49 | .. include:: ../../../back_references/imbens.ensemble.BalanceCascadeClassifier.examples
50 |
51 | .. raw:: html
52 |
53 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.OverBaggingClassifier.rst:
--------------------------------------------------------------------------------
1 | OverBaggingClassifier
2 | ===================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: OverBaggingClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~OverBaggingClassifier.decision_function
18 |
19 |
20 | ~OverBaggingClassifier.fit
21 |
22 |
23 | ~OverBaggingClassifier.get_metadata_routing
24 |
25 |
26 | ~OverBaggingClassifier.get_params
27 |
28 |
29 | ~OverBaggingClassifier.predict
30 |
31 |
32 | ~OverBaggingClassifier.predict_log_proba
33 |
34 |
35 | ~OverBaggingClassifier.predict_proba
36 |
37 |
38 | ~OverBaggingClassifier.score
39 |
40 |
41 | ~OverBaggingClassifier.set_fit_request
42 |
43 |
44 | ~OverBaggingClassifier.set_params
45 |
46 |
47 | ~OverBaggingClassifier.set_score_request
48 |
49 |
50 |
51 |
52 | .. include:: ../../../back_references/imbens.ensemble.OverBaggingClassifier.examples
53 |
54 | .. raw:: html
55 |
56 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.EasyEnsembleClassifier.rst:
--------------------------------------------------------------------------------
1 | EasyEnsembleClassifier
2 | ====================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: EasyEnsembleClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~EasyEnsembleClassifier.decision_function
18 |
19 |
20 | ~EasyEnsembleClassifier.fit
21 |
22 |
23 | ~EasyEnsembleClassifier.get_metadata_routing
24 |
25 |
26 | ~EasyEnsembleClassifier.get_params
27 |
28 |
29 | ~EasyEnsembleClassifier.predict
30 |
31 |
32 | ~EasyEnsembleClassifier.predict_log_proba
33 |
34 |
35 | ~EasyEnsembleClassifier.predict_proba
36 |
37 |
38 | ~EasyEnsembleClassifier.score
39 |
40 |
41 | ~EasyEnsembleClassifier.set_fit_request
42 |
43 |
44 | ~EasyEnsembleClassifier.set_params
45 |
46 |
47 | ~EasyEnsembleClassifier.set_score_request
48 |
49 |
50 |
51 |
52 | .. include:: ../../../back_references/imbens.ensemble.EasyEnsembleClassifier.examples
53 |
54 | .. raw:: html
55 |
56 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.SMOTEBaggingClassifier.rst:
--------------------------------------------------------------------------------
1 | SMOTEBaggingClassifier
2 | ====================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: SMOTEBaggingClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~SMOTEBaggingClassifier.decision_function
18 |
19 |
20 | ~SMOTEBaggingClassifier.fit
21 |
22 |
23 | ~SMOTEBaggingClassifier.get_metadata_routing
24 |
25 |
26 | ~SMOTEBaggingClassifier.get_params
27 |
28 |
29 | ~SMOTEBaggingClassifier.predict
30 |
31 |
32 | ~SMOTEBaggingClassifier.predict_log_proba
33 |
34 |
35 | ~SMOTEBaggingClassifier.predict_proba
36 |
37 |
38 | ~SMOTEBaggingClassifier.score
39 |
40 |
41 | ~SMOTEBaggingClassifier.set_fit_request
42 |
43 |
44 | ~SMOTEBaggingClassifier.set_params
45 |
46 |
47 | ~SMOTEBaggingClassifier.set_score_request
48 |
49 |
50 |
51 |
52 | .. include:: ../../../back_references/imbens.ensemble.SMOTEBaggingClassifier.examples
53 |
54 | .. raw:: html
55 |
56 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.SelfPacedEnsembleClassifier.rst:
--------------------------------------------------------------------------------
1 | SelfPacedEnsembleClassifier
2 | =========================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: SelfPacedEnsembleClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~SelfPacedEnsembleClassifier.decision_function
18 |
19 |
20 | ~SelfPacedEnsembleClassifier.fit
21 |
22 |
23 | ~SelfPacedEnsembleClassifier.get_metadata_routing
24 |
25 |
26 | ~SelfPacedEnsembleClassifier.get_params
27 |
28 |
29 | ~SelfPacedEnsembleClassifier.predict
30 |
31 |
32 | ~SelfPacedEnsembleClassifier.predict_proba
33 |
34 |
35 | ~SelfPacedEnsembleClassifier.score
36 |
37 |
38 | ~SelfPacedEnsembleClassifier.set_fit_request
39 |
40 |
41 | ~SelfPacedEnsembleClassifier.set_params
42 |
43 |
44 | ~SelfPacedEnsembleClassifier.set_score_request
45 |
46 |
47 |
48 |
49 | .. include:: ../../../back_references/imbens.ensemble.SelfPacedEnsembleClassifier.examples
50 |
51 | .. raw:: html
52 |
53 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.UnderBaggingClassifier.rst:
--------------------------------------------------------------------------------
1 | UnderBaggingClassifier
2 | ====================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: UnderBaggingClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~UnderBaggingClassifier.decision_function
18 |
19 |
20 | ~UnderBaggingClassifier.fit
21 |
22 |
23 | ~UnderBaggingClassifier.get_metadata_routing
24 |
25 |
26 | ~UnderBaggingClassifier.get_params
27 |
28 |
29 | ~UnderBaggingClassifier.predict
30 |
31 |
32 | ~UnderBaggingClassifier.predict_log_proba
33 |
34 |
35 | ~UnderBaggingClassifier.predict_proba
36 |
37 |
38 | ~UnderBaggingClassifier.score
39 |
40 |
41 | ~UnderBaggingClassifier.set_fit_request
42 |
43 |
44 | ~UnderBaggingClassifier.set_params
45 |
46 |
47 | ~UnderBaggingClassifier.set_score_request
48 |
49 |
50 |
51 |
52 | .. include:: ../../../back_references/imbens.ensemble.UnderBaggingClassifier.examples
53 |
54 | .. raw:: html
55 |
56 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.sampler._under_sampling` submodule contains
3 | methods to under-sample a dataset.
4 | """
5 |
6 | from ._prototype_generation import ClusterCentroids
7 |
8 | from ._prototype_selection import RandomUnderSampler
9 | from ._prototype_selection import TomekLinks
10 | from ._prototype_selection import NearMiss
11 | from ._prototype_selection import CondensedNearestNeighbour
12 | from ._prototype_selection import OneSidedSelection
13 | from ._prototype_selection import NeighbourhoodCleaningRule
14 | from ._prototype_selection import EditedNearestNeighbours
15 | from ._prototype_selection import RepeatedEditedNearestNeighbours
16 | from ._prototype_selection import AllKNN
17 | from ._prototype_selection import InstanceHardnessThreshold
18 | from ._prototype_selection import BalanceCascadeUnderSampler
19 | from ._prototype_selection import SelfPacedUnderSampler
20 |
21 | __all__ = [
22 | "ClusterCentroids",
23 | "RandomUnderSampler",
24 | "InstanceHardnessThreshold",
25 | "NearMiss",
26 | "TomekLinks",
27 | "EditedNearestNeighbours",
28 | "RepeatedEditedNearestNeighbours",
29 | "AllKNN",
30 | "OneSidedSelection",
31 | "CondensedNearestNeighbour",
32 | "NeighbourhoodCleaningRule",
33 | "BalanceCascadeUnderSampler",
34 | "SelfPacedUnderSampler",
35 | ]
36 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_selection/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.sampler._under_sampling.prototype_selection`
3 | submodule contains methods that select samples in order to balance the dataset.
4 | """
5 |
6 | from ._random_under_sampler import RandomUnderSampler
7 | from ._tomek_links import TomekLinks
8 | from ._nearmiss import NearMiss
9 | from ._condensed_nearest_neighbour import CondensedNearestNeighbour
10 | from ._one_sided_selection import OneSidedSelection
11 | from ._neighbourhood_cleaning_rule import NeighbourhoodCleaningRule
12 | from ._edited_nearest_neighbours import EditedNearestNeighbours
13 | from ._edited_nearest_neighbours import RepeatedEditedNearestNeighbours
14 | from ._edited_nearest_neighbours import AllKNN
15 | from ._instance_hardness_threshold import InstanceHardnessThreshold
16 | from ._balance_cascade_under_sampler import BalanceCascadeUnderSampler
17 | from ._self_paced_under_sampler import SelfPacedUnderSampler
18 |
19 | __all__ = [
20 | "RandomUnderSampler",
21 | "InstanceHardnessThreshold",
22 | "NearMiss",
23 | "TomekLinks",
24 | "EditedNearestNeighbours",
25 | "RepeatedEditedNearestNeighbours",
26 | "AllKNN",
27 | "OneSidedSelection",
28 | "CondensedNearestNeighbour",
29 | "NeighbourhoodCleaningRule",
30 | "BalanceCascadeUnderSampler",
31 | "SelfPacedUnderSampler",
32 | ]
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.CompatibleBaggingClassifier.rst:
--------------------------------------------------------------------------------
1 | CompatibleBaggingClassifier
2 | =========================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: CompatibleBaggingClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~CompatibleBaggingClassifier.decision_function
18 |
19 |
20 | ~CompatibleBaggingClassifier.fit
21 |
22 |
23 | ~CompatibleBaggingClassifier.get_metadata_routing
24 |
25 |
26 | ~CompatibleBaggingClassifier.get_params
27 |
28 |
29 | ~CompatibleBaggingClassifier.predict
30 |
31 |
32 | ~CompatibleBaggingClassifier.predict_log_proba
33 |
34 |
35 | ~CompatibleBaggingClassifier.predict_proba
36 |
37 |
38 | ~CompatibleBaggingClassifier.score
39 |
40 |
41 | ~CompatibleBaggingClassifier.set_fit_request
42 |
43 |
44 | ~CompatibleBaggingClassifier.set_params
45 |
46 |
47 | ~CompatibleBaggingClassifier.set_score_request
48 |
49 |
50 |
51 |
52 | .. include:: ../../../back_references/imbens.ensemble.CompatibleBaggingClassifier.examples
53 |
54 | .. raw:: html
55 |
56 |
--------------------------------------------------------------------------------
/imbens/utils/tests/test_docstring.py:
--------------------------------------------------------------------------------
1 | """Test utilities for docstring."""
2 |
3 | # Authors: Guillaume Lemaitre
4 | # License: MIT
5 |
6 | import pytest
7 |
8 | from imbens.utils import Substitution
9 | from imbens.utils._docstring import _n_jobs_docstring, _random_state_docstring
10 |
11 | func_docstring = """A function.
12 |
13 | Parameters
14 | ----------
15 | xxx
16 |
17 | yyy
18 | """
19 |
20 |
# Test fixture: the docstring below is a *template* whose {param_1}/{param_2}
# placeholders are filled by `Substitution` in the tests. Its text must stay
# exactly in sync with the expected `func_docstring` — do not edit it.
def func(param_1, param_2):
    """A function.

    Parameters
    ----------
    {param_1}

    {param_2}
    """
    return param_1, param_2
31 |
32 |
33 | cls_docstring = """A class.
34 |
35 | Parameters
36 | ----------
37 | xxx
38 |
39 | yyy
40 | """
41 |
42 |
class cls:
    """A class.

    Parameters
    ----------
    {param_1}

    {param_2}
    """

    # Test fixture: the docstring above is a template whose placeholders are
    # filled by `Substitution`; it must stay exactly in sync with the expected
    # `cls_docstring` — do not edit it.

    def __init__(self, param_1, param_2):
        # Store the arguments as-is; the class only exists to exercise
        # docstring substitution on a class object.
        self.param_1 = param_1
        self.param_2 = param_2
56 |
57 |
@pytest.mark.parametrize(
    "obj, obj_docstring", [(func, func_docstring), (cls, cls_docstring)]
)
def test_docstring_inject(obj, obj_docstring):
    """Substitution must fill the placeholders to match the raw docstring."""
    injected = Substitution(param_1="xxx", param_2="yyy")(obj)
    assert injected.__doc__ == obj_docstring
64 |
65 |
def test_docstring_template():
    """The shared docstring templates must mention their parameter names."""
    checks = (
        ("random_state", _random_state_docstring),
        ("n_jobs", _n_jobs_docstring),
    )
    for fragment, template in checks:
        assert fragment in template
69 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.BalancedRandomForestClassifier.rst:
--------------------------------------------------------------------------------
1 | BalancedRandomForestClassifier
2 | ============================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: BalancedRandomForestClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~BalancedRandomForestClassifier.apply
18 |
19 |
20 | ~BalancedRandomForestClassifier.decision_path
21 |
22 |
23 | ~BalancedRandomForestClassifier.fit
24 |
25 |
26 | ~BalancedRandomForestClassifier.get_metadata_routing
27 |
28 |
29 | ~BalancedRandomForestClassifier.get_params
30 |
31 |
32 | ~BalancedRandomForestClassifier.predict
33 |
34 |
35 | ~BalancedRandomForestClassifier.predict_log_proba
36 |
37 |
38 | ~BalancedRandomForestClassifier.predict_proba
39 |
40 |
41 | ~BalancedRandomForestClassifier.score
42 |
43 |
44 | ~BalancedRandomForestClassifier.set_fit_request
45 |
46 |
47 | ~BalancedRandomForestClassifier.set_params
48 |
49 |
50 | ~BalancedRandomForestClassifier.set_score_request
51 |
52 |
53 |
54 |
55 | .. include:: ../../../back_references/imbens.ensemble.BalancedRandomForestClassifier.examples
56 |
57 | .. raw:: html
58 |
59 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.AdaCostClassifier.rst:
--------------------------------------------------------------------------------
1 | AdaCostClassifier
2 | ===============================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: AdaCostClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~AdaCostClassifier.decision_function
18 |
19 |
20 | ~AdaCostClassifier.fit
21 |
22 |
23 | ~AdaCostClassifier.get_metadata_routing
24 |
25 |
26 | ~AdaCostClassifier.get_params
27 |
28 |
29 | ~AdaCostClassifier.predict
30 |
31 |
32 | ~AdaCostClassifier.predict_log_proba
33 |
34 |
35 | ~AdaCostClassifier.predict_proba
36 |
37 |
38 | ~AdaCostClassifier.score
39 |
40 |
41 | ~AdaCostClassifier.set_fit_request
42 |
43 |
44 | ~AdaCostClassifier.set_params
45 |
46 |
47 | ~AdaCostClassifier.set_score_request
48 |
49 |
50 | ~AdaCostClassifier.staged_decision_function
51 |
52 |
53 | ~AdaCostClassifier.staged_predict
54 |
55 |
56 | ~AdaCostClassifier.staged_predict_proba
57 |
58 |
59 | ~AdaCostClassifier.staged_score
60 |
61 |
62 |
63 |
64 | .. include:: ../../../back_references/imbens.ensemble.AdaCostClassifier.examples
65 |
66 | .. raw:: html
67 |
68 |
--------------------------------------------------------------------------------
/imbens/__init__.py:
--------------------------------------------------------------------------------
1 | """Toolbox for ensemble learning on class-imbalanced dataset.
2 |
3 | ``imbalanced-ensemble`` is a set of python-based ensemble learning methods for
4 | dealing with class-imbalanced classification problems in machine learning.
5 |
6 | Subpackages
7 | -----------
8 | ensemble
9 | Module which provides ensemble imbalanced learning methods.
10 | sampler
11 | Module which provides samplers for resampling class-imbalanced data.
12 | visualizer
13 | Module which provides a visualizer for convenient visualization of
14 | ensemble learning process and results.
15 | metrics
16 | Module which provides metrics to quantified the classification performance
17 | with imbalanced dataset.
18 | utils
19 | Module including various utilities.
20 | exceptions
21 | Module including custom warnings and error classes used across
22 | imbalanced-learn.
23 | pipeline
24 | Module which allowing to create pipeline with scikit-learn estimators.
25 | """
26 |
27 | from . import ensemble
28 | from . import sampler
29 | from . import visualizer
30 | from . import metrics
31 | from . import utils
32 | from . import exceptions
33 | from . import pipeline
34 | from . import datasets
35 |
36 | from .sampler.base import FunctionSampler
37 |
38 | from ._version import __version__
39 |
40 | __all__ = [
41 | "ensemble",
42 | "sampler",
43 | "visualizer",
44 | "metrics",
45 | "utils",
46 | "exceptions",
47 | "pipeline",
48 | "datasets",
49 | "FunctionSampler",
50 | "__version__",
51 | ]
52 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.RUSBoostClassifier.rst:
--------------------------------------------------------------------------------
1 | RUSBoostClassifier
2 | ================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: RUSBoostClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~RUSBoostClassifier.decision_function
18 |
19 |
20 | ~RUSBoostClassifier.fit
21 |
22 |
23 | ~RUSBoostClassifier.get_metadata_routing
24 |
25 |
26 | ~RUSBoostClassifier.get_params
27 |
28 |
29 | ~RUSBoostClassifier.predict
30 |
31 |
32 | ~RUSBoostClassifier.predict_log_proba
33 |
34 |
35 | ~RUSBoostClassifier.predict_proba
36 |
37 |
38 | ~RUSBoostClassifier.score
39 |
40 |
41 | ~RUSBoostClassifier.set_fit_request
42 |
43 |
44 | ~RUSBoostClassifier.set_params
45 |
46 |
47 | ~RUSBoostClassifier.set_score_request
48 |
49 |
50 | ~RUSBoostClassifier.staged_decision_function
51 |
52 |
53 | ~RUSBoostClassifier.staged_predict
54 |
55 |
56 | ~RUSBoostClassifier.staged_predict_proba
57 |
58 |
59 | ~RUSBoostClassifier.staged_score
60 |
61 |
62 |
63 |
64 | .. include:: ../../../back_references/imbens.ensemble.RUSBoostClassifier.examples
65 |
66 | .. raw:: html
67 |
68 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.AdaUBoostClassifier.rst:
--------------------------------------------------------------------------------
1 | AdaUBoostClassifier
2 | =================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: AdaUBoostClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~AdaUBoostClassifier.decision_function
18 |
19 |
20 | ~AdaUBoostClassifier.fit
21 |
22 |
23 | ~AdaUBoostClassifier.get_metadata_routing
24 |
25 |
26 | ~AdaUBoostClassifier.get_params
27 |
28 |
29 | ~AdaUBoostClassifier.predict
30 |
31 |
32 | ~AdaUBoostClassifier.predict_log_proba
33 |
34 |
35 | ~AdaUBoostClassifier.predict_proba
36 |
37 |
38 | ~AdaUBoostClassifier.score
39 |
40 |
41 | ~AdaUBoostClassifier.set_fit_request
42 |
43 |
44 | ~AdaUBoostClassifier.set_params
45 |
46 |
47 | ~AdaUBoostClassifier.set_score_request
48 |
49 |
50 | ~AdaUBoostClassifier.staged_decision_function
51 |
52 |
53 | ~AdaUBoostClassifier.staged_predict
54 |
55 |
56 | ~AdaUBoostClassifier.staged_predict_proba
57 |
58 |
59 | ~AdaUBoostClassifier.staged_score
60 |
61 |
62 |
63 |
64 | .. include:: ../../../back_references/imbens.ensemble.AdaUBoostClassifier.examples
65 |
66 | .. raw:: html
67 |
68 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.AsymBoostClassifier.rst:
--------------------------------------------------------------------------------
1 | AsymBoostClassifier
2 | =================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: AsymBoostClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~AsymBoostClassifier.decision_function
18 |
19 |
20 | ~AsymBoostClassifier.fit
21 |
22 |
23 | ~AsymBoostClassifier.get_metadata_routing
24 |
25 |
26 | ~AsymBoostClassifier.get_params
27 |
28 |
29 | ~AsymBoostClassifier.predict
30 |
31 |
32 | ~AsymBoostClassifier.predict_log_proba
33 |
34 |
35 | ~AsymBoostClassifier.predict_proba
36 |
37 |
38 | ~AsymBoostClassifier.score
39 |
40 |
41 | ~AsymBoostClassifier.set_fit_request
42 |
43 |
44 | ~AsymBoostClassifier.set_params
45 |
46 |
47 | ~AsymBoostClassifier.set_score_request
48 |
49 |
50 | ~AsymBoostClassifier.staged_decision_function
51 |
52 |
53 | ~AsymBoostClassifier.staged_predict
54 |
55 |
56 | ~AsymBoostClassifier.staged_predict_proba
57 |
58 |
59 | ~AsymBoostClassifier.staged_score
60 |
61 |
62 |
63 |
64 | .. include:: ../../../back_references/imbens.ensemble.AsymBoostClassifier.examples
65 |
66 | .. raw:: html
67 |
68 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.OverBoostClassifier.rst:
--------------------------------------------------------------------------------
1 | OverBoostClassifier
2 | =================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: OverBoostClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~OverBoostClassifier.decision_function
18 |
19 |
20 | ~OverBoostClassifier.fit
21 |
22 |
23 | ~OverBoostClassifier.get_metadata_routing
24 |
25 |
26 | ~OverBoostClassifier.get_params
27 |
28 |
29 | ~OverBoostClassifier.predict
30 |
31 |
32 | ~OverBoostClassifier.predict_log_proba
33 |
34 |
35 | ~OverBoostClassifier.predict_proba
36 |
37 |
38 | ~OverBoostClassifier.score
39 |
40 |
41 | ~OverBoostClassifier.set_fit_request
42 |
43 |
44 | ~OverBoostClassifier.set_params
45 |
46 |
47 | ~OverBoostClassifier.set_score_request
48 |
49 |
50 | ~OverBoostClassifier.staged_decision_function
51 |
52 |
53 | ~OverBoostClassifier.staged_predict
54 |
55 |
56 | ~OverBoostClassifier.staged_predict_proba
57 |
58 |
59 | ~OverBoostClassifier.staged_score
60 |
61 |
62 |
63 |
64 | .. include:: ../../../back_references/imbens.ensemble.OverBoostClassifier.examples
65 |
66 | .. raw:: html
67 |
68 |
--------------------------------------------------------------------------------
/docs/source/api/pipeline/_autosummary/imbens.pipeline.Pipeline.rst:
--------------------------------------------------------------------------------
1 | Pipeline
2 | ======================================
3 |
4 | .. currentmodule:: imbens.pipeline
5 |
6 | .. autoclass:: Pipeline
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~Pipeline.decision_function
18 |
19 |
20 | ~Pipeline.fit
21 |
22 |
23 | ~Pipeline.fit_predict
24 |
25 |
26 | ~Pipeline.fit_resample
27 |
28 |
29 | ~Pipeline.fit_transform
30 |
31 |
32 | ~Pipeline.get_feature_names_out
33 |
34 |
35 | ~Pipeline.get_metadata_routing
36 |
37 |
38 | ~Pipeline.get_params
39 |
40 |
41 | ~Pipeline.inverse_transform
42 |
43 |
44 | ~Pipeline.predict
45 |
46 |
47 | ~Pipeline.predict_log_proba
48 |
49 |
50 | ~Pipeline.predict_proba
51 |
52 |
53 | ~Pipeline.score
54 |
55 |
56 | ~Pipeline.score_samples
57 |
58 |
59 | ~Pipeline.set_fit_request
60 |
61 |
62 | ~Pipeline.set_output
63 |
64 |
65 | ~Pipeline.set_params
66 |
67 |
68 | ~Pipeline.set_score_request
69 |
70 |
71 | ~Pipeline.transform
72 |
73 |
74 |
75 |
76 | .. include:: ../../../back_references/imbens.pipeline.Pipeline.examples
77 |
78 | .. raw:: html
79 |
80 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.SMOTEBoostClassifier.rst:
--------------------------------------------------------------------------------
1 | SMOTEBoostClassifier
2 | ==================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: SMOTEBoostClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~SMOTEBoostClassifier.decision_function
18 |
19 |
20 | ~SMOTEBoostClassifier.fit
21 |
22 |
23 | ~SMOTEBoostClassifier.get_metadata_routing
24 |
25 |
26 | ~SMOTEBoostClassifier.get_params
27 |
28 |
29 | ~SMOTEBoostClassifier.predict
30 |
31 |
32 | ~SMOTEBoostClassifier.predict_log_proba
33 |
34 |
35 | ~SMOTEBoostClassifier.predict_proba
36 |
37 |
38 | ~SMOTEBoostClassifier.score
39 |
40 |
41 | ~SMOTEBoostClassifier.set_fit_request
42 |
43 |
44 | ~SMOTEBoostClassifier.set_params
45 |
46 |
47 | ~SMOTEBoostClassifier.set_score_request
48 |
49 |
50 | ~SMOTEBoostClassifier.staged_decision_function
51 |
52 |
53 | ~SMOTEBoostClassifier.staged_predict
54 |
55 |
56 | ~SMOTEBoostClassifier.staged_predict_proba
57 |
58 |
59 | ~SMOTEBoostClassifier.staged_score
60 |
61 |
62 |
63 |
64 | .. include:: ../../../back_references/imbens.ensemble.SMOTEBoostClassifier.examples
65 |
66 | .. raw:: html
67 |
68 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.KmeansSMOTEBoostClassifier.rst:
--------------------------------------------------------------------------------
1 | KmeansSMOTEBoostClassifier
2 | ========================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: KmeansSMOTEBoostClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~KmeansSMOTEBoostClassifier.decision_function
18 |
19 |
20 | ~KmeansSMOTEBoostClassifier.fit
21 |
22 |
23 | ~KmeansSMOTEBoostClassifier.get_metadata_routing
24 |
25 |
26 | ~KmeansSMOTEBoostClassifier.get_params
27 |
28 |
29 | ~KmeansSMOTEBoostClassifier.predict
30 |
31 |
32 | ~KmeansSMOTEBoostClassifier.predict_log_proba
33 |
34 |
35 | ~KmeansSMOTEBoostClassifier.predict_proba
36 |
37 |
38 | ~KmeansSMOTEBoostClassifier.score
39 |
40 |
41 | ~KmeansSMOTEBoostClassifier.set_fit_request
42 |
43 |
44 | ~KmeansSMOTEBoostClassifier.set_params
45 |
46 |
47 | ~KmeansSMOTEBoostClassifier.set_score_request
48 |
49 |
50 | ~KmeansSMOTEBoostClassifier.staged_decision_function
51 |
52 |
53 | ~KmeansSMOTEBoostClassifier.staged_predict
54 |
55 |
56 | ~KmeansSMOTEBoostClassifier.staged_predict_proba
57 |
58 |
59 | ~KmeansSMOTEBoostClassifier.staged_score
60 |
61 |
62 |
63 |
64 | .. include:: ../../../back_references/imbens.ensemble.KmeansSMOTEBoostClassifier.examples
65 |
66 | .. raw:: html
67 |
68 |
--------------------------------------------------------------------------------
/docs/source/api/ensemble/_autosummary/imbens.ensemble.CompatibleAdaBoostClassifier.rst:
--------------------------------------------------------------------------------
1 | CompatibleAdaBoostClassifier
2 | ==========================================================
3 |
4 | .. currentmodule:: imbens.ensemble
5 |
6 | .. autoclass:: CompatibleAdaBoostClassifier
7 |
8 |
9 |
10 |
11 | .. rubric:: Methods
12 |
13 | .. autosummary::
14 |
15 |
16 |
17 | ~CompatibleAdaBoostClassifier.decision_function
18 |
19 |
20 | ~CompatibleAdaBoostClassifier.fit
21 |
22 |
23 | ~CompatibleAdaBoostClassifier.get_metadata_routing
24 |
25 |
26 | ~CompatibleAdaBoostClassifier.get_params
27 |
28 |
29 | ~CompatibleAdaBoostClassifier.predict
30 |
31 |
32 | ~CompatibleAdaBoostClassifier.predict_log_proba
33 |
34 |
35 | ~CompatibleAdaBoostClassifier.predict_proba
36 |
37 |
38 | ~CompatibleAdaBoostClassifier.score
39 |
40 |
41 | ~CompatibleAdaBoostClassifier.set_fit_request
42 |
43 |
44 | ~CompatibleAdaBoostClassifier.set_params
45 |
46 |
47 | ~CompatibleAdaBoostClassifier.set_score_request
48 |
49 |
50 | ~CompatibleAdaBoostClassifier.staged_decision_function
51 |
52 |
53 | ~CompatibleAdaBoostClassifier.staged_predict
54 |
55 |
56 | ~CompatibleAdaBoostClassifier.staged_predict_proba
57 |
58 |
59 | ~CompatibleAdaBoostClassifier.staged_score
60 |
61 |
62 |
63 |
64 | .. include:: ../../../back_references/imbens.ensemble.CompatibleAdaBoostClassifier.examples
65 |
66 | .. raw:: html
67 |
68 |
--------------------------------------------------------------------------------
/imbens/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.sampler` submodule provides a
3 | set of methods to perform resampling.
4 | """
5 |
6 | from . import _under_sampling
7 | from . import _over_sampling
8 |
9 | from ._under_sampling import ClusterCentroids
10 | from ._under_sampling import RandomUnderSampler
11 | from ._under_sampling import TomekLinks
12 | from ._under_sampling import NearMiss
13 | from ._under_sampling import CondensedNearestNeighbour
14 | from ._under_sampling import OneSidedSelection
15 | from ._under_sampling import NeighbourhoodCleaningRule
16 | from ._under_sampling import EditedNearestNeighbours
17 | from ._under_sampling import RepeatedEditedNearestNeighbours
18 | from ._under_sampling import AllKNN
19 | from ._under_sampling import InstanceHardnessThreshold
20 | from ._under_sampling import BalanceCascadeUnderSampler
21 | from ._under_sampling import SelfPacedUnderSampler
22 |
23 | from ._over_sampling import ADASYN
24 | from ._over_sampling import RandomOverSampler
25 | from ._over_sampling import SMOTE
26 | from ._over_sampling import BorderlineSMOTE
27 | from ._over_sampling import KMeansSMOTE
28 | from ._over_sampling import SVMSMOTE
29 |
30 |
31 | __all__ = [
32 | "_under_sampling",
33 | "_over_sampling",
34 |
35 | "ClusterCentroids",
36 | "RandomUnderSampler",
37 | "InstanceHardnessThreshold",
38 | "NearMiss",
39 | "TomekLinks",
40 | "EditedNearestNeighbours",
41 | "RepeatedEditedNearestNeighbours",
42 | "AllKNN",
43 | "OneSidedSelection",
44 | "CondensedNearestNeighbour",
45 | "NeighbourhoodCleaningRule",
46 | "BalanceCascadeUnderSampler",
47 | "SelfPacedUnderSampler",
48 |
49 | "ADASYN",
50 | "RandomOverSampler",
51 | "KMeansSMOTE",
52 | "SMOTE",
53 | "BorderlineSMOTE",
54 | "SVMSMOTE",
55 | ]
56 |
--------------------------------------------------------------------------------
/imbens/ensemble/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The :mod:`imbens.ensemble` module contains a set of
3 | ensemble imbalanced learning methods.
4 | """
5 |
6 | from . import _under_sampling
7 | from . import _over_sampling
8 | from . import _reweighting
9 | from . import _compatible
10 |
11 | from ._under_sampling import SelfPacedEnsembleClassifier
12 | from ._under_sampling import BalanceCascadeClassifier
13 | from ._under_sampling import BalancedRandomForestClassifier
14 | from ._under_sampling import EasyEnsembleClassifier
15 | from ._under_sampling import RUSBoostClassifier
16 | from ._under_sampling import UnderBaggingClassifier
17 |
18 | from ._over_sampling import OverBoostClassifier
19 | from ._over_sampling import SMOTEBoostClassifier
20 | from ._over_sampling import KmeansSMOTEBoostClassifier
21 | from ._over_sampling import SMOTEBaggingClassifier
22 | from ._over_sampling import OverBaggingClassifier
23 |
24 | from ._reweighting import AdaCostClassifier
25 | from ._reweighting import AdaUBoostClassifier
26 | from ._reweighting import AsymBoostClassifier
27 |
28 | from ._compatible import CompatibleAdaBoostClassifier
29 | from ._compatible import CompatibleBaggingClassifier
30 |
31 | __all__ = [
32 | "_under_sampling",
33 | "_over_sampling",
34 | "_reweighting",
35 | "_compatible",
36 |
37 | "SelfPacedEnsembleClassifier",
38 | "BalanceCascadeClassifier",
39 | "BalancedRandomForestClassifier",
40 | "EasyEnsembleClassifier",
41 | "RUSBoostClassifier",
42 | "UnderBaggingClassifier",
43 |
44 | "OverBoostClassifier",
45 | "SMOTEBoostClassifier",
46 | "KmeansSMOTEBoostClassifier",
47 | "OverBaggingClassifier",
48 | "SMOTEBaggingClassifier",
49 |
50 | "AdaCostClassifier",
51 | "AdaUBoostClassifier",
52 | "AsymBoostClassifier",
53 |
54 | "CompatibleAdaBoostClassifier",
55 | "CompatibleBaggingClassifier",
56 | ]
--------------------------------------------------------------------------------
/examples/evaluation/plot_classification_report.py:
--------------------------------------------------------------------------------
1 | """
2 | =============================================
3 | Evaluate classification by compiling a report
4 | =============================================
5 |
6 | Specific metrics have been developed to evaluate classifier which has been
7 | trained using imbalanced data. "mod:`imbens` provides a classification report
8 | (:func:`imbens.metrics.classification_report_imbalanced`)
9 | similar to :mod:`sklearn`, with additional metrics specific to imbalanced
10 | learning problem.
11 | """
12 |
13 | # Adapted from imbalanced-learn
14 | # Authors: Guillaume Lemaitre
15 | # License: MIT
16 |
17 | from sklearn import datasets
18 | from sklearn.svm import LinearSVC
19 | from sklearn.model_selection import train_test_split
20 |
21 | from imbens.sampler import SMOTE
22 | from imbens import pipeline as pl
23 | from imbens.metrics import classification_report_imbalanced
24 |
25 | print(__doc__)
26 |
27 | RANDOM_STATE = 42
28 |
29 | # sphinx_gallery_thumbnail_path = '../../docs/source/_static/thumbnail.png'
30 |
31 | # Generate a dataset
32 | X, y = datasets.make_classification(
33 | n_classes=2,
34 | class_sep=2,
35 | weights=[0.1, 0.9],
36 | n_informative=10,
37 | n_redundant=1,
38 | flip_y=0,
39 | n_features=20,
40 | n_clusters_per_class=4,
41 | n_samples=5000,
42 | random_state=RANDOM_STATE,
43 | )
44 |
45 | pipeline = pl.make_pipeline(
46 | SMOTE(random_state=RANDOM_STATE), LinearSVC(random_state=RANDOM_STATE)
47 | )
48 |
49 | # Split the data
50 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=RANDOM_STATE)
51 |
52 | # Train the classifier with balancing
53 | pipeline.fit(X_train, y_train)
54 |
55 | # Test the classifier and get the prediction
56 | y_pred_bal = pipeline.predict(X_test)
57 |
58 | # Show the classification report
59 | print(classification_report_imbalanced(y_test, y_pred_bal))
60 |
--------------------------------------------------------------------------------
/imbens/utils/deprecation.py:
--------------------------------------------------------------------------------
1 | """Utilities for deprecation
2 | """
3 | # Adapted from imbalanced-learn
4 |
5 | # Authors: Guillaume Lemaitre
6 | # License: MIT
7 |
8 | import warnings
9 |
10 |
def deprecate_parameter(sampler, version_deprecation, param_deprecated, new_param=None):
    """Helper to deprecate a parameter by another one.

    Emits a :class:`DeprecationWarning` when the deprecated parameter is set
    (i.e. not ``None``) on ``sampler``, and, when a replacement parameter is
    given, forwards the deprecated value to it.

    Parameters
    ----------
    sampler : sampler object,
        The object which will be inspected.

    version_deprecation : str,
        The version from which the parameter will be deprecated. The format
        should be ``'x.y'``

    param_deprecated : str,
        The parameter being deprecated.

    new_param : str,
        The parameter used instead of the deprecated parameter. By default, no
        parameter is expected.

    Returns
    -------
    None

    """
    # Deprecation cycle: a parameter deprecated in 'x.y' is removed in 'x.(y+2)'.
    x, y = version_deprecation.split(".")
    version_removed = x + "." + str(int(y) + 2)
    if new_param is None:
        if getattr(sampler, param_deprecated) is not None:
            # BUGFIX: the original message concatenated "and " with
            # " will be removed", producing a double space.
            warnings.warn(
                f"'{param_deprecated}' is deprecated from {version_deprecation} and "
                f"will be removed in {version_removed} for the estimator "
                f"{sampler.__class__}.",
                category=DeprecationWarning,
            )
    else:
        if getattr(sampler, param_deprecated) is not None:
            warnings.warn(
                f"'{param_deprecated}' is deprecated from {version_deprecation} and "
                f"will be removed in {version_removed} for the estimator "
                f"{sampler.__class__}. Use '{new_param}' instead.",
                category=DeprecationWarning,
            )
            # Forward the deprecated value to its replacement parameter.
            setattr(sampler, new_param, getattr(sampler, param_deprecated))
54 |
--------------------------------------------------------------------------------
/docs/source/sphinxext/README.txt:
--------------------------------------------------------------------------------
1 | =====================================
2 | numpydoc -- Numpy's Sphinx extensions
3 | =====================================
4 |
5 | Numpy's documentation uses several custom extensions to Sphinx. These
6 | are shipped in this ``numpydoc`` package, in case you want to make use
7 | of them in third-party projects.
8 |
9 | The following extensions are available:
10 |
11 | - ``numpydoc``: support for the Numpy docstring format in Sphinx, and add
12 | the code description directives ``np-function``, ``np-cfunction``, etc.
13 | that support the Numpy docstring syntax.
14 |
15 | - ``numpydoc.traitsdoc``: For gathering documentation about Traits attributes.
16 |
17 | - ``numpydoc.plot_directives``: Adaptation of Matplotlib's ``plot::``
18 | directive. Note that this implementation may still undergo severe
19 | changes or eventually be deprecated.
20 |
21 | - ``numpydoc.only_directives``: (DEPRECATED)
22 |
23 | - ``numpydoc.autosummary``: (DEPRECATED) An ``autosummary::`` directive.
24 | Available in Sphinx 0.6.2 and (to-be) 1.0 as ``sphinx.ext.autosummary``,
25 |    and the Sphinx 1.0 version is recommended over that included in
26 | Numpydoc.
27 |
28 |
29 | numpydoc
30 | ========
31 |
32 | Numpydoc inserts a hook into Sphinx's autodoc that converts docstrings
33 | following the Numpy/Scipy format to a form palatable to Sphinx.
34 |
35 | Options
36 | -------
37 |
38 | The following options can be set in conf.py:
39 |
40 | - numpydoc_use_plots: bool
41 |
42 | Whether to produce ``plot::`` directives for Examples sections that
43 | contain ``import matplotlib``.
44 |
45 | - numpydoc_show_class_members: bool
46 |
47 | Whether to show all members of a class in the Methods and Attributes
48 | sections automatically.
49 |
50 | - numpydoc_edit_link: bool (DEPRECATED -- edit your HTML template instead)
51 |
52 | Whether to insert an edit link after docstrings.
53 |
--------------------------------------------------------------------------------
/docs/source/get_start.rst:
--------------------------------------------------------------------------------
1 | Getting Started
2 | ***************
3 |
4 | Background
5 | ====================================
6 |
7 | Class-imbalance (also known as the long-tail problem) is the fact that the
8 | classes are not represented equally in a classification problem, which is
9 | quite common in practice. For instance, fraud detection, prediction of
10 | rare adverse drug reactions, and prediction of gene families. Failure to account
11 | for the class imbalance often causes inaccurate and decreased predictive
12 | performance of many classification algorithms.
13 |
14 | Imbalanced learning (IL) aims
15 | to tackle the class imbalance problem to learn an unbiased model from
16 | imbalanced data. This is usually achieved by changing the training data
17 | distribution by resampling or reweighting. However, naive resampling or
18 | reweighting may introduce bias/variance to the training data, especially
19 | when the data has class-overlapping or contains noise.
20 |
21 | Ensemble imbalanced learning (EIL) is known to effectively improve typical
22 | IL solutions by combining the outputs of multiple classifiers, thereby
23 | reducing the variance introduced by resampling/reweighting.
24 |
25 | About ``imbens``
26 | ====================================
27 |
28 | ``imbens`` aims to provide users with easy-to-use EIL methods
29 | and related utilities, so that everyone can quickly deploy EIL algorithms
30 | to their tasks. The EIL methods implemented in this package have
31 | unified APIs and are compatible with other popular Python machine-learning
32 | packages such as `scikit-learn <https://scikit-learn.org/stable/>`__
33 | and `imbalanced-learn <https://imbalanced-learn.org/stable/>`__.
34 |
35 | ``imbens`` is an early version software and is under development.
36 | Any kinds of contributions are welcome!
37 |
38 | > Note: *many resampling algorithms and utilities are adapted from*
39 | `imbalanced-learn <https://imbalanced-learn.org/stable/>`__, *which is an amazing
40 | project!*
--------------------------------------------------------------------------------
/imbens/utils/tests/test_show_versions.py:
--------------------------------------------------------------------------------
1 | """Test for the show_versions helper. Based on the sklearn tests."""
2 | # Author: Alexander L. Hayes
3 | # License: MIT
4 |
5 | # %%
6 |
7 | from imbens.utils._show_versions import _get_deps_info, show_versions
8 |
9 |
def test_get_deps_info():
    """_get_deps_info must report every core dependency by name."""
    deps_info = _get_deps_info()
    expected_packages = (
        "pip",
        "setuptools",
        "imblearn",
        "imbens",
        "sklearn",
        "numpy",
        "scipy",
        "Cython",
        "pandas",
        "joblib",
    )
    for package_name in expected_packages:
        assert package_name in deps_info
22 |
23 |
def test_show_versions_default(capsys):
    """Default (plain-text) output must mention system info and dependencies."""
    show_versions()
    captured = capsys.readouterr()
    expected_tokens = (
        "python",
        "executable",
        "machine",
        "pip",
        "setuptools",
        "imblearn",
        "imbens",
        "sklearn",
        "numpy",
        "scipy",
        "Cython",
        "pandas",
        "joblib",
    )
    for token in expected_tokens:
        assert token in captured.out
40 |
41 |
def test_show_versions_github(capsys):
    """GitHub-markup output must be wrapped in a collapsible details block."""
    show_versions(github=True)
    out, err = capsys.readouterr()
    # NOTE(review): the HTML tags in the original assertions were garbled in
    # the source; reconstructed from the <details>/<summary> markup emitted
    # by show_versions(github=True) — confirm against upstream.
    assert "<details><summary>System, Dependency Information</summary>" in out
    assert "**System Information**" in out
    assert "* python" in out
    assert "* executable" in out
    assert "* machine" in out
    assert "**Python Dependencies**" in out
    assert "* pip" in out
    assert "* setuptools" in out
    assert "* imblearn" in out
    assert "* imbens" in out
    assert "* sklearn" in out
    assert "* numpy" in out
    assert "* scipy" in out
    assert "* Cython" in out
    assert "* pandas" in out
    assert "* joblib" in out
    assert "</details>" in out
62 |
63 |
64 | # %%
65 |
--------------------------------------------------------------------------------
/.all-contributorsrc:
--------------------------------------------------------------------------------
1 | {
2 | "files": [
3 | "README.md"
4 | ],
5 | "imageSize": 100,
6 | "commit": false,
7 | "badgeTemplate": "
-orange.svg\">",
8 | "contributors": [
9 | {
10 | "login": "ZhiningLiu1998",
11 | "name": "Zhining Liu",
12 | "avatar_url": "https://avatars.githubusercontent.com/u/26108487?v=4",
13 | "profile": "http://zhiningliu.com",
14 | "contributions": [
15 | "code",
16 | "ideas",
17 | "maintenance",
18 | "bug",
19 | "doc"
20 | ]
21 | },
22 | {
23 | "login": "leaphan",
24 | "name": "leaphan",
25 | "avatar_url": "https://avatars.githubusercontent.com/u/35593707?v=4",
26 | "profile": "https://github.com/leaphan",
27 | "contributions": [
28 | "bug"
29 | ]
30 | },
31 | {
32 | "login": "hannanhtang",
33 | "name": "hannanhtang",
34 | "avatar_url": "https://avatars.githubusercontent.com/u/23587399?v=4",
35 | "profile": "https://github.com/hannanhtang",
36 | "contributions": [
37 | "bug"
38 | ]
39 | },
40 | {
41 | "login": "huajuanren",
42 | "name": "H.J.Ren",
43 | "avatar_url": "https://avatars.githubusercontent.com/u/37321841?v=4",
44 | "profile": "https://github.com/huajuanren",
45 | "contributions": [
46 | "bug"
47 | ]
48 | },
49 | {
50 | "login": "MarcSkovMadsen",
51 | "name": "Marc Skov Madsen",
52 | "avatar_url": "https://avatars.githubusercontent.com/u/42288570?v=4",
53 | "profile": "http://datamodelsanalytics.com",
54 | "contributions": [
55 | "bug"
56 | ]
57 | }
58 | ],
59 | "contributorsPerLine": 7,
60 | "projectName": "imbalanced-ensemble",
61 | "projectOwner": "ZhiningLiu1998",
62 | "repoType": "github",
63 | "repoHost": "https://github.com",
64 | "skipCi": true,
65 | "commitConvention": "angular"
66 | }
67 |
--------------------------------------------------------------------------------
/imbens/sampler/_over_sampling/_smote/tests/test_borderline_smote.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from sklearn.neighbors import NearestNeighbors
5 | from sklearn.utils._testing import assert_allclose
6 | from sklearn.utils._testing import assert_array_equal
7 |
8 | from imbens.sampler._over_sampling import BorderlineSMOTE
9 |
10 |
@pytest.fixture
def data():
    """Return a small 2-D binary toy dataset (20 samples; 8 zeros, 12 ones)."""
    X = np.array(
        [
            [0.11622591, -0.0317206],
            [0.77481731, 0.60935141],
            [1.25192108, -0.22367336],
            [0.53366841, -0.30312976],
            [1.52091956, -0.49283504],
            [-0.28162401, -2.10400981],
            [0.83680821, 1.72827342],
            [0.3084254, 0.33299982],
            [0.70472253, -0.73309052],
            [0.28893132, -0.38761769],
            [1.15514042, 0.0129463],
            [0.88407872, 0.35454207],
            [1.31301027, -0.92648734],
            [-1.11515198, -0.93689695],
            [-0.18410027, -0.45194484],
            [0.9281014, 0.53085498],
            [-0.14374509, 0.27370049],
            [-0.41635887, -0.38299653],
            [0.08711622, 0.93259929],
            [1.70580611, -0.11219234],
        ]
    )
    y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
    return X, y
39 |
40 |
def test_borderline_smote_wrong_kind(data):
    """An unsupported ``kind`` must raise a ValueError at resampling time."""
    sampler = BorderlineSMOTE(kind="rand")
    with pytest.raises(ValueError, match='The possible "kind" of algorithm'):
        sampler.fit_resample(*data)
45 |
46 |
47 | @pytest.mark.parametrize("kind", ["borderline-1", "borderline-2"])
48 | def test_borderline_smote(kind, data):
49 | bsmote = BorderlineSMOTE(kind=kind, random_state=42)
50 | bsmote_nn = BorderlineSMOTE(
51 | kind=kind,
52 | random_state=42,
53 | k_neighbors=NearestNeighbors(n_neighbors=6),
54 | m_neighbors=NearestNeighbors(n_neighbors=11),
55 | )
56 |
57 | X_res_1, y_res_1 = bsmote.fit_resample(*data)
58 | X_res_2, y_res_2 = bsmote_nn.fit_resample(*data)
59 |
60 | assert_allclose(X_res_1, X_res_2)
61 | assert_array_equal(y_res_1, y_res_2)
62 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_selection/tests/test_tomek_links.py:
--------------------------------------------------------------------------------
1 | """Test the module Tomek's links."""
2 | # Authors: Guillaume Lemaitre
3 | # Christos Aridas
4 | # License: MIT
5 |
6 | import numpy as np
7 | from sklearn.utils._testing import assert_array_equal
8 |
9 | from imbens.sampler._under_sampling import TomekLinks
10 |
# Shared 2-D toy dataset for the Tomek-links tests: 20 samples,
# class 0 is the minority (7 of 20 labels in Y below).
X = np.array(
    [
        [0.31230513, 0.1216318],
        [0.68481731, 0.51935141],
        [1.34192108, -0.13367336],
        [0.62366841, -0.21312976],
        [1.61091956, -0.40283504],
        [-0.37162401, -2.19400981],
        [0.74680821, 1.63827342],
        [0.2184254, 0.24299982],
        [0.61472253, -0.82309052],
        [0.19893132, -0.47761769],
        [1.06514042, -0.0770537],
        [0.97407872, 0.44454207],
        [1.40301027, -0.83648734],
        [-1.20515198, -1.02689695],
        [-0.27410027, -0.54194484],
        [0.8381014, 0.44085498],
        [-0.23374509, 0.18370049],
        [-0.32635887, -0.29299653],
        [-0.00288378, 0.84259929],
        [1.79580611, -0.02219234],
    ]
)
# Binary labels aligned with X (row i of X is labelled Y[i]).
Y = np.array([1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
36 |
37 |
def test_tl_init():
    """A default-constructed TomekLinks leaves ``n_jobs`` unset."""
    sampler = TomekLinks()
    assert sampler.n_jobs is None
41 |
42 |
def test_tl_fit_resample():
    """fit_resample must drop the samples forming Tomek links (20 -> 17)."""
    tl = TomekLinks()
    X_resampled, y_resampled = tl.fit_resample(X, Y)

    # Expected output: the module-level X with three Tomek-link samples
    # removed (compare row-by-row with X above).
    X_gt = np.array(
        [
            [0.31230513, 0.1216318],
            [0.68481731, 0.51935141],
            [1.34192108, -0.13367336],
            [0.62366841, -0.21312976],
            [1.61091956, -0.40283504],
            [-0.37162401, -2.19400981],
            [0.74680821, 1.63827342],
            [0.2184254, 0.24299982],
            [0.61472253, -0.82309052],
            [0.19893132, -0.47761769],
            [0.97407872, 0.44454207],
            [1.40301027, -0.83648734],
            [-1.20515198, -1.02689695],
            [-0.23374509, 0.18370049],
            [-0.32635887, -0.29299653],
            [-0.00288378, 0.84259929],
            [1.79580611, -0.02219234],
        ]
    )
    y_gt = np.array([1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0])
    assert_array_equal(X_resampled, X_gt)
    assert_array_equal(y_resampled, y_gt)
71 |
--------------------------------------------------------------------------------
/imbens/utils/_show_versions.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility method which prints system info to help with debugging,
3 | and filing issues on GitHub.
4 | Adapted from :func:`sklearn.show_versions`,
5 | which was adapted from :func:`pandas.show_versions`
6 | """
7 | # Adapted from imbalanced-learn
8 |
9 | # Author: Alexander L. Hayes
10 | # License: MIT
11 |
12 | from .. import __version__
13 |
14 |
def _get_deps_info():
    """Overview of the installed version of main dependencies.

    Returns
    -------
    deps_info: dict
        Mapping of distribution name to installed version string, or ``None``
        when the distribution is not installed. The first entry is always
        ``'imbalanced-ensemble'`` mapped to this package's ``__version__``.
    """
    from importlib.metadata import PackageNotFoundError, version

    def _installed_version(dist_name):
        # Resolve the installed version, or None when the package is absent.
        try:
            return version(dist_name)
        except PackageNotFoundError:
            return None

    relevant_packages = (
        "pip",
        "setuptools",
        "imblearn",
        "imbens",
        "sklearn",
        "numpy",
        "scipy",
        "Cython",
        "pandas",
        "joblib",
    )

    deps_info = {"imbalanced-ensemble": __version__}
    deps_info.update(
        {name: _installed_version(name) for name in relevant_packages}
    )
    return deps_info
47 |
48 |
def show_versions(github=False):
    """Print debugging information.

    .. versionadded:: 0.5

    Parameters
    ----------
    github : bool,
        If true, wrap system info with GitHub markup.
    """

    from sklearn.utils._show_versions import _get_sys_info

    _sys_info = _get_sys_info()
    _deps_info = _get_deps_info()
    # GitHub-flavoured markup: wrap the report in a collapsible
    # <details>/<summary> section (as in sklearn/imblearn show_versions)
    # so a long report does not clutter an issue thread.
    _github_markup = (
        "<details>"
        "<summary>System, Dependency Information</summary>\n\n"
        "**System Information**\n\n"
        "{0}\n"
        "**Python Dependencies**\n\n"
        "{1}\n"
        "</details>"
    )

    if github:
        _sys_markup = ""
        _deps_markup = ""

        # Render each key/value pair as a markdown bullet with the value
        # in inline-code style.
        for k, stat in _sys_info.items():
            _sys_markup += f"* {k:<10}: `{stat}`\n"
        for k, stat in _deps_info.items():
            _deps_markup += f"* {k:<10}: `{stat}`\n"

        print(_github_markup.format(_sys_markup, _deps_markup))

    else:
        # Plain-text output, right-aligned keys for readability.
        print("\nSystem:")
        for k, stat in _sys_info.items():
            print(f"{k:>11}: {stat}")

        print("\nPython dependencies:")
        for k, stat in _deps_info.items():
            print(f"{k:>11}: {stat}")
93 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 | docs/source/auto_examples/
74 | docs/source/back_references/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # vscode
134 | .vscode/
--------------------------------------------------------------------------------
/imbens/sampler/_over_sampling/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class for the over-sampling method.
3 | """
4 | # Adapted from imbalanced-learn
5 |
6 | # Authors: Guillaume Lemaitre
7 | # Christos Aridas
8 | # Zhining Liu
9 | # License: MIT
10 |
# LOCAL_DEBUG switches between the packaged relative import (normal use)
# and a sys.path hack that lets this module be run stand-alone during
# local development.
LOCAL_DEBUG = False

if not LOCAL_DEBUG:
    from ..base import BaseSampler
else:  # pragma: no cover
    import sys  # For local test
    sys.path.append("../..")
    from sampler.base import BaseSampler
19 |
20 |
class BaseOverSampler(BaseSampler):
    """Base class for over-sampling algorithms.

    Warning: This class should not be used directly. Use the derived classes
    instead.
    """

    # Tag read by the BaseSampler machinery to validate sampling_strategy
    # values appropriate for over-sampling.
    _sampling_type = "over-sampling"

    # Reusable docstring fragment substituted into subclasses' docstrings;
    # documents the shared `sampling_strategy` parameter. This is a runtime
    # string, not a docstring — its content must not be altered.
    _sampling_strategy_docstring = """sampling_strategy : float, str, dict or callable, default='auto'
        Sampling information to resample the data set.

        - When ``float``, it corresponds to the desired ratio of the number of
          samples in the minority class over the number of samples in the
          majority class after resampling. Therefore, the ratio is expressed as
          :math:`\\alpha_{os} = N_{rm} / N_{M}` where :math:`N_{rm}` is the
          number of samples in the minority class after resampling and
          :math:`N_{M}` is the number of samples in the majority class.

        .. warning::
           ``float`` is only available for **binary** classification. An
           error is raised for multi-class classification.

        - When ``str``, specify the class targeted by the resampling. The
          number of samples in the different classes will be equalized.
          Possible choices are:

            ``'minority'``: resample only the minority class;

            ``'not minority'``: resample all classes but the minority class;

            ``'not majority'``: resample all classes but the majority class;

            ``'all'``: resample all classes;

            ``'auto'``: equivalent to ``'not majority'``.

        - When ``dict``, the keys correspond to the targeted classes. The
          values correspond to the desired number of samples for each targeted
          class.

        - When callable, function taking ``y`` and returns a ``dict``. The keys
          correspond to the targeted classes. The values correspond to the
          desired number of samples for each class.
        """.strip()
66 |
--------------------------------------------------------------------------------
/imbens/metrics/tests/test_score_objects.py:
--------------------------------------------------------------------------------
1 | """Test for score"""
2 | # Authors: Guillaume Lemaitre
3 | # Christos Aridas
4 | # License: MIT
5 |
6 | import pytest
7 | from sklearn.datasets import make_blobs
8 | from sklearn.metrics import make_scorer
9 | from sklearn.model_selection import GridSearchCV, train_test_split
10 | from sklearn.svm import LinearSVC
11 |
12 | from imbens.metrics import (
13 | geometric_mean_score,
14 | make_index_balanced_accuracy,
15 | sensitivity_score,
16 | specificity_score,
17 | )
18 |
19 | R_TOL = 1e-2
20 |
21 |
@pytest.fixture
def data():
    """Train/test split of a two-centre blob classification dataset."""
    features, labels = make_blobs(random_state=0, centers=2)
    return train_test_split(features, labels, random_state=0)
26 |
27 |
@pytest.mark.filterwarnings("ignore:Liblinear failed to converge")
@pytest.mark.parametrize(
    "score, expected_score",
    [
        (sensitivity_score, 0.92),
        (specificity_score, 0.92),
        (geometric_mean_score, 0.92),
        (make_index_balanced_accuracy()(geometric_mean_score), 0.85),
    ],
)
@pytest.mark.parametrize("average", ["macro", "weighted", "micro"])
def test_scorer_common_average(data, score, expected_score, average):
    """Each averaged imbalanced metric works as a GridSearchCV scorer."""
    X_train, X_test, y_train, _ = data

    # pos_label=None because an explicit `average` mode is requested.
    scorer = make_scorer(score, pos_label=None, average=average)
    search = GridSearchCV(
        LinearSVC(dual="auto", random_state=0),
        param_grid={"C": [1, 10]},
        scoring=scorer,
        cv=3,
    )
    search.fit(X_train, y_train).predict(X_test)

    assert search.best_score_ == pytest.approx(expected_score, rel=R_TOL)
52 |
53 |
@pytest.mark.filterwarnings("ignore:Liblinear failed to converge")
@pytest.mark.parametrize(
    "score, average, expected_score",
    [
        (sensitivity_score, "binary", 0.92),
        (specificity_score, "binary", 0.95),
        (geometric_mean_score, "multiclass", 0.92),
        (
            make_index_balanced_accuracy()(geometric_mean_score),
            "multiclass",
            0.84,
        ),
    ],
)
def test_scorer_default_average(data, score, average, expected_score):
    """Each metric's default averaging mode works as a GridSearchCV scorer."""
    X_train, X_test, y_train, _ = data

    # pos_label=1 exercises the binary / per-metric default behaviour.
    scorer = make_scorer(score, pos_label=1, average=average)
    search = GridSearchCV(
        LinearSVC(dual="auto", random_state=0),
        param_grid={"C": [1, 10]},
        scoring=scorer,
        cv=3,
    )
    search.fit(X_train, y_train).predict(X_test)

    assert search.best_score_ == pytest.approx(expected_score, rel=R_TOL)
81 |
--------------------------------------------------------------------------------
/examples/pipeline/plot_pipeline_classification.py:
--------------------------------------------------------------------------------
"""
====================================
Usage of pipeline embedding samplers
====================================

An example of the :class:`~imbens.pipeline.Pipeline` object (or
:func:`~imbens.pipeline.make_pipeline` helper function) working with
transformers (:class:`~sklearn.decomposition.PCA`,
:class:`~sklearn.neighbors.KNeighborsClassifier` from *scikit-learn*) and resamplers
(:class:`~imbens.sampler.EditedNearestNeighbours`,
:class:`~imbens.sampler.SMOTE`).
"""

# Adapted from imbalanced-learn
# Authors: Christos Aridas
#          Guillaume Lemaitre
# License: MIT

# %%
print(__doc__)

# sphinx_gallery_thumbnail_path = '../../docs/source/_static/thumbnail.png'

# %% [markdown]
# Let's first create an imbalanced dataset and split in to two sets.

# %%
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# 30%/70% binary task; flip_y=0 keeps the labels noise-free.
X, y = make_classification(
    n_classes=2,
    class_sep=1.25,
    weights=[0.3, 0.7],
    n_informative=3,
    n_redundant=1,
    flip_y=0,
    n_features=5,
    n_clusters_per_class=1,
    n_samples=5000,
    random_state=10,
)

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

# %% [markdown]
# Now, we will create each individual steps
# that we would like later to combine

# %%
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from imbens.sampler import EditedNearestNeighbours
from imbens.sampler import SMOTE

pca = PCA(n_components=2)
enn = EditedNearestNeighbours()
smote = SMOTE(random_state=0)
knn = KNeighborsClassifier(n_neighbors=1)

# %% [markdown]
# Now, we can finally create a pipeline to specify in which order the different
# transformers and samplers should be executed before to provide the data to
# the final classifier.

# %%
from imbens.pipeline import make_pipeline

# Steps run in order: PCA -> ENN cleaning -> SMOTE over-sampling -> 1-NN.
model = make_pipeline(pca, enn, smote, knn)

# %% [markdown]
# We can now use the pipeline created as a normal classifier where resampling
# will happen when calling `fit` and disabled when calling `decision_function`,
# `predict_proba`, or `predict`.

# %%
from sklearn.metrics import classification_report

model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred))
82 |
--------------------------------------------------------------------------------
/imbens/sampler/_over_sampling/_smote/tests/test_svm_smote.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 | from sklearn.neighbors import NearestNeighbors
5 | from sklearn.svm import SVC
6 |
7 | from sklearn.utils._testing import assert_allclose
8 | from sklearn.utils._testing import assert_array_equal
9 |
10 | from imbens.sampler._over_sampling import SVMSMOTE
11 |
12 |
@pytest.fixture
def data():
    """20-sample 2-D toy dataset: 8 samples of class 0, 12 of class 1."""
    X = np.array(
        [
            [0.11622591, -0.0317206],
            [0.77481731, 0.60935141],
            [1.25192108, -0.22367336],
            [0.53366841, -0.30312976],
            [1.52091956, -0.49283504],
            [-0.28162401, -2.10400981],
            [0.83680821, 1.72827342],
            [0.3084254, 0.33299982],
            [0.70472253, -0.73309052],
            [0.28893132, -0.38761769],
            [1.15514042, 0.0129463],
            [0.88407872, 0.35454207],
            [1.31301027, -0.92648734],
            [-1.11515198, -0.93689695],
            [-0.18410027, -0.45194484],
            [0.9281014, 0.53085498],
            [-0.14374509, 0.27370049],
            [-0.41635887, -0.38299653],
            [0.08711622, 0.93259929],
            [1.70580611, -0.11219234],
        ]
    )
    y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
    return X, y
41 |
42 |
def test_svm_smote(data):
    """Default SVMSMOTE helpers must equal explicitly configured ones."""
    default_sampler = SVMSMOTE(random_state=42)
    explicit_sampler = SVMSMOTE(
        random_state=42,
        k_neighbors=NearestNeighbors(n_neighbors=6),
        m_neighbors=NearestNeighbors(n_neighbors=11),
        svm_estimator=SVC(gamma="scale", random_state=42),
    )

    X_default, y_default = default_sampler.fit_resample(*data)
    X_explicit, y_explicit = explicit_sampler.fit_resample(*data)

    assert_allclose(X_default, X_explicit)
    assert_array_equal(y_default, y_explicit)
57 |
58 |
def test_svm_smote_sample_weight(data):
    """Passing uniform sample weights must not change SVMSMOTE's output."""
    default_sampler = SVMSMOTE(random_state=42)
    explicit_sampler = SVMSMOTE(
        random_state=42,
        k_neighbors=NearestNeighbors(n_neighbors=6),
        m_neighbors=NearestNeighbors(n_neighbors=11),
        svm_estimator=SVC(gamma="scale", random_state=42),
    )
    X, y = data
    uniform_weights = np.ones_like(y)

    X_a, y_a, w_a = default_sampler.fit_resample(X, y, sample_weight=uniform_weights)
    X_b, y_b, w_b = explicit_sampler.fit_resample(X, y, sample_weight=uniform_weights)

    assert_allclose(X_a, X_b)
    assert_array_equal(y_a, y_b)
    assert_array_equal(w_a, w_b)
--------------------------------------------------------------------------------
/imbens/datasets/tests/test_imbalance.py:
--------------------------------------------------------------------------------
1 | """Test the module easy ensemble."""
2 | # Authors: Guillaume Lemaitre
3 | # Christos Aridas
4 | # License: MIT
5 |
6 | from collections import Counter
7 |
8 | import numpy as np
9 | import pytest
10 | from sklearn.datasets import fetch_openml, load_iris
11 |
12 | from imbens.datasets import make_imbalance
13 |
14 |
@pytest.fixture
def iris():
    """The iris dataset as an ``(X, y)`` tuple."""
    return load_iris(return_X_y=True)
18 |
19 |
@pytest.mark.parametrize(
    "sampling_strategy, err_msg",
    [
        ({0: -100, 1: 50, 2: 50}, "in a class cannot be negative"),
        ({0: 10, 1: 70}, "should be less or equal to the original"),
        ("random-string", "has to be a dictionary or a function"),
    ],
)
def test_make_imbalance_error(iris, sampling_strategy, err_msg):
    """Invalid sampling strategies raise an informative ValueError."""
    # Partially duplicates utils.check_sampling_strategy validation, which
    # the common tests do not cover for this helper.
    features, targets = iris
    with pytest.raises(ValueError, match=err_msg):
        make_imbalance(features, targets, sampling_strategy=sampling_strategy)
34 |
35 |
def test_make_imbalance_error_single_class(iris):
    """make_imbalance rejects a target vector holding a single class."""
    features = iris[0]
    single_class_target = np.zeros_like(iris[1])
    with pytest.raises(ValueError, match="needs to have more than 1 class."):
        make_imbalance(features, single_class_target, sampling_strategy={0: 10})
41 |
42 |
@pytest.mark.parametrize(
    "sampling_strategy, expected_counts",
    [
        ({0: 10, 1: 20, 2: 30}, {0: 10, 1: 20, 2: 30}),
        ({0: 10, 1: 20}, {0: 10, 1: 20, 2: 50}),
    ],
)
def test_make_imbalance_dict(iris, sampling_strategy, expected_counts):
    """A dict strategy yields exactly the requested per-class counts."""
    features, targets = iris
    _, resampled_targets = make_imbalance(
        features, targets, sampling_strategy=sampling_strategy
    )
    assert Counter(resampled_targets) == expected_counts
54 |
55 |
@pytest.mark.parametrize("as_frame", [True, False], ids=["dataframe", "array"])
@pytest.mark.parametrize(
    "sampling_strategy, expected_counts",
    [
        (
            {"Iris-setosa": 10, "Iris-versicolor": 20, "Iris-virginica": 30},
            {"Iris-setosa": 10, "Iris-versicolor": 20, "Iris-virginica": 30},
        ),
        (
            {"Iris-setosa": 10, "Iris-versicolor": 20},
            {"Iris-setosa": 10, "Iris-versicolor": 20, "Iris-virginica": 50},
        ),
    ],
)
def test_make_imbalanced_iris(as_frame, sampling_strategy, expected_counts):
    """make_imbalance handles string class labels and pandas DataFrames."""
    pytest.importorskip("pandas")
    # Fetched from OpenML so string labels ("Iris-setosa", ...) are exercised.
    X, y = fetch_openml(
        "iris", parser='auto', version=1, return_X_y=True, as_frame=as_frame
    )
    X_res, y_res = make_imbalance(X, y, sampling_strategy=sampling_strategy)
    if as_frame:
        # DataFrame input must stay a DataFrame (has .loc) after resampling.
        assert hasattr(X_res, "loc")
    assert Counter(y_res) == expected_counts
79 |
--------------------------------------------------------------------------------
/examples/basic/plot_basic_visualize.py:
--------------------------------------------------------------------------------
"""
=========================================================
Visualize an ensemble classifier
=========================================================

This example illustrates how to quickly visualize an
:mod:`imbens.ensemble` classifier with
the :mod:`imbens.visualizer` module.

This example uses:

- :class:`imbens.ensemble.SelfPacedEnsembleClassifier`
- :class:`imbens.visualizer.ImbalancedEnsembleVisualizer`
"""

# Authors: Zhining Liu
# License: MIT


# %%
print(__doc__)

# Import imbalanced-ensemble
import imbens

# Import utilities
import sklearn
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

RANDOM_STATE = 42

# sphinx_gallery_thumbnail_number = 2

# %% [markdown]
# Prepare data
# ------------
# Make a toy 3-class imbalanced classification task.

# make dataset (10%/30%/60% class distribution, noise-free labels)
X, y = make_classification(
    n_classes=3,
    class_sep=2,
    weights=[0.1, 0.3, 0.6],
    n_informative=3,
    n_redundant=1,
    flip_y=0,
    n_features=20,
    n_clusters_per_class=2,
    n_samples=2000,
    random_state=0,
)

# train valid split (stratified so both halves keep the class ratios)
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=RANDOM_STATE
)


# %% [markdown]
# Train an ensemble classifier
# ----------------------------
# Take ``SelfPacedEnsembleClassifier`` as example

# Initialize and train an SPE classifier
clf = imbens.ensemble.SelfPacedEnsembleClassifier(random_state=RANDOM_STATE).fit(
    X_train, y_train
)

# Store the fitted SelfPacedEnsembleClassifier
fitted_ensembles = {'SPE': clf}


# %% [markdown]
# Fit an ImbalancedEnsembleVisualizer
# -----------------------------------------------------

# Initialize visualizer with both splits so curves can be compared
visualizer = imbens.visualizer.ImbalancedEnsembleVisualizer(
    eval_datasets={
        'training': (X_train, y_train),
        'validation': (X_valid, y_valid),
    },
)

# Fit visualizer
visualizer.fit(fitted_ensembles)


# %% [markdown]
# Plot performance curve
# ----------------------
# **performance w.r.t. number of base estimators**

fig, axes = visualizer.performance_lineplot()

# %% [markdown]
# Plot confusion matrix
# ---------------------

fig, axes = visualizer.confusion_matrix_heatmap(
    on_datasets=['validation'],  # only on validation set
    sup_title=False,
)
105 |
--------------------------------------------------------------------------------
/examples/datasets/plot_make_imbalance.py:
--------------------------------------------------------------------------------
"""
===============================
Make a dataset class-imbalanced
===============================

An illustration of the :func:`~imbens.datasets.make_imbalance` function to
create an imbalanced dataset from a balanced dataset. We show the ability of
:func:`~imbens.datasets.make_imbalance` of dealing with Pandas DataFrame.
"""

# Adapted from imbalanced-learn
# Authors: Dayvid Oliveira
#          Christos Aridas
#          Guillaume Lemaitre
#          Zhining Liu
# License: MIT

# %%
print(__doc__)

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_context("poster")

# %% [markdown]
# Generate the dataset
# --------------------
#
# First, we will generate a dataset and convert it to a
# :class:`~pandas.DataFrame` with arbitrary column names. We will plot the
# original dataset.

# %%
import pandas as pd
from sklearn.datasets import make_moons

# Balanced two-moons data wrapped in a DataFrame with named columns.
X, y = make_moons(n_samples=200, shuffle=True, noise=0.25, random_state=10)
X = pd.DataFrame(X, columns=["feature 1", "feature 2"])

fig = plt.figure(figsize=(6, 5))
ax = sns.scatterplot(
    data=X,
    x="feature 1",
    y="feature 2",
    hue=y,
    style=y,
)

# %% [markdown]
# Make a dataset imbalanced
# -------------------------
#
# Now, we will show the helpers :func:`~imbens.datasets.make_imbalance`
# that is useful to random select a subset of samples. It will impact the
# class distribution as specified by the parameters.

# %%
from collections import Counter
60 |
61 |
def ratio_func(y, multiplier, minority_class):
    """Shrink ``minority_class`` to ``multiplier`` times its current size."""
    class_counts = Counter(y)
    return {minority_class: int(multiplier * class_counts[minority_class])}
65 |
66 |
# %%
from imbens.datasets import make_imbalance

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))

# Top-left panel: the untouched data, for reference.
sns.scatterplot(
    data=X,
    x="feature 1",
    y="feature 2",
    hue=y,
    style=y,
    ax=axs[0, 0],
)
axs[0, 0].set_title("Original set")

# Remaining five panels: shrink class 1 progressively via ratio_func.
multipliers = [0.9, 0.75, 0.5, 0.25, 0.1]
for ax, multiplier in zip(axs.ravel()[1:], multipliers):
    X_resampled, y_resampled = make_imbalance(
        X,
        y,
        sampling_strategy=ratio_func,
        **{"multiplier": multiplier, "minority_class": 1},
    )

    sns.scatterplot(
        data=X_resampled,
        x="feature 1",
        y="feature 2",
        hue=y_resampled,
        style=y_resampled,
        ax=ax,
    )
    ax.set_title(f"Sampling ratio = {multiplier}")

plt.tight_layout()
plt.show()
103 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py:
--------------------------------------------------------------------------------
1 | """Test the module neighbourhood cleaning rule."""
2 | # Authors: Guillaume Lemaitre
3 | # Christos Aridas
4 | # License: MIT
5 |
6 | import numpy as np
7 | import pytest
8 | from sklearn.utils._testing import assert_array_equal
9 |
10 | from imbens.sampler._under_sampling import NeighbourhoodCleaningRule
11 |
12 | X = np.array(
13 | [
14 | [1.57737838, 0.1997882],
15 | [0.8960075, 0.46130762],
16 | [0.34096173, 0.50947647],
17 | [-0.91735824, 0.93110278],
18 | [-0.14619583, 1.33009918],
19 | [-0.20413357, 0.64628718],
20 | [0.85713638, 0.91069295],
21 | [0.35967591, 2.61186964],
22 | [0.43142011, 0.52323596],
23 | [0.90701028, -0.57636928],
24 | [-1.20809175, -1.49917302],
25 | [-0.60497017, -0.66630228],
26 | [1.39272351, -0.51631728],
27 | [-1.55581933, 1.09609604],
28 | [1.55157493, -1.6981518],
29 | ]
30 | )
31 | Y = np.array([1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 0, 0, 2, 1, 2])
32 |
33 |
@pytest.mark.parametrize(
    "ncr_params, err_msg",
    [
        ({"threshold_cleaning": -10}, "value between 0 and 1"),
        ({"threshold_cleaning": 10}, "value between 0 and 1"),
        ({"n_neighbors": "rnd"}, "has to be one of"),
    ],
)
def test_ncr_error(ncr_params, err_msg):
    """Invalid constructor parameters surface as ValueError at fit time."""
    cleaner = NeighbourhoodCleaningRule(**ncr_params)
    with pytest.raises(ValueError, match=err_msg):
        cleaner.fit_resample(X, Y)
46 |
47 |
def test_ncr_fit_resample():
    """Default NCR cleaning keeps 10 of 15 samples, matching the ground truth."""
    ncr = NeighbourhoodCleaningRule()
    X_resampled, y_resampled = ncr.fit_resample(X, Y)

    # Hard-coded expected output after cleaning.
    X_gt = np.array(
        [
            [0.34096173, 0.50947647],
            [-0.91735824, 0.93110278],
            [-0.20413357, 0.64628718],
            [0.35967591, 2.61186964],
            [0.90701028, -0.57636928],
            [-1.20809175, -1.49917302],
            [-0.60497017, -0.66630228],
            [1.39272351, -0.51631728],
            [-1.55581933, 1.09609604],
            [1.55157493, -1.6981518],
        ]
    )
    y_gt = np.array([1, 1, 1, 2, 2, 0, 0, 2, 1, 2])
    assert_array_equal(X_resampled, X_gt)
    assert_array_equal(y_resampled, y_gt)
69 |
70 |
def test_ncr_fit_resample_mode():
    """kind_sel="mode" yields the same result as the default on this data."""
    ncr = NeighbourhoodCleaningRule(kind_sel="mode")
    X_resampled, y_resampled = ncr.fit_resample(X, Y)

    # Same expected output as test_ncr_fit_resample on this dataset.
    X_gt = np.array(
        [
            [0.34096173, 0.50947647],
            [-0.91735824, 0.93110278],
            [-0.20413357, 0.64628718],
            [0.35967591, 2.61186964],
            [0.90701028, -0.57636928],
            [-1.20809175, -1.49917302],
            [-0.60497017, -0.66630228],
            [1.39272351, -0.51631728],
            [-1.55581933, 1.09609604],
            [1.55157493, -1.6981518],
        ]
    )
    y_gt = np.array([1, 1, 1, 2, 2, 0, 0, 2, 1, 2])
    assert_array_equal(X_resampled, X_gt)
    assert_array_equal(y_resampled, y_gt)
92 |
--------------------------------------------------------------------------------
/imbens/datasets/tests/test_zenodo.py:
--------------------------------------------------------------------------------
1 | """Test the datasets loader.
2 |
3 | Skipped if datasets is not already downloaded to data_home.
4 | """
5 |
6 | # Authors: Guillaume Lemaitre
7 | # Christos Aridas
8 | # License: MIT
9 |
10 | import pytest
11 | from sklearn.utils._testing import SkipTest
12 |
13 | from imbens.datasets import fetch_zenodo_datasets
14 |
15 | DATASET_SHAPE = {
16 | "ecoli": (336, 7),
17 | "optical_digits": (5620, 64),
18 | "satimage": (6435, 36),
19 | "pen_digits": (10992, 16),
20 | "abalone": (4177, 10),
21 | "sick_euthyroid": (3163, 42),
22 | "spectrometer": (531, 93),
23 | "car_eval_34": (1728, 21),
24 | "isolet": (7797, 617),
25 | "us_crime": (1994, 100),
26 | "yeast_ml8": (2417, 103),
27 | "scene": (2407, 294),
28 | "libras_move": (360, 90),
29 | "thyroid_sick": (3772, 52),
30 | "coil_2000": (9822, 85),
31 | "arrhythmia": (452, 278),
32 | "solar_flare_m0": (1389, 32),
33 | "oil": (937, 49),
34 | "car_eval_4": (1728, 21),
35 | "wine_quality": (4898, 11),
36 | "letter_img": (20000, 16),
37 | "yeast_me2": (1484, 8),
38 | "webpage": (34780, 300),
39 | "ozone_level": (2536, 72),
40 | "mammography": (11183, 6),
41 | "protein_homo": (145751, 74),
42 | "abalone_19": (4177, 10),
43 | }
44 |
45 |
def fetch(*args, **kwargs):
    """Thin test helper that forces ``download_if_missing=True``."""
    return fetch_zenodo_datasets(*args, download_if_missing=True, **kwargs)
48 |
49 |
@pytest.mark.xfail
def test_fetch():
    """Shuffled fetches with different seeds agree on shapes for all datasets."""
    try:
        shuffled_a = fetch(shuffle=True, random_state=42)
    except IOError:
        raise SkipTest("Zenodo dataset can not be loaded.")

    shuffled_b = fetch(shuffle=True, random_state=37)

    for name, expected_shape in DATASET_SHAPE.items():
        data_a, data_b = shuffled_a[name].data, shuffled_b[name].data
        assert expected_shape == data_a.shape
        assert data_a.shape == data_b.shape

        target_a, target_b = shuffled_a[name].target, shuffled_b[name].target
        assert (data_a.shape[0],) == target_a.shape
        assert (data_a.shape[0],) == target_b.shape
68 |
69 |
def test_fetch_filter():
    """Filtering by numeric ID and by name must yield the same 'ecoli' data."""
    try:
        by_id = fetch(filter_data=tuple([1]), shuffle=True, random_state=42)
    except IOError:
        raise SkipTest("Zenodo dataset can not be loaded.")

    by_name = fetch(filter_data=tuple(["ecoli"]), shuffle=True, random_state=37)

    data_a, data_b = by_id["ecoli"].data, by_name["ecoli"].data
    assert DATASET_SHAPE["ecoli"] == data_a.shape
    assert data_a.shape == data_b.shape

    # Shuffling permutes rows, so compare an order-invariant statistic.
    assert data_a.sum() == pytest.approx(data_b.sum())

    target_a, target_b = by_id["ecoli"].target, by_name["ecoli"].target
    assert (data_a.shape[0],) == target_a.shape
    assert (data_a.shape[0],) == target_b.shape
87 |
88 |
@pytest.mark.parametrize(
    "filter_data, err_msg",
    [
        (("rnf",), "is not a dataset available"),
        ((-1,), "dataset with the ID="),
        ((100,), "dataset with the ID="),
        ((1.00,), "value in the tuple"),
    ],
)
def test_fetch_error(filter_data, err_msg):
    """Unknown names, out-of-range IDs and bad types raise ValueError."""
    with pytest.raises(ValueError, match=err_msg):
        fetch_zenodo_datasets(filter_data=filter_data)
101 |
--------------------------------------------------------------------------------
/imbens/ensemble/_compatible/tests/test_adabost_compatible.py:
--------------------------------------------------------------------------------
1 | """Test CompatibleAdaBoostClassifier."""
2 |
3 | # Authors: Guillaume Lemaitre
4 | # Christos Aridas
5 | # Zhining Liu
6 | # License: MIT
7 |
8 | import numpy as np
9 | import pytest
10 | import sklearn
11 | from sklearn.datasets import load_iris, make_classification
12 | from sklearn.model_selection import train_test_split
13 | from sklearn.utils._testing import assert_array_equal
14 | from sklearn.utils.fixes import parse_version
15 |
16 | from imbens.ensemble import CompatibleAdaBoostClassifier
17 |
18 | sklearn_version = parse_version(sklearn.__version__)
19 |
20 |
@pytest.fixture
def imbalanced_dataset():
    """Three-class dataset with a heavily skewed 1%/5%/94% distribution."""
    return make_classification(
        n_samples=10000,
        n_features=3,
        n_informative=2,
        n_redundant=0,
        n_repeated=0,
        n_classes=3,
        n_clusters_per_class=1,
        weights=[0.01, 0.05, 0.94],
        class_sep=0.8,
        random_state=0,
    )
35 |
36 |
@pytest.mark.parametrize("algorithm", ["SAMME"])
def test_adaboost(imbalanced_dataset, algorithm):
    """End-to-end sanity checks for a fitted CompatibleAdaBoostClassifier."""
    X, y = imbalanced_dataset
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=1
    )
    classes = np.unique(y)

    n_estimators = 500
    adaboost = CompatibleAdaBoostClassifier(
        n_estimators=n_estimators, algorithm=algorithm, random_state=0
    )
    adaboost.fit(X_train, y_train)
    assert_array_equal(classes, adaboost.classes_)

    # check that we have an ensemble of estimators with a
    # consistent size
    assert len(adaboost.estimators_) > 1

    # each estimator in the ensemble should have different random state
    assert len({est.random_state for est in adaboost.estimators_}) == len(
        adaboost.estimators_
    )

    # check the consistency of the feature importances
    assert len(adaboost.feature_importances_) == imbalanced_dataset[0].shape[1]

    # check the consistency of the prediction outputs: one column per class
    y_pred = adaboost.predict_proba(X_test)
    assert y_pred.shape[1] == len(classes)
    assert adaboost.decision_function(X_test).shape[1] == len(classes)

    # loose accuracy bound — just verifies the ensemble actually learned
    score = adaboost.score(X_test, y_test)
    assert score > 0.6, f"Failed with algorithm {algorithm} and score {score}"

    y_pred = adaboost.predict(X_test)
    assert y_pred.shape == y_test.shape
74 |
75 |
@pytest.mark.parametrize("algorithm", ["SAMME"])
def test_adaboost_sample_weight(imbalanced_dataset, algorithm):
    """Uniform weights are a no-op; non-uniform weights change predictions."""
    X, y = imbalanced_dataset
    uniform_weights = np.ones_like(y)
    booster = CompatibleAdaBoostClassifier(algorithm=algorithm, random_state=0)

    # All-ones weights must reproduce the unweighted fit exactly.
    weighted_pred = booster.fit(X, y, sample_weight=uniform_weights).predict(X)
    unweighted_pred = booster.fit(X, y).predict(X)
    assert_array_equal(weighted_pred, unweighted_pred)

    # Random weights should alter at least one prediction.
    rng = np.random.RandomState(42)
    random_weights = rng.rand(y.shape[0])
    weighted_pred = booster.fit(X, y, sample_weight=random_weights).predict(X)
    with pytest.raises(AssertionError):
        assert_array_equal(unweighted_pred, weighted_pred)
94 |
--------------------------------------------------------------------------------
/examples/evaluation/plot_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | =======================================
3 | Metrics specific to imbalanced learning
4 | =======================================
5 |
6 | Specific metrics have been developed to evaluate classifier which
7 | has been trained using imbalanced data. :mod:`imbens` provides mainly
8 | two additional metrics which are not implemented in :mod:`sklearn`: (i)
9 | geometric mean (:func:`imbens.metrics.geometric_mean_score`)
10 | and (ii) index balanced accuracy (:func:`imbens.metrics.make_index_balanced_accuracy`).
11 | """
12 |
13 | # Adapted from imbalanced-learn
14 | # Authors: Guillaume Lemaitre
15 | # License: MIT
16 |
17 | # %%
18 | print(__doc__)
19 |
20 | RANDOM_STATE = 42
21 |
22 | # sphinx_gallery_thumbnail_path = '../../docs/source/_static/thumbnail.png'
23 |
24 | # %% [markdown]
25 | # First, we will generate some imbalanced dataset.
26 |
27 | # %%
28 | from sklearn.datasets import make_classification
29 |
30 | X, y = make_classification(
31 | n_classes=3,
32 | class_sep=2,
33 | weights=[0.1, 0.9],
34 | n_informative=10,
35 | n_redundant=1,
36 | flip_y=0,
37 | n_features=20,
38 | n_clusters_per_class=4,
39 | n_samples=5000,
40 | random_state=RANDOM_STATE,
41 | )
42 |
43 | # %% [markdown]
44 | # We will split the data into a training and testing set.
45 |
46 | # %%
47 | from sklearn.model_selection import train_test_split
48 |
49 | X_train, X_test, y_train, y_test = train_test_split(
50 | X, y, stratify=y, random_state=RANDOM_STATE
51 | )
52 |
53 | # %% [markdown]
54 | # We will create a pipeline made of a :class:`~imbens.sampler.SMOTE`
55 | # over-sampler followed by a :class:`~sklearn.svm.LinearSVC` classifier.
56 |
57 | # %%
58 | from imbens.pipeline import make_pipeline
59 | from imbens.sampler import SMOTE
60 | from sklearn.svm import LinearSVC
61 |
62 | model = make_pipeline(
63 | SMOTE(random_state=RANDOM_STATE), LinearSVC(random_state=RANDOM_STATE)
64 | )
65 |
66 | # %% [markdown]
67 | # Now, we will train the model on the training set and get the prediction
68 | # associated with the testing set. Be aware that the resampling will happen
69 | # only when calling `fit`: the number of samples in `y_pred` is the same than
70 | # in `y_test`.
71 |
72 | # %%
73 | model.fit(X_train, y_train)
74 | y_pred = model.predict(X_test)
75 |
76 | # %% [markdown]
77 | # The geometric mean corresponds to the square root of the product of the
78 | # sensitivity and specificity. Combining the two metrics should account for
79 | # the balancing of the dataset.
80 |
81 | # %%
82 | from imbens.metrics import geometric_mean_score
83 |
84 | print(f"The geometric mean is {geometric_mean_score(y_test, y_pred):.3f}")
85 |
86 | # %% [markdown]
87 | # The index balanced accuracy can transform any metric to be used in
88 | # imbalanced learning problems.
89 |
90 | # %%
91 | from imbens.metrics import make_index_balanced_accuracy
92 |
93 | alpha = 0.1
94 | geo_mean = make_index_balanced_accuracy(alpha=alpha, squared=True)(geometric_mean_score)
95 |
96 | print(
97 | f"The IBA using alpha={alpha} and the geometric mean: "
98 | f"{geo_mean(y_test, y_pred):.3f}"
99 | )
100 |
101 | # %%
102 | alpha = 0.5
103 | geo_mean = make_index_balanced_accuracy(alpha=alpha, squared=True)(geometric_mean_score)
104 |
105 | print(
106 | f"The IBA using alpha={alpha} and the geometric mean: "
107 | f"{geo_mean(y_test, y_pred):.3f}"
108 | )
109 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_selection/tests/test_instance_hardness_threshold.py:
--------------------------------------------------------------------------------
1 | """Test the module ."""
2 | # Authors: Guillaume Lemaitre
3 | # Christos Aridas
4 | # License: MIT
5 |
6 | import numpy as np
7 | import pytest
8 | from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
9 | from sklearn.naive_bayes import GaussianNB as NB
10 | from sklearn.utils._testing import assert_array_equal
11 |
12 | from imbens.sampler._under_sampling import InstanceHardnessThreshold
13 |
RND_SEED = 0
# Fixed 15-sample / 2-feature toy dataset shared by all tests in this module.
X = np.array(
    [
        [-0.3879569, 0.6894251],
        [-0.09322739, 1.28177189],
        [-0.77740357, 0.74097941],
        [0.91542919, -0.65453327],
        [-0.03852113, 0.40910479],
        [-0.43877303, 1.07366684],
        [-0.85795321, 0.82980738],
        [-0.18430329, 0.52328473],
        [-0.30126957, -0.66268378],
        [-0.65571327, 0.42412021],
        [-0.28305528, 0.30284991],
        [0.20246714, -0.34727125],
        [1.06446472, -1.09279772],
        [0.30543283, -0.02589502],
        [-0.00717161, 0.00318087],
    ]
)
# Imbalanced binary labels: 6 samples of class 0, 9 of class 1.
Y = np.array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0])
# Deterministic classifier used as the default `estimator` in the tests below.
ESTIMATOR = GradientBoostingClassifier(random_state=RND_SEED)
36 |
37 |
def test_iht_init():
    """The constructor should store its parameters untouched."""
    iht = InstanceHardnessThreshold(
        estimator=ESTIMATOR,
        sampling_strategy="auto",
        random_state=RND_SEED,
    )

    assert iht.random_state == RND_SEED
    assert iht.sampling_strategy == "auto"
48 |
49 |
def test_iht_fit_resample():
    """Default 'auto' strategy keeps 12 of the 15 toy samples."""
    iht = InstanceHardnessThreshold(estimator=ESTIMATOR, random_state=RND_SEED)
    resampled_X, resampled_y = iht.fit_resample(X, Y)
    assert resampled_X.shape == (12, 2)
    assert resampled_y.shape == (12,)
55 |
56 |
def test_iht_fit_resample_half():
    """An explicit per-class sample-count dict must be honoured exactly."""
    iht = InstanceHardnessThreshold(
        estimator=NB(),
        sampling_strategy={0: 3, 1: 3},
        random_state=RND_SEED,
    )
    resampled_X, resampled_y = iht.fit_resample(X, Y)
    assert resampled_X.shape == (6, 2)
    assert resampled_y.shape == (6,)
67 |
68 |
def test_iht_fit_resample_class_obj():
    """A freshly constructed estimator instance works the same as ESTIMATOR."""
    classifier = GradientBoostingClassifier(random_state=RND_SEED)
    iht = InstanceHardnessThreshold(estimator=classifier, random_state=RND_SEED)
    resampled_X, resampled_y = iht.fit_resample(X, Y)
    assert resampled_X.shape == (12, 2)
    assert resampled_y.shape == (12,)
75 |
76 |
def test_iht_fit_resample_wrong_class_obj():
    """A non-classifier estimator must be rejected with a ValueError."""
    from sklearn.cluster import KMeans

    clusterer = KMeans(n_init='auto')
    iht = InstanceHardnessThreshold(estimator=clusterer, random_state=RND_SEED)
    with pytest.raises(ValueError, match="Invalid parameter `estimator`"):
        iht.fit_resample(X, Y)
84 |
85 |
def test_iht_reproducibility():
    """Selected indices must not depend on the estimator's own seed."""
    from sklearn.datasets import load_digits

    X_digits, y_digits = load_digits(return_X_y=True)
    collected_indices = []
    for seed in range(5):
        forest = RandomForestClassifier(n_estimators=10, random_state=seed)
        iht = InstanceHardnessThreshold(estimator=forest, random_state=RND_SEED)
        iht.fit_resample(X_digits, y_digits)
        collected_indices.append(iht.sample_indices_.copy())
    # Compare consecutive runs pairwise; all must select the same samples.
    for previous, current in zip(collected_indices, collected_indices[1:]):
        assert_array_equal(previous, current)
98 |
--------------------------------------------------------------------------------
/imbens/ensemble/_over_sampling/tests/test_over_boost.py:
--------------------------------------------------------------------------------
1 | """Test OverBoostClassifier."""
2 |
3 | # Authors: Zhining Liu
4 | # License: MIT
5 |
6 | import numpy as np
7 | import pytest
8 | import sklearn
9 | from sklearn.datasets import load_iris, make_classification
10 | from sklearn.model_selection import train_test_split
11 | from sklearn.utils._testing import assert_array_equal
12 | from sklearn.utils.fixes import parse_version
13 |
14 | from imbens.ensemble import OverBoostClassifier
15 |
16 | sklearn_version = parse_version(sklearn.__version__)
17 |
18 |
@pytest.fixture
def imbalanced_dataset():
    """Ternary classification task skewed as 1% / 5% / 94%."""
    dataset_spec = dict(
        n_samples=10000,
        n_features=3,
        n_informative=2,
        n_redundant=0,
        n_repeated=0,
        n_classes=3,
        n_clusters_per_class=1,
        weights=[0.01, 0.05, 0.94],
        class_sep=0.8,
        random_state=0,
    )
    return make_classification(**dataset_spec)
33 |
34 |
@pytest.mark.parametrize("algorithm", ["SAMME"])
def test_overboost(imbalanced_dataset, algorithm):
    """Smoke-test fitting and the whole prediction API of OverBoost."""
    X, y = imbalanced_dataset
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=1
    )
    expected_classes = np.unique(y)

    clf = OverBoostClassifier(
        n_estimators=100, algorithm=algorithm, random_state=0
    )
    clf.fit(X_train, y_train)
    assert_array_equal(expected_classes, clf.classes_)

    # The ensemble must hold several members, one sampler per estimator.
    n_members = len(clf.estimators_)
    assert n_members > 1
    assert n_members == len(clf.samplers_)

    # Samplers and estimators should each carry pairwise-distinct seeds.
    assert len({sampler.random_state for sampler in clf.samplers_}) == len(
        clf.samplers_
    )
    assert len({est.random_state for est in clf.estimators_}) == n_members

    # One importance value per input feature.
    assert len(clf.feature_importances_) == imbalanced_dataset[0].shape[1]

    # Prediction outputs must cover every class.
    assert clf.predict_proba(X_test).shape[1] == len(expected_classes)
    assert clf.decision_function(X_test).shape[1] == len(expected_classes)

    score = clf.score(X_test, y_test)
    assert score > 0.6, f"Failed with algorithm {algorithm} and score {score}"

    assert clf.predict(X_test).shape == y_test.shape
77 |
78 |
@pytest.mark.parametrize("algorithm", ["SAMME"])
def test_overboost_sample_weight(imbalanced_dataset, algorithm):
    """Uniform weights are a no-op; non-uniform weights change predictions."""
    X, y = imbalanced_dataset
    clf = OverBoostClassifier(algorithm=algorithm, random_state=0)

    # All-ones weights must reproduce the unweighted fit exactly.
    uniform_weights = np.ones_like(y)
    pred_weighted = clf.fit(X, y, sample_weight=uniform_weights).predict(X)
    pred_unweighted = clf.fit(X, y).predict(X)
    assert_array_equal(pred_weighted, pred_unweighted)

    # Random (non-uniform) weights should alter the fitted model.
    random_weights = np.random.RandomState(42).rand(y.shape[0])
    pred_weighted = clf.fit(X, y, sample_weight=random_weights).predict(X)
    with pytest.raises(AssertionError):
        assert_array_equal(pred_unweighted, pred_weighted)
97 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#! /usr/bin/env python
"""Toolbox for ensemble learning on class-imbalanced dataset."""

import io
import os

from setuptools import find_packages, setup

# Single-source __version__ from imbens/_version.py without importing the
# package (importing it could fail before its dependencies are installed).
ver_file = os.path.join("imbens", "_version.py")
with open(ver_file) as f:
    exec(f.read())

# Runtime dependencies are maintained in requirements.txt.
with open("requirements.txt") as f:
    requirements = f.read().splitlines()

DISTNAME = "imbalanced-ensemble"
DESCRIPTION = "Toolbox for ensemble learning on class-imbalanced dataset."

# Use the README as the long description; fall back to the short description
# when it is missing (e.g. in a stripped-down sdist).
# Note: this will only work if 'README.md' is present in your MANIFEST.in file!
here = os.path.abspath(os.path.dirname(__file__))
try:
    with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
        LONG_DESCRIPTION = "\n" + f.read()
except FileNotFoundError:
    LONG_DESCRIPTION = DESCRIPTION

AUTHOR = "Zhining Liu"
AUTHOR_EMAIL = "zhining.liu@outlook.com"
MAINTAINER = "Zhining Liu"
MAINTAINER_EMAIL = "zhining.liu@outlook.com"
URL = "https://github.com/ZhiningLiu1998/imbalanced-ensemble"
PROJECT_URLS = {
    "Documentation": "https://imbalanced-ensemble.readthedocs.io/",
    "Source": "https://github.com/ZhiningLiu1998/imbalanced-ensemble",
    "Tracker": "https://github.com/ZhiningLiu1998/imbalanced-ensemble/issues",
    "Changelog": "https://imbalanced-ensemble.readthedocs.io/en/latest/release_history.html",
    "Download": "https://pypi.org/project/imbalanced-ensemble/#files",
}
LICENSE = "MIT"
VERSION = __version__
CLASSIFIERS = [
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: C",
    "Programming Language :: Python",
    "Topic :: Software Development",
    "Topic :: Scientific/Engineering",
    "Operating System :: Microsoft :: Windows",
    "Operating System :: POSIX",
    "Operating System :: Unix",
    "Operating System :: MacOS",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3 :: Only",
]
INSTALL_REQUIRES = requirements
EXTRAS_REQUIRE = {
    "dev": [
        "black",
        "flake8",
    ],
    "test": [
        "pytest",
        "pytest-cov",
    ],
    "doc": [
        "sphinx",
        "sphinx-gallery",
        "sphinx_rtd_theme",
        "pydata-sphinx-theme",
        "numpydoc",
        "sphinxcontrib-bibtex",
        "torch",
        "pytest",
    ],
}

setup(
    name=DISTNAME,
    author=AUTHOR,
    author_email=AUTHOR_EMAIL,
    maintainer=MAINTAINER,
    maintainer_email=MAINTAINER_EMAIL,
    description=DESCRIPTION,
    license=LICENSE,
    url=URL,
    version=VERSION,
    project_urls=PROJECT_URLS,
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    zip_safe=False,  # the package can run out of an .egg file
    classifiers=CLASSIFIERS,
    packages=find_packages(),
    # Keep in sync with the Python-version classifiers above.
    python_requires=">=3.9",
    install_requires=INSTALL_REQUIRES,
    extras_require=EXTRAS_REQUIRE,
)
106 |
--------------------------------------------------------------------------------
/docs/source/sphinxext/github_link.py:
--------------------------------------------------------------------------------
1 | # %%
2 |
3 | from operator import attrgetter
4 | import inspect
5 | import subprocess
6 | import os
7 | import sys
8 | from functools import partial
9 |
# Shell command that yields the abbreviated hash of the current git HEAD.
REVISION_CMD = "git rev-parse --short HEAD"
11 |
12 | # %%
13 |
14 |
def _get_git_revision():
    """Return the short hash of the current git HEAD, or None on failure."""
    try:
        raw = subprocess.check_output(REVISION_CMD.split())
    except (subprocess.CalledProcessError, OSError):
        # git missing or not a repository -- the caller treats None as "no link".
        print("Failed to execute git to get revision")
        return None
    return raw.strip().decode("utf-8")
22 |
23 |
24 | def _linkcode_resolve(domain, info, package, url_fmt, revision):
25 | """Determine a link to online source for a class/method/function
26 |
27 | This is called by sphinx.ext.linkcode
28 |
29 | An example with a long-untouched module that everyone has
30 | >>> _linkcode_resolve('py', {'module': 'tty',
31 | ... 'fullname': 'setraw'},
32 | ... package='tty',
33 | ... url_fmt='http://hg.python.org/cpython/file/'
34 | ... '{revision}/Lib/{package}/{path}#L{lineno}',
35 | ... revision='xxxx')
36 | 'http://hg.python.org/cpython/file/xxxx/Lib/tty/tty.py#L18'
37 | """
38 |
39 | if revision is None:
40 | return
41 | if domain not in ("py", "pyx"):
42 | return
43 | if not info.get("module") or not info.get("fullname"):
44 | return
45 |
46 | class_name = info["fullname"].split(".")[0]
47 | if type(class_name) != str:
48 | # Python 2 only
49 | class_name = class_name.encode("utf-8")
50 | module = __import__(info["module"], fromlist=[class_name])
51 | obj = attrgetter(info["fullname"])(module)
52 |
53 | # try:
54 | # fn = inspect.getsourcefile(obj)
55 | # except Exception:
56 | # fn = None
57 | # if not fn:
58 | # try:
59 | # fn = inspect.getsourcefile(sys.modules[obj.__module__])
60 | # except Exception:
61 | # fn = None
62 | try:
63 | fn = inspect.getsourcefile(sys.modules[obj.__module__])
64 | except Exception:
65 | fn = None
66 | if not fn:
67 | return
68 |
69 | fn = os.path.relpath(fn, start=os.path.dirname(__import__(package).__file__))
70 | try:
71 | src_code_lines, lineno = inspect.getsourcelines(obj)
72 | i = 0
73 | for l in src_code_lines:
74 | if 'def' in l or 'class' in l:
75 | break
76 | i += 1
77 | lineno += i
78 | except Exception:
79 | lineno = ""
80 | return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno)
81 |
82 |
def make_linkcode_resolve(package, url_fmt):
    """Returns a linkcode_resolve function for the given URL format

    revision is a git commit reference (hash or name)

    package is the name of the root module of the package

    url_fmt is along the lines of ('https://github.com/USER/PROJECT/'
                                   'blob/{revision}/{package}/'
                                   '{path}#L{lineno}')
    """
    # Bind the current git revision once; every resolved link shares it.
    return partial(
        _linkcode_resolve,
        revision=_get_git_revision(),
        package=package,
        url_fmt=url_fmt,
    )
98 |
99 |
100 | # %%
101 |
102 | linkcode_resolve = make_linkcode_resolve(
103 | "imbens",
104 | "https://github.com/ZhiningLiu1998/"
105 | "imbalanced-ensemble/blob/{revision}/"
106 | "{package}/{path}#L{lineno}",
107 | )
108 |
109 | linkcode_resolve
110 |
--------------------------------------------------------------------------------
/imbens/ensemble/_under_sampling/tests/test_rus_boost.py:
--------------------------------------------------------------------------------
1 | """Test RUSBoostClassifier."""
2 |
3 | # Authors: Guillaume Lemaitre
4 | # Christos Aridas
5 | # Zhining Liu
6 | # License: MIT
7 |
8 | import numpy as np
9 | import pytest
10 | import sklearn
11 | from sklearn.datasets import load_iris, make_classification
12 | from sklearn.model_selection import train_test_split
13 | from sklearn.utils._testing import assert_array_equal
14 | from sklearn.utils.fixes import parse_version
15 |
16 | from imbens.ensemble import RUSBoostClassifier
17 |
18 | sklearn_version = parse_version(sklearn.__version__)
19 |
20 |
@pytest.fixture
def imbalanced_dataset():
    """Synthetic three-class problem with 1% / 5% / 94% class shares."""
    config = dict(
        n_samples=10000,
        n_features=3,
        n_informative=2,
        n_redundant=0,
        n_repeated=0,
        n_classes=3,
        n_clusters_per_class=1,
        weights=[0.01, 0.05, 0.94],
        class_sep=0.8,
        random_state=0,
    )
    return make_classification(**config)
35 |
36 |
@pytest.mark.parametrize("algorithm", ["SAMME"])
def test_rusboost(imbalanced_dataset, algorithm):
    """Smoke-test fitting and the whole prediction API of RUSBoost."""
    X, y = imbalanced_dataset
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=1
    )
    expected_classes = np.unique(y)

    clf = RUSBoostClassifier(
        n_estimators=500, algorithm=algorithm, random_state=0
    )
    clf.fit(X_train, y_train)
    assert_array_equal(expected_classes, clf.classes_)

    # The ensemble must hold several members, one sampler per estimator.
    n_members = len(clf.estimators_)
    assert n_members > 1
    assert n_members == len(clf.samplers_)

    # Samplers and estimators should each carry pairwise-distinct seeds.
    assert len({sampler.random_state for sampler in clf.samplers_}) == len(
        clf.samplers_
    )
    assert len({est.random_state for est in clf.estimators_}) == n_members

    # One importance value per input feature.
    assert len(clf.feature_importances_) == imbalanced_dataset[0].shape[1]

    # Prediction outputs must cover every class.
    assert clf.predict_proba(X_test).shape[1] == len(expected_classes)
    assert clf.decision_function(X_test).shape[1] == len(expected_classes)

    score = clf.score(X_test, y_test)
    assert score > 0.6, f"Failed with algorithm {algorithm} and score {score}"

    assert clf.predict(X_test).shape == y_test.shape
79 |
80 |
@pytest.mark.parametrize("algorithm", ["SAMME"])
def test_rusboost_sample_weight(imbalanced_dataset, algorithm):
    """Uniform weights are a no-op; non-uniform weights change predictions."""
    X, y = imbalanced_dataset
    clf = RUSBoostClassifier(algorithm=algorithm, random_state=0)

    # All-ones weights must reproduce the unweighted fit exactly.
    uniform_weights = np.ones_like(y)
    pred_weighted = clf.fit(X, y, sample_weight=uniform_weights).predict(X)
    pred_unweighted = clf.fit(X, y).predict(X)
    assert_array_equal(pred_weighted, pred_unweighted)

    # Random (non-uniform) weights should alter the fitted model.
    random_weights = np.random.RandomState(42).rand(y.shape[0])
    pred_weighted = clf.fit(X, y, sample_weight=random_weights).predict(X)
    with pytest.raises(AssertionError):
        assert_array_equal(pred_unweighted, pred_weighted)
99 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_selection/tests/test_one_sided_selection.py:
--------------------------------------------------------------------------------
1 | """Test the module one-sided selection."""
2 | # Authors: Guillaume Lemaitre
3 | # Christos Aridas
4 | # License: MIT
5 |
6 | import numpy as np
7 | import pytest
8 | from sklearn.neighbors import KNeighborsClassifier
9 | from sklearn.utils._testing import assert_array_equal
10 |
11 | from imbens.sampler._under_sampling import OneSidedSelection
12 |
RND_SEED = 0
# Fixed 15-sample / 2-feature toy dataset shared by all tests in this module.
X = np.array(
    [
        [-0.3879569, 0.6894251],
        [-0.09322739, 1.28177189],
        [-0.77740357, 0.74097941],
        [0.91542919, -0.65453327],
        [-0.03852113, 0.40910479],
        [-0.43877303, 1.07366684],
        [-0.85795321, 0.82980738],
        [-0.18430329, 0.52328473],
        [-0.30126957, -0.66268378],
        [-0.65571327, 0.42412021],
        [-0.28305528, 0.30284991],
        [0.20246714, -0.34727125],
        [1.06446472, -1.09279772],
        [0.30543283, -0.02589502],
        [-0.00717161, 0.00318087],
    ]
)
# Imbalanced binary labels: 6 samples of class 0, 9 of class 1.
Y = np.array([0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0])
34 |
35 |
def test_oss_init():
    """The constructor should expose its defaults and stored parameters."""
    oss = OneSidedSelection(random_state=RND_SEED)

    assert oss.random_state == RND_SEED
    assert oss.n_seeds_S == 1
    assert oss.n_jobs is None
42 |
43 |
def test_oss_fit_resample():
    """Resampling the toy dataset must return the known kept samples."""
    oss = OneSidedSelection(random_state=RND_SEED)
    resampled_X, resampled_y = oss.fit_resample(X, Y)

    # Expected output: 6 class-0 samples followed by 6 class-1 samples.
    expected_X = np.array(
        [
            [-0.3879569, 0.6894251],
            [0.91542919, -0.65453327],
            [-0.65571327, 0.42412021],
            [1.06446472, -1.09279772],
            [0.30543283, -0.02589502],
            [-0.00717161, 0.00318087],
            [-0.09322739, 1.28177189],
            [-0.77740357, 0.74097941],
            [-0.43877303, 1.07366684],
            [-0.85795321, 0.82980738],
            [-0.30126957, -0.66268378],
            [0.20246714, -0.34727125],
        ]
    )
    expected_y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    assert_array_equal(resampled_X, expected_X)
    assert_array_equal(resampled_y, expected_y)
67 |
68 |
def test_oss_with_object():
    """`n_neighbors` may be a KNN estimator instance or a plain int."""
    estimator = KNeighborsClassifier(n_neighbors=1)
    oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=estimator)
    resampled_X, resampled_y = oss.fit_resample(X, Y)

    expected_X = np.array(
        [
            [-0.3879569, 0.6894251],
            [0.91542919, -0.65453327],
            [-0.65571327, 0.42412021],
            [1.06446472, -1.09279772],
            [0.30543283, -0.02589502],
            [-0.00717161, 0.00318087],
            [-0.09322739, 1.28177189],
            [-0.77740357, 0.74097941],
            [-0.43877303, 1.07366684],
            [-0.85795321, 0.82980738],
            [-0.30126957, -0.66268378],
            [0.20246714, -0.34727125],
        ]
    )
    expected_y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    assert_array_equal(resampled_X, expected_X)
    assert_array_equal(resampled_y, expected_y)

    # The integer form must produce the exact same result.
    oss = OneSidedSelection(random_state=RND_SEED, n_neighbors=1)
    resampled_X, resampled_y = oss.fit_resample(X, Y)
    assert_array_equal(resampled_X, expected_X)
    assert_array_equal(resampled_y, expected_y)
98 |
99 |
def test_oss_with_wrong_object():
    """A string is not a valid `n_neighbors` argument."""
    oss = OneSidedSelection(random_state=RND_SEED, n_neighbors="rnd")
    with pytest.raises(ValueError, match="has to be a int"):
        oss.fit_resample(X, Y)
105 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/_prototype_selection/tests/test_condensed_nearest_neighbour.py:
--------------------------------------------------------------------------------
1 | """Test the module condensed nearest neighbour."""
2 | # Authors: Guillaume Lemaitre
3 | # Christos Aridas
4 | # License: MIT
5 |
6 | import numpy as np
7 | import pytest
8 | from sklearn.neighbors import KNeighborsClassifier
9 | from sklearn.utils._testing import assert_array_equal
10 |
11 | from imbens.sampler._under_sampling import CondensedNearestNeighbour
12 |
RND_SEED = 0
# Fixed 20-sample / 2-feature toy dataset shared by all tests in this module.
X = np.array(
    [
        [2.59928271, 0.93323465],
        [0.25738379, 0.95564169],
        [1.42772181, 0.526027],
        [1.92365863, 0.82718767],
        [-0.10903849, -0.12085181],
        [-0.284881, -0.62730973],
        [0.57062627, 1.19528323],
        [0.03394306, 0.03986753],
        [0.78318102, 2.59153329],
        [0.35831463, 1.33483198],
        [-0.14313184, -1.0412815],
        [0.01936241, 0.17799828],
        [-1.25020462, -0.40402054],
        [-0.09816301, -0.74662486],
        [-0.01252787, 0.34102657],
        [0.52726792, -0.38735648],
        [0.2821046, -0.07862747],
        [0.05230552, 0.09043907],
        [0.15198585, 0.12512646],
        [0.70524765, 0.39816382],
    ]
)
# Three-class labels: 2 / 6 / 12 samples of classes 0 / 1 / 2.
Y = np.array([1, 2, 1, 1, 0, 2, 2, 2, 2, 2, 2, 0, 1, 2, 2, 2, 2, 1, 2, 1])
39 |
40 |
def test_cnn_init():
    """The constructor should expose its documented defaults."""
    cnn = CondensedNearestNeighbour(random_state=RND_SEED)

    assert cnn.n_jobs is None
    assert cnn.n_seeds_S == 1
46 |
47 |
def test_cnn_fit_resample():
    """Resampling the toy dataset must return the known condensed subset."""
    cnn = CondensedNearestNeighbour(random_state=RND_SEED)
    resampled_X, resampled_y = cnn.fit_resample(X, Y)

    expected_X = np.array(
        [
            [-0.10903849, -0.12085181],
            [0.01936241, 0.17799828],
            [0.05230552, 0.09043907],
            [-1.25020462, -0.40402054],
            [0.70524765, 0.39816382],
            [0.35831463, 1.33483198],
            [-0.284881, -0.62730973],
            [0.03394306, 0.03986753],
            [-0.01252787, 0.34102657],
            [0.15198585, 0.12512646],
        ]
    )
    expected_y = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
    assert_array_equal(resampled_X, expected_X)
    assert_array_equal(resampled_y, expected_y)
69 |
70 |
def test_cnn_fit_resample_with_object():
    """`n_neighbors` may be a KNN estimator instance or a plain int."""
    estimator = KNeighborsClassifier(n_neighbors=1)
    cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=estimator)
    resampled_X, resampled_y = cnn.fit_resample(X, Y)

    expected_X = np.array(
        [
            [-0.10903849, -0.12085181],
            [0.01936241, 0.17799828],
            [0.05230552, 0.09043907],
            [-1.25020462, -0.40402054],
            [0.70524765, 0.39816382],
            [0.35831463, 1.33483198],
            [-0.284881, -0.62730973],
            [0.03394306, 0.03986753],
            [-0.01252787, 0.34102657],
            [0.15198585, 0.12512646],
        ]
    )
    expected_y = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
    assert_array_equal(resampled_X, expected_X)
    assert_array_equal(resampled_y, expected_y)

    # The integer form must produce the exact same result.
    cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors=1)
    resampled_X, resampled_y = cnn.fit_resample(X, Y)
    assert_array_equal(resampled_X, expected_X)
    assert_array_equal(resampled_y, expected_y)
98 |
99 |
def test_cnn_fit_resample_with_wrong_object():
    """A string is not a valid `n_neighbors` argument."""
    cnn = CondensedNearestNeighbour(random_state=RND_SEED, n_neighbors="rnd")
    with pytest.raises(ValueError, match="has to be a int or an "):
        cnn.fit_resample(X, Y)
105 |
--------------------------------------------------------------------------------
/imbens/sampler/_under_sampling/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base class for the under-sampling method.
3 | """
4 | # Adapted from imbalanced-learn
5 |
6 | # Authors: Guillaume Lemaitre
7 | # License: MIT
8 |
9 | # %%
10 | LOCAL_DEBUG = False
11 |
12 | if not LOCAL_DEBUG:
13 | from ..base import BaseSampler
14 | else: # pragma: no cover
15 | import sys # For local test
16 |
17 | sys.path.append("../..")
18 | from sampler.base import BaseSampler
19 |
20 |
class BaseUnderSampler(BaseSampler):
    """Base class for under-sampling algorithms.

    Warning: This class should not be used directly. Use the derived classes
    instead.
    """

    # Tag consumed by the shared sampling-strategy validation machinery.
    _sampling_type = "under-sampling"

    # Reusable docstring fragment for the ``sampling_strategy`` parameter;
    # runtime data injected into subclass docstrings -- do not edit casually.
    _sampling_strategy_docstring = """sampling_strategy : float, str, dict, callable, default='auto'
    Sampling information to sample the data set.

    - When ``float``, it corresponds to the desired ratio of the number of
      samples in the minority class over the number of samples in the
      majority class after resampling. Therefore, the ratio is expressed as
      :math:`\\alpha_{us} = N_{m} / N_{rM}` where :math:`N_{m}` is the
      number of samples in the minority class and
      :math:`N_{rM}` is the number of samples in the majority class
      after resampling.

        .. warning::
            ``float`` is only available for **binary** classification. An
            error is raised for multi-class classification.

    - When ``str``, specify the class targeted by the resampling. The
      number of samples in the different classes will be equalized.
      Possible choices are:

        ``'majority'``: resample only the majority class;

        ``'not minority'``: resample all classes but the minority class;

        ``'not majority'``: resample all classes but the majority class;

        ``'all'``: resample all classes;

        ``'auto'``: equivalent to ``'not minority'``.

    - When ``dict``, the keys correspond to the targeted classes. The
      values correspond to the desired number of samples for each targeted
      class.

    - When callable, function taking ``y`` and returns a ``dict``. The keys
      correspond to the targeted classes. The values correspond to the
      desired number of samples for each class.
    """.rstrip()
67 |
68 |
class BaseCleaningSampler(BaseSampler):
    """Base class for under-sampling algorithms.

    Warning: This class should not be used directly. Use the derived classes
    instead.
    """

    # Tag consumed by the shared sampling-strategy validation machinery;
    # cleaning samplers do not accept exact per-class target counts.
    _sampling_type = "clean-sampling"

    # Reusable docstring fragment for the ``sampling_strategy`` parameter;
    # runtime data injected into subclass docstrings -- do not edit casually.
    _sampling_strategy_docstring = """sampling_strategy : str, list or callable
    Sampling information to sample the data set.

    - When ``str``, specify the class targeted by the resampling. Note the
      the number of samples will not be equal in each. Possible choices
      are:

        ``'majority'``: resample only the majority class;

        ``'not minority'``: resample all classes but the minority class;

        ``'not majority'``: resample all classes but the majority class;

        ``'all'``: resample all classes;

        ``'auto'``: equivalent to ``'not minority'``.

    - When ``list``, the list contains the classes targeted by the
      resampling.

    - When callable, function taking ``y`` and returns a ``dict``. The keys
      correspond to the targeted classes. The values correspond to the
      desired number of samples for each class.
    """.rstrip()
102 |
103 |
104 | # %%
105 |
--------------------------------------------------------------------------------
/examples/datasets/plot_make_imbalance_digits.py:
--------------------------------------------------------------------------------
1 | """
2 | =========================================================
3 | Make digits dataset class-imbalanced
4 | =========================================================
5 |
6 | An illustration of the :func:`~imbens.datasets.make_imbalance`
7 | function to create an imbalanced version of the digits dataset.
8 | """
9 |
10 | # Authors: Zhining Liu
11 | # License: MIT
12 |
13 | # %%
14 | print(__doc__)
15 |
16 | # Import imbalanced-ensemble
17 | import imbens
18 |
19 | # Import utilities
20 | import sklearn
21 | from imbens.datasets import make_imbalance
22 | from imbens.utils._plot import plot_2Dprojection_and_cardinality, plot_scatter
23 | import matplotlib.pyplot as plt
24 | import seaborn as sns
25 |
26 | RANDOM_STATE = 42
27 |
28 | # sphinx_gallery_thumbnail_number = -1
29 |
30 | # %% [markdown]
31 | # Digits dataset
32 | # --------------
33 | # The digits dataset consists of 8x8 pixel images of digits. The images attribute of the dataset stores 8x8 arrays of grayscale values for each image. We will use these arrays to visualize the first 4 images. The target attribute of the dataset stores the digit each image represents and this is included in the title of the 4 plots below.
34 |
35 | digits = sklearn.datasets.load_digits()
36 |
37 | # flatten the images
38 | n_samples = len(digits.images)
39 | X, y = digits.images.reshape((n_samples, -1)), digits.target
40 |
41 | _, axes = plt.subplots(nrows=3, ncols=4, figsize=(10, 9))
42 | for ax, image, label in zip(axes.flatten(), digits.images, digits.target):
43 | ax.set_axis_off()
44 | ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
45 | ax.set_title('Training: %i' % label)
46 |
47 |
48 | # %% [markdown]
49 | # **The original digits dataset**
50 |
51 | fig = plot_2Dprojection_and_cardinality(X, y, figsize=(8, 4))
52 |
53 |
54 | # %% [markdown]
55 | # **Make class-imbalanced digits dataset**
56 |
57 | imbalance_distr = {
58 | 0: 178,
59 | 1: 120,
60 | 2: 80,
61 | 3: 60,
62 | 4: 50,
63 | 5: 44,
64 | 6: 40,
65 | 7: 40,
66 | 8: 40,
67 | 9: 40,
68 | }
69 |
70 | X_imb, y_imb = make_imbalance(
71 | X, y, sampling_strategy=imbalance_distr, random_state=RANDOM_STATE
72 | )
73 |
74 | fig = plot_2Dprojection_and_cardinality(X_imb, y_imb, figsize=(8, 4))
75 |
76 |
77 | # %% [markdown]
78 | # Use TSNE to compare the original & imbalanced Digits datasets
79 | # -------------------------------------------------------------
80 | # We can observe that it is more difficult to distinguish the tail classes from each other in the imbalanced Digits dataset.
81 | # These tailed classes are not well represented, thus it is harder for a learning model to learn their patterns.
82 |
83 | sns.set_context('talk')
84 |
85 | tsne = sklearn.manifold.TSNE(
86 | n_components=2, perplexity=100, n_iter=500, random_state=RANDOM_STATE
87 | )
88 |
89 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
90 |
91 | # Plot original digits data
92 | plot_scatter(
93 | tsne.fit_transform(X),
94 | y,
95 | title='Original Digits Data',
96 | weights=100,
97 | vis_params={'edgecolor': 'black', 'alpha': 0.8},
98 | ax=ax1,
99 | )
100 | ax1.legend(
101 | ncol=2,
102 | loc=2,
103 | columnspacing=0.01,
104 | borderaxespad=0.1,
105 | handletextpad=0.01,
106 | labelspacing=0.01,
107 | handlelength=None,
108 | )
109 |
110 | # Plot imbalanced digits data
111 | plot_scatter(
112 | tsne.fit_transform(X_imb),
113 | y_imb,
114 | title='Imbalanced Digits Data',
115 | weights=100,
116 | vis_params={'edgecolor': 'black', 'alpha': 0.8},
117 | ax=ax2,
118 | )
119 | ax2.legend(
120 | ncol=2,
121 | loc=2,
122 | columnspacing=0.01,
123 | borderaxespad=0.1,
124 | handletextpad=0.01,
125 | labelspacing=0.01,
126 | handlelength=None,
127 | )
128 |
129 | fig.tight_layout()
130 |
--------------------------------------------------------------------------------
/examples/basic/plot_training_log.py:
--------------------------------------------------------------------------------
1 | """
2 | =========================================================
3 | Customize ensemble training log
4 | =========================================================
5 |
6 | This example illustrates how to enable and customize the training
7 | log when training an :mod:`imbens.ensemble` classifier.
8 |
9 | This example uses:
10 |
11 | - :class:`imbens.ensemble.SelfPacedEnsembleClassifier`
12 | """
13 |
14 | # Authors: Zhining Liu
15 | # License: MIT
16 |
17 |
18 | # %%
19 | print(__doc__)
20 |
21 | # Import imbalanced-ensemble
22 | import imbens
23 |
24 | # Import utilities
25 | import sklearn
26 | from sklearn.datasets import make_classification
27 | from sklearn.model_selection import train_test_split
28 |
29 | RANDOM_STATE = 42
30 |
31 | # sphinx_gallery_thumbnail_path = '../../docs/source/_static/training_log_thumbnail.png'
32 |
33 | # %% [markdown]
34 | # Prepare data
35 | # ----------------------------
36 | # Make a toy 3-class imbalanced classification task.
37 |
38 | # make dataset
39 | X, y = make_classification(
40 | n_classes=3,
41 | class_sep=2,
42 | weights=[0.1, 0.3, 0.6],
43 | n_informative=3,
44 | n_redundant=1,
45 | flip_y=0,
46 | n_features=20,
47 | n_clusters_per_class=2,
48 | n_samples=2000,
49 | random_state=0,
50 | )
51 |
52 | # train valid split
53 | X_train, X_valid, y_train, y_valid = train_test_split(
54 | X, y, test_size=0.5, stratify=y, random_state=RANDOM_STATE
55 | )
56 |
57 | # %% [markdown]
58 | # Customize training log
59 | # ---------------------------------------------------------------------------
60 | # Take ``SelfPacedEnsembleClassifier`` as an example; the training log is controlled by 3 parameters of the ``fit()`` method:
61 | #
62 | # - ``eval_datasets``: Dataset(s) used for evaluation during the ensemble training.
63 | # - ``eval_metrics``: Metric(s) used for evaluation during the ensemble training.
64 | # - ``train_verbose``: Controls the granularity and content of the training log.
65 |
66 | clf = imbens.ensemble.SelfPacedEnsembleClassifier(random_state=RANDOM_STATE)
67 |
68 | # %% [markdown]
69 | # Set training log format
70 | # -----------------------
71 | # (``fit()`` parameter: ``train_verbose``: bool, int or dict)
72 |
73 | # %% [markdown]
74 | # **Enable auto training log**
75 |
76 | clf.fit(
77 | X_train,
78 | y_train,
79 | train_verbose=True,
80 | )
81 |
82 |
83 | # %% [markdown]
84 | # **Customize training log granularity**
85 |
86 | clf.fit(
87 | X_train,
88 | y_train,
89 | train_verbose={
90 | 'granularity': 10,
91 | },
92 | )
93 |
94 |
95 | # %% [markdown]
96 | # **Customize training log content column**
97 |
98 | clf.fit(
99 | X_train,
100 | y_train,
101 | train_verbose={
102 | 'granularity': 10,
103 | 'print_distribution': False,
104 | 'print_metrics': True,
105 | },
106 | )
107 |
108 |
109 | # %% [markdown]
110 | # Add additional evaluation dataset(s)
111 | # ------------------------------------
112 | # (``fit()`` parameter: ``eval_datasets``: dict)
113 |
114 | clf.fit(
115 | X_train,
116 | y_train,
117 | eval_datasets={
118 | 'valid': (X_valid, y_valid), # add validation data
119 | },
120 | train_verbose={
121 | 'granularity': 10,
122 | },
123 | )
124 |
125 |
126 | # %% [markdown]
127 | # Specify evaluation metric(s)
128 | # ----------------------------
129 | # (``fit()`` parameter: ``eval_metrics``: dict)
130 |
131 | clf.fit(
132 | X_train,
133 | y_train,
134 | eval_datasets={
135 | 'valid': (X_valid, y_valid),
136 | },
137 | eval_metrics={
138 | 'weighted_f1': (
139 | sklearn.metrics.f1_score,
140 | {'average': 'weighted'},
141 | ), # use weighted_f1
142 | },
143 | train_verbose={
144 | 'granularity': 10,
145 | },
146 | )
147 |
148 | # %%
149 |
--------------------------------------------------------------------------------
/docs/source/sg_execution_times.rst:
--------------------------------------------------------------------------------
1 |
2 | :orphan:
3 |
4 | .. _sphx_glr_sg_execution_times:
5 |
6 |
7 | Computation times
8 | =================
9 | **00:02.812** total execution time for 18 files **from all galleries**:
10 |
11 | .. container::
12 |
13 | .. raw:: html
14 |
15 |
19 |
20 |
21 |
22 |
27 |
28 | .. list-table::
29 | :header-rows: 1
30 | :class: table table-striped sg-datatable
31 |
32 | * - Example
33 | - Time
34 | - Mem (MB)
35 | * - :ref:`sphx_glr_auto_examples_classification_plot_torch.py` (``..\..\examples\classification\plot_torch.py``)
36 | - 00:02.812
37 | - 0.0
38 | * - :ref:`sphx_glr_auto_examples_basic_plot_basic_example.py` (``..\..\examples\basic\plot_basic_example.py``)
39 | - 00:00.000
40 | - 0.0
41 | * - :ref:`sphx_glr_auto_examples_basic_plot_basic_visualize.py` (``..\..\examples\basic\plot_basic_visualize.py``)
42 | - 00:00.000
43 | - 0.0
44 | * - :ref:`sphx_glr_auto_examples_basic_plot_training_log.py` (``..\..\examples\basic\plot_training_log.py``)
45 | - 00:00.000
46 | - 0.0
47 | * - :ref:`sphx_glr_auto_examples_classification_plot_classifier_comparison.py` (``..\..\examples\classification\plot_classifier_comparison.py``)
48 | - 00:00.000
49 | - 0.0
50 | * - :ref:`sphx_glr_auto_examples_classification_plot_cost_matrix.py` (``..\..\examples\classification\plot_cost_matrix.py``)
51 | - 00:00.000
52 | - 0.0
53 | * - :ref:`sphx_glr_auto_examples_classification_plot_digits.py` (``..\..\examples\classification\plot_digits.py``)
54 | - 00:00.000
55 | - 0.0
56 | * - :ref:`sphx_glr_auto_examples_classification_plot_probability.py` (``..\..\examples\classification\plot_probability.py``)
57 | - 00:00.000
58 | - 0.0
59 | * - :ref:`sphx_glr_auto_examples_classification_plot_resampling_target.py` (``..\..\examples\classification\plot_resampling_target.py``)
60 | - 00:00.000
61 | - 0.0
62 | * - :ref:`sphx_glr_auto_examples_classification_plot_sampling_schedule.py` (``..\..\examples\classification\plot_sampling_schedule.py``)
63 | - 00:00.000
64 | - 0.0
65 | * - :ref:`sphx_glr_auto_examples_datasets_plot_generate_imbalance.py` (``..\..\examples\datasets\plot_generate_imbalance.py``)
66 | - 00:00.000
67 | - 0.0
68 | * - :ref:`sphx_glr_auto_examples_datasets_plot_make_imbalance.py` (``..\..\examples\datasets\plot_make_imbalance.py``)
69 | - 00:00.000
70 | - 0.0
71 | * - :ref:`sphx_glr_auto_examples_datasets_plot_make_imbalance_digits.py` (``..\..\examples\datasets\plot_make_imbalance_digits.py``)
72 | - 00:00.000
73 | - 0.0
74 | * - :ref:`sphx_glr_auto_examples_evaluation_plot_classification_report.py` (``..\..\examples\evaluation\plot_classification_report.py``)
75 | - 00:00.000
76 | - 0.0
77 | * - :ref:`sphx_glr_auto_examples_evaluation_plot_metrics.py` (``..\..\examples\evaluation\plot_metrics.py``)
78 | - 00:00.000
79 | - 0.0
80 | * - :ref:`sphx_glr_auto_examples_pipeline_plot_pipeline_classification.py` (``..\..\examples\pipeline\plot_pipeline_classification.py``)
81 | - 00:00.000
82 | - 0.0
83 | * - :ref:`sphx_glr_auto_examples_visualizer_plot_confusion_matrix.py` (``..\..\examples\visualizer\plot_confusion_matrix.py``)
84 | - 00:00.000
85 | - 0.0
86 | * - :ref:`sphx_glr_auto_examples_visualizer_plot_performance_curve.py` (``..\..\examples\visualizer\plot_performance_curve.py``)
87 | - 00:00.000
88 | - 0.0
89 |
--------------------------------------------------------------------------------
/examples/classification/plot_probability.py:
--------------------------------------------------------------------------------
1 | """
2 | =================================================================
3 | Plot probabilities with different base classifiers
4 | =================================================================
5 |
6 | Plot the classification probability for ensemble models with different base classifiers.
7 |
8 | We use a 3-class imbalanced dataset, and we classify it with a ``SelfPacedEnsembleClassifier`` (ensemble size = 5).
9 | We use Decision Tree, Support Vector Machine (rbf kernel), and Gaussian process classifier as the base classifiers.
10 |
11 | This example uses:
12 |
13 | - :class:`imbens.ensemble.SelfPacedEnsembleClassifier`
14 | """
15 |
16 | # Adapted from sklearn
17 | # Author: Zhining Liu
18 | # Alexandre Gramfort
19 | # License: BSD 3 clause
20 |
21 | # %%
22 | print(__doc__)
23 |
24 | # Import imbalanced-ensemble
25 | import imbens
26 |
27 | # Import utilities
28 | import numpy as np
29 | from collections import Counter
30 | import sklearn
31 | from imbens.datasets import make_imbalance
32 | from imbens.ensemble.base import sort_dict_by_key
33 |
34 | RANDOM_STATE = 42
35 |
36 | # %% [markdown]
37 | # Preparation
38 | # -----------
39 | # **Make 3 imbalanced iris classification tasks.**
40 |
41 | iris = sklearn.datasets.load_iris()
42 | X = iris.data[:, 0:2] # we only take the first two features for visualization
43 | y = iris.target
44 |
45 | X, y = make_imbalance(
46 | X, y, sampling_strategy={0: 50, 1: 30, 2: 10}, random_state=RANDOM_STATE
47 | )
48 | print(
49 | 'Class distribution of imbalanced iris dataset: \n%s' % sort_dict_by_key(Counter(y))
50 | )
51 |
52 |
53 | # %% [markdown]
54 | # **Create SPE (ensemble size = 5) with different base classifiers.**
55 |
56 | from sklearn.svm import SVC
57 | from sklearn.tree import DecisionTreeClassifier
58 | from sklearn.gaussian_process import GaussianProcessClassifier
59 | from sklearn.gaussian_process.kernels import RBF
60 |
61 | classifiers = {
62 | 'SPE-DT': imbens.ensemble.SelfPacedEnsembleClassifier(
63 | n_estimators=5,
64 | estimator=DecisionTreeClassifier(),
65 | ),
66 | 'SPE-SVM-rbf': imbens.ensemble.SelfPacedEnsembleClassifier(
67 | n_estimators=5,
68 | estimator=SVC(kernel='rbf', probability=True),
69 | ),
70 | 'SPE-GPC': imbens.ensemble.SelfPacedEnsembleClassifier(
71 | n_estimators=5,
72 | estimator=GaussianProcessClassifier(1.0 * RBF([1.0, 1.0])),
73 | ),
74 | }
75 |
76 | n_classifiers = len(classifiers)
77 |
78 |
79 | # %% [markdown]
80 | # Plot classification probabilities
81 | # ---------------------------------
82 |
83 | import matplotlib.pyplot as plt
84 |
85 | n_features = X.shape[1]
86 |
87 | plt.figure(figsize=(3 * 2, n_classifiers * 2))
88 | plt.subplots_adjust(bottom=0.2, top=0.95)
89 |
90 | xx = np.linspace(3, 9, 100)
91 | yy = np.linspace(1, 5, 100).T
92 | xx, yy = np.meshgrid(xx, yy)
93 | Xfull = np.c_[xx.ravel(), yy.ravel()]
94 |
95 | for index, (name, classifier) in enumerate(classifiers.items()):
96 | classifier.fit(X, y)
97 |
98 | y_pred = classifier.predict(X)
99 | accuracy = sklearn.metrics.balanced_accuracy_score(y, y_pred)
100 | print("Balanced Accuracy (train) for %s: %0.1f%% " % (name, accuracy * 100))
101 |
102 | # View probabilities:
103 | probas = classifier.predict_proba(Xfull)
104 | n_classes = np.unique(y_pred).size
105 | for k in range(n_classes):
106 | plt.subplot(n_classifiers, n_classes, index * n_classes + k + 1)
107 | plt.title("Class %d" % k)
108 | if k == 0:
109 | plt.ylabel(name)
110 | imshow_handle = plt.imshow(
111 | probas[:, k].reshape((100, 100)), extent=(3, 9, 1, 5), origin='lower'
112 | )
113 | plt.xticks(())
114 | plt.yticks(())
115 | idx = y_pred == k
116 | if idx.any():
117 | plt.scatter(X[idx, 0], X[idx, 1], marker='o', c='w', edgecolor='k')
118 |
119 | ax = plt.axes([0.15, 0.04, 0.7, 0.05])
120 | plt.title("Probability")
121 | plt.colorbar(imshow_handle, cax=ax, orientation='horizontal')
122 | plt.show()
123 |
--------------------------------------------------------------------------------
/imbens/utils/_plot.py:
--------------------------------------------------------------------------------
1 | """Utilities for data visualization."""
2 |
3 | # Authors: Zhining Liu
4 | # License: MIT
5 |
6 | from collections import Counter
7 | from copy import copy
8 |
9 | import matplotlib.image as mpimg
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import pandas as pd
13 | import seaborn as sns
14 | from sklearn.decomposition import KernelPCA
15 |
16 | DEFAULT_VIS_KWARGS = {
17 | # 'cmap': plt.cm.rainbow,
18 | 'edgecolor': 'black',
19 | 'alpha': 0.6,
20 | }
21 |
22 |
23 | def set_ax_border(ax, border_color='black', border_width=2):
24 | '''Set border color and width'''
25 | for _, spine in ax.spines.items():
26 | spine.set_color(border_color)
27 | spine.set_linewidth(border_width)
28 |
29 | return ax
30 |
31 |
32 | def plot_scatter(
33 | X, y, ax=None, weights=None, title='', projection=None, vis_params=None
34 | ):
35 | '''Plot scatter with given projection'''
36 |
37 | if ax is None:
38 | ax = plt.axes()
39 | if projection is None:
40 | projection = KernelPCA(n_components=2).fit(X, y)
41 | if vis_params is None:
42 | vis_params = copy(DEFAULT_VIS_KWARGS)
43 |
44 | if X.shape[1] > 2:
45 | X_vis = projection.transform(X)
46 | title += ' (2D projection by {})'.format(
47 | str(projection.__class__).split('.')[-1][:-2]
48 | )
49 | else:
50 | X_vis = X
51 |
52 | size = 50 if weights is None else weights
53 | if np.unique(y).shape[0] > 2:
54 | vis_params['palette'] = plt.cm.rainbow
55 | sns.scatterplot(
56 | x=X_vis[:, 0],
57 | y=X_vis[:, 1],
58 | hue=y,
59 | style=y,
60 | s=size,
61 | **vis_params,
62 | legend='full',
63 | ax=ax
64 | )
65 |
66 | ax.set_title(title)
67 | ax = set_ax_border(ax, border_color='black', border_width=2)
68 | ax.grid(color='black', linestyle='-.', alpha=0.5)
69 |
70 | return ax
71 |
72 |
73 | def plot_class_distribution(y, ax=None, title='', sort_values=True, plot_average=True):
74 | '''Plot class distribution of a given dataset'''
75 | count = pd.DataFrame(list(Counter(y).items()), columns=['Class', 'Frequency'])
76 | if sort_values:
77 | count = count.sort_values(by='Frequency', ascending=False)
78 | if ax is None:
79 | ax = plt.axes()
80 | count.plot.bar(x='Class', y='Frequency', title=title, ax=ax)
81 |
82 | ax.set_title(title)
83 | ax = set_ax_border(ax, border_color='black', border_width=2)
84 | ax.grid(color='black', linestyle='-.', alpha=0.5, axis='y')
85 |
86 | if plot_average:
87 | ax.axhline(y=count['Frequency'].mean(), ls="dashdot", c="red")
88 | xlim_min, xlim_max, ylim_min, ylim_max = ax.axis()
89 | ax.text(
90 | x=xlim_min + (xlim_max - xlim_min) * 0.82,
91 | y=count['Frequency'].mean() + (ylim_max - ylim_min) * 0.03,
92 | c="red",
93 | s='Average',
94 | )
95 |
96 | return ax
97 |
98 |
99 | def plot_2Dprojection_and_cardinality(
100 | X,
101 | y,
102 | figsize=(10, 4),
103 | vis_params=None,
104 | projection=None,
105 | weights=None,
106 | plot_average=True,
107 | title1='Dataset',
108 | title2='Class Distribution',
109 | ):
110 | '''Plot the distribution of a given dataset'''
111 |
112 | if vis_params is None:
113 | vis_params = copy(DEFAULT_VIS_KWARGS)
114 |
115 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
116 |
117 | ax1 = plot_scatter(
118 | X,
119 | y,
120 | ax=ax1,
121 | weights=weights,
122 | title=title1,
123 | projection=projection,
124 | vis_params=vis_params,
125 | )
126 | ax2 = plot_class_distribution(
127 | y, ax=ax2, title=title2, sort_values=True, plot_average=plot_average
128 | )
129 | plt.tight_layout()
130 |
131 | return fig, (ax1, ax2)
132 |
133 |
134 | def plot_online_figure(url: str = None): # pragma: no cover
135 | '''Plot an online figure'''
136 | figure = mpimg.imread(url)
137 | plt.axis('off')
138 | plt.imshow(figure)
139 | plt.tight_layout()
140 |
--------------------------------------------------------------------------------
/examples/basic/plot_basic_example.py:
--------------------------------------------------------------------------------
1 | """
2 | =========================================================
3 | Train and predict with an ensemble classifier
4 | =========================================================
5 |
6 | This example shows the basic usage of an
7 | :mod:`imbens.ensemble` classifier.
8 |
9 | This example uses:
10 |
11 | - :class:`imbens.ensemble.SelfPacedEnsembleClassifier`
12 | """
13 |
14 | # Authors: Zhining Liu
15 | # License: MIT
16 |
17 |
18 | # %%
19 | print(__doc__)
20 |
21 | # Import imbalanced-ensemble
22 | import imbens
23 |
24 | # Import utilities
25 | from collections import Counter
26 | import sklearn
27 | from sklearn.datasets import make_classification
28 | from sklearn.model_selection import train_test_split
29 | from imbens.ensemble.base import sort_dict_by_key
30 |
31 | # Import plot utilities
32 | import matplotlib.pyplot as plt
33 | from imbens.utils._plot import plot_2Dprojection_and_cardinality
34 |
35 | RANDOM_STATE = 42
36 |
37 | # %% [markdown]
38 | # Prepare & visualize the data
39 | # ----------------------------
40 | # Make a toy 3-class imbalanced classification task.
41 |
42 | # Generate and split a synthetic dataset
43 | X, y = make_classification(
44 | n_classes=3,
45 | n_samples=2000,
46 | class_sep=2,
47 | weights=[0.1, 0.3, 0.6],
48 | n_informative=3,
49 | n_redundant=1,
50 | flip_y=0,
51 | n_features=20,
52 | n_clusters_per_class=2,
53 | random_state=RANDOM_STATE,
54 | )
55 | X_train, X_valid, y_train, y_valid = train_test_split(
56 | X, y, test_size=0.5, stratify=y, random_state=RANDOM_STATE
57 | )
58 |
59 | # Visualize the training dataset
60 | fig = plot_2Dprojection_and_cardinality(X_train, y_train, figsize=(8, 4))
61 | plt.show()
62 |
63 | # Print class distribution
64 | print('Training dataset distribution %s' % sort_dict_by_key(Counter(y_train)))
65 | print('Validation dataset distribution %s' % sort_dict_by_key(Counter(y_valid)))
66 |
67 | # %% [markdown]
68 | # Using ensemble classifiers in ``imbens``
69 | # -----------------------------------------------------
70 | # Take ``SelfPacedEnsembleClassifier`` as an example
71 |
72 | # Initialize a SelfPacedEnsembleClassifier
73 | clf = imbens.ensemble.SelfPacedEnsembleClassifier(random_state=RANDOM_STATE)
74 |
75 | # Train a SelfPacedEnsembleClassifier
76 | clf.fit(X_train, y_train)
77 |
78 | # Make predictions
79 | y_pred_proba = clf.predict_proba(X_valid)
80 | y_pred = clf.predict(X_valid)
81 |
82 | # Evaluate
83 | balanced_acc_score = sklearn.metrics.balanced_accuracy_score(y_valid, y_pred)
84 | print(f'SPE: ensemble of {clf.n_estimators} {clf.estimator_}')
85 | print('Validation Balanced Accuracy: {:.3f}'.format(balanced_acc_score))
86 |
87 |
88 | # %% [markdown]
89 | # Set the ensemble size
90 | # ---------------------
91 | # (parameter ``n_estimators``: int)
92 |
93 | from imbens.ensemble import SelfPacedEnsembleClassifier as SPE
94 | from sklearn.metrics import balanced_accuracy_score
95 |
96 | clf = SPE(
97 | n_estimators=5, # Set ensemble size to 5
98 | random_state=RANDOM_STATE,
99 | ).fit(X_train, y_train)
100 |
101 | # Evaluate
102 | balanced_acc_score = balanced_accuracy_score(y_valid, clf.predict(X_valid))
103 | print(f'SPE: ensemble of {clf.n_estimators} {clf.estimator_}')
104 | print('Validation Balanced Accuracy: {:.3f}'.format(balanced_acc_score))
105 |
106 |
107 | # %% [markdown]
108 | # Use different base estimator
109 | # ----------------------------
110 | # (parameter ``estimator``: estimator object)
111 |
112 | from sklearn.svm import SVC
113 |
114 | clf = SPE(
115 | n_estimators=5,
116 | estimator=SVC(probability=True), # Use SVM as the base estimator
117 | random_state=RANDOM_STATE,
118 | ).fit(X_train, y_train)
119 |
120 | # Evaluate
121 | balanced_acc_score = balanced_accuracy_score(y_valid, clf.predict(X_valid))
122 | print(f'SPE: ensemble of {clf.n_estimators} {clf.estimator_}')
123 | print('Validation Balanced Accuracy: {:.3f}'.format(balanced_acc_score))
124 |
125 |
126 | # %% [markdown]
127 | # Enable training log
128 | # -------------------
129 | # (``fit()`` parameter ``train_verbose``: bool, int or dict)
130 |
131 | clf = SPE(random_state=RANDOM_STATE).fit(
132 | X_train,
133 | y_train,
134 | train_verbose=True, # Enable training log
135 | )
136 |
--------------------------------------------------------------------------------
/imbens/ensemble/tests/test_base_ensemble.py:
--------------------------------------------------------------------------------
1 | """Test SelfPacedEnsembleClassifier."""
2 |
3 | # Authors: Zhining Liu
4 | # License: MIT
5 |
6 | import pytest
7 | import sklearn
8 | from sklearn.datasets import load_iris
9 | from sklearn.model_selection import train_test_split
10 | from sklearn.utils.fixes import parse_version
11 |
12 | from imbens.datasets import make_imbalance
13 | from imbens.utils.testing import all_estimators
14 |
15 | sklearn_version = parse_version(sklearn.__version__)
16 | iris = load_iris()
17 | all_ensembles = all_estimators(type_filter='ensemble')
18 |
19 | X, y = make_imbalance(
20 | iris.data,
21 | iris.target,
22 | sampling_strategy={0: 20, 1: 25, 2: 50},
23 | random_state=0,
24 | )
25 | X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)
26 | init_param = {'random_state': 0, 'n_estimators': 20}
27 |
28 |
29 | @pytest.mark.parametrize(
30 | "ensemble",
31 | all_ensembles,
32 | )
33 | def test_evaluate(ensemble):
34 | """Check classification with dynamic logging."""
35 | (ensemble_name, EnsembleCLass) = ensemble
36 | clf = EnsembleCLass(**init_param)
37 | clf.fit(
38 | X_train,
39 | y_train,
40 | train_verbose=True,
41 | )
42 | clf._evaluate('train', return_value_dict=True)
43 |
44 |
45 | @pytest.mark.parametrize(
46 | "ensemble",
47 | all_ensembles,
48 | )
49 | def test_evaluate_verbose(ensemble):
50 | """Check customization of training log granularity and content."""
51 | (ensemble_name, EnsembleCLass) = ensemble
52 | clf = EnsembleCLass(**init_param)
53 | if clf._properties['training_type'] == 'parallel':
54 | with pytest.raises(TypeError, match="can only be of type `bool`"):
55 | clf.fit(
56 | X_train,
57 | y_train,
58 | train_verbose={
59 | 'granularity': 10,
60 | },
61 | )
62 | else:
63 | clf.fit(
64 | X_train,
65 | y_train,
66 | train_verbose={
67 | 'granularity': 10,
68 | },
69 | )
70 | clf.fit(
71 | X_train,
72 | y_train,
73 | train_verbose={
74 | 'granularity': 10,
75 | 'print_distribution': False,
76 | 'print_metrics': True,
77 | },
78 | )
79 |
80 |
81 | @pytest.mark.parametrize(
82 | "ensemble",
83 | all_ensembles,
84 | )
85 | def test_evaluate_eval_datasets(ensemble):
86 | """Check training with additional evaluation dataset(s)."""
87 | (ensemble_name, EnsembleCLass) = ensemble
88 | clf = EnsembleCLass(**init_param)
89 | if clf._properties['training_type'] == 'parallel':
90 | with pytest.raises(TypeError, match="can only be of type `bool`"):
91 | clf.fit(
92 | X_train,
93 | y_train,
94 | eval_datasets={
95 | 'valid': (X_valid, y_valid), # add validation data
96 | },
97 | train_verbose={
98 | 'granularity': 10,
99 | },
100 | )
101 | else:
102 | clf.fit(
103 | X_train,
104 | y_train,
105 | eval_datasets={
106 | 'valid': (X_valid, y_valid), # add validation data
107 | },
108 | train_verbose={
109 | 'granularity': 10,
110 | },
111 | )
112 |
113 |
114 | @pytest.mark.parametrize(
115 | "ensemble",
116 | all_ensembles,
117 | )
118 | def test_evaluate_eval_metrics(ensemble):
119 | """Check training with user-specified evaluation metric(s)."""
120 | (ensemble_name, EnsembleCLass) = ensemble
121 | clf = EnsembleCLass(**init_param)
122 | clf.fit(
123 | X_train,
124 | y_train,
125 | eval_datasets={
126 | 'valid': (X_valid, y_valid),
127 | },
128 | eval_metrics={
129 | 'weighted_f1': (
130 | sklearn.metrics.f1_score,
131 | {'average': 'weighted'},
132 | ), # use weighted_f1
133 | 'roc': (
134 | sklearn.metrics.roc_auc_score,
135 | {'multi_class': 'ovr', 'average': 'macro'},
136 | ), # use roc_auc score
137 | },
138 | train_verbose=True,
139 | )
140 |
--------------------------------------------------------------------------------
/imbens/utils/testing.py:
--------------------------------------------------------------------------------
1 | """Test utilities.
2 | """
3 | # Adapted from imbalanced-learn
4 |
5 | # Adapted from scikit-learn
6 | # Authors: Guillaume Lemaitre
7 | # License: MIT
8 |
9 | import inspect
10 | import pkgutil
11 | from importlib import import_module
12 | from operator import itemgetter
13 | from pathlib import Path
14 |
15 | from sklearn.base import BaseEstimator
16 | from sklearn.utils._testing import ignore_warnings
17 |
18 |
19 | def all_estimators(
20 | type_filter=None,
21 | ):
22 | """Get a list of all estimators from imbens.
23 |
24 | This function crawls the module and gets all classes that inherit
25 | from BaseEstimator. Classes that are defined in test-modules are not
26 | included.
27 | By default meta_estimators are also not included.
28 | This function is adapted from sklearn.
29 |
30 | Parameters
31 | ----------
32 | type_filter : string, list of string, or None, default=None
33 | Which kind of estimators should be returned. If None, no
34 | filter is applied and all estimators are returned. Possible
35 | values are 'sampler' or 'ensemble' to get estimators only of
36 | these specific types, or a list of these to get the estimators
37 | that fit at least one of the types.
38 |
39 | Returns
40 | -------
41 | estimators : list of tuples
42 | List of (name, class), where ``name`` is the class name as string
43 | and ``class`` is the actual type of the class.
44 |
45 | """
46 | from ..ensemble.base import ImbalancedEnsembleClassifierMixin
47 | from ..sampler.base import SamplerMixin
48 |
49 | def is_abstract(c):
50 | if not (hasattr(c, "__abstractmethods__")):
51 | return False
52 | if not len(c.__abstractmethods__):
53 | return False
54 | return True
55 |
56 | all_classes = []
57 | modules_to_ignore = {"tests"}
58 | root = str(Path(__file__).parent.parent)
59 | # Ignore deprecation warnings triggered at import time and from walking
60 | # packages
61 | with ignore_warnings(category=FutureWarning):
62 | for importer, modname, ispkg in pkgutil.walk_packages(
63 | path=[root], prefix="imbens."
64 | ):
65 | mod_parts = modname.split(".")
66 | if any(part in modules_to_ignore for part in mod_parts) or "._" in modname:
67 | continue
68 | module = import_module(modname)
69 | classes = inspect.getmembers(module, inspect.isclass)
70 | classes = [
71 | (name, est_cls) for name, est_cls in classes if not name.startswith("_")
72 | ]
73 |
74 | all_classes.extend(classes)
75 |
76 | all_classes = set(all_classes)
77 |
78 | estimators = [
79 | c
80 | for c in all_classes
81 | if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
82 | ]
83 | # get rid of abstract base classes
84 | estimators = [c for c in estimators if not is_abstract(c[1])]
85 |
86 | # get rid of sklearn estimators which have been imported in some classes
87 | estimators = [c for c in estimators if "sklearn" not in c[1].__module__]
88 |
89 | if type_filter is not None:
90 | if not isinstance(type_filter, list):
91 | type_filter = [type_filter]
92 | else:
93 | type_filter = list(type_filter) # copy
94 | filtered_estimators = []
95 | filters = {
96 | "sampler": SamplerMixin,
97 | "ensemble": ImbalancedEnsembleClassifierMixin,
98 | }
99 | for name, mixin in filters.items():
100 | if name in type_filter:
101 | type_filter.remove(name)
102 | filtered_estimators.extend(
103 | [est for est in estimators if issubclass(est[1], mixin)]
104 | )
105 | estimators = filtered_estimators
106 | if type_filter:
107 | raise ValueError(
108 | "Parameter type_filter must be 'sampler', 'ensemble' or "
109 | "None, got"
110 | " %s." % repr(type_filter)
111 | )
112 |
113 | # drop duplicates, sort for reproducibility
114 | # itemgetter is used to ensure the sort does not extend to the 2nd item of
115 | # the tuple
116 | return sorted(set(estimators), key=itemgetter(0))
117 |
--------------------------------------------------------------------------------