├── .all-contributorsrc ├── .codecov.yml ├── .coveragerc ├── .github └── workflows │ ├── build-and-test.yml │ └── code-quality.yml ├── .gitignore ├── .readthedocs.yml ├── CHANGELOG.rst ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── CONTRIBUTORS.md ├── LICENSE ├── MINIFEST.in ├── README.rst ├── build_tools ├── cleanup.bat ├── linting.sh └── requirements.txt ├── docs ├── Makefile ├── _images │ ├── badge.svg │ ├── badge_small.png │ ├── fusion.png │ ├── gradient_boosting.png │ ├── lenet_cifar10.png │ ├── lenet_mnist.png │ ├── resnet_cifar10.png │ └── resnet_cifar100.png ├── _static │ └── custom.css ├── _templates │ ├── apidoc │ │ ├── module.rst_t │ │ ├── package.rst_t │ │ └── toc.rst_t │ ├── class.rst │ ├── class_with_call.rst │ ├── class_without_init.rst │ ├── function.rst │ ├── module.rst │ └── numpydoc_docstring.rst ├── changelog.rst ├── code_of_conduct.rst ├── conf.py ├── contributors.rst ├── experiment.rst ├── experiments.xlsx ├── guide.rst ├── index.rst ├── introduction.rst ├── make.bat ├── parameters.rst ├── plotting │ ├── lenet_cifar10.py │ ├── lenet_mnist.py │ ├── resnet_cifar10.py │ └── resnet_cifar100.py ├── quick_start.rst ├── requirements.txt └── roadmap.rst ├── examples ├── classification_cifar10_cnn.py ├── classification_mnist_tree_ensemble.py ├── fast_geometric_ensemble_cifar10_resnet18.py ├── regression_YearPredictionMSD_mlp.py └── snapshot_ensemble_cifar10_resnet18.py ├── pyproject.toml ├── requirements.txt ├── setup.py └── torchensemble ├── __init__.py ├── _base.py ├── _constants.py ├── adversarial_training.py ├── bagging.py ├── fast_geometric.py ├── fusion.py ├── gradient_boosting.py ├── snapshot_ensemble.py ├── soft_gradient_boosting.py ├── tests ├── test_adversarial_training.py ├── test_all_models.py ├── test_all_models_multi_input.py ├── test_fixed_dataloder.py ├── test_logging.py ├── test_neural_tree_ensemble.py ├── test_operator.py ├── test_set_optimizer.py ├── test_set_scheduler.py ├── test_tb_logging.py └── test_training_params.py ├── utils ├── __init__.py ├── dataloder.py ├── io.py ├── logging.py ├── operator.py └── set_module.py └── voting.py /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "projectName": "Ensemble-Pytorch", 3 | "projectOwner": "TorchEnsemble-Community", 4 | "repoType": "github", 5 | "repoHost": "https://github.com", 6 | "files": [ 7 | "CONTRIBUTORS.md" 8 | ], 9 | "imageSize": 100, 10 | "commit": false, 11 | "commitConvention": "none", 12 | "contributorsSortAlphabetically": true, 13 | "contributors": [ 14 | { 15 | "login": "xuyxu", 16 | "name": "Yi-Xuan Xu", 17 | "avatar_url": "https://avatars.githubusercontent.com/u/22359569?v=4", 18 | "profile": "https://github.com/xuyxu", 19 | "contributions": [ 20 | "code", 21 | "doc", 22 | "test", 23 | "example" 24 | ] 25 | }, 26 | { 27 | "login": "zzzzwj", 28 | "name": "Wenjie Zhang", 29 | "avatar_url": "https://avatars.githubusercontent.com/u/23235538?v=4", 30 | "profile": "https://github.com/zzzzwj", 31 | "contributions": [ 32 | "code", 33 | "test" 34 | ] 35 | }, 36 | { 37 | "login": "mttgdd", 38 | "name": "Matt Gadd", 39 | "avatar_url": "https://avatars.githubusercontent.com/u/3154919?v=4", 40 | "profile": "https://github.com/mttgdd", 41 | "contributions": [ 42 | "code" 43 | ] 44 | }, 45 | { 46 | "login": "Xiaohui9607", 47 | "name": "Xiaohui Chen", 48 | "avatar_url": "https://avatars.githubusercontent.com/u/37996225?v=4", 49 | "profile": "https://github.com/Xiaohui9607", 50 | "contributions": [ 51 | "bug" 52 | ] 53 | }, 54 | { 55 | "login": 
"cspsampedro", 56 | "name": "cspsampedro", 57 | "avatar_url": "https://avatars.githubusercontent.com/u/7384605?v=4", 58 | "profile": "https://github.com/cspsampedro", 59 | "contributions": [ 60 | "ideas", 61 | "code" 62 | ] 63 | }, 64 | { 65 | "login": "nolaurence", 66 | "name": "nolaurence", 67 | "avatar_url": "https://avatars.githubusercontent.com/u/53215736?v=4", 68 | "profile": "https://github.com/nolaurence", 69 | "contributions": [ 70 | "code" 71 | ] 72 | }, 73 | { 74 | "login": "by256", 75 | "name": "Batuhan Yildirim", 76 | "avatar_url": "https://avatars.githubusercontent.com/u/44163664?v=4", 77 | "profile": "https://by256.github.io/", 78 | "contributions": [ 79 | "code" 80 | ] 81 | }, 82 | { 83 | "login": "kiranchari", 84 | "name": "kiranchari", 85 | "avatar_url": "https://avatars.githubusercontent.com/u/1838910?v=4", 86 | "profile": "https://github.com/kiranchari", 87 | "contributions": [ 88 | "doc" 89 | ] 90 | }, 91 | { 92 | "login": "Arcify", 93 | "name": "Arcify", 94 | "avatar_url": "https://avatars.githubusercontent.com/u/96500824?v=4", 95 | "profile": "https://github.com/Arcify", 96 | "contributions": [ 97 | "doc" 98 | ] 99 | }, 100 | { 101 | "login": "SunHaozhe", 102 | "name": "Sun Haozhe", 103 | "avatar_url": "https://avatars.githubusercontent.com/u/26926814?v=4", 104 | "profile": "https://github.com/SunHaozhe", 105 | "contributions": [ 106 | "code" 107 | ] 108 | }, 109 | { 110 | "login": "FedericoV", 111 | "name": "Federico Vaggi", 112 | "avatar_url": "https://avatars.githubusercontent.com/u/630134?v=4", 113 | "profile": "https://github.com/FedericoV", 114 | "contributions": [ 115 | "code" 116 | ] 117 | }, 118 | { 119 | "login": "LukasGardberg", 120 | "name": "LukasGardberg", 121 | "avatar_url": "https://avatars.githubusercontent.com/u/52111220?v=4", 122 | "profile": "https://github.com/LukasGardberg", 123 | "contributions": [ 124 | "code", 125 | "test" 126 | ] 127 | }, 128 | { 129 | "login": "SarthakJariwala", 130 | "name": "SarthakJariwala", 131 | "avatar_url": "https://avatars.githubusercontent.com/u/35085572?v=4", 132 | "profile": "http://www.sarthakjariwala.com", 133 | "contributions": [ 134 | "code", 135 | "test" 136 | ] 137 | }, 138 | { 139 | "login": "lorenzozanisi", 140 | "name": "Lorenzo Zanisi", 141 | "avatar_url": "https://avatars.githubusercontent.com/u/33519471?v=4", 142 | "profile": "https://github.com/lorenzozanisi", 143 | "contributions": [ 144 | "bug" 145 | ] 146 | }, 147 | { 148 | "login": "AtiqurRahmanAni", 149 | "name": "AtiqurRahmanAni", 150 | "avatar_url": "https://avatars.githubusercontent.com/u/56642339?v=4", 151 | "profile": "https://github.com/AtiqurRahmanAni", 152 | "contributions": [ 153 | "code" 154 | ] 155 | }, 156 | { 157 | "login": "Malephilosopher", 158 | "name": "Malephilosopher", 159 | "avatar_url": "https://avatars.githubusercontent.com/u/68840672?v=4", 160 | "profile": "https://github.com/Malephilosopher", 161 | "contributions": [ 162 | "code" 163 | ] 164 | }, 165 | { 166 | "login": "h2soheili", 167 | "name": "Hamed Soheili", 168 | "avatar_url": "https://avatars.githubusercontent.com/u/48531597?v=4", 169 | "profile": "https://github.com/h2soheili", 170 | "contributions": [ 171 | "code" 172 | ] 173 | } 174 | ], 175 | "contributorsPerLine": 7, 176 | "skipCi": true, 177 | "commitType": "docs" 178 | } 179 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | # PR status check 2 | coverage: 3 | status: 4 | 
project: 5 | default: 6 | # threshold: 1% 7 | informational: true 8 | patch: 9 | default: 10 | informational: true 11 | 12 | # post coverage report as comment on PR 13 | comment: false 14 | 15 | # enable codecov to report to GitHub status checks 16 | github_checks: 17 | annotations: false -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | show_missing = True 3 | 4 | [run] 5 | include = */torchensemble/* 6 | omit = 7 | */setup.py 8 | */torchensemble/__init__.py 9 | */torchensemble/_constants.py 10 | */torchensemble/tests/* 11 | -------------------------------------------------------------------------------- /.github/workflows/build-and-test.yml: -------------------------------------------------------------------------------- 1 | name: torchensemble-CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest, windows-latest] 15 | python-version: ["3.9", "3.10"] 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Display python version 23 | run: python -c "import sys; print(sys.version)" 24 | - name: Install package dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install -r requirements.txt 28 | pip install -r build_tools/requirements.txt 29 | - name: Install 30 | run: pip install --verbose --editable . 31 | - name: Run tests 32 | run: | 33 | pytest --cov-config=.coveragerc --cov-report=xml --cov=torchensemble torchensemble 34 | - name: Publish code coverage 35 | uses: codecov/codecov-action@v1 36 | with: 37 | token: ${{ secrets.CODECOV_TOKEN }} 38 | file: ./coverage.xml 39 | -------------------------------------------------------------------------------- /.github/workflows/code-quality.yml: -------------------------------------------------------------------------------- 1 | name: code-quality 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | code-quality: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest] 15 | python-version: [3.9] 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up python 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Display python version 23 | run: python -c "import sys; print(sys.version)" 24 | - name: Install package dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install -r build_tools/requirements.txt 28 | - name: Check code quality 29 | run: | 30 | black --skip-string-normalization --check --config pyproject.toml ./ 31 | chmod +x "${GITHUB_WORKSPACE}/build_tools/linting.sh" 32 | ./build_tools/linting.sh 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | *.swp 4 | *.vsdx 5 | *.pth 6 | *.xml 7 | .DS_Store 8 | .coverage 9 | package-lock.json 10 | 11 | .idea/ 12 | .vscode/ 13 | .pytest_cache/ 14 | build/ 15 | _build/ 16 | dist/ 17 | logs/ 18 | torchensemble.egg-info 19 | -------------------------------------------------------------------------------- /.readthedocs.yml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | formats: [] 4 | 5 | build: 6 | os: "ubuntu-22.04" 7 | tools: 8 | python: "3.10" 9 | 10 | sphinx: 11 | configuration: docs/conf.py 12 | 13 | python: 14 | install: 15 | - requirements: docs/requirements.txt 16 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | +---------------+-----------------------------------------------------------+ 5 | | Badge | Meaning | 6 | +===============+===========================================================+ 7 | | |Feature| | Add something that cannot be achieved before. | 8 | +---------------+-----------------------------------------------------------+ 9 | | |Efficiency| | Improve the efficiency on the computation or memory. | 10 | +---------------+-----------------------------------------------------------+ 11 | | |Enhancement| | Miscellaneous minor improvements. | 12 | +---------------+-----------------------------------------------------------+ 13 | | |Fix| | Fix up something that does not work as expected. | 14 | +---------------+-----------------------------------------------------------+ 15 | | |API| | You will need to change the code to have the same effect. | 16 | +---------------+-----------------------------------------------------------+ 17 | 18 | Ver 0.1.* 19 | --------- 20 | 21 | * |Fix| Fix the issue that the model will be overwritten when using the validation data loader | `@xuyxu `__ 22 | * |Feature| Add internal :meth:`unsqueeze` operation in :meth:`forward` of all classifiers | `@xuyxu `__ 23 | * |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifier`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg `__ 24 | * |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe `__ 25 | * |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu `__ 26 | * |Fix| Relax check on input dataloader | `@xuyxu `__ 27 | * |Feature| |API| Support arbitrary training criteria for all ensembles except Gradient Boosting | `@by256 `__ and `@xuyxu `__ 28 | * |Fix| Fix missing functionality of ``save_model`` for :meth:`fit` of Soft Gradient Boosting | `@xuyxu `__ 29 | * |Feature| |API| Add :class:`SoftGradientBoostingClassifier` and :class:`SoftGradientBoostingRegressor` | `@xuyxu `__ 30 | * |Feature| |API| Support using dataloader with multiple inputs | `@xuyxu `__ 31 | * |Fix| Fix missing functionality of ``use_reduction_sum`` for :meth:`fit` of Gradient Boosting | `@xuyxu `__ 32 | * |Enhancement| Relax :mod:`tensorboard` as a soft dependency | `@xuyxu `__ 33 | * |Enhancement| |API| Simplify the training workflow of :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu `__ 34 | * |Feature| |API| Support TensorBoard logging in :meth:`set_logger` | `@zzzzwj `__ 35 | * |Enhancement| |API| Add ``use_reduction_sum`` parameter for :meth:`fit` of Gradient Boosting | `@xuyxu `__ 36 | * |Feature| |API| Improve the functionality of :meth:`evaluate` and :meth:`predict` | `@xuyxu `__ 37 | * |Feature| |API| Add :class:`FastGeometricClassifier` and :class:`FastGeometricRegressor` | `@xuyxu `__ 38 | * |Enhancement| Add flexible instantiation of optimizers and schedulers | `@cspsampedro `__ 39 | * |Feature| |API| Add support for accepting instantiated
base estimators as valid input | `@xuyxu `__ 40 | * |Fix| Fix missing base estimators when calling :meth:`load()` for all ensembles | `@xuyxu `__ 41 | * |Feature| |API| Add the :meth:`load()` method for model deserialization for all ensembles | `@mttgdd `__ 42 | 43 | Beta 44 | ---- 45 | 46 | * |Feature| |API| Add :meth:`set_scheduler` for all ensembles | `@xuyxu `__ 47 | * |MajorFeature| Add :class:`AdversarialTrainingClassifier` and :class:`AdversarialTrainingRegressor` | `@xuyxu `__ 48 | * |MajorFeature| Add :class:`SnapshotEnsembleClassifier` and :class:`SnapshotEnsembleRegressor` | `@xuyxu `__ 49 | * |Feature| |API| Add model validation and serialization | `@ozanpkr `__ and `@xuyxu `__ 50 | * |Enhancement| Add CI and maintenance tools | `@xuyxu `__ 51 | * |Enhancement| Add code coverage reporting on codecov | `@xuyxu `__ 52 | * |Enhancement| Add the version numbers to requirements.txt | `@zackhardtoname `__ and `@xuyxu `__ 53 | * |Enhancement| Improve the logging module using :class:`logging` | `@zzzzwj `__ 54 | * |API| Remove the input argument ``output_dim`` from all methods | `@xuyxu `__ 55 | * |API| Refactor the optimizer setup into :meth:`set_optimizer` | `@xuyxu `__ 56 | * |API| Refactor the code for operating on tensors into an independent module | `@zzzzwj `__ 57 | * |Fix| Fix the bug in the logging module when using multi-processing | `@zzzzwj `__ 58 | * |Fix| Fix the binding problem between scheduler and optimizer when using parallelization | `@Alex-Medium `__ and `@xuyxu `__ 59 | 60 | .. role:: raw-html(raw) 61 | :format: html 62 | 63 | .. role:: raw-latex(raw) 64 | :format: latex 65 | 66 | .. |MajorFeature| replace:: :raw-html:`Major Feature` :raw-latex:`{\small\sc [Major Feature]}` 67 | .. |Feature| replace:: :raw-html:`Feature` :raw-latex:`{\small\sc [Feature]}` 68 | .. |Efficiency| replace:: :raw-html:`Efficiency` :raw-latex:`{\small\sc [Efficiency]}` 69 | .. |Enhancement| replace:: :raw-html:`Enhancement` :raw-latex:`{\small\sc [Enhancement]}` 70 | .. |Fix| replace:: :raw-html:`Fix` :raw-latex:`{\small\sc [Fix]}` 71 | .. |API| replace:: :raw-html:`API Change` :raw-latex:`{\small\sc [API Change]}` 72 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the current project maintainer. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [PyTorch][homepage], version 1.8.1, 71 | available at https://github.com/pytorch/pytorch/blob/master/CODE_OF_CONDUCT.md 72 | 73 | [homepage]: https://pytorch.org/ 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Welcome to the contributing guidelines of Ensemble-Pytorch! 2 | 3 | Ensemble-Pytorch is a community-driven project and your contributions are highly welcome.
Feel free to [raise an issue](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/issues/new/choose) if you have any problem. Below is the table of contents in this contributing guidelines. 4 | 5 | - [Where to contribute](#where-to-contribute) 6 | - [Areas of contribution](#areas-of-contribution) 7 | - [Roadmap](#roadmap) 8 | - [Acknowledging contributions](#acknowledging-contributions) 9 | - [Installation](#installation) 10 | - [Reporting bugs](#reporting-bugs) 11 | - [Continuous integration](#continuous-integration) 12 | - [Unit testing](#unit-testing) 13 | - [Test coverage](#test-coverage) 14 | - [Coding style](#coding-style) 15 | - [API design](#api-design) 16 | - [Documentation](#documentation) 17 | - [Acknowledgement](#acknowledgement) 18 | 19 | Where to contribute 20 | ------------------- 21 | 22 | #### Areas of contribution 23 | 24 | We value all kinds of contributions - not just code. The following table gives an overview of key contribution areas. 25 | 26 | | Area | Description | 27 | |---------------|--------------------------------------------------------------------------------------------------------------| 28 | | algorithm | collect and report novel algorithms relevant to torchensemble, mainly from top-tier conferences and journals | 29 | | code | implement algorithms, improve or add functionality, fix bugs | 30 | | documentation | improve or add docstrings, user guide, introduction, and experiments | 31 | | testing | report bugs, improve or add unit tests, improve the coverage of unit tests | 32 | | maintenance | improve the development pipeline (continuous integration, Github bots), manage and view issues/pull-requests | 33 | | api design | design interfaces for estimators and other functionality | 34 | 35 | ### Roadmap 36 | 37 | For a more detailed overview of current and future work, check out our [development roadmap](https://ensemble-pytorch.readthedocs.io/en/stable/roadmap.html). 38 | 39 | Acknowledging contributions 40 | --------------------------- 41 | 42 | We follow the [all-contributors specification](https://allcontributors.org) and recognise various types of contributions. Take a look at our past and current [contributors](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/CONTRIBUTORS.md)! 43 | 44 | If you are a new contributor, please make sure we add you to our list of contributors. All contributions are recorded in [.all-contributorsrc](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/.all-contributorsrc). 45 | 46 | If we have missed anything, please [raise an issue](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/issues/new/choose) or create a pull request! 47 | 48 | Installation 49 | ------------ 50 | 51 | Please visit our [installation instructions](https://ensemble-pytorch.readthedocs.io/en/stable/quick_start.html#installation) to resolve any package issues and dependency errors. Feel free to [raise an issue](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/issues/new/choose) if the problem still exists. 52 | 53 | Reporting bugs 54 | -------------- 55 | 56 | We use GitHub issues to track all bugs and feature requests; feel free to open an issue if you have found a bug or wish to see a feature implemented. 
57 | 58 | It is recommended to check that your issue complies with the following rules before submitting: 59 | 60 | - Verify that your issue is not being currently addressed by other [issues](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/issues) or [pull requests](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/pulls). 61 | - Please ensure all code snippets and error messages are formatted in appropriate code blocks. See [Creating and highlighting code blocks](https://help.github.com/articles/creating-and-highlighting-code-blocks). 62 | - Please be specific about what estimators and/or functions are involved and the shape of the data, as appropriate; please include a [reproducible](https://stackoverflow.com/help/mcve) code snippet. If an exception is raised, please provide the traceback. 63 | 64 | Continuous integration 65 | ---------------------- 66 | 67 | We use continuous integration services on GitHub to automatically check that new pull requests do not break anything and meet code quality standards. Please visit our [config files on continuous integration](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/tree/master/.github/workflows). 68 | 69 | ### Unit testing 70 | 71 | We use [pytest](https://docs.pytest.org/en/latest/) for unit testing. To check if your code passes all tests locally, you need to install the development version of torchensemble and all extra dependencies. 72 | 73 | 1. Install all extra requirements from the root directory of torchensemble: 74 | 75 | ```bash 76 | pip install -r build_tools/requirements.txt 77 | ``` 78 | 79 | 2. Install the development version of torchensemble: 80 | 81 | ```bash 82 | pip install -e . 83 | ``` 84 | 85 | 3. To run all unit tests, run the following command from the root directory: 86 | 87 | ```bash 88 | pytest ./ 89 | ``` 90 | 91 | ### Test coverage 92 | 93 | We use [coverage](https://coverage.readthedocs.io/en/coverage-5.3/) via the [pytest-cov](https://github.com/pytest-dev/pytest-cov) plugin and [codecov](https://codecov.io) to measure and compare test coverage of our code. 94 | 95 | Coding style 96 | ------------ 97 | 98 | We follow the [PEP8](https://www.python.org/dev/peps/pep-0008/) coding guidelines. A good example can be found [here](https://gist.github.com/nateGeorge/5455d2c57fb33c1ae04706f2dc4fee01). 99 | 100 | We use the [pre-commit](#Code-quality-checks) workflow together with [black](https://black.readthedocs.io/en/stable/) and [flake8](https://flake8.pycqa.org/en/latest/) to automatically apply consistent formatting and check whether your contribution complies with the PEP8 style. In addition, if you are using GitHub Desktop on Windows, the following code snippet allows you to format and check the coding style manually. 101 | 102 | ```bash 103 | black --skip-string-normalization --config pyproject.toml ./ 104 | flake8 --filename=*.py torchensemble/ 105 | ``` 106 | 107 | API design 108 | ---------- 109 | 110 | The general API design we use in torchensemble is similar to [scikit-learn](https://scikit-learn.org/) and [skorch](https://skorch.readthedocs.io/en/latest/?badge=latest). 111 | 112 | For docstrings, we use the [numpy docstring standard](https://numpydoc.readthedocs.io/en/latest/format.html\#docstring-standard). 113 | 114 | Documentation 115 | ------------- 116 | 117 | We use [sphinx](https://www.sphinx-doc.org/en/master/) and [readthedocs](https://readthedocs.org/projects/ensemble-pytorch/) to build and deploy our online documentation.
You can find our online documentation [here](https://ensemble-pytorch.readthedocs.io). 118 | 119 | The source files used to generate the online documentation can be found in [docs/](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/tree/master/docs). For example, the main configuration file for sphinx is [conf.py](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/docs/conf.py) and the main page is [index.rst](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/docs/index.rst). To add new pages, you need to add a new `.rst` file and include it in the `index.rst` file. 120 | 121 | To build the documentation locally, you need to install a few extra dependencies listed in [docs/requirements.txt](https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/docs/requirements.txt). 122 | 123 | 1. To install the extra requirements, run from the root directory: 124 | 125 | ```bash 126 | pip install -r docs/requirements.txt 127 | ``` 128 | 129 | 2. To build the website locally, run: 130 | 131 | ```bash 132 | cd docs 133 | make html 134 | ``` 135 | 136 | You can find the generated files in the `Ensemble-Pytorch/docs/_build/` folder. To view the website, open `Ensemble-Pytorch/docs/_build/html/index.html` with your preferred web browser. 137 | 138 | Acknowledgement 139 | --------------- 140 | 141 | This CONTRIBUTING file is adapted from those of [PyTorch](https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md) and [Sktime](https://github.com/alan-turing-institute/sktime/blob/main/CONTRIBUTING.md). -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | 2 | ## Contributors ✨ 3 | 4 | Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): 5 |
| Contributor | Contributions |
| --- | --- |
| [Arcify](https://github.com/Arcify) | 📖 |
| [AtiqurRahmanAni](https://github.com/AtiqurRahmanAni) | 💻 |
| [Batuhan Yildirim](https://by256.github.io/) | 💻 |
| [Federico Vaggi](https://github.com/FedericoV) | 💻 |
| [Hamed Soheili](https://github.com/h2soheili) | 💻 |
| [Lorenzo Zanisi](https://github.com/lorenzozanisi) | 🐛 |
| [LukasGardberg](https://github.com/LukasGardberg) | 💻 ⚠️ |
| [Malephilosopher](https://github.com/Malephilosopher) | 💻 |
| [Matt Gadd](https://github.com/mttgdd) | 💻 |
| [SarthakJariwala](http://www.sarthakjariwala.com) | 💻 ⚠️ |
| [Sun Haozhe](https://github.com/SunHaozhe) | 💻 |
| [Wenjie Zhang](https://github.com/zzzzwj) | 💻 ⚠️ |
| [Xiaohui Chen](https://github.com/Xiaohui9607) | 🐛 |
| [Yi-Xuan Xu](https://github.com/xuyxu) | 💻 📖 ⚠️ 💡 |
| [cspsampedro](https://github.com/cspsampedro) | 🤔 💻 |
| [kiranchari](https://github.com/kiranchari) | 📖 |
| [nolaurence](https://github.com/nolaurence) | 💻 |
36 | 37 | 38 | 39 | 40 | 41 | 42 | This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019 - 2024 The Ensemble-Pytorch developers. 4 | 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /MINIFEST.in: -------------------------------------------------------------------------------- 1 | include MANIFEST.in 2 | 3 | exclude examples 4 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. image:: ./docs/_images/badge_small.png 2 | 3 | |github|_ |readthedocs|_ |codecov|_ |license|_ 4 | 5 | .. |github| image:: https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/workflows/torchensemble-CI/badge.svg 6 | .. _github: https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/actions 7 | 8 | .. |readthedocs| image:: https://readthedocs.org/projects/ensemble-pytorch/badge/?version=latest 9 | .. _readthedocs: https://ensemble-pytorch.readthedocs.io/en/latest/index.html 10 | 11 | .. |codecov| image:: https://codecov.io/gh/TorchEnsemble-Community/Ensemble-Pytorch/branch/master/graph/badge.svg?token=2FXCFRIDTV 12 | .. _codecov: https://codecov.io/gh/TorchEnsemble-Community/Ensemble-Pytorch 13 | 14 | .. |license| image:: https://img.shields.io/github/license/TorchEnsemble-Community/Ensemble-Pytorch 15 | .. _license: https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/LICENSE 16 | 17 | Ensemble PyTorch 18 | ================ 19 | 20 | A unified ensemble framework for pytorch_ to easily improve the performance and robustness of your deep learning model. 
Ensemble-PyTorch is part of the `pytorch ecosystem `__, which requires the project to be well maintained. 21 | 22 | * `Document `__ 23 | * `Experiment `__ 24 | 25 | Installation 26 | ------------ 27 | 28 | .. code:: bash 29 | 30 | pip install torchensemble 31 | 32 | Example 33 | ------- 34 | 35 | .. code:: python 36 | 37 | from torchensemble import VotingClassifier # voting is a classic ensemble strategy 38 | 39 | # Load data 40 | train_loader = DataLoader(...) 41 | test_loader = DataLoader(...) 42 | 43 | # Define the ensemble 44 | ensemble = VotingClassifier( 45 | estimator=base_estimator, # estimator is your pytorch model 46 | n_estimators=10, # number of base estimators 47 | ) 48 | 49 | # Set the optimizer 50 | ensemble.set_optimizer( 51 | "Adam", # type of parameter optimizer 52 | lr=learning_rate, # learning rate of parameter optimizer 53 | weight_decay=weight_decay, # weight decay of parameter optimizer 54 | ) 55 | 56 | # Set the learning rate scheduler 57 | ensemble.set_scheduler( 58 | "CosineAnnealingLR", # type of learning rate scheduler 59 | T_max=epochs, # additional arguments on the scheduler 60 | ) 61 | 62 | # Train the ensemble 63 | ensemble.fit( 64 | train_loader, 65 | epochs=epochs, # number of training epochs 66 | ) 67 | 68 | # Evaluate the ensemble 69 | acc = ensemble.evaluate(test_loader) # testing accuracy
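Here, ``base_estimator`` can be any PyTorch model, i.e., a subclass of ``nn.Module``. A minimal sketch, assuming flattened 28x28 inputs and 10 classes (the ``MLP`` class below is purely illustrative):

.. code:: python

    import torch.nn as nn
    import torch.nn.functional as F

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(784, 128)
            self.linear2 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)  # flatten the input batch
            return self.linear2(F.relu(self.linear1(x)))

    base_estimator = MLP  # instantiated models are also accepted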
70 | 71 | Supported Ensemble 72 | ------------------ 73 | 74 | +------------------------------+------------+---------------------------+-----------------------------+ 75 | | **Ensemble Name** | **Type** | **Source Code** | **Problem** | 76 | +==============================+============+===========================+=============================+ 77 | | Fusion | Mixed | fusion.py | Classification / Regression | 78 | +------------------------------+------------+---------------------------+-----------------------------+ 79 | | Voting [1]_ | Parallel | voting.py | Classification / Regression | 80 | +------------------------------+------------+---------------------------+-----------------------------+ 81 | | Neural Forest | Parallel | voting.py | Classification / Regression | 82 | +------------------------------+------------+---------------------------+-----------------------------+ 83 | | Bagging [2]_ | Parallel | bagging.py | Classification / Regression | 84 | +------------------------------+------------+---------------------------+-----------------------------+ 85 | | Gradient Boosting [3]_ | Sequential | gradient_boosting.py | Classification / Regression | 86 | +------------------------------+------------+---------------------------+-----------------------------+ 87 | | Snapshot Ensemble [4]_ | Sequential | snapshot_ensemble.py | Classification / Regression | 88 | +------------------------------+------------+---------------------------+-----------------------------+ 89 | | Adversarial Training [5]_ | Parallel | adversarial_training.py | Classification / Regression | 90 | +------------------------------+------------+---------------------------+-----------------------------+ 91 | | Fast Geometric Ensemble [6]_ | Sequential | fast_geometric.py | Classification / Regression | 92 | +------------------------------+------------+---------------------------+-----------------------------+ 93 | | Soft Gradient Boosting [7]_ | Parallel | soft_gradient_boosting.py | Classification / Regression | 94 | +------------------------------+------------+---------------------------+-----------------------------+ 95 | 96 | Dependencies 97 | ------------ 98 | 99 | - scikit-learn>=0.23.0 100 | - torch>=1.4.0 101 | - torchvision>=0.2.2 102 | 103 | Reference 104 | --------- 105 | 106 | .. [1] Zhou, Zhi-Hua. Ensemble Methods: Foundations and Algorithms. CRC press, 2012. 107 | 108 | .. [2] Breiman, Leo. Bagging Predictors. Machine Learning (1996): 123-140. 109 | 110 | .. [3] Friedman, Jerome H. Greedy Function Approximation: A Gradient Boosting Machine. Annals of Statistics (2001): 1189-1232. 111 | 112 | .. [4] Huang, Gao, et al. Snapshot Ensembles: Train 1, Get M For Free. ICLR, 2017. 113 | 114 | .. [5] Lakshminarayanan, Balaji, et al. Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles. NIPS, 2017. 115 | 116 | .. [6] Garipov, Timur, et al. Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs. NeurIPS, 2018. 117 | 118 | .. [7] Feng, Ji, et al. Soft Gradient Boosting Machine. ArXiv, 2020. 119 | 120 | .. _pytorch: https://pytorch.org/ 121 | 122 | .. _pypi: https://pypi.org/project/torchensemble/ 123 | 124 | Thanks to all our contributors 125 | ------------------------------ 126 | 127 | |contributors| 128 | 129 | .. |contributors| image:: https://contributors-img.web.app/image?repo=TorchEnsemble-Community/Ensemble-Pytorch 130 | .. _contributors: https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/graphs/contributors 131 | -------------------------------------------------------------------------------- /build_tools/cleanup.bat: -------------------------------------------------------------------------------- 1 | cd .. 2 | 3 | del /s /q /f *.pth 4 | del /s /q /f *.log 5 | -------------------------------------------------------------------------------- /build_tools/linting.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e -x 4 | set -o pipefail 5 | 6 | if ! flake8 --verbose --filename=*.py torchensemble/; then 7 | echo 'Failure on Code Quality Check.' 8 | exit 1 9 | fi -------------------------------------------------------------------------------- /build_tools/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest==7.1.1 2 | numpy==1.22.4 3 | flake8 4 | pytest-cov 5 | click==8.0.3 6 | black==20.8b1 7 | tensorboard==2.* 8 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_images/badge_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/docs/_images/badge_small.png -------------------------------------------------------------------------------- /docs/_images/fusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/docs/_images/fusion.png -------------------------------------------------------------------------------- /docs/_images/gradient_boosting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/docs/_images/gradient_boosting.png -------------------------------------------------------------------------------- /docs/_images/lenet_cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/docs/_images/lenet_cifar10.png -------------------------------------------------------------------------------- /docs/_images/lenet_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/docs/_images/lenet_mnist.png -------------------------------------------------------------------------------- /docs/_images/resnet_cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/docs/_images/resnet_cifar10.png -------------------------------------------------------------------------------- /docs/_images/resnet_cifar100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/docs/_images/resnet_cifar100.png -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | .center {text-align: center;} -------------------------------------------------------------------------------- /docs/_templates/apidoc/module.rst_t: -------------------------------------------------------------------------------- 1 | {%- if show_headings %} 2 | {{- basename | e | heading(2) }} 3 | 4 | {% endif -%} 5 | .. automodule:: {{ qualname }} 6 | {%- for option in automodule_options %} 7 | :{{ option }}: 8 | {%- endfor %} 9 | 10 | -------------------------------------------------------------------------------- /docs/_templates/apidoc/package.rst_t: -------------------------------------------------------------------------------- 1 | {%- macro automodule(modname, options) -%} 2 | .. automodule:: {{ modname }} 3 | {%- for option in options %} 4 | :{{ option }}: 5 | {%- endfor %} 6 | {%- endmacro %} 7 | 8 | {%- macro toctree(docnames) -%} 9 | .. 
toctree:: 10 | {% for docname in docnames %} 11 | {{ docname }} 12 | {%- endfor %} 13 | {%- endmacro %} 14 | 15 | {%- if is_namespace %} 16 | {{- [pkgname, "namespace"] | join(" ") | e | heading }} 17 | {% else %} 18 | {{- pkgname | e | heading(2) }} 19 | {% endif %} 20 | 21 | {%- if modulefirst and not is_namespace %} 22 | {{ automodule(pkgname, automodule_options) }} 23 | {% endif %} 24 | 25 | {%- if subpackages %} 26 | Subpackages 27 | ----------- 28 | 29 | {{ toctree(subpackages) }} 30 | {% endif %} 31 | 32 | {%- if submodules %} 33 | Submodules 34 | ---------- 35 | {% if separatemodules %} 36 | {{ toctree(submodules) }} 37 | {%- else %} 38 | {%- for submodule in submodules %} 39 | {% if show_headings %} 40 | {{- submodule | e | heading(2) }} 41 | {% endif %} 42 | {{ automodule(submodule, automodule_options) }} 43 | {% endfor %} 44 | {%- endif %} 45 | {% endif %} 46 | 47 | {%- if not modulefirst and not is_namespace %} 48 | Module contents 49 | --------------- 50 | 51 | {{ automodule(pkgname, automodule_options) }} 52 | {% endif %} 53 | -------------------------------------------------------------------------------- /docs/_templates/apidoc/toc.rst_t: -------------------------------------------------------------------------------- 1 | {{ header | heading }} 2 | 3 | .. toctree:: 4 | :maxdepth: {{ maxdepth }} 5 | {% for docname in docnames %} 6 | {{ docname }} 7 | {%- endfor %} 8 | 9 | -------------------------------------------------------------------------------- /docs/_templates/class.rst: -------------------------------------------------------------------------------- 1 | {{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | .. automethod:: __init__ 10 | {% endblock %} 11 | -------------------------------------------------------------------------------- /docs/_templates/class_with_call.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}=============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | .. automethod:: __init__ 10 | .. automethod:: __call__ 11 | {% endblock %} 12 | -------------------------------------------------------------------------------- /docs/_templates/class_without_init.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | -------------------------------------------------------------------------------- /docs/_templates/function.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}==================== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autofunction:: {{ objname }} 7 | -------------------------------------------------------------------------------- /docs/_templates/module.rst: -------------------------------------------------------------------------------- 1 | .. _mod-{{ fullname }}: 2 | 3 | {{ fullname | underline }} 4 | 5 | .. automodule:: {{ fullname }} 6 | 7 | {% block functions %} 8 | {% if functions %} 9 | .. rubric:: Functions 10 | 11 | .. 
autosummary:: 12 | :toctree: {{ objname }} 13 | :template: function.rst 14 | {% for item in functions %} 15 | {{ item }} 16 | {%- endfor %} 17 | {% endif %} 18 | {% endblock %} 19 | 20 | {% block classes %} 21 | {% if classes %} 22 | .. rubric:: Classes 23 | 24 | .. autosummary:: 25 | :toctree: {{ objname }} 26 | :template: class.rst 27 | {% for item in classes %} 28 | {{ item }} 29 | {%- endfor %} 30 | {% endif %} 31 | {% endblock %} 32 | 33 | {% block exceptions %} 34 | {% if exceptions %} 35 | .. rubric:: Exceptions 36 | 37 | .. autosummary:: 38 | {% for item in exceptions %} 39 | {{ item }} 40 | {%- endfor %} 41 | {% endif %} 42 | {% endblock %} 43 | -------------------------------------------------------------------------------- /docs/_templates/numpydoc_docstring.rst: -------------------------------------------------------------------------------- 1 | {{index}} 2 | {{summary}} 3 | {{extended_summary}} 4 | {{parameters}} 5 | {{returns}} 6 | {{yields}} 7 | {{other_parameters}} 8 | {{attributes}} 9 | {{raises}} 10 | {{warns}} 11 | {{warnings}} 12 | {{see_also}} 13 | {{notes}} 14 | {{references}} 15 | {{examples}} 16 | {{methods}} 17 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. _changelog: 2 | 3 | .. include:: ../CHANGELOG.rst -------------------------------------------------------------------------------- /docs/code_of_conduct.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../CODE_OF_CONDUCT.md -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("../")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | project = 'Ensemble-PyTorch' 21 | copyright = '2021, Ensemble-PyTorch Developers' 22 | author = 'Ensemble-PyTorch Developers' 23 | 24 | # The master toctree document. 25 | master_doc = 'index' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | autodoc_mock_imports = ["joblib", 30 | "scikit-learn", 31 | "torch", 32 | "torchvision"] 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 
37 | extensions = [ 38 | 'sphinx.ext.autodoc', 39 | 'sphinx.ext.autosummary', 40 | 'sphinx.ext.todo', 41 | 'sphinx.ext.napoleon', 42 | 'sphinx_copybutton', 43 | 'sphinx_panels', 44 | 'sphinxemoji.sphinxemoji', 45 | 'm2r2' 46 | ] 47 | 48 | autoapi_dirs = ['../torchensemble'] 49 | 50 | # Napoleon settings 51 | napoleon_google_docstring = False 52 | napoleon_numpy_docstring = True 53 | napoleon_include_init_with_doc = True 54 | napoleon_include_private_with_doc = False 55 | napoleon_include_special_with_doc = True 56 | napoleon_use_admonition_for_examples = False 57 | napoleon_use_admonition_for_notes = False 58 | napoleon_use_admonition_for_references = False 59 | napoleon_use_ivar = False 60 | napoleon_use_param = True 61 | napoleon_use_rtype = True 62 | napoleon_type_aliases = None 63 | 64 | autodoc_default_flags = ['members', 'inherited-members', 'show-inheritance'] 65 | autodoc_default_options = { 66 | "members": True, 67 | "inherited-members": False, 68 | "show-inheritance": False, 69 | "member-order": "bysource" 70 | } 71 | 72 | # Add any paths that contain templates here, relative to this directory. 73 | templates_path = ['_templates'] 74 | 75 | # List of patterns, relative to source directory, that match files and 76 | # directories to ignore when looking for source files. 77 | # This pattern also affects html_static_path and html_extra_path. 78 | exclude_patterns = [] 79 | 80 | # The name of the Pygments (syntax highlighting) style to use. 81 | pygments_style = "default" 82 | 83 | # -- Options for HTML output ------------------------------------------------- 84 | 85 | # The theme to use for HTML and HTML Help pages. See the documentation for 86 | # a list of builtin themes. 87 | # 88 | html_theme = 'sphinx_rtd_theme' 89 | 90 | html_sidebars = { 91 | '**': ['logo-text.html', 'globaltoc.html', 'searchbox.html'] 92 | } 93 | 94 | # Add any paths that contain custom static files (such as style sheets) here, 95 | # relative to this directory. They are copied after the builtin static files, 96 | # so a file named "default.css" will overwrite the builtin "default.css". 97 | html_static_path = ['_static'] 98 | 99 | html_css_files = [ 100 | 'custom.css', 101 | ] 102 | -------------------------------------------------------------------------------- /docs/contributors.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../CONTRIBUTORS.md -------------------------------------------------------------------------------- /docs/experiment.rst: -------------------------------------------------------------------------------- 1 | Experiments 2 | =========== 3 | 4 | Setup 5 | ~~~~~ 6 | 7 | Experiments here are designed to evaluate the performance of each ensemble implemented in Ensemble-PyTorch. We have collected four different configurations of dataset and base estimator, as shown in the table below. In addition, scripts for producing all figures below are available on `GitHub `__. 8 | 9 |
.. table:: 10 | :align: center 11 | 12 | +------------------+----------------+-----------+-------------------+ 13 | | Config Name | Estimator | Dataset | n_estimators | 14 | +==================+================+===========+===================+ 15 | | LeNet\@MNIST | LeNet-5 | MNIST | 5, 10, 15, 20 | 16 | +------------------+----------------+-----------+-------------------+ 17 | | LeNet\@CIFAR-10 | LeNet-5 | CIFAR-10 | 5, 10, 15, 20 | 18 | +------------------+----------------+-----------+-------------------+ 19 | | ResNet\@CIFAR-10 | ResNet-18 | CIFAR-10 | 2, 5, 7, 10 | 20 | +------------------+----------------+-----------+-------------------+ 21 | |ResNet\@CIFAR-100 | ResNet-18 | CIFAR-100 | 2, 5, 7, 10 | 22 | +------------------+----------------+-----------+-------------------+ 23 | 24 | 1. Data augmentation was applied to both **CIFAR-10** and **CIFAR-100** datasets. 25 | 2. For **LeNet-5**, the ``Adam`` optimizer with learning rate ``1e-3`` and weight decay ``5e-4`` was used. 26 | 3. For **ResNet-18**, the ``SGD`` optimizer with learning rate ``1e-1``, weight decay ``5e-4``, and momentum ``0.9`` was used (see the sketch below). 27 | 4. Reference code: `ResNet-18 on CIFAR-10 `__ and `ResNet-18 on CIFAR-100 `__.
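For reference, the optimizer settings above correspond to :meth:`set_optimizer` calls along the following lines, where ``ensemble`` stands for any Ensemble-PyTorch model; this is a sketch of the configuration, not the full experiment script:

.. code:: python

    # LeNet-5 configurations
    ensemble.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)

    # ResNet-18 configurations
    ensemble.set_optimizer("SGD", lr=1e-1, weight_decay=5e-4, momentum=0.9)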
28 | 29 | .. tip:: 30 | 31 | For each experiment shown below, we have added some comments that may be worth your attention. Feel free to open an `issue `__ if you have any questions about the results. 32 | 33 | LeNet\@MNIST 34 | ~~~~~~~~~~~~ 35 | 36 | .. image:: ./_images/lenet_mnist.png 37 | :align: center 38 | :width: 400 39 | 40 | * MNIST is a very easy dataset, and the testing acc of a single LeNet-5 estimator is over 99% 41 | * voting and bagging are the most effective ensembles in this case 42 | * bagging is even better than voting since the bootstrap sampling on training data injects more diversity into the ensemble 43 | * fusion does not perform well in this case, possibly because the model complexity of a single LeNet-5 estimator is already sufficient for MNIST. Therefore, simply encapsulating several LeNet-5 estimators into a large model will only make the over-fitting problem more severe. 44 | 45 | LeNet\@CIFAR-10 46 | ~~~~~~~~~~~~~~~ 47 | 48 | .. image:: ./_images/lenet_cifar10.png 49 | :align: center 50 | :width: 400 51 | 52 | * CIFAR-10 is a hard dataset for LeNet-5, and the testing acc of a single LeNet-5 estimator is around 70% 53 | * gradient boosting is the most effective ensemble because, as a bias-reduction method, it is able to improve the performance of weak estimators by a large margin 54 | * bagging is worse than voting since less training data is available 55 | * snapshot ensemble, more precisely, the customized learning rate scheduler in snapshot ensemble, does not adapt well to LeNet-5 (more training epochs are needed) 56 | 57 | ResNet\@CIFAR-10 58 | ~~~~~~~~~~~~~~~~ 59 | 60 | .. image:: ./_images/resnet_cifar10.png 61 | :align: center 62 | :width: 400 63 | 64 | * CIFAR-10 is a relatively easy dataset for ResNet-18, and the testing acc of a single ResNet-18 estimator is between 94% and 95% 65 | * voting and snapshot ensemble are the most effective ensembles in this case 66 | * snapshot ensemble is even better when taking the training cost into consideration 67 | 68 | ResNet\@CIFAR-100 69 | ~~~~~~~~~~~~~~~~~ 70 | 71 | .. image:: ./_images/resnet_cifar100.png 72 | :align: center 73 | :width: 400 74 | 75 | * CIFAR-100 is a hard dataset for ResNet-18, and the testing acc of a single ResNet-18 estimator is around 76% 76 | * voting is the most effective ensemble in this case 77 | * fusion does not perform well in this case 78 | * the result of gradient boosting is omitted because of its prohibitively long training time 79 | 80 | Acknowledgement 81 | ~~~~~~~~~~~~~~~ 82 | 83 | We would like to thank the `LAMDA Group `__ from Nanjing University for providing us with the powerful V-100 GPU server. 84 | -------------------------------------------------------------------------------- /docs/experiments.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/docs/experiments.xlsx -------------------------------------------------------------------------------- /docs/guide.rst: -------------------------------------------------------------------------------- 1 | Guidance 2 | ======== 3 | 4 | This page provides useful instructions on how to choose the appropriate ensemble method for your deep learning model. 5 | 6 | Check Your Model 7 | ---------------- 8 | 9 | A good rule of thumb is to check the performance of your deep learning model first. Below are two important aspects that you should pay attention to: 10 | 11 | * What is the final performance of your model? 12 | * Does your model suffer from the over-fitting problem? 13 | 14 | To check these two aspects, it is recommended to evaluate the performance of your model on the :obj:`test_loader` after each training epoch. 15 | 16 | .. tip:: 17 | Using Ensemble-PyTorch, you can pass your model to :class:`Fusion` or :class:`Voting` with the argument ``n_estimators`` set to ``1``. The behavior of the ensemble should be the same as that of a single model. 18 | 19 | If the performance of your model is relatively good, for example, the testing accuracy of your LeNet-5 CNN model on MNIST is over 99%, then regarding the first point, it is unlikely that your model suffers from the under-fitting problem, and you can skip the section :ref:`under_fit`. 20 | 21 | .. _under_fit: 22 | 23 | Under-fit 24 | --------- 25 | 26 | If the performance of your model is unsatisfactory, you can try out the :class:`Gradient Boosting` related ensemble methods. :class:`Gradient Boosting` focuses on reducing the bias term from the perspective of the `Bias-Variance Decomposition `__, so it usually works well when your deep learning model is a weak learner on the dataset. 27 | 28 | Below are the pros and cons of using :class:`Gradient Boosting`: 29 | 30 | * Pros: 31 | - You can obtain much larger improvements than with other ensemble methods 32 | * Cons: 33 | - Relatively longer training time 34 | - Suffers from the over-fitting problem if the value of ``n_estimators`` is large 35 | 36 | .. tip:: 37 | :class:`Gradient Boosting` in Ensemble-PyTorch supports early stopping to alleviate over-fitting. To use early stopping, you need to set the input arguments ``test_loader`` and ``early_stopping_rounds`` when calling the :meth:`fit` function of :class:`Gradient Boosting`. In addition, using a small ``shrinkage_rate`` when declaring the model also helps to alleviate the over-fitting problem, as in the sketch below.
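A minimal sketch of the tip above, with illustrative values for ``shrinkage_rate``, ``epochs``, and ``early_stopping_rounds``:

.. code:: python

    from torchensemble import GradientBoostingClassifier

    ensemble = GradientBoostingClassifier(
        estimator=base_estimator,  # your own nn.Module
        n_estimators=10,
        shrinkage_rate=0.5,        # a small shrinkage rate alleviates over-fitting
    )
    ensemble.set_optimizer("Adam", lr=1e-3)

    # Passing test_loader together with early_stopping_rounds enables early stopping.
    ensemble.fit(
        train_loader,
        epochs=100,
        test_loader=test_loader,
        early_stopping_rounds=2,
    )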
_over_fit: 40 | 41 | Over-fit 42 | -------- 43 | If your model suffers from the over-fitting problem, consider ensemble methods that take the average over the predictions of several independently trained base estimators, such as :class:`Voting` and :class:`Bagging`. Averaging helps to reduce the variance term, and :class:`Bagging` further uses sampling with replacement to encourage the diversity between base estimators. 44 | Large Training Costs 45 | -------------------- 46 | 47 | Training an ensemble of large deep learning models can take a prohibitively long time and easily run out of memory. If you are suffering from large training costs when using Ensemble-PyTorch, the recommended ensemble method would be :class:`Snapshot Ensemble`. The training costs of :class:`Snapshot Ensemble` are approximately the same as those of training a single base estimator. Please refer to the related section in `Introduction <./introduction.html>`__ for details on :class:`Snapshot Ensemble`. 48 | 49 | However, :class:`Snapshot Ensemble` does not work well across all deep learning models. To reduce the costs of other parallel ensemble methods (i.e., :class:`Voting`, :class:`Bagging`, :class:`Adversarial Training`), you can set ``n_jobs`` to ``None`` or ``1``, which disables the parallelization conducted internally. 50 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. image:: ./_images/badge.svg 2 | :align: center 3 | :width: 400 4 | 5 | Ensemble PyTorch Documentation 6 | ============================== 7 | 8 | Ensemble PyTorch is a unified ensemble framework for PyTorch to easily improve the performance and robustness of your deep learning model. It provides: 9 | 10 | * Easy ways to improve the performance and robustness of your deep learning model. 11 | * Easy-to-use APIs on training and evaluating the ensemble. 12 | * High training efficiency with parallelization. 13 | 14 | Guidepost 15 | --------- 16 | 17 | * To get started, please refer to `Quick Start <./quick_start.html>`__; 18 | * To learn more about the supported ensemble methods, please refer to `Introduction <./introduction.html>`__; 19 | * If you are unsure which ensemble method to use, our `experiments <./experiment.html>`__ and the instructions in `guidance <./guide.html>`__ may be helpful. 20 | 21 | Example 22 | ------- 23 | 24 | .. code:: python 25 | 26 | from torchensemble import VotingClassifier # voting is a classic ensemble strategy 27 | 28 | # Load data 29 | train_loader = DataLoader(...) 30 | test_loader = DataLoader(...) 31 | 32 | # Define the ensemble 33 | ensemble = VotingClassifier( 34 | estimator=base_estimator, # here is your deep learning model 35 | n_estimators=10, # number of base estimators 36 | ) 37 | # Set the criterion 38 | criterion = nn.CrossEntropyLoss() # training objective 39 | ensemble.set_criterion(criterion) 40 | 41 | # Set the optimizer 42 | ensemble.set_optimizer( 43 | "Adam", # type of parameter optimizer 44 | lr=learning_rate, # learning rate of parameter optimizer 45 | weight_decay=weight_decay, # weight decay of parameter optimizer 46 | ) 47 | 48 | # Set the learning rate scheduler 49 | ensemble.set_scheduler( 50 | "CosineAnnealingLR", # type of learning rate scheduler 51 | T_max=epochs, # additional arguments on the scheduler 52 | ) 53 | 54 | # Train the ensemble 55 | ensemble.fit( 56 | train_loader, 57 | epochs=epochs, # number of training epochs 58 | ) 59 | 60 | # Evaluate the ensemble 61 | acc = ensemble.predict(test_loader) # testing accuracy 62 | 63 | Content 64 | ------- 65 | 66 | .. toctree:: 67 | :caption: For Users 68 | :maxdepth: 2 69 | :hidden: 70 | 71 | Quick Start 72 | Introduction 73 | Guidance 74 | Experiment 75 | API Reference 76 | 77 | ..
toctree:: 78 | :caption: For Developers 79 | :maxdepth: 2 80 | :hidden: 81 | 82 | Changelog 83 | Roadmap 84 | Contributors 85 | Code of Conduct 86 | -------------------------------------------------------------------------------- /docs/introduction.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | This page briefly introduces ensemble methods implemented in Ensemble-PyTorch. 5 | 6 | To begin with, below are some notations that will be used throughout the introduction. 7 | 8 | - :math:`\mathcal{B} = \{\mathbf{x}_i, y_i\}_{i=1}^B`: A batch of data with :math:`B` samples; 9 | - :math:`\{h^1, h^2, \cdots, h^m, \cdots, h^M\}`: A set of :math:`M` base estimators; 10 | - :math:`\mathbf{o}_i^m`: The output of the base estimator :math:`h^m` on sample :math:`\mathbf{x}_i`. For regression, it is a scalar or a real-valued vector; for classification, it is a class vector whose size is the number of distinct classes; 11 | - :math:`\mathcal{L}(\mathbf{o}_i, y_i)`: Training loss computed on the output :math:`\mathbf{o}_i` and the ground-truth :math:`y_i`. For regression, it could be the mean squared error; for classification, it could be the cross-entropy loss for binary or multi-class classification. 12 | 13 | Fusion 14 | ------ 15 | 16 | The output of fusion is the averaged output from all base estimators. Formally, given a sample :math:`\mathbf{x}_i`, the output of fusion is :math:`\mathbf{o}_i = \frac{1}{M} \sum_{m=1}^M \mathbf{o}_i^m`. 17 | 18 | During the training stage, all base estimators in fusion are jointly trained with mini-batch gradient descent. Given the output of fusion on a data batch :math:`\mathcal{B}`, the training loss is: :math:`\frac{1}{B} \sum_{i=1}^B \mathcal{L}(\mathbf{o}_i, y_i)`. The parameters of all base estimators can then be jointly updated with the auto-differentiation system in PyTorch and gradient descent. The figure below presents the data flow of fusion: 19 | 20 | .. image:: ./_images/fusion.png 21 | :align: center 22 | :width: 200 23 | 24 | Voting and Bagging 25 | ------------------ 26 | 27 | Voting and bagging are widely used ensemble methods. Basically, voting and bagging fit :math:`M` base estimators independently, and the final prediction takes the average over the predictions from all base estimators. 28 | 29 | Compared to voting, bagging further uses sampling with replacement on each batch of data. Notice that sub-sampling is not typically used when training neural networks, because neural networks typically achieve better performance with more training data. 30 | 31 | Gradient Boosting [1]_ 32 | ---------------------- 33 | 34 | Gradient boosting trains all base estimators in a sequential fashion, as the learning target of a base estimator :math:`h^m` is associated with the outputs from base estimators fitted before, i.e., :math:`\{h^1, \cdots, h^{m-1}\}`. 35 | 36 | Given the :math:`M` fitted base estimators in gradient boosting, the output of the entire ensemble on a sample is :math:`\mathbf{o}_i = \sum_{m=1}^M \epsilon \mathbf{o}_i^m`, where :math:`\epsilon` is a pre-defined scalar in the range :math:`(0, 1]`, known as the shrinkage rate or learning rate in gradient boosting.
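To make the role of the shrinkage rate concrete, below is a minimal sketch of how the shrinkage-weighted ensemble output could be computed in PyTorch. It is an illustration under the notations above rather than the internal implementation of Ensemble-PyTorch, and it assumes ``estimators`` is a list of fitted base estimators:

.. code-block:: python

    def gradient_boosting_output(estimators, x, shrinkage_rate=0.1):
        # o_i = sum_m epsilon * o_i^m, with epsilon the shrinkage rate
        output = None
        for estimator in estimators:
            contribution = shrinkage_rate * estimator(x)
            output = contribution if output is None else output + contribution
        return output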
37 | 38 | The training routine of the m-th base estimator in gradient boosting can be summarized as follows: 39 | 40 | - **Decide the learning target on each sample** :math:`\mathbf{r}_i^m`: Given the ground truth :math:`y_i` and the accumulated output from base estimators fitted before: :math:`\mathbf{o}_i^{[:m]}=\sum_{p=1}^{m-1} \epsilon \mathbf{o}_i^p`, the learning target is defined as :math:`\mathbf{r}_i^m = - \frac{\partial\ \mathcal{L}(\mathbf{o}_i^{[:m]},\ y_i)}{\partial\ \mathbf{o}_i^{[:m]}}`. Therefore, the learning target is simply the negative gradient of the training loss :math:`\mathcal{L}` with respect to the accumulated output from base estimators fitted before :math:`\mathbf{o}_i^{[:m]}`; 41 | - **Fit the m-th base estimator via least squares regression**, that is, the training loss for the m-th base estimator is :math:`l^m = \frac{1}{B} \sum_{i=1}^B \|\mathbf{r}_i^m - \mathbf{o}_i^m\|_2^2`. Given :math:`l^m`, the parameters of :math:`h^m` can then be fitted using gradient descent; 42 | - **Update the accumulated output**: :math:`\mathbf{o}_i^{[:m+1]} = \mathbf{o}_i^{[:m]} + \epsilon \mathbf{o}_i^m`, and then move to the training routine of the (m+1)-th base estimator. 43 | 44 | For regression with the mean squared error, :math:`\mathbf{r}_i^m = \mathbf{y}_i - \mathbf{o}_i^{[:m]}`. For classification with the cross-entropy loss, :math:`\mathbf{r}_i^m = \mathbf{y}_i - \text{Softmax}(\mathbf{o}_i^{[:m]})`, where :math:`\mathbf{y}_i` is the one-hot encoded vector of the class label :math:`y_i`. 45 | 46 | The figure below presents the data flow of gradient boosting during the training and evaluating stages, respectively. Notice that the training stage runs sequentially from left to right. 47 | 48 | .. image:: ./_images/gradient_boosting.png 49 | :align: center 50 | :width: 500 51 | 52 | Snapshot Ensemble [2]_ 53 | ---------------------- 54 | 55 | Unlike all methods above, where :math:`M` independent base estimators will be trained, snapshot ensemble generates the ensemble by enforcing a single base estimator to converge to different local minima :math:`M` times. At each minimum, the parameters of this estimator are saved (i.e., a snapshot), serving as a base estimator in the ensemble. The output of snapshot ensemble also takes the average over the predictions from all snapshots. 56 | 57 | To obtain snapshots with good performance, snapshot ensemble uses a **cyclic annealing schedule on the learning rate** to train the base estimator. Suppose that the initial learning rate is :math:`\alpha_0` and the total number of training iterations is :math:`T`; the learning rate at iteration :math:`t` is: 58 | 59 | .. math:: 60 | \alpha_t = \frac{\alpha_0}{2} \left(\cos \left(\pi \frac{(t-1) \bmod \left \lceil T/M \right \rceil}{\left \lceil T/M \right \rceil}\right) + 1\right). 61 | 62 | Notice that an iteration here refers to one mini-batch update within an epoch, rather than to an entire training epoch. 63 | 64 | Adversarial Training [3]_ 65 | ------------------------- 66 | 67 | Adversarial samples can be used to improve the performance of base estimators, as validated by the authors in [3]_. The implemented ``AdversarialTrainingClassifier`` and ``AdversarialTrainingRegressor`` contain :math:`M` independent base estimators, and each of them is fitted independently as in Voting and Bagging.
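As detailed in the next paragraph, the adversarial samples used by these ensembles are generated with the fast gradient sign method (FGSM). A minimal sketch of FGSM is given below; it assumes a differentiable ``estimator`` and a loss function ``criterion``, and is an illustration of the idea rather than the exact implementation in Ensemble-PyTorch:

.. code-block:: python

    import torch

    def fgsm_sample(estimator, criterion, x, y, epsilon=0.01):
        # x_adv = x + epsilon * sign(dL/dx)
        x_adv = x.clone().detach().requires_grad_(True)
        loss = criterion(estimator(x_adv), y)
        loss.backward()
        return (x_adv + epsilon * x_adv.grad.sign()).detach()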
68 | 69 | During the training stage of each base estimator :math:`h^m`, an adversarial sample :math:`\mathbf{x}_i^{adv}` is first generated for each sample :math:`\mathbf{x}_i` in the current data batch, using the fast gradient sign method (FGSM). The parameters of the base estimator are then optimized to minimize the training loss :math:`\mathcal{L}(\mathbf{o}_i, y_i) + \mathcal{L}(\mathbf{o}_i^{adv}, y_i)`, where :math:`\mathbf{o}_i^{adv}` is the output of :math:`h^m` on the adversarial sample :math:`\mathbf{x}_i^{adv}`. Clearly, this training loss encourages each base estimator to perform well on both original samples and adversarial samples. 70 | 71 | As in Voting and Bagging, the output of ``AdversarialTrainingClassifier`` or ``AdversarialTrainingRegressor`` during the evaluating stage is the average over predictions from all base estimators. 72 | 73 | Fast Geometric Ensemble [4]_ 74 | ---------------------------- 75 | 76 | Motivated by geometric insights on the loss surface of deep neural networks, Fast Geometric Ensembling (FGE) is an efficient ensemble that uses a customized learning rate scheduler to generate base estimators, similar to snapshot ensemble. 77 | 78 | **References** 79 | 80 | .. [1] Jerome H. Friedman, "Greedy Function Approximation: A Gradient Boosting Machine." The Annals of Statistics, 2001. 81 | .. [2] Gao Huang, Yixuan Li, Geoff Pleiss, et al., "Snapshot Ensembles: Train 1, Get M for Free." ICLR, 2017. 82 | .. [3] Balaji Lakshminarayanan, Alexander Pritzel, Charles Blundell, "Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles." NIPS, 2017. 83 | .. [4] Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, et al., "Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs." NeurIPS, 2018. 84 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/parameters.rst: -------------------------------------------------------------------------------- 1 | Parameters 2 | ========== 3 | 4 | This page provides the API reference of :mod:`torchensemble`. Below is a list of functions supported by all ensembles.
5 | 6 | * :meth:`fit`: Training stage of the ensemble 7 | * :meth:`evaluate`: Evaluating stage of the ensemble 8 | * :meth:`predict`: Return the predictions of the ensemble 9 | * :meth:`forward`: Data forward process of the ensemble 10 | * :meth:`set_optimizer`: Set the parameter optimizer for training the ensemble 11 | * :meth:`set_scheduler`: Set the learning rate scheduler for training the ensemble 12 | 13 | Fusion 14 | ------ 15 | 16 | In fusion-based ensemble methods, the predictions from all base estimators 17 | are first aggregated as an average output. The training loss is then 18 | computed based on this average output and the ground-truth, and 19 | back-propagated to all base estimators simultaneously. 20 | 21 | FusionClassifier 22 | **************** 23 | 24 | .. autoclass:: torchensemble.fusion.FusionClassifier 25 | :members: 26 | 27 | FusionRegressor 28 | *************** 29 | 30 | .. autoclass:: torchensemble.fusion.FusionRegressor 31 | :members: 32 | 33 | Voting 34 | ------ 35 | 36 | In voting-based ensemble methods, each base estimator is trained 37 | independently, and the final prediction takes the average over predictions 38 | from all base estimators. 39 | 40 | VotingClassifier 41 | **************** 42 | 43 | .. autoclass:: torchensemble.voting.VotingClassifier 44 | :members: 45 | 46 | VotingRegressor 47 | *************** 48 | 49 | .. autoclass:: torchensemble.voting.VotingRegressor 50 | :members: 51 | 52 | Bagging 53 | ------- 54 | 55 | In bagging-based ensemble methods, each base estimator is trained 56 | independently. In addition, sampling with replacement is conducted on the 57 | training data to further encourage the diversity between different base 58 | estimators in the ensemble. 59 | 60 | BaggingClassifier 61 | ***************** 62 | 63 | .. autoclass:: torchensemble.bagging.BaggingClassifier 64 | :members: 65 | 66 | BaggingRegressor 67 | **************** 68 | 69 | .. autoclass:: torchensemble.bagging.BaggingRegressor 70 | :members: 71 | 72 | Gradient Boosting 73 | ----------------- 74 | 75 | Gradient boosting is a classic sequential ensemble method. At each iteration, 76 | the learning target of a new base estimator is to fit the pseudo residual 77 | computed based on the ground truth and the output from base estimators 78 | fitted before, using ordinary least squares. 79 | 80 | .. tip:: 81 | The input argument ``shrinkage_rate`` in :class:`gradient_boosting` is also known as the learning rate in other gradient boosting libraries such as `XGBoost `__. However, its meaning is entirely different from that of the learning rate used by the parameter optimizer in deep learning. 82 | 83 | GradientBoostingClassifier 84 | ************************** 85 | 86 | .. autoclass:: torchensemble.gradient_boosting.GradientBoostingClassifier 87 | :members: 88 | 89 | GradientBoostingRegressor 90 | ************************* 91 | 92 | .. autoclass:: torchensemble.gradient_boosting.GradientBoostingRegressor 93 | :members: 94 | 95 | Neural Tree Ensemble 96 | -------------------- 97 | 98 | Neural tree ensembles are extensions of voting and gradient boosting that 99 | use the neural tree as the base estimator. Neural trees are differentiable 100 | trees that use logistic regression in internal nodes to split samples 101 | into child nodes with different probabilities. Model details are available at 102 | `Distilling a neural network into a soft decision tree 103 | `_.
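To make the soft splitting mechanism concrete, below is a minimal sketch of soft routing in a depth-2 differentiable tree. It is an illustration of the idea only, not the torchensemble implementation, and the tensor shapes are chosen just for this example:

.. code-block:: python

    import torch

    x = torch.randn(1, 4)                        # one sample with 4 features
    w = torch.randn(3, 4)                        # one logistic regressor per internal node
    b = torch.randn(3)                           # bias term of each logistic regressor
    p = torch.sigmoid(x @ w.t() + b).squeeze(0)  # probability of routing left at each node
    leaf_probs = torch.stack([
        p[0] * p[1],                             # left  -> left
        p[0] * (1 - p[1]),                       # left  -> right
        (1 - p[0]) * p[2],                       # right -> left
        (1 - p[0]) * (1 - p[2]),                 # right -> right
    ])                                           # sums to 1 over the four leaves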
104 | 105 | 106 | NeuralForestClassifier 107 | ********************** 108 | 109 | .. autoclass:: torchensemble.voting.NeuralForestClassifier 110 | :members: 111 | 112 | NeuralForestRegressor 113 | ********************** 114 | 115 | .. autoclass:: torchensemble.voting.NeuralForestRegressor 116 | :members: 117 | 118 | Snapshot Ensemble 119 | ----------------- 120 | 121 | Snapshot ensemble generates many base estimators by enforcing a base 122 | estimator to converge to different local minima many times, and saves the 123 | model parameters at each minimum as a snapshot. The final prediction takes 124 | the average over predictions from all snapshot models. 125 | 126 | Reference: 127 | G. Huang, Y.-X. Li, G. Pleiss et al., Snapshot Ensembles: Train 1, 128 | Get M for Free, ICLR, 2017. 129 | 130 | SnapshotEnsembleClassifier 131 | ************************** 132 | 133 | .. autoclass:: torchensemble.snapshot_ensemble.SnapshotEnsembleClassifier 134 | :members: 135 | 136 | SnapshotEnsembleRegressor 137 | ************************* 138 | 139 | .. autoclass:: torchensemble.snapshot_ensemble.SnapshotEnsembleRegressor 140 | :members: 141 | 142 | Adversarial Training 143 | -------------------- 144 | 145 | Adversarial training is able to improve the performance of an ensemble by 146 | treating adversarial samples as the augmented training data. The fast 147 | gradient sign method (FGSM) is used to generate adversarial samples. 148 | 149 | Reference: 150 | B. Lakshminarayanan, A. Pritzel, C. Blundell, Simple and Scalable 151 | Predictive Uncertainty Estimation using Deep Ensembles, NIPS, 2017. 152 | 153 | .. warning:: 154 | When your base estimator is under-fit on the dataset, it is not recommended to use the :mod:`AdversarialTrainingClassifier` or :mod:`AdversarialTrainingRegressor`, because they may deteriorate the performance further. 155 | 156 | AdversarialTrainingClassifier 157 | ***************************** 158 | 159 | .. autoclass:: torchensemble.adversarial_training.AdversarialTrainingClassifier 160 | :members: 161 | 162 | AdversarialTrainingRegressor 163 | ***************************** 164 | 165 | .. autoclass:: torchensemble.adversarial_training.AdversarialTrainingRegressor 166 | :members: 167 | 168 | Fast Geometric Ensemble 169 | ----------------------- 170 | 171 | Motivated by geometric insights on the loss surface of deep neural networks, 172 | Fast Geometric Ensembling (FGE) is an efficient ensemble that uses a 173 | customized learning rate scheduler to generate base estimators, similar to 174 | snapshot ensemble. 175 | 176 | Reference: 177 | T. Garipov, P. Izmailov, D. Podoprikhin et al., Loss Surfaces, Mode 178 | Connectivity, and Fast Ensembling of DNNs, NeurIPS, 2018. 179 | 180 | FastGeometricClassifier 181 | *********************** 182 | 183 | .. autoclass:: torchensemble.fast_geometric.FastGeometricClassifier 184 | :members: 185 | 186 | FastGeometricRegressor 187 | *********************** 188 | 189 | .. autoclass:: torchensemble.fast_geometric.FastGeometricRegressor 190 | :members: 191 | 192 | Soft Gradient Boosting 193 | ---------------------- 194 | 195 | In soft gradient boosting, all base estimators can be fitted simultaneously, 196 | while achieving boosting improvements similar to those of gradient boosting. 197 | 198 | Reference: 199 | J. Feng, Y.-X. Xu et al., Soft Gradient Boosting Machine, ArXiv, 2020. 200 | 201 | SoftGradientBoostingClassifier 202 | ****************************** 203 | 204 | ..
autoclass:: torchensemble.soft_gradient_boosting.SoftGradientBoostingClassifier 205 | :members: 206 | 207 | SoftGradientBoostingRegressor 208 | ***************************** 209 | 210 | .. autoclass:: torchensemble.soft_gradient_boosting.SoftGradientBoostingRegressor 211 | :members: 212 | -------------------------------------------------------------------------------- /docs/plotting/lenet_cifar10.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | plt.style.use("seaborn") 4 | 5 | n_estimators = [5, 10, 15, 20] 6 | fusion = [78.26, 78.79, 78.44, 77.20] 7 | voting = [77.71, 78.22, 78.13, 78.73] 8 | bagging = [77.15, 77.32, 77.52, 77.85] 9 | gradient_boosting = [77.28, 78.88, 79.19, 78.60] 10 | snapshot_ensemble = [70.83, 70.69, 70.06, 70.37] 11 | 12 | plt.plot(n_estimators, fusion, label="fusion", linewidth=5) 13 | plt.plot(n_estimators, voting, label="voting", linewidth=5) 14 | plt.plot(n_estimators, bagging, label="bagging", linewidth=5) 15 | plt.plot(n_estimators, gradient_boosting, label="gradient boosting", linewidth=5) 16 | plt.plot(n_estimators, snapshot_ensemble, label="snapshot ensemble", linewidth=5) 17 | 18 | plt.title("LeNet@CIFAR-10", fontdict={"size": 30}) 19 | plt.xlabel("n_estimators", fontdict={"size": 30}) 20 | plt.ylabel("testing acc", fontdict={"size": 30}) 21 | plt.xticks(n_estimators, fontsize=25) 22 | plt.yticks(fontsize=25) 23 | plt.legend(prop={"size": 30}) 24 | -------------------------------------------------------------------------------- /docs/plotting/lenet_mnist.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | plt.style.use("seaborn") 4 | 5 | n_estimators = [5, 10, 15, 20] 6 | fusion = [99.16, 99.24, 99.20, 99.23] 7 | voting = [99.47, 99.47, 99.51, 99.52] 8 | bagging = [99.47, 99.52, 99.52, 99.54] 9 | gradient_boosting = [99.13, 99.53, 99.51, 99.43] 10 | snapshot_ensemble = [99.22, 99.25, 99.28, 99.32] 11 | 12 | plt.plot(n_estimators, fusion, label="fusion", linewidth=5) 13 | plt.plot(n_estimators, voting, label="voting", linewidth=5) 14 | plt.plot(n_estimators, bagging, label="bagging", linewidth=5) 15 | plt.plot(n_estimators, gradient_boosting, label="gradient boosting", linewidth=5) 16 | plt.plot(n_estimators, snapshot_ensemble, label="snapshot ensemble", linewidth=5) 17 | 18 | plt.title("LeNet@MNIST", fontdict={"size": 30}) 19 | plt.xlabel("n_estimators", fontdict={"size": 30}) 20 | plt.ylabel("testing acc", fontdict={"size": 30}) 21 | plt.xticks(n_estimators, fontsize=25) 22 | plt.yticks(fontsize=25) 23 | plt.legend(prop={"size": 30}) 24 | -------------------------------------------------------------------------------- /docs/plotting/resnet_cifar10.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | plt.style.use("seaborn") 4 | 5 | n_estimators = [2, 5, 7, 10] 6 | fusion = [95.32, 95.34, 95.31, 95.29] 7 | voting = [96.00, 94.77, 96.41, 95.34] 8 | bagging = [95.71, 94.97, 96.16, 94.94] 9 | gradient_boosting = [95.63, 95.78, 95.77, 96.03] 10 | snapshot_ensemble = [95.57, 95.71, 96.34, 95.68] 11 | 12 | plt.plot(n_estimators, fusion, label="fusion", linewidth=5) 13 | plt.plot(n_estimators, voting, label="voting", linewidth=5) 14 | plt.plot(n_estimators, bagging, label="bagging", linewidth=5) 15 | plt.plot(n_estimators, gradient_boosting, label="gradient boosting", linewidth=5) 16 | plt.plot(n_estimators, snapshot_ensemble, 
label="snapshot ensemble", linewidth=5) 17 | 18 | plt.title("ResNet@CIFAR-10", fontdict={"size": 30}) 19 | plt.xlabel("n_estimators", fontdict={"size": 30}) 20 | plt.ylabel("testing acc", fontdict={"size": 30}) 21 | plt.xticks(n_estimators, fontsize=25) 22 | plt.yticks(fontsize=25) 23 | plt.legend(prop={"size": 30}) 24 | -------------------------------------------------------------------------------- /docs/plotting/resnet_cifar100.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | plt.style.use("seaborn") 4 | 5 | n_estimators = [2, 5, 7, 10] 6 | fusion = [77.14, 76.81, 76.52, 76.06] 7 | voting = [78.04, 79.39, 79.63, 79.78] 8 | bagging = [77.21, 78.94, 79.09, 79.47] 9 | snapshot_ensemble = [77.30, 77.39, 77.91, 77.58] 10 | 11 | plt.plot(n_estimators, fusion, label="fusion", linewidth=5) 12 | plt.plot(n_estimators, voting, label="voting", linewidth=5) 13 | plt.plot(n_estimators, bagging, label="bagging", linewidth=5) 14 | plt.plot(n_estimators, snapshot_ensemble, label="snapshot ensemble", linewidth=5) 15 | 16 | plt.title("ResNet@CIFAR-100", fontdict={"size": 30}) 17 | plt.xlabel("n_estimators", fontdict={"size": 30}) 18 | plt.ylabel("testing acc", fontdict={"size": 30}) 19 | plt.xticks(n_estimators, fontsize=25) 20 | plt.yticks(fontsize=25) 21 | plt.legend(prop={"size": 30}) 22 | -------------------------------------------------------------------------------- /docs/quick_start.rst: -------------------------------------------------------------------------------- 1 | Get started 2 | =========== 3 | 4 | Installation 5 | ------------ 6 | 7 | You can install the stable version of Ensemble-PyTorch with the following command: 8 | 9 | .. code-block:: bash 10 | 11 | $ pip install torchensemble 12 | 13 | Ensemble-PyTorch is designed to be portable and has very few package dependencies. It is recommended to use the package environment and PyTorch installed from `Anaconda `__. 14 | 15 | Define Your Base Estimator 16 | -------------------------- 17 | 18 | Since Ensemble-PyTorch uses different ensemble methods to improve the performance, a key input argument is your deep learning model, serving as the base estimator. Same as PyTorch, the class of your model should inherit from :mod:`torch.nn.Module`, and it should at least implement two methods: 19 | 20 | * :meth:`__init__`: Instantiate sub-modules for your model and assign them as the member variables. 21 | * :meth:`forward`: Define the data forward process for your model. 22 | 23 | For example, the code snippet below defines a multi-layered perceptron (MLP) of the structure ``Input(784) - 128 - 128 - Output(10)``: 24 | 25 | .. code-block:: python 26 | 27 | import torch.nn as nn 28 | from torch.nn import functional as F 29 | 30 | class MLP(nn.Module): 31 | 32 | def __init__(self): 33 | super(MLP, self).__init__() 34 | self.linear1 = nn.Linear(784, 128) 35 | self.linear2 = nn.Linear(128, 128) 36 | self.linear3 = nn.Linear(128, 10) 37 | 38 | def forward(self, data): 39 | data = data.view(data.size(0), -1) # flatten 40 | output = F.relu(self.linear1(data)) 41 | output = F.relu(self.linear2(output)) 42 | output = self.linear3(output) 43 | return output 44 | 45 | Set the Logger 46 | -------------- 47 | 48 | Ensemble-PyTorch uses a global logger to track and print the intermediate logging information. The code snippet below shows how to set up a logger: 49 | 50 | .. 
code-block:: python 51 | 52 | from torchensemble.utils.logging import set_logger 53 | 54 | logger = set_logger('classification_mnist_mlp') 55 | 56 | With this logger, all logging information will be printed on the command line and saved to the specified text file: ``classification_mnist_mlp``. 57 | 58 | In addition, when passing ``use_tb_logger=True`` into the method :meth:`set_logger`, you can use TensorBoard for better visualization of training and evaluating the ensemble. 59 | 60 | .. code-block:: bash 61 | 62 | tensorboard --logdir=logs/ 63 | 64 | Choose the Ensemble 65 | ------------------- 66 | 67 | After defining the base estimator, we can then wrap it using one of the ensemble models available in Ensemble-PyTorch. Different models have very similar APIs; take the ``VotingClassifier`` as an example: 68 | 69 | .. code-block:: python 70 | 71 | from torchensemble import VotingClassifier 72 | 73 | model = VotingClassifier( 74 | estimator=MLP, 75 | n_estimators=10, 76 | cuda=True, 77 | ) 78 | 79 | The meaning of different arguments is listed as follows: 80 | 81 | * ``estimator``: The class of your model, used to instantiate base estimators in the ensemble. 82 | * ``n_estimators``: The number of base estimators in the ensemble. 83 | * ``cuda``: Specify whether to use GPU for training and evaluating the ensemble. 84 | 85 | Set the Criterion 86 | ----------------- 87 | 88 | The next step is to set the objective function. Since our ensemble model is a classifier, we 89 | will use cross-entropy: 90 | 91 | .. code-block:: python 92 | 93 | criterion = nn.CrossEntropyLoss() 94 | model.set_criterion(criterion) 95 | 96 | Set the Optimizer 97 | ----------------- 98 | 99 | After declaring the ensemble, another step before the training stage is to set the optimizer. Suppose that we are going to use the Adam optimizer with learning rate ``1e-3`` and weight decay ``5e-4`` to train the ensemble; this can be achieved by calling the ``set_optimizer`` method of the ensemble: 100 | 101 | .. code-block:: python 102 | 103 | model.set_optimizer('Adam', # parameter optimizer 104 | lr=1e-3, # learning rate of the optimizer 105 | weight_decay=5e-4) # weight decay of the optimizer 106 | 107 | Notice that all arguments after the optimizer name (i.e., ``Adam``) should be in the form of keyword arguments. They will be delivered to the :mod:`torch.optim.Optimizer` to instantiate the internal parameter optimizer. 108 | 109 | Setting the learning rate scheduler for the ensemble is also supported; please refer to the :meth:`set_scheduler` in `API Reference <./parameters.html>`__. 110 | 111 | Train and Evaluate 112 | ------------------ 113 | 114 | Given the ensemble with the optimizer already set, Ensemble-PyTorch provides Scikit-Learn APIs on the training and evaluating stage of the ensemble: 115 | 116 | .. code-block:: python 117 | 118 | # Training 119 | model.fit(train_loader=train_loader, # training data 120 | epochs=100) # the number of training epochs 121 | 122 | # Evaluating 123 | accuracy = model.predict(test_loader) 124 | 125 | In the code snippet above, ``train_loader`` and ``test_loader`` are PyTorch :mod:`DataLoader` objects that contain your data. In addition, ``epochs`` specifies the number of training epochs. Since ``VotingClassifier`` is used for classification, :meth:`predict` will return the classification accuracy on the ``test_loader``.
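Related to the optimizer set-up above, a learning rate scheduler can optionally be attached in the same keyword-argument style before training. A small sketch using the ``CosineAnnealingLR`` scheduler that also appears elsewhere in these docs (the scheduler type and its arguments are example choices to adapt to your task):

.. code-block:: python

    # Optional: attach a learning rate scheduler before calling fit()
    model.set_scheduler('CosineAnnealingLR', T_max=100)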
126 | 127 | Notice that the ``test_loader`` can also be passed to :meth:`fit`; in this case, the ensemble will treat it as validation data and evaluate on it after each training epoch. 128 | 129 | Save and Reload 130 | --------------- 131 | 132 | By setting ``save_model`` to ``True`` and ``save_dir`` to a directory in :meth:`fit`, model parameters will be automatically saved to the path ``save_dir`` (by default, the same folder as the running script). You can then use the following code snippet to load the saved ensemble. 133 | 134 | .. code-block:: python 135 | 136 | from torchensemble.utils import io 137 | 138 | io.load(new_ensemble, save_dir) # reload 139 | 140 | where :obj:`new_ensemble` is an ensemble instantiated in the same way as the original ensemble. 141 | 142 | Example on MNIST 143 | ---------------- 144 | 145 | The script below shows an example of using VotingClassifier with 10 MLPs for classification on the MNIST dataset. 146 | 147 | .. code-block:: python 148 | 149 | import torch 150 | import torch.nn as nn 151 | from torch.nn import functional as F 152 | from torchvision import datasets, transforms 153 | 154 | from torchensemble import VotingClassifier 155 | from torchensemble.utils.logging import set_logger 156 | 157 | # Define Your Base Estimator 158 | class MLP(nn.Module): 159 | 160 | def __init__(self): 161 | super(MLP, self).__init__() 162 | self.linear1 = nn.Linear(784, 128) 163 | self.linear2 = nn.Linear(128, 128) 164 | self.linear3 = nn.Linear(128, 10) 165 | 166 | def forward(self, data): 167 | data = data.view(data.size(0), -1) 168 | output = F.relu(self.linear1(data)) 169 | output = F.relu(self.linear2(output)) 170 | output = self.linear3(output) 171 | return output 172 | 173 | # Load MNIST dataset 174 | transform=transforms.Compose([ 175 | transforms.ToTensor(), 176 | transforms.Normalize((0.1307,), (0.3081,)) 177 | ]) 178 | 179 | train = datasets.MNIST('../Dataset', train=True, download=True, transform=transform) 180 | test = datasets.MNIST('../Dataset', train=False, transform=transform) 181 | train_loader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True) 182 | test_loader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=True) 183 | 184 | # Set the Logger 185 | logger = set_logger('classification_mnist_mlp') 186 | 187 | # Define the ensemble 188 | model = VotingClassifier( 189 | estimator=MLP, 190 | n_estimators=10, 191 | cuda=True, 192 | ) 193 | 194 | # Set the criterion 195 | criterion = nn.CrossEntropyLoss() 196 | model.set_criterion(criterion) 197 | 198 | # Set the optimizer 199 | model.set_optimizer('Adam', lr=1e-3, weight_decay=5e-4) 200 | 201 | # Train and Evaluate 202 | model.fit( 203 | train_loader, 204 | epochs=50, 205 | test_loader=test_loader, 206 | ) 207 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==5.0 2 | sphinx-panels 3 | sphinxemoji==0.1.8 4 | sphinx_rtd_theme 5 | sphinx-copybutton 6 | m2r2 7 | mistune==0.8.4 8 | Jinja2<3.1 9 | Numpy -------------------------------------------------------------------------------- /docs/roadmap.rst: -------------------------------------------------------------------------------- 1 | Roadmap 2 | ======= 3 | 4 | Anyone is welcome to work on any feature or enhancement listed below.
Your contributions would be clearly posted in `Contributors <./contributors.html>`__ and `Changelog <./changelog.html>`__. 5 | 6 | Long-term Objective 7 | ------------------- 8 | 9 | * |MajorFeature| |API| Verify the effectiveness of new ensembles that could be included in torchensemble 10 | * |Efficiency| Improve the training and evaluating efficiency of existing ensembles 11 | 12 | New Functionality 13 | ----------------- 14 | 15 | * |MajorFeature| |API| Add the partial-fit mode for all ensembles 16 | * |MajorFeature| |API| Support manually-specified evaluation metrics when calling :meth:`predict` 17 | * |MajorFeature| |API| Integration with :mod:`torch.utils.tensorboard` for better visualization 18 | * |MajorFeature| |API| Support arbitrary training criteria for existing ensembles 19 | 20 | Maintenance Related 21 | ------------------- 22 | 23 | * |Enhancement| Refactor existing code to reduce the redundancy 24 | * |Enhancement| Reduce the randomness in the experiment part by running each experiment multiple times with different random seeds 25 | 26 | .. role:: raw-html(raw) 27 | :format: html 28 | 29 | .. role:: raw-latex(raw) 30 | :format: latex 31 | 32 | .. |MajorFeature| replace:: :raw-html:`Major Feature` :raw-latex:`{\small\sc [Major Feature]}` 33 | .. |Feature| replace:: :raw-html:`Feature` :raw-latex:`{\small\sc [Feature]}` 34 | .. |Efficiency| replace:: :raw-html:`Efficiency` :raw-latex:`{\small\sc [Efficiency]}` 35 | .. |Enhancement| replace:: :raw-html:`Enhancement` :raw-latex:`{\small\sc [Enhancement]}` 36 | .. |Fix| replace:: :raw-html:`Fix` :raw-latex:`{\small\sc [Fix]}` 37 | .. |API| replace:: :raw-html:`API Change` :raw-latex:`{\small\sc [API Change]}` -------------------------------------------------------------------------------- /examples/classification_cifar10_cnn.py: -------------------------------------------------------------------------------- 1 | """Example on classification using CIFAR-10.""" 2 | 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | from torchvision import datasets, transforms 9 | 10 | from torchensemble.fusion import FusionClassifier 11 | from torchensemble.voting import VotingClassifier 12 | from torchensemble.bagging import BaggingClassifier 13 | from torchensemble.gradient_boosting import GradientBoostingClassifier 14 | from torchensemble.snapshot_ensemble import SnapshotEnsembleClassifier 15 | from torchensemble.soft_gradient_boosting import SoftGradientBoostingClassifier 16 | 17 | from torchensemble.utils.logging import set_logger 18 | 19 | 20 | def display_records(records, logger): 21 | msg = ( 22 | "{:<28} | Testing Acc: {:.2f} % | Training Time: {:.2f} s |" 23 | " Evaluating Time: {:.2f} s" 24 | ) 25 | 26 | print("\n") 27 | for method, training_time, evaluating_time, acc in records: 28 | logger.info(msg.format(method, acc, training_time, evaluating_time)) 29 | 30 | 31 | class LeNet5(nn.Module): 32 | def __init__(self): 33 | super(LeNet5, self).__init__() 34 | self.conv1 = nn.Conv2d(3, 6, 5) 35 | self.pool = nn.MaxPool2d(2, 2) 36 | self.conv2 = nn.Conv2d(6, 16, 5) 37 | self.fc1 = nn.Linear(400, 120) 38 | self.fc2 = nn.Linear(120, 84) 39 | self.fc3 = nn.Linear(84, 10) 40 | 41 | def forward(self, x): 42 | x = self.pool(F.relu(self.conv1(x))) 43 | x = self.pool(F.relu(self.conv2(x))) 44 | x = x.view(-1, 400) 45 | x = F.relu(self.fc1(x)) 46 | x = F.relu(self.fc2(x)) 47 | x = self.fc3(x) 48 | return x 49 | 50 | 51 | if __name__ == "__main__": 52 | 53 | # 
Hyper-parameters 54 | n_estimators = 10 55 | lr = 1e-3 56 | weight_decay = 5e-4 57 | epochs = 100 58 | 59 | # Utils 60 | batch_size = 128 61 | data_dir = "../../Dataset/cifar" # MODIFY THIS IF YOU WANT 62 | records = [] 63 | torch.manual_seed(0) 64 | 65 | # Load data 66 | train_transformer = transforms.Compose( 67 | [ 68 | transforms.RandomHorizontalFlip(), 69 | transforms.RandomCrop(32, 4), 70 | transforms.ToTensor(), 71 | transforms.Normalize( 72 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 73 | ), 74 | ] 75 | ) 76 | 77 | test_transformer = transforms.Compose( 78 | [ 79 | transforms.ToTensor(), 80 | transforms.Normalize( 81 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 82 | ), 83 | ] 84 | ) 85 | 86 | train_loader = DataLoader( 87 | datasets.CIFAR10( 88 | data_dir, train=True, download=True, transform=train_transformer 89 | ), 90 | batch_size=batch_size, 91 | shuffle=True, 92 | ) 93 | 94 | test_loader = DataLoader( 95 | datasets.CIFAR10(data_dir, train=False, transform=test_transformer), 96 | batch_size=batch_size, 97 | shuffle=True, 98 | ) 99 | 100 | logger = set_logger("classification_cifar10_cnn", use_tb_logger=True) 101 | 102 | # FusionClassifier 103 | model = FusionClassifier( 104 | estimator=LeNet5, n_estimators=n_estimators, cuda=True 105 | ) 106 | 107 | # Set the optimizer 108 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 109 | 110 | # Training 111 | tic = time.time() 112 | model.fit(train_loader, epochs=epochs) 113 | toc = time.time() 114 | training_time = toc - tic 115 | 116 | # Evaluating 117 | tic = time.time() 118 | testing_acc = model.evaluate(test_loader) 119 | toc = time.time() 120 | evaluating_time = toc - tic 121 | 122 | records.append( 123 | ("FusionClassifier", training_time, evaluating_time, testing_acc) 124 | ) 125 | 126 | # VotingClassifier 127 | model = VotingClassifier( 128 | estimator=LeNet5, n_estimators=n_estimators, cuda=True 129 | ) 130 | 131 | # Set the optimizer 132 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 133 | 134 | # Training 135 | tic = time.time() 136 | model.fit(train_loader, epochs=epochs) 137 | toc = time.time() 138 | training_time = toc - tic 139 | 140 | # Evaluating 141 | tic = time.time() 142 | testing_acc = model.evaluate(test_loader) 143 | toc = time.time() 144 | evaluating_time = toc - tic 145 | 146 | records.append( 147 | ("VotingClassifier", training_time, evaluating_time, testing_acc) 148 | ) 149 | 150 | # BaggingClassifier 151 | model = BaggingClassifier( 152 | estimator=LeNet5, n_estimators=n_estimators, cuda=True 153 | ) 154 | 155 | # Set the optimizer 156 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 157 | 158 | # Training 159 | tic = time.time() 160 | model.fit(train_loader, epochs=epochs) 161 | toc = time.time() 162 | training_time = toc - tic 163 | 164 | # Evaluating 165 | tic = time.time() 166 | testing_acc = model.evaluate(test_loader) 167 | toc = time.time() 168 | evaluating_time = toc - tic 169 | 170 | records.append( 171 | ("BaggingClassifier", training_time, evaluating_time, testing_acc) 172 | ) 173 | 174 | # GradientBoostingClassifier 175 | model = GradientBoostingClassifier( 176 | estimator=LeNet5, n_estimators=n_estimators, cuda=True 177 | ) 178 | 179 | # Set the optimizer 180 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 181 | 182 | # Training 183 | tic = time.time() 184 | model.fit(train_loader, epochs=epochs) 185 | toc = time.time() 186 | training_time = toc - tic 187 | 188 | # Evaluating 189 | tic = time.time() 190 | testing_acc = 
model.evaluate(test_loader) 191 | toc = time.time() 192 | evaluating_time = toc - tic 193 | 194 | records.append( 195 | ( 196 | "GradientBoostingClassifier", 197 | training_time, 198 | evaluating_time, 199 | testing_acc, 200 | ) 201 | ) 202 | 203 | # SnapshotEnsembleClassifier 204 | model = SnapshotEnsembleClassifier( 205 | estimator=LeNet5, n_estimators=n_estimators, cuda=True 206 | ) 207 | 208 | # Set the optimizer 209 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 210 | 211 | # Training 212 | tic = time.time() 213 | model.fit(train_loader, epochs=epochs) 214 | toc = time.time() 215 | training_time = toc - tic 216 | 217 | # Evaluating 218 | tic = time.time() 219 | testing_acc = model.evaluate(test_loader) 220 | toc = time.time() 221 | evaluating_time = toc - tic 222 | 223 | records.append( 224 | ( 225 | "SnapshotEnsembleClassifier", 226 | training_time, 227 | evaluating_time, 228 | testing_acc, 229 | ) 230 | ) 231 | 232 | # SoftGradientBoostingClassifier 233 | model = SoftGradientBoostingClassifier( 234 | estimator=LeNet5, n_estimators=n_estimators, cuda=True 235 | ) 236 | 237 | # Set the optimizer 238 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 239 | 240 | # Training 241 | tic = time.time() 242 | model.fit(train_loader, epochs=epochs) 243 | toc = time.time() 244 | training_time = toc - tic 245 | 246 | # Evaluating 247 | tic = time.time() 248 | testing_acc = model.evaluate(test_loader) 249 | toc = time.time() 250 | evaluating_time = toc - tic 251 | 252 | records.append( 253 | ( 254 | "SoftGradientBoostingClassifier", 255 | training_time, 256 | evaluating_time, 257 | testing_acc, 258 | ) 259 | ) 260 | 261 | # Print results on different ensemble methods 262 | display_records(records, logger) 263 | -------------------------------------------------------------------------------- /examples/classification_mnist_tree_ensemble.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torchvision import datasets, transforms 4 | 5 | from torchensemble import NeuralForestClassifier 6 | from torchensemble.utils.logging import set_logger 7 | 8 | 9 | if __name__ == "__main__": 10 | 11 | # Hyper-parameters 12 | n_estimators = 5 13 | depth = 5 14 | lamda = 1e-3 15 | lr = 1e-3 16 | weight_decay = 5e-4 17 | epochs = 50 18 | 19 | # Utils 20 | cuda = False 21 | n_jobs = 1 22 | batch_size = 128 23 | data_dir = "../../Dataset/mnist" # MODIFY THIS IF YOU WANT 24 | 25 | # Load data 26 | train_loader = torch.utils.data.DataLoader( 27 | datasets.MNIST( 28 | data_dir, 29 | train=True, 30 | download=True, 31 | transform=transforms.Compose( 32 | [ 33 | transforms.ToTensor(), 34 | transforms.Normalize((0.1307,), (0.3081,)), 35 | ] 36 | ), 37 | ), 38 | batch_size=batch_size, 39 | shuffle=True, 40 | ) 41 | 42 | test_loader = torch.utils.data.DataLoader( 43 | datasets.MNIST( 44 | data_dir, 45 | train=False, 46 | download=True, 47 | transform=transforms.Compose( 48 | [ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.1307,), (0.3081,)), 51 | ] 52 | ), 53 | ), 54 | batch_size=batch_size, 55 | shuffle=True, 56 | ) 57 | 58 | logger = set_logger( 59 | "classification_mnist_tree_ensemble", use_tb_logger=False 60 | ) 61 | 62 | model = NeuralForestClassifier( 63 | n_estimators=n_estimators, 64 | depth=depth, 65 | lamda=lamda, 66 | cuda=cuda, 67 | n_jobs=-1, 68 | ) 69 | 70 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 71 | 72 | tic = time.time() 73 | model.fit(train_loader, epochs=epochs, 
test_loader=test_loader) 74 | toc = time.time() 75 | training_time = toc - tic 76 | -------------------------------------------------------------------------------- /examples/fast_geometric_ensemble_cifar10_resnet18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets, transforms 6 | 7 | from torchensemble import FastGeometricClassifier 8 | from torchensemble.utils.logging import set_logger 9 | 10 | 11 | # The class `BasicBlock` and `ResNet` is modified from: 12 | # https://github.com/kuangliu/pytorch-cifar 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = nn.Conv2d( 19 | in_planes, 20 | planes, 21 | kernel_size=3, 22 | stride=stride, 23 | padding=1, 24 | bias=False, 25 | ) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.conv2 = nn.Conv2d( 28 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 29 | ) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != self.expansion * planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d( 36 | in_planes, 37 | self.expansion * planes, 38 | kernel_size=1, 39 | stride=stride, 40 | bias=False, 41 | ), 42 | nn.BatchNorm2d(self.expansion * planes), 43 | ) 44 | 45 | def forward(self, x): 46 | out = F.relu(self.bn1(self.conv1(x))) 47 | out = self.bn2(self.conv2(out)) 48 | out += self.shortcut(x) 49 | out = F.relu(out) 50 | return out 51 | 52 | 53 | class ResNet(nn.Module): 54 | def __init__(self, block, num_blocks, num_classes=10): 55 | super(ResNet, self).__init__() 56 | self.in_planes = 64 57 | 58 | self.conv1 = nn.Conv2d( 59 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 60 | ) 61 | self.bn1 = nn.BatchNorm2d(64) 62 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 63 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 64 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 65 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 66 | self.linear = nn.Linear(512 * block.expansion, num_classes) 67 | 68 | def _make_layer(self, block, planes, num_blocks, stride): 69 | strides = [stride] + [1] * (num_blocks - 1) 70 | layers = [] 71 | for stride in strides: 72 | layers.append(block(self.in_planes, planes, stride)) 73 | self.in_planes = planes * block.expansion 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | out = F.relu(self.bn1(self.conv1(x))) 78 | out = self.layer1(out) 79 | out = self.layer2(out) 80 | out = self.layer3(out) 81 | out = self.layer4(out) 82 | out = F.avg_pool2d(out, 4) 83 | out = out.view(out.size(0), -1) 84 | out = self.linear(out) 85 | return out 86 | 87 | 88 | if __name__ == "__main__": 89 | 90 | # Hyper-parameters 91 | n_estimators = 10 92 | lr = 1e-1 93 | weight_decay = 5e-4 94 | momentum = 0.9 95 | epochs = 200 96 | 97 | # Utils 98 | batch_size = 128 99 | data_dir = "../../Dataset/cifar" # MODIFY THIS IF YOU WANT 100 | torch.manual_seed(0) 101 | torch.cuda.set_device(0) 102 | 103 | # Load data 104 | train_transformer = transforms.Compose( 105 | [ 106 | transforms.RandomHorizontalFlip(), 107 | transforms.RandomCrop(32, 4), 108 | transforms.ToTensor(), 109 | transforms.Normalize( 110 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 111 | ), 112 | ] 113 
| ) 114 | 115 | test_transformer = transforms.Compose( 116 | [ 117 | transforms.ToTensor(), 118 | transforms.Normalize( 119 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 120 | ), 121 | ] 122 | ) 123 | 124 | train_loader = DataLoader( 125 | datasets.CIFAR10( 126 | data_dir, train=True, download=True, transform=train_transformer 127 | ), 128 | batch_size=batch_size, 129 | shuffle=True, 130 | ) 131 | 132 | test_loader = DataLoader( 133 | datasets.CIFAR10(data_dir, train=False, transform=test_transformer), 134 | batch_size=batch_size, 135 | shuffle=True, 136 | ) 137 | 138 | # Set the Logger 139 | logger = set_logger( 140 | "FastGeometricClassifier_cifar10_resnet", use_tb_logger=True 141 | ) 142 | 143 | # Choose the Ensemble Method 144 | model = FastGeometricClassifier( 145 | estimator=ResNet, 146 | estimator_args={"block": BasicBlock, "num_blocks": [2, 2, 2, 2]}, 147 | n_estimators=n_estimators, 148 | cuda=True, 149 | ) 150 | 151 | # Set the Optimizer 152 | model.set_optimizer( 153 | "SGD", lr=lr, weight_decay=weight_decay, momentum=momentum 154 | ) 155 | 156 | # Set the Scheduler 157 | model.set_scheduler("CosineAnnealingLR", T_max=epochs) 158 | 159 | # Train 160 | estimator = model.fit( 161 | train_loader, 162 | epochs=epochs, 163 | test_loader=test_loader, 164 | cycle=4, 165 | lr_1=5e-2, 166 | lr_2=5e-4, 167 | ) 168 | 169 | # Evaluate 170 | acc = model.evaluate(test_loader) 171 | print("Testing Acc: {:.3f}".format(acc)) 172 | -------------------------------------------------------------------------------- /examples/regression_YearPredictionMSD_mlp.py: -------------------------------------------------------------------------------- 1 | """Example on regression using YearPredictionMSD.""" 2 | 3 | import time 4 | import torch 5 | import numbers 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from sklearn.preprocessing import scale 9 | from sklearn.datasets import load_svmlight_file 10 | from torch.utils.data import TensorDataset, DataLoader 11 | 12 | from torchensemble.fusion import FusionRegressor 13 | from torchensemble.voting import VotingRegressor 14 | from torchensemble.bagging import BaggingRegressor 15 | from torchensemble.gradient_boosting import GradientBoostingRegressor 16 | from torchensemble.snapshot_ensemble import SnapshotEnsembleRegressor 17 | 18 | from torchensemble.utils.logging import set_logger 19 | 20 | 21 | def load_data(batch_size): 22 | 23 | # The dataset can be downloaded from: 24 | # https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression.html#YearPredictionMSD 25 | 26 | if not isinstance(batch_size, numbers.Integral): 27 | msg = "`batch_size` should be an integer, but got {} instead." 
28 | raise ValueError(msg.format(batch_size)) 29 | 30 | # MODIFY THE PATH IF YOU WANT 31 | train_path = "../../Dataset/LIBSVM/yearpredictionmsd_training" 32 | test_path = "../../Dataset/LIBSVM/yearpredictionmsd_testing" 33 | 34 | train = load_svmlight_file(train_path) 35 | test = load_svmlight_file(test_path) 36 | 37 | # Numpy array -> Tensor 38 | X_train, X_test = ( 39 | torch.FloatTensor(train[0].toarray()), 40 | torch.FloatTensor(test[0].toarray()), 41 | ) 42 | 43 | y_train, y_test = ( 44 | torch.FloatTensor(scale(train[1]).reshape(-1, 1)), 45 | torch.FloatTensor(scale(test[1]).reshape(-1, 1)), 46 | ) 47 | 48 | # Tensor -> Data loader 49 | train_data = TensorDataset(X_train, y_train) 50 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True) 51 | 52 | test_data = TensorDataset(X_test, y_test) 53 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True) 54 | 55 | return train_loader, test_loader 56 | 57 | 58 | def display_records(records, logger): 59 | msg = ( 60 | "{:<28} | Testing MSE: {:.2f} | Training Time: {:.2f} s |" 61 | " Evaluating Time: {:.2f} s" 62 | ) 63 | 64 | print("\n") 65 | for method, training_time, evaluating_time, mse in records: 66 | logger.info(msg.format(method, mse, training_time, evaluating_time)) 67 | 68 | 69 | class MLP(nn.Module): 70 | def __init__(self): 71 | super(MLP, self).__init__() 72 | self.linear1 = nn.Linear(90, 128) 73 | self.linear2 = nn.Linear(128, 128) 74 | self.linear3 = nn.Linear(128, 1) 75 | 76 | def forward(self, x): 77 | x = x.view(x.size()[0], -1) 78 | x = F.relu(self.linear1(x)) 79 | x = F.relu(self.linear2(x)) 80 | x = self.linear3(x) 81 | return x 82 | 83 | 84 | if __name__ == "__main__": 85 | 86 | # Hyper-parameters 87 | n_estimators = 10 88 | lr = 1e-3 89 | weight_decay = 5e-4 90 | epochs = 50 91 | 92 | # Utils 93 | batch_size = 512 94 | records = [] 95 | torch.manual_seed(0) 96 | 97 | # Load data 98 | train_loader, test_loader = load_data(batch_size) 99 | print("Finish loading data...\n") 100 | 101 | logger = set_logger("regression_YearPredictionMSD_mlp", use_tb_logger=True) 102 | 103 | # FusionRegressor 104 | model = FusionRegressor( 105 | estimator=MLP, n_estimators=n_estimators, cuda=True 106 | ) 107 | 108 | # Set the optimizer 109 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 110 | 111 | tic = time.time() 112 | model.fit(train_loader, epochs=epochs) 113 | toc = time.time() 114 | training_time = toc - tic 115 | 116 | tic = time.time() 117 | testing_mse = model.evaluate(test_loader) 118 | toc = time.time() 119 | evaluating_time = toc - tic 120 | 121 | records.append( 122 | ("FusionRegressor", training_time, evaluating_time, testing_mse) 123 | ) 124 | 125 | # VotingRegressor 126 | model = VotingRegressor( 127 | estimator=MLP, n_estimators=n_estimators, cuda=True 128 | ) 129 | 130 | # Set the optimizer 131 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 132 | 133 | tic = time.time() 134 | model.fit(train_loader, epochs=epochs) 135 | toc = time.time() 136 | training_time = toc - tic 137 | 138 | tic = time.time() 139 | testing_mse = model.evaluate(test_loader) 140 | toc = time.time() 141 | evaluating_time = toc - tic 142 | 143 | records.append( 144 | ("VotingRegressor", training_time, evaluating_time, testing_mse) 145 | ) 146 | 147 | # BaggingRegressor 148 | model = BaggingRegressor( 149 | estimator=MLP, n_estimators=n_estimators, cuda=True 150 | ) 151 | 152 | # Set the optimizer 153 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 154 | 155 | tic = 
time.time() 156 | model.fit(train_loader, epochs=epochs) 157 | toc = time.time() 158 | training_time = toc - tic 159 | 160 | tic = time.time() 161 | testing_mse = model.evaluate(test_loader) 162 | toc = time.time() 163 | evaluating_time = toc - tic 164 | 165 | records.append( 166 | ("BaggingRegressor", training_time, evaluating_time, testing_mse) 167 | ) 168 | 169 | # GradientBoostingRegressor 170 | model = GradientBoostingRegressor( 171 | estimator=MLP, n_estimators=n_estimators, cuda=True 172 | ) 173 | 174 | # Set the optimizer 175 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 176 | 177 | tic = time.time() 178 | model.fit(train_loader, epochs=epochs) 179 | toc = time.time() 180 | training_time = toc - tic 181 | 182 | tic = time.time() 183 | testing_mse = model.evaluate(test_loader) 184 | toc = time.time() 185 | evaluating_time = toc - tic 186 | 187 | records.append( 188 | ( 189 | "GradientBoostingRegressor", 190 | training_time, 191 | evaluating_time, 192 | testing_mse, 193 | ) 194 | ) 195 | 196 | # SnapshotEnsembleRegressor 197 | model = SnapshotEnsembleRegressor( 198 | estimator=MLP, n_estimators=n_estimators, cuda=True 199 | ) 200 | 201 | # Set the optimizer 202 | model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay) 203 | 204 | tic = time.time() 205 | model.fit(train_loader, epochs=epochs) 206 | toc = time.time() 207 | training_time = toc - tic 208 | 209 | tic = time.time() 210 | testing_acc = model.evaluate(test_loader) 211 | toc = time.time() 212 | evaluating_time = toc - tic 213 | 214 | records.append( 215 | ( 216 | "SnapshotEnsembleRegressor", 217 | training_time, 218 | evaluating_time, 219 | testing_acc, 220 | ) 221 | ) 222 | 223 | # Print results on different ensemble methods 224 | display_records(records, logger) 225 | -------------------------------------------------------------------------------- /examples/snapshot_ensemble_cifar10_resnet18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets, transforms 6 | 7 | from torchensemble.snapshot_ensemble import SnapshotEnsembleClassifier 8 | from torchensemble.utils.logging import set_logger 9 | 10 | 11 | # The class `BasicBlock` and `ResNet` is modified from: 12 | # https://github.com/kuangliu/pytorch-cifar 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = nn.Conv2d( 19 | in_planes, 20 | planes, 21 | kernel_size=3, 22 | stride=stride, 23 | padding=1, 24 | bias=False, 25 | ) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.conv2 = nn.Conv2d( 28 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 29 | ) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != self.expansion * planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d( 36 | in_planes, 37 | self.expansion * planes, 38 | kernel_size=1, 39 | stride=stride, 40 | bias=False, 41 | ), 42 | nn.BatchNorm2d(self.expansion * planes), 43 | ) 44 | 45 | def forward(self, x): 46 | out = F.relu(self.bn1(self.conv1(x))) 47 | out = self.bn2(self.conv2(out)) 48 | out += self.shortcut(x) 49 | out = F.relu(out) 50 | return out 51 | 52 | 53 | class ResNet(nn.Module): 54 | def __init__(self, block, num_blocks, num_classes=10): 55 | super(ResNet, self).__init__() 56 | self.in_planes = 
64 57 | 58 | self.conv1 = nn.Conv2d( 59 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 60 | ) 61 | self.bn1 = nn.BatchNorm2d(64) 62 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 63 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 64 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 65 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 66 | self.linear = nn.Linear(512 * block.expansion, num_classes) 67 | 68 | def _make_layer(self, block, planes, num_blocks, stride): 69 | strides = [stride] + [1] * (num_blocks - 1) 70 | layers = [] 71 | for stride in strides: 72 | layers.append(block(self.in_planes, planes, stride)) 73 | self.in_planes = planes * block.expansion 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | out = F.relu(self.bn1(self.conv1(x))) 78 | out = self.layer1(out) 79 | out = self.layer2(out) 80 | out = self.layer3(out) 81 | out = self.layer4(out) 82 | out = F.avg_pool2d(out, 4) 83 | out = out.view(out.size(0), -1) 84 | out = self.linear(out) 85 | return out 86 | 87 | 88 | if __name__ == "__main__": 89 | 90 | # Hyper-parameters 91 | n_estimators = 5 92 | lr = 1e-1 93 | weight_decay = 5e-4 94 | momentum = 0.9 95 | epochs = 200 # i.e., 40 epochs for each snapshot 96 | 97 | # Utils 98 | batch_size = 128 99 | data_dir = "../../Dataset/cifar" # MODIFY THIS IF YOU WANT 100 | torch.manual_seed(0) 101 | torch.cuda.set_device(0) 102 | 103 | # Load data 104 | train_transformer = transforms.Compose( 105 | [ 106 | transforms.RandomHorizontalFlip(), 107 | transforms.RandomCrop(32, 4), 108 | transforms.ToTensor(), 109 | transforms.Normalize( 110 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 111 | ), 112 | ] 113 | ) 114 | 115 | test_transformer = transforms.Compose( 116 | [ 117 | transforms.ToTensor(), 118 | transforms.Normalize( 119 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 120 | ), 121 | ] 122 | ) 123 | 124 | train_loader = DataLoader( 125 | datasets.CIFAR10( 126 | data_dir, train=True, download=True, transform=train_transformer 127 | ), 128 | batch_size=batch_size, 129 | shuffle=True, 130 | ) 131 | 132 | test_loader = DataLoader( 133 | datasets.CIFAR10(data_dir, train=False, transform=test_transformer), 134 | batch_size=batch_size, 135 | shuffle=True, 136 | ) 137 | 138 | # Set the Logger 139 | logger = set_logger( 140 | "snapshot_ensemble_cifar10_resnet18", use_tb_logger=True 141 | ) 142 | 143 | # Choose the Ensemble Method 144 | model = SnapshotEnsembleClassifier( 145 | estimator=ResNet, 146 | estimator_args={"block": BasicBlock, "num_blocks": [2, 2, 2, 2]}, 147 | n_estimators=n_estimators, 148 | cuda=True, 149 | ) 150 | 151 | # Set the Optimizer 152 | model.set_optimizer( 153 | "SGD", lr=lr, weight_decay=weight_decay, momentum=momentum 154 | ) 155 | 156 | # Train and Evaluate 157 | model.fit( 158 | train_loader, 159 | epochs=epochs, 160 | test_loader=test_loader, 161 | ) 162 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel", 5 | "numpy>=1.22.4", 6 | "torch>=1.4.0", 7 | "torchvision>=0.2.2", 8 | "scikit-learn>=0.23.0" 9 | ] 10 | [tool.black] 11 | line-length = 79 12 | include = '\.pyi?$' 13 | exclude = ''' 14 | /( 15 | \.git 16 | | \.hg 17 | | \.mypy_cache 18 | | \.tox 19 | | \.venv 20 | | _build 21 | | buck-out 22 | | build 23 | | dist 24 | | docs 25 | )/ 26 | ''' 
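The example scripts above all follow the same four-step torchensemble workflow: define a base estimator, wrap it in an ensemble, attach an optimizer via set_optimizer, then call fit and evaluate. The sketch below condenses that workflow into a minimal self-contained form. The ToyMLP network, the synthetic tensors, and all hyper-parameter values here are illustrative assumptions rather than part of this repository; the ensemble calls themselves mirror the ones used in the examples above and in the unit tests further down.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from torchensemble import VotingClassifier
from torchensemble.utils.logging import set_logger


class ToyMLP(nn.Module):
    # Hypothetical, deliberately tiny base estimator for illustration only.
    def __init__(self):
        super(ToyMLP, self).__init__()
        self.linear1 = nn.Linear(4, 8)
        self.linear2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))


if __name__ == "__main__":
    # Synthetic two-class data (illustrative only)
    X = torch.rand(64, 4)
    y = (X.sum(dim=1) > 2.0).long()
    loader = DataLoader(TensorDataset(X, y), batch_size=8, shuffle=True)

    # Step 1: set up a logger, as every example in this repository does
    set_logger("minimal_workflow_sketch")

    # Step 2: wrap the base estimator class in an ensemble
    model = VotingClassifier(estimator=ToyMLP, n_estimators=3, cuda=False)

    # Step 3: configure the optimizer by name, plus its keyword arguments
    model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)

    # Step 4: train, then evaluate (on the training loader, for brevity)
    model.fit(loader, epochs=2, save_model=False)
    acc = model.evaluate(loader)
    print("Toy accuracy: {:.2f} %".format(acc))

Because every ensemble exported from torchensemble/__init__.py exposes this same fit/evaluate/predict interface, VotingClassifier can be swapped for any other classifier in the package without touching the surrounding code; the parametrized tests later in this listing rely on exactly that property.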
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | torchvision>=0.2.2 3 | scikit-learn>=0.23.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from distutils.command.clean import clean as Clean 4 | from setuptools import setup, find_packages 5 | from codecs import open 6 | from os import path 7 | 8 | here = path.abspath(path.dirname(__file__)) 9 | 10 | # Get the long description from README.rst 11 | with open(path.join(here, "README.rst"), encoding="utf-8") as f: 12 | long_description = f.read() 13 | 14 | # Get the dependencies and installs 15 | with open(path.join(here, "requirements.txt"), encoding="utf-8") as f: 16 | all_reqs = f.read().split("\n") 17 | 18 | install_requires = [x.strip() for x in all_reqs if "git+" not in x] 19 | 20 | cmdclass = {} 21 | 22 | 23 | # Custom clean command to remove build artifacts 24 | class CleanCommand(Clean): 25 | description = "Remove build artifacts from the source tree" 26 | 27 | def run(self): 28 | Clean.run(self) 29 | # Remove c files if we are not within a sdist package 30 | cwd = os.path.abspath(os.path.dirname(__file__)) 31 | remove_c_files = not os.path.exists(os.path.join(cwd, "PKG-INFO")) 32 | if remove_c_files: 33 | print("Will remove generated .c files") 34 | if os.path.exists("build"): 35 | shutil.rmtree("build") 36 | for dirpath, dirnames, filenames in os.walk("torchensemble"): 37 | for filename in filenames: 38 | if any( 39 | filename.endswith(suffix) 40 | for suffix in (".so", ".pyd", ".dll", ".pyc") 41 | ): 42 | os.unlink(os.path.join(dirpath, filename)) 43 | continue 44 | extension = os.path.splitext(filename)[1] 45 | if remove_c_files and extension in [".c", ".cpp"]: 46 | pyx_file = str.replace(filename, extension, ".pyx") 47 | if os.path.exists(os.path.join(dirpath, pyx_file)): 48 | os.unlink(os.path.join(dirpath, filename)) 49 | for dirname in dirnames: 50 | if dirname == "__pycache__": 51 | shutil.rmtree(os.path.join(dirpath, dirname)) 52 | 53 | 54 | cmdclass.update({"clean": CleanCommand}) 55 | 56 | 57 | setup( 58 | name="torchensemble", 59 | maintainer="Yi-Xuan Xu", 60 | maintainer_email="xuyx@lamda.nju.edu.cn", 61 | description=( 62 | "A unified ensemble framework for PyTorch to improve the performance" 63 | " and robustness of your deep learning model" 64 | ), 65 | license="BSD 3-Clause", 66 | url="https://github.com/TorchEnsemble-Community/Ensemble-Pytorch", 67 | project_urls={ 68 | "Bug Tracker": "https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/issues", 69 | "Documentation": "https://ensemble-pytorch.readthedocs.io", 70 | "Source Code": "https://github.com/TorchEnsemble-Community/Ensemble-Pytorch", 71 | }, 72 | version="0.2.0", 73 | long_description=long_description, 74 | classifiers=[ 75 | "Intended Audience :: Science/Research", 76 | "Intended Audience :: Developers", 77 | "Topic :: Software Development", 78 | "Topic :: Scientific/Engineering", 79 | "Operating System :: Microsoft :: Windows", 80 | "Operating System :: POSIX", 81 | "Operating System :: Unix", 82 | "Operating System :: MacOS", 83 | "Programming Language :: Python :: 3.9", 84 | "Programming Language :: Python :: 3.10", 85 | ], 86 | keywords=["Deep Learning", "PyTorch", "Ensemble Learning"], 87 | packages=find_packages(), 88 | cmdclass=cmdclass, 89
| python_requires=">=3.9", 90 | install_requires=install_requires, 91 | ) 92 | -------------------------------------------------------------------------------- /torchensemble/__init__.py: -------------------------------------------------------------------------------- 1 | from .fusion import FusionClassifier 2 | from .fusion import FusionRegressor 3 | from .voting import VotingClassifier 4 | from .voting import VotingRegressor 5 | from .voting import NeuralForestClassifier 6 | from .voting import NeuralForestRegressor 7 | from .bagging import BaggingClassifier 8 | from .bagging import BaggingRegressor 9 | from .gradient_boosting import GradientBoostingClassifier 10 | from .gradient_boosting import GradientBoostingRegressor 11 | from .snapshot_ensemble import SnapshotEnsembleClassifier 12 | from .snapshot_ensemble import SnapshotEnsembleRegressor 13 | from .adversarial_training import AdversarialTrainingClassifier 14 | from .adversarial_training import AdversarialTrainingRegressor 15 | from .fast_geometric import FastGeometricClassifier 16 | from .fast_geometric import FastGeometricRegressor 17 | from .soft_gradient_boosting import SoftGradientBoostingClassifier 18 | from .soft_gradient_boosting import SoftGradientBoostingRegressor 19 | 20 | 21 | __all__ = [ 22 | "FusionClassifier", 23 | "FusionRegressor", 24 | "VotingClassifier", 25 | "VotingRegressor", 26 | "NeuralForestClassifier", 27 | "NeuralForestRegressor", 28 | "BaggingClassifier", 29 | "BaggingRegressor", 30 | "GradientBoostingClassifier", 31 | "GradientBoostingRegressor", 32 | "SnapshotEnsembleClassifier", 33 | "SnapshotEnsembleRegressor", 34 | "AdversarialTrainingClassifier", 35 | "AdversarialTrainingRegressor", 36 | "FastGeometricClassifier", 37 | "FastGeometricRegressor", 38 | "SoftGradientBoostingClassifier", 39 | "SoftGradientBoostingRegressor", 40 | ] 41 | -------------------------------------------------------------------------------- /torchensemble/_constants.py: -------------------------------------------------------------------------------- 1 | __model_doc = """ 2 | Parameters 3 | ---------- 4 | estimator : torch.nn.Module 5 | The class or object of your base estimator. 6 | 7 | - If :obj:`class`, it should inherit from :mod:`torch.nn.Module`. 8 | - If :obj:`object`, it should be instantiated from a class inherited 9 | from :mod:`torch.nn.Module`. 10 | n_estimators : int 11 | The number of base estimators in the ensemble. 12 | estimator_args : dict, default=None 13 | The dictionary of hyper-parameters used to instantiate base 14 | estimators. This parameter will have no effect if ``estimator`` is a 15 | base estimator object after instantiation. 16 | cuda : bool, default=True 17 | 18 | - If ``True``, use GPU to train and evaluate the ensemble. 19 | - If ``False``, use CPU to train and evaluate the ensemble. 20 | n_jobs : int, default=None 21 | The number of workers for training the ensemble. This input 22 | argument is used for parallel ensemble methods such as 23 | :mod:`voting` and :mod:`bagging`. Setting it to an integer larger 24 | than ``1`` enables ``n_jobs`` base estimators to be trained 25 | simultaneously. 26 | 27 | Attributes 28 | ---------- 29 | estimators_ : torch.nn.ModuleList 30 | An internal container that stores all fitted base estimators. 31 | """ 32 | 33 | 34 | __seq_model_doc = """ 35 | Parameters 36 | ---------- 37 | estimator : torch.nn.Module 38 | The class or object of your base estimator. 39 | 40 | - If :obj:`class`, it should inherit from :mod:`torch.nn.Module`. 
41 | - If :obj:`object`, it should be instantiated from a class inherited 42 | from :mod:`torch.nn.Module`. 43 | n_estimators : int 44 | The number of base estimators in the ensemble. 45 | estimator_args : dict, default=None 46 | The dictionary of hyper-parameters used to instantiate base 47 | estimators. This parameter will have no effect if ``estimator`` is a 48 | base estimator object after instantiation. 49 | cuda : bool, default=True 50 | 51 | - If ``True``, use GPU to train and evaluate the ensemble. 52 | - If ``False``, use CPU to train and evaluate the ensemble. 53 | 54 | Attributes 55 | ---------- 56 | estimators_ : torch.nn.ModuleList 57 | An internal container that stores all fitted base estimators. 58 | """ 59 | 60 | 61 | __tree_ensemble_doc = """ 62 | Parameters 63 | ---------- 64 | n_estimators : int 65 | The number of neural trees in the ensemble. 66 | depth : int, default=5 67 | The depth of the neural tree. A tree with depth ``d`` has :math:`2^d` 68 | leaf nodes and :math:`2^d-1` internal nodes. 69 | lamda : float, default=1e-3 70 | The coefficient of the regularization term when training neural 71 | trees, proposed in the paper: `Distilling a neural network into a 72 | soft decision tree <https://arxiv.org/abs/1711.09784>`_. 73 | cuda : bool, default=True 74 | 75 | - If ``True``, use GPU to train and evaluate the ensemble. 76 | - If ``False``, use CPU to train and evaluate the ensemble. 77 | n_jobs : int, default=None 78 | The number of workers for training the ensemble. This input 79 | argument is used for parallel ensemble methods such as 80 | :mod:`voting` and :mod:`bagging`. Setting it to an integer larger 81 | than ``1`` enables ``n_jobs`` base estimators to be trained 82 | simultaneously. 83 | 84 | Attributes 85 | ---------- 86 | estimators_ : torch.nn.ModuleList 87 | An internal container that stores all fitted base estimators. 88 | """ 89 | 90 | 91 | __set_optimizer_doc = """ 92 | Parameters 93 | ---------- 94 | optimizer_name : string 95 | The name of the optimizer, should be one of {``Adadelta``, ``Adagrad``, 96 | ``Adam``, ``AdamW``, ``Adamax``, ``ASGD``, ``RMSprop``, ``Rprop``, 97 | ``SGD``}. 98 | **kwargs : keyword arguments 99 | Keyword arguments for setting the optimizer, should be in the form: 100 | ``lr=1e-3, weight_decay=5e-4, ...``. These keyword arguments 101 | will be directly passed to :mod:`torch.optim.Optimizer`. 102 | """ 103 | 104 | 105 | __set_scheduler_doc = """ 106 | Parameters 107 | ---------- 108 | scheduler_name : string 109 | The name of the scheduler, should be one of {``LambdaLR``, 110 | ``MultiplicativeLR``, ``StepLR``, ``MultiStepLR``, ``ExponentialLR``, 111 | ``CosineAnnealingLR``, ``ReduceLROnPlateau``}. 112 | **kwargs : keyword arguments 113 | Keyword arguments for setting the scheduler. These keyword arguments 114 | will be directly passed to :mod:`torch.optim.lr_scheduler`. 115 | """ 116 | 117 | 118 | __set_criterion_doc = """ 119 | Parameters 120 | ---------- 121 | criterion : torch.nn.loss 122 | The customized training criterion object. 123 | """ 124 | 125 | 126 | __fit_doc = """ 127 | Parameters 128 | ---------- 129 | train_loader : torch.utils.data.DataLoader 130 | A :mod:`torch.utils.data.DataLoader` container that contains the 131 | training data. 132 | epochs : int, default=100 133 | The number of training epochs. 134 | log_interval : int, default=100 135 | The number of batches to wait before logging the training status.
136 | test_loader : torch.utils.data.DataLoader, default=None 137 | A :mod:`torch.utils.data.DataLoader` container that contains the 138 | evaluating data. 139 | 140 | - If ``None``, no validation is conducted during the training 141 | stage. 142 | - If not ``None``, the ensemble will be evaluated on this 143 | dataloader after each training epoch. 144 | save_model : bool, default=True 145 | Specify whether to save the model parameters. 146 | 147 | - If test_loader is ``None``, the fully trained ensemble will be 148 | saved. 149 | - If test_loader is not ``None``, the ensemble with the best 150 | validation performance will be saved. 151 | save_dir : string, default=None 152 | Specify where to save the model parameters. 153 | 154 | - If ``None``, the model will be saved in the current directory. 155 | - If not ``None``, the model will be saved in the specified 156 | directory: ``save_dir``. 157 | """ 158 | 159 | 160 | __predict_doc = """ 161 | Return the predictions of the ensemble given the testing data. 162 | 163 | Parameters 164 | ---------- 165 | X : {tensor, numpy array} 166 | A data batch in the form of tensor or numpy array. 167 | 168 | Returns 169 | ------- 170 | pred : tensor of shape (n_samples, n_outputs) 171 | For classifiers, ``n_outputs`` is the number of distinct classes. For 172 | regressors, ``n_outputs`` is the number of target variables. 173 | """ 174 | 175 | 176 | __classification_forward_doc = """ 177 | Parameters 178 | ---------- 179 | X : tensor 180 | An input batch of data, which should be a valid input data batch 181 | for base estimators in the ensemble. 182 | 183 | Returns 184 | ------- 185 | proba : tensor of shape (batch_size, n_classes) 186 | The predicted class distribution. 187 | """ 188 | 189 | 190 | __classification_evaluate_doc = """ 191 | Compute the classification accuracy of the ensemble given the testing 192 | dataloader and optionally the average cross-entropy loss. 193 | 194 | Parameters 195 | ---------- 196 | test_loader : torch.utils.data.DataLoader 197 | A data loader that contains the testing data. 198 | return_loss : bool, default=False 199 | Whether to return the average cross-entropy loss over all batches 200 | in the ``test_loader``. 201 | 202 | Returns 203 | ------- 204 | accuracy : float 205 | The classification accuracy of the fitted ensemble on ``test_loader``. 206 | loss : float 207 | The average cross-entropy loss of the fitted ensemble on 208 | ``test_loader``, only available when ``return_loss`` is True. 209 | """ 210 | 211 | 212 | __regression_forward_doc = """ 213 | Parameters 214 | ---------- 215 | X : tensor 216 | An input batch of data, which should be a valid input data batch 217 | for base estimators in the ensemble. 218 | 219 | Returns 220 | ------- 221 | pred : tensor of shape (batch_size, n_outputs) 222 | The predicted values. 223 | """ 224 | 225 | 226 | __regression_evaluate_doc = """ 227 | Compute the mean squared error (MSE) of the ensemble given the testing 228 | dataloader. 229 | 230 | Parameters 231 | ---------- 232 | test_loader : torch.utils.data.DataLoader 233 | A data loader that contains the testing data. 234 | 235 | Returns 236 | ------- 237 | mse : float 238 | The testing mean squared error of the fitted ensemble on 239 | ``test_loader``.
240 | """ 241 | -------------------------------------------------------------------------------- /torchensemble/fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | In fusion-based ensemble, predictions from all base estimators are 3 | first aggregated as an average output. After then, the training loss is 4 | computed based on this average output and the ground-truth. The training 5 | loss is then back-propagated to all base estimators simultaneously. 6 | """ 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from ._base import BaseClassifier, BaseRegressor 14 | from ._base import torchensemble_model_doc 15 | from .utils import io 16 | from .utils import set_module 17 | from .utils import operator as op 18 | 19 | 20 | __all__ = ["FusionClassifier", "FusionRegressor"] 21 | 22 | 23 | @torchensemble_model_doc( 24 | """Implementation on the FusionClassifier.""", "model" 25 | ) 26 | class FusionClassifier(BaseClassifier): 27 | def _forward(self, *x): 28 | """ 29 | Implementation on the internal data forwarding in FusionClassifier. 30 | """ 31 | # Average 32 | outputs = [estimator(*x) for estimator in self.estimators_] 33 | output = op.average(outputs) 34 | 35 | return output 36 | 37 | @torchensemble_model_doc( 38 | """Implementation on the data forwarding in FusionClassifier.""", 39 | "classifier_forward", 40 | ) 41 | def forward(self, *x): 42 | output = op.unsqueeze_tensor(self._forward(*x)) 43 | proba = F.softmax(output, dim=1) 44 | 45 | return proba 46 | 47 | @torchensemble_model_doc( 48 | """Set the attributes on optimizer for FusionClassifier.""", 49 | "set_optimizer", 50 | ) 51 | def set_optimizer(self, optimizer_name, **kwargs): 52 | super().set_optimizer(optimizer_name, **kwargs) 53 | 54 | @torchensemble_model_doc( 55 | """Set the attributes on scheduler for FusionClassifier.""", 56 | "set_scheduler", 57 | ) 58 | def set_scheduler(self, scheduler_name, **kwargs): 59 | super().set_scheduler(scheduler_name, **kwargs) 60 | 61 | @torchensemble_model_doc( 62 | """Set the training criterion for FusionClassifier.""", 63 | "set_criterion", 64 | ) 65 | def set_criterion(self, criterion): 66 | super().set_criterion(criterion) 67 | 68 | @torchensemble_model_doc( 69 | """Implementation on the training stage of FusionClassifier.""", "fit" 70 | ) 71 | def fit( 72 | self, 73 | train_loader, 74 | epochs=100, 75 | log_interval=100, 76 | test_loader=None, 77 | save_model=True, 78 | save_dir=None, 79 | ): 80 | 81 | # Instantiate base estimators and set attributes 82 | for _ in range(self.n_estimators): 83 | self.estimators_.append(self._make_estimator()) 84 | self._validate_parameters(epochs, log_interval) 85 | self.n_outputs = self._decide_n_outputs(train_loader) 86 | optimizer = set_module.set_optimizer( 87 | self, self.optimizer_name, **self.optimizer_args 88 | ) 89 | 90 | # Set the scheduler if `set_scheduler` was called before 91 | if self.use_scheduler_: 92 | self.scheduler_ = set_module.set_scheduler( 93 | optimizer, self.scheduler_name, **self.scheduler_args 94 | ) 95 | 96 | # Check the training criterion 97 | if not hasattr(self, "_criterion"): 98 | self._criterion = nn.CrossEntropyLoss() 99 | 100 | # Utils 101 | best_acc = 0.0 102 | total_iters = 0 103 | 104 | # Training loop 105 | for epoch in range(epochs): 106 | self.train() 107 | for batch_idx, elem in enumerate(train_loader): 108 | 109 | data, target = io.split_data_target(elem, self.device) 110 | batch_size = data[0].size(0) 111 | 112 | 
optimizer.zero_grad() 113 | output = self._forward(*data) 114 | loss = self._criterion(output, target) 115 | loss.backward() 116 | optimizer.step() 117 | 118 | # Print training status 119 | if batch_idx % log_interval == 0: 120 | with torch.no_grad(): 121 | _, predicted = torch.max(output.data, 1) 122 | correct = (predicted == target).sum().item() 123 | 124 | msg = ( 125 | "Epoch: {:03d} | Batch: {:03d} | Loss:" 126 | " {:.5f} | Correct: {:d}/{:d}" 127 | ) 128 | self.logger.info( 129 | msg.format( 130 | epoch, batch_idx, loss, correct, batch_size 131 | ) 132 | ) 133 | if self.tb_logger: 134 | self.tb_logger.add_scalar( 135 | "fusion/Train_Loss", loss, total_iters 136 | ) 137 | total_iters += 1 138 | 139 | # Validation 140 | if test_loader: 141 | self.eval() 142 | with torch.no_grad(): 143 | correct = 0 144 | total = 0 145 | for _, elem in enumerate(test_loader): 146 | data, target = io.split_data_target(elem, self.device) 147 | output = self.forward(*data) 148 | _, predicted = torch.max(output.data, 1) 149 | correct += (predicted == target).sum().item() 150 | total += target.size(0) 151 | acc = 100 * correct / total 152 | 153 | if acc > best_acc: 154 | best_acc = acc 155 | if save_model: 156 | io.save(self, save_dir, self.logger) 157 | 158 | msg = ( 159 | "Epoch: {:03d} | Validation Acc: {:.3f}" 160 | " % | Historical Best: {:.3f} %" 161 | ) 162 | self.logger.info(msg.format(epoch, acc, best_acc)) 163 | if self.tb_logger: 164 | self.tb_logger.add_scalar( 165 | "fusion/Validation_Acc", acc, epoch 166 | ) 167 | 168 | # Update the scheduler 169 | if hasattr(self, "scheduler_"): 170 | if self.scheduler_name == "ReduceLROnPlateau": 171 | if test_loader: 172 | self.scheduler_.step(acc) 173 | else: 174 | self.scheduler_.step(loss) 175 | else: 176 | self.scheduler_.step() 177 | 178 | if save_model and not test_loader: 179 | io.save(self, save_dir, self.logger) 180 | 181 | @torchensemble_model_doc(item="classifier_evaluate") 182 | def evaluate(self, test_loader, return_loss=False): 183 | return super().evaluate(test_loader, return_loss) 184 | 185 | @torchensemble_model_doc(item="predict") 186 | def predict(self, *x): 187 | return super().predict(*x) 188 | 189 | 190 | @torchensemble_model_doc("""Implementation on the FusionRegressor.""", "model") 191 | class FusionRegressor(BaseRegressor): 192 | @torchensemble_model_doc( 193 | """Implementation on the data forwarding in FusionRegressor.""", 194 | "regressor_forward", 195 | ) 196 | def forward(self, *x): 197 | # Average 198 | outputs = [estimator(*x) for estimator in self.estimators_] 199 | pred = op.average(outputs) 200 | 201 | return pred 202 | 203 | @torchensemble_model_doc( 204 | """Set the attributes on optimizer for FusionRegressor.""", 205 | "set_optimizer", 206 | ) 207 | def set_optimizer(self, optimizer_name, **kwargs): 208 | super().set_optimizer(optimizer_name, **kwargs) 209 | 210 | @torchensemble_model_doc( 211 | """Set the attributes on scheduler for FusionRegressor.""", 212 | "set_scheduler", 213 | ) 214 | def set_scheduler(self, scheduler_name, **kwargs): 215 | super().set_scheduler(scheduler_name, **kwargs) 216 | 217 | @torchensemble_model_doc( 218 | """Set the training criterion for FusionRegressor.""", 219 | "set_criterion", 220 | ) 221 | def set_criterion(self, criterion): 222 | super().set_criterion(criterion) 223 | 224 | @torchensemble_model_doc( 225 | """Implementation on the training stage of FusionRegressor.""", "fit" 226 | ) 227 | def fit( 228 | self, 229 | train_loader, 230 | epochs=100, 231 | log_interval=100, 232 | 
test_loader=None, 233 | save_model=True, 234 | save_dir=None, 235 | ): 236 | # Instantiate base estimators and set attributes 237 | for _ in range(self.n_estimators): 238 | self.estimators_.append(self._make_estimator()) 239 | self._validate_parameters(epochs, log_interval) 240 | self.n_outputs = self._decide_n_outputs(train_loader) 241 | optimizer = set_module.set_optimizer( 242 | self, self.optimizer_name, **self.optimizer_args 243 | ) 244 | 245 | # Set the scheduler if `set_scheduler` was called before 246 | if self.use_scheduler_: 247 | self.scheduler_ = set_module.set_scheduler( 248 | optimizer, self.scheduler_name, **self.scheduler_args 249 | ) 250 | 251 | # Check the training criterion 252 | if not hasattr(self, "_criterion"): 253 | self._criterion = nn.MSELoss() 254 | 255 | # Utils 256 | best_loss = float("inf") 257 | total_iters = 0 258 | 259 | # Training loop 260 | for epoch in range(epochs): 261 | self.train() 262 | for batch_idx, elem in enumerate(train_loader): 263 | 264 | data, target = io.split_data_target(elem, self.device) 265 | 266 | optimizer.zero_grad() 267 | output = self.forward(*data) 268 | loss = self._criterion(output, target) 269 | loss.backward() 270 | optimizer.step() 271 | 272 | # Print training status 273 | if batch_idx % log_interval == 0: 274 | with torch.no_grad(): 275 | msg = "Epoch: {:03d} | Batch: {:03d} | Loss: {:.5f}" 276 | self.logger.info(msg.format(epoch, batch_idx, loss)) 277 | if self.tb_logger: 278 | self.tb_logger.add_scalar( 279 | "fusion/Train_Loss", loss, total_iters 280 | ) 281 | total_iters += 1 282 | 283 | # Validation 284 | if test_loader: 285 | self.eval() 286 | with torch.no_grad(): 287 | val_loss = 0.0 288 | for _, elem in enumerate(test_loader): 289 | data, target = io.split_data_target(elem, self.device) 290 | output = self.forward(*data) 291 | val_loss += self._criterion(output, target) 292 | val_loss /= len(test_loader) 293 | 294 | if val_loss < best_loss: 295 | best_loss = val_loss 296 | if save_model: 297 | io.save(self, save_dir, self.logger) 298 | 299 | msg = ( 300 | "Epoch: {:03d} | Validation Loss: {:.5f} |" 301 | " Historical Best: {:.5f}" 302 | ) 303 | self.logger.info(msg.format(epoch, val_loss, best_loss)) 304 | if self.tb_logger: 305 | self.tb_logger.add_scalar( 306 | "fusion/Validation_Loss", val_loss, epoch 307 | ) 308 | 309 | # Update the scheduler 310 | if hasattr(self, "scheduler_"): 311 | if self.scheduler_name == "ReduceLROnPlateau": 312 | if test_loader: 313 | self.scheduler_.step(val_loss) 314 | else: 315 | self.scheduler_.step(loss) 316 | else: 317 | self.scheduler_.step() 318 | 319 | if save_model and not test_loader: 320 | io.save(self, save_dir, self.logger) 321 | 322 | @torchensemble_model_doc(item="regressor_evaluate") 323 | def evaluate(self, test_loader): 324 | return super().evaluate(test_loader) 325 | 326 | @torchensemble_model_doc(item="predict") 327 | def predict(self, *x): 328 | return super().predict(*x) 329 | -------------------------------------------------------------------------------- /torchensemble/tests/test_adversarial_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.utils.data import TensorDataset, DataLoader 6 | 7 | import torchensemble 8 | from torchensemble.utils.logging import set_logger 9 | 10 | 11 | set_logger("pytest_adversarial_training") 12 | 13 | 14 | X_train = torch.Tensor(np.array(([1, 1], [2, 2], [3, 3], [4, 4]))) 15 | y_train_clf = 
torch.LongTensor(np.array(([0, 0, 1, 1]))) 16 | 17 | 18 | # Base estimator 19 | class MLP_clf(nn.Module): 20 | def __init__(self): 21 | super(MLP_clf, self).__init__() 22 | self.linear1 = nn.Linear(2, 2) 23 | self.linear2 = nn.Linear(2, 2) 24 | 25 | def forward(self, X): 26 | X = X.view(X.size()[0], -1) 27 | output = self.linear1(X) 28 | output = self.linear2(output) 29 | return output 30 | 31 | 32 | def test_adversarial_training_range(): 33 | """ 34 | This unit test checks the input range check in adversarial_training.py. 35 | """ 36 | model = torchensemble.AdversarialTrainingClassifier( 37 | estimator=MLP_clf, n_estimators=2, cuda=False 38 | ) 39 | 40 | model.set_optimizer("Adam") 41 | 42 | # Prepare data 43 | train = TensorDataset(X_train, y_train_clf) 44 | train_loader = DataLoader(train, batch_size=2) 45 | 46 | # Training 47 | with pytest.raises(ValueError) as excinfo: 48 | model.fit(train_loader) 49 | assert "input range of samples passed to adversarial" in str(excinfo.value) 50 | -------------------------------------------------------------------------------- /torchensemble/tests/test_all_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | import torch.nn as nn 5 | from numpy.testing import assert_array_equal 6 | from torch.utils.data import TensorDataset, DataLoader 7 | 8 | import torchensemble 9 | from torchensemble.utils import io 10 | from torchensemble.utils.logging import set_logger 11 | 12 | 13 | # All classifiers 14 | all_clf = [ 15 | torchensemble.FusionClassifier, 16 | torchensemble.VotingClassifier, 17 | torchensemble.BaggingClassifier, 18 | torchensemble.GradientBoostingClassifier, 19 | torchensemble.SnapshotEnsembleClassifier, 20 | torchensemble.AdversarialTrainingClassifier, 21 | torchensemble.FastGeometricClassifier, 22 | torchensemble.SoftGradientBoostingClassifier, 23 | ] 24 | 25 | 26 | # All regressors 27 | all_reg = [ 28 | torchensemble.FusionRegressor, 29 | torchensemble.VotingRegressor, 30 | torchensemble.BaggingRegressor, 31 | torchensemble.GradientBoostingRegressor, 32 | torchensemble.SnapshotEnsembleRegressor, 33 | torchensemble.AdversarialTrainingRegressor, 34 | torchensemble.FastGeometricRegressor, 35 | torchensemble.SoftGradientBoostingRegressor, 36 | ] 37 | 38 | 39 | np.random.seed(0) 40 | torch.manual_seed(0) 41 | logger = set_logger("pytest_all_models") 42 | 43 | 44 | # Base estimator 45 | class MLP_clf(nn.Module): 46 | def __init__(self): 47 | super(MLP_clf, self).__init__() 48 | self.linear1 = nn.Linear(2, 2) 49 | self.linear2 = nn.Linear(2, 2) 50 | 51 | def forward(self, X): 52 | X = X.view(X.size(0), -1) 53 | output = self.linear1(X) 54 | output = self.linear2(output) 55 | return output 56 | 57 | 58 | class MLP_reg(nn.Module): 59 | def __init__(self): 60 | super(MLP_reg, self).__init__() 61 | self.linear1 = nn.Linear(2, 2) 62 | self.linear2 = nn.Linear(2, 1) 63 | 64 | def forward(self, X): 65 | X = X.view(X.size()[0], -1) 66 | output = self.linear1(X) 67 | output = self.linear2(output) 68 | return output 69 | 70 | 71 | # Training data 72 | X_train = torch.Tensor( 73 | np.array(([0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4])) 74 | ) 75 | 76 | y_train_clf = torch.LongTensor(np.array(([0, 0, 1, 1]))) 77 | y_train_reg = torch.FloatTensor(np.array(([0.1, 0.2, 0.3, 0.4]))) 78 | y_train_reg = y_train_reg.view(-1, 1) 79 | 80 | 81 | # Testing data 82 | numpy_X_test = np.array(([0.5, 0.5], [0.6, 0.6])) 83 | X_test = torch.Tensor(numpy_X_test) 84 | 85 | 
y_test_clf = torch.LongTensor(np.array(([1, 0]))) 86 | y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6]))) 87 | y_test_reg = y_test_reg.view(-1, 1) 88 | 89 | 90 | @pytest.mark.parametrize("clf", all_clf) 91 | def test_clf_class(clf): 92 | """ 93 | This unit test checks the training and evaluating stage of all classifiers. 94 | """ 95 | epochs = 1 96 | n_estimators = 2 97 | 98 | model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False) 99 | 100 | # Optimizer 101 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 102 | 103 | # Scheduler (Snapshot Ensemble Excluded) 104 | if not isinstance(model, torchensemble.SnapshotEnsembleClassifier): 105 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 106 | 107 | # Prepare data 108 | train = TensorDataset(X_train, y_train_clf) 109 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 110 | test = TensorDataset(X_test, y_test_clf) 111 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 112 | 113 | # Snapshot Ensemble needs more epochs 114 | if isinstance(model, torchensemble.SnapshotEnsembleClassifier): 115 | epochs = 6 116 | 117 | # Train 118 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 119 | 120 | # Evaluate & Save 121 | model.evaluate(test_loader) 122 | io.save(model, "./", logger) 123 | 124 | # Predict 125 | for _, (data, target) in enumerate(test_loader): 126 | model.predict(data) 127 | break 128 | 129 | # Reload 130 | new_model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False) 131 | io.load(new_model) 132 | 133 | new_model.evaluate(test_loader) 134 | 135 | for _, (data, target) in enumerate(test_loader): 136 | new_model.predict(data) 137 | break 138 | 139 | 140 | @pytest.mark.parametrize("clf", all_clf) 141 | def test_clf_object(clf): 142 | """ 143 | This unit test checks the training and evaluating stage of all classifiers. 144 | """ 145 | epochs = 1 146 | n_estimators = 2 147 | 148 | model = clf(estimator=MLP_clf(), n_estimators=n_estimators, cuda=False) 149 | 150 | # Optimizer 151 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 152 | 153 | # Scheduler (Snapshot Ensemble Excluded) 154 | if not isinstance(model, torchensemble.SnapshotEnsembleClassifier): 155 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 156 | 157 | # Prepare data 158 | train = TensorDataset(X_train, y_train_clf) 159 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 160 | test = TensorDataset(X_test, y_test_clf) 161 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 162 | 163 | # Snapshot Ensemble needs more epochs 164 | if isinstance(model, torchensemble.SnapshotEnsembleClassifier): 165 | epochs = 6 166 | 167 | # Train 168 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 169 | 170 | # Evaluate & Save 171 | model.evaluate(test_loader) 172 | io.save(model, "./", logger) 173 | 174 | # Predict 175 | for _, (data, target) in enumerate(test_loader): 176 | model.predict(data) 177 | break 178 | 179 | # Reload 180 | new_model = clf(estimator=MLP_clf(), n_estimators=n_estimators, cuda=False) 181 | io.load(new_model) 182 | 183 | new_model.evaluate(test_loader) 184 | 185 | for _, (data, target) in enumerate(test_loader): 186 | new_model.predict(data) 187 | break 188 | 189 | 190 | @pytest.mark.parametrize("reg", all_reg) 191 | def test_reg_class(reg): 192 | """ 193 | This unit test checks the training and evaluating stage of all regressors. 
194 | """ 195 | epochs = 1 196 | n_estimators = 2 197 | 198 | model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False) 199 | 200 | # Optimizer 201 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 202 | 203 | # Scheduler (Snapshot Ensemble Excluded) 204 | if not isinstance(model, torchensemble.SnapshotEnsembleRegressor): 205 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 206 | 207 | # Prepare data 208 | train = TensorDataset(X_train, y_train_reg) 209 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 210 | test = TensorDataset(X_test, y_test_reg) 211 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 212 | 213 | # Snapshot Ensemble needs more epochs 214 | if isinstance(model, torchensemble.SnapshotEnsembleRegressor): 215 | epochs = 6 216 | 217 | # Train 218 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 219 | 220 | # Evaluate & Save 221 | model.evaluate(test_loader) 222 | io.save(model, "./", logger) 223 | 224 | # Predict 225 | for _, (data, target) in enumerate(test_loader): 226 | model.predict(data) 227 | break 228 | 229 | # Reload 230 | new_model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False) 231 | io.load(new_model) 232 | 233 | new_model.evaluate(test_loader) 234 | 235 | for _, (data, target) in enumerate(test_loader): 236 | new_model.predict(data) 237 | break 238 | 239 | 240 | @pytest.mark.parametrize("reg", all_reg) 241 | def test_reg_object(reg): 242 | """ 243 | This unit test checks the training and evaluating stage of all regressors. 244 | """ 245 | epochs = 1 246 | n_estimators = 2 247 | 248 | model = reg(estimator=MLP_reg(), n_estimators=n_estimators, cuda=False) 249 | 250 | # Optimizer 251 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 252 | 253 | # Scheduler (Snapshot Ensemble Excluded) 254 | if not isinstance(model, torchensemble.SnapshotEnsembleRegressor): 255 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 256 | 257 | # Prepare data 258 | train = TensorDataset(X_train, y_train_reg) 259 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 260 | test = TensorDataset(X_test, y_test_reg) 261 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 262 | 263 | # Snapshot Ensemble needs more epochs 264 | if isinstance(model, torchensemble.SnapshotEnsembleRegressor): 265 | epochs = 6 266 | 267 | # Train 268 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 269 | 270 | # Evaluate & Save 271 | model.evaluate(test_loader) 272 | io.save(model, "./", logger) 273 | 274 | # Predict 275 | for _, (data, target) in enumerate(test_loader): 276 | model.predict(data) 277 | break 278 | 279 | # Reload 280 | new_model = reg(estimator=MLP_reg(), n_estimators=n_estimators, cuda=False) 281 | io.load(new_model) 282 | 283 | new_model.evaluate(test_loader) 284 | 285 | for _, (data, target) in enumerate(test_loader): 286 | new_model.predict(data) 287 | break 288 | 289 | 290 | def test_predict(): 291 | 292 | fusion = all_clf[0] # FusionClassifier 293 | model = fusion(estimator=MLP_clf, n_estimators=2, cuda=False) 294 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 295 | 296 | train = TensorDataset(X_train, y_train_clf) 297 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 298 | model.fit(train_loader, epochs=1) 299 | 300 | assert_array_equal(model.predict(X_test), model.predict(numpy_X_test)) 301 | 302 | with pytest.raises(ValueError) as excinfo: 303 | model.predict([X_test]) # list 304 | assert "The type of input X should be one of" in 
str(excinfo.value) 305 | -------------------------------------------------------------------------------- /torchensemble/tests/test_all_models_multi_input.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.utils.data import TensorDataset, DataLoader 6 | 7 | import torchensemble 8 | from torchensemble.utils import io 9 | from torchensemble.utils.logging import set_logger 10 | 11 | 12 | # All classifiers 13 | all_clf = [ 14 | torchensemble.FusionClassifier, 15 | torchensemble.VotingClassifier, 16 | torchensemble.BaggingClassifier, 17 | torchensemble.GradientBoostingClassifier, 18 | torchensemble.SnapshotEnsembleClassifier, 19 | torchensemble.AdversarialTrainingClassifier, 20 | torchensemble.FastGeometricClassifier, 21 | torchensemble.SoftGradientBoostingClassifier, 22 | ] 23 | 24 | 25 | # All regressors 26 | all_reg = [ 27 | torchensemble.FusionRegressor, 28 | torchensemble.VotingRegressor, 29 | torchensemble.BaggingRegressor, 30 | torchensemble.GradientBoostingRegressor, 31 | torchensemble.SnapshotEnsembleRegressor, 32 | torchensemble.AdversarialTrainingRegressor, 33 | torchensemble.FastGeometricRegressor, 34 | torchensemble.SoftGradientBoostingRegressor, 35 | ] 36 | 37 | 38 | np.random.seed(0) 39 | torch.manual_seed(0) 40 | device = torch.device("cpu") 41 | logger = set_logger("pytest_all_models_multiple_input") 42 | 43 | 44 | # Base estimator 45 | class MLP_clf(nn.Module): 46 | def __init__(self): 47 | super(MLP_clf, self).__init__() 48 | self.linear1 = nn.Linear(2, 2) 49 | self.linear2 = nn.Linear(2, 2) 50 | 51 | def forward(self, X_1, X_2): 52 | X_1 = X_1.view(X_1.size()[0], -1) 53 | X_2 = X_2.view(X_2.size()[0], -1) 54 | output_1 = self.linear1(X_1) 55 | output_1 = self.linear2(output_1) 56 | output_2 = self.linear1(X_2) 57 | output_2 = self.linear2(output_2) 58 | return 0.5 * output_1 + 0.5 * output_2 59 | 60 | 61 | class MLP_reg(nn.Module): 62 | def __init__(self): 63 | super(MLP_reg, self).__init__() 64 | self.linear1 = nn.Linear(2, 2) 65 | self.linear2 = nn.Linear(2, 1) 66 | 67 | def forward(self, X_1, X_2): 68 | X_1 = X_1.view(X_1.size()[0], -1) 69 | X_2 = X_2.view(X_2.size()[0], -1) 70 | output_1 = self.linear1(X_1) 71 | output_1 = self.linear2(output_1) 72 | output_2 = self.linear1(X_2) 73 | output_2 = self.linear2(output_2) 74 | return 0.5 * output_1 + 0.5 * output_2 75 | 76 | 77 | # Training data 78 | X_train = torch.Tensor( 79 | np.array(([0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4])) 80 | ) 81 | 82 | y_train_clf = torch.LongTensor(np.array(([0, 0, 1, 1]))) 83 | y_train_reg = torch.FloatTensor(np.array(([0.1, 0.2, 0.3, 0.4]))) 84 | y_train_reg = y_train_reg.view(-1, 1) 85 | 86 | 87 | # Testing data 88 | numpy_X_test = np.array(([0.5, 0.5], [0.6, 0.6])) 89 | X_test = torch.Tensor(numpy_X_test) 90 | 91 | y_test_clf = torch.LongTensor(np.array(([1, 0]))) 92 | y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6]))) 93 | y_test_reg = y_test_reg.view(-1, 1) 94 | 95 | 96 | @pytest.mark.parametrize("clf", all_clf) 97 | def test_clf_class(clf): 98 | """ 99 | This unit test checks the training and evaluating stage of all classifiers. 
100 | """ 101 | epochs = 1 102 | n_estimators = 2 103 | 104 | model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False) 105 | 106 | # Optimizer 107 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 108 | 109 | # Scheduler (Snapshot Ensemble Excluded) 110 | if not isinstance(model, torchensemble.SnapshotEnsembleClassifier): 111 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 112 | 113 | # Prepare data with multiple inputs 114 | train = TensorDataset(X_train, X_train, y_train_clf) 115 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 116 | test = TensorDataset(X_test, X_test, y_test_clf) 117 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 118 | 119 | # Snapshot Ensemble needs more epochs 120 | if isinstance(model, torchensemble.SnapshotEnsembleClassifier): 121 | epochs = 6 122 | 123 | # Train 124 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 125 | 126 | # Evaluate 127 | model.evaluate(test_loader) 128 | 129 | # Predict 130 | for _, elem in enumerate(test_loader): 131 | data, target = io.split_data_target(elem, device) 132 | model.predict(*data) 133 | break 134 | 135 | # Reload 136 | new_model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False) 137 | io.load(new_model) 138 | 139 | new_model.evaluate(test_loader) 140 | 141 | for _, elem in enumerate(test_loader): 142 | data, target = io.split_data_target(elem, device) 143 | new_model.predict(*data) 144 | break 145 | 146 | 147 | @pytest.mark.parametrize("clf", all_clf) 148 | def test_clf_object(clf): 149 | """ 150 | This unit test checks the training and evaluating stage of all classifiers. 151 | """ 152 | epochs = 1 153 | n_estimators = 2 154 | 155 | model = clf(estimator=MLP_clf(), n_estimators=n_estimators, cuda=False) 156 | 157 | # Optimizer 158 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 159 | 160 | # Scheduler (Snapshot Ensemble Excluded) 161 | if not isinstance(model, torchensemble.SnapshotEnsembleClassifier): 162 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 163 | 164 | # Prepare data with multiple inputs 165 | train = TensorDataset(X_train, X_train, y_train_clf) 166 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 167 | test = TensorDataset(X_test, X_test, y_test_clf) 168 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 169 | 170 | # Snapshot Ensemble needs more epochs 171 | if isinstance(model, torchensemble.SnapshotEnsembleClassifier): 172 | epochs = 6 173 | 174 | # Train 175 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 176 | 177 | # Evaluate 178 | model.evaluate(test_loader) 179 | 180 | # Predict 181 | for _, elem in enumerate(test_loader): 182 | data, target = io.split_data_target(elem, device) 183 | model.predict(*data) 184 | break 185 | 186 | # Reload 187 | new_model = clf(estimator=MLP_clf(), n_estimators=n_estimators, cuda=False) 188 | io.load(new_model) 189 | 190 | new_model.evaluate(test_loader) 191 | 192 | for _, elem in enumerate(test_loader): 193 | data, target = io.split_data_target(elem, device) 194 | new_model.predict(*data) 195 | break 196 | 197 | 198 | @pytest.mark.parametrize("reg", all_reg) 199 | def test_reg_class(reg): 200 | """ 201 | This unit test checks the training and evaluating stage of all regressors. 
202 | """ 203 | epochs = 1 204 | n_estimators = 2 205 | 206 | model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False) 207 | 208 | # Optimizer 209 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 210 | 211 | # Scheduler (Snapshot Ensemble Excluded) 212 | if not isinstance(model, torchensemble.SnapshotEnsembleRegressor): 213 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 214 | 215 | # Prepare data with multiple inputs 216 | train = TensorDataset(X_train, X_train, y_train_reg) 217 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 218 | test = TensorDataset(X_test, X_test, y_test_reg) 219 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 220 | 221 | # Snapshot Ensemble needs more epochs 222 | if isinstance(model, torchensemble.SnapshotEnsembleRegressor): 223 | epochs = 6 224 | 225 | # Train 226 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 227 | 228 | # Evaluate 229 | model.evaluate(test_loader) 230 | 231 | # Predict 232 | for _, elem in enumerate(test_loader): 233 | data, target = io.split_data_target(elem, device) 234 | model.predict(*data) 235 | break 236 | 237 | # Reload 238 | new_model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False) 239 | io.load(new_model) 240 | 241 | new_model.evaluate(test_loader) 242 | 243 | for _, elem in enumerate(test_loader): 244 | data, target = io.split_data_target(elem, device) 245 | new_model.predict(*data) 246 | break 247 | 248 | 249 | @pytest.mark.parametrize("reg", all_reg) 250 | def test_reg_object(reg): 251 | """ 252 | This unit test checks the training and evaluating stage of all regressors. 253 | """ 254 | epochs = 1 255 | n_estimators = 2 256 | 257 | model = reg(estimator=MLP_reg(), n_estimators=n_estimators, cuda=False) 258 | 259 | # Optimizer 260 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 261 | 262 | # Scheduler (Snapshot Ensemble Excluded) 263 | if not isinstance(model, torchensemble.SnapshotEnsembleRegressor): 264 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 265 | 266 | # Prepare data with multiple inputs 267 | train = TensorDataset(X_train, X_train, y_train_reg) 268 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 269 | test = TensorDataset(X_test, X_test, y_test_reg) 270 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 271 | 272 | # Snapshot Ensemble needs more epochs 273 | if isinstance(model, torchensemble.SnapshotEnsembleRegressor): 274 | epochs = 6 275 | 276 | # Train 277 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 278 | 279 | # Evaluate 280 | model.evaluate(test_loader) 281 | 282 | # Predict 283 | for _, elem in enumerate(test_loader): 284 | data, target = io.split_data_target(elem, device) 285 | model.predict(*data) 286 | break 287 | 288 | # Reload 289 | new_model = reg(estimator=MLP_reg(), n_estimators=n_estimators, cuda=False) 290 | io.load(new_model) 291 | 292 | new_model.evaluate(test_loader) 293 | 294 | for _, elem in enumerate(test_loader): 295 | data, target = io.split_data_target(elem, device) 296 | new_model.predict(*data) 297 | break 298 | 299 | 300 | def test_split_data_target_invalid_data_type(): 301 | with pytest.raises(ValueError) as excinfo: 302 | io.split_data_target(0.0, device, logger) 303 | assert "Invalid dataloader" in str(excinfo.value) 304 | 305 | 306 | def test_split_data_target_invalid_list_length(): 307 | with pytest.raises(ValueError) as excinfo: 308 | io.split_data_target([0.0], device, logger) 309 | assert "should at least contain two tensors" 
in str(excinfo.value) 310 | -------------------------------------------------------------------------------- /torchensemble/tests/test_fixed_dataloder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | from torch.utils.data import TensorDataset, DataLoader 5 | 6 | from torchensemble.utils.dataloder import FixedDataLoader 7 | 8 | 9 | # Data 10 | X = torch.Tensor(np.array(([0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]))) 11 | y = torch.LongTensor(np.array(([0, 0, 1, 1]))) 12 | 13 | data = TensorDataset(X, y) 14 | dataloader = DataLoader(data, batch_size=2, shuffle=False) 15 | 16 | 17 | def test_fixed_dataloder(): 18 | fixed_dataloader = FixedDataLoader(dataloader) 19 | for _, (fixed_elem, elem) in enumerate(zip(fixed_dataloader, dataloader)): 20 | # Check same elements 21 | for elem_1, elem_2 in zip(fixed_elem, elem): 22 | assert torch.equal(elem_1, elem_2) 23 | 24 | # Check dataloader length 25 | assert len(fixed_dataloader) == 2 26 | 27 | 28 | def test_fixed_dataloader_invalid_type(): 29 | with pytest.raises(ValueError) as excinfo: 30 | FixedDataLoader((X, y)) 31 | assert "input used to instantiate FixedDataLoader" in str(excinfo.value) 32 | -------------------------------------------------------------------------------- /torchensemble/tests/test_logging.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torchensemble.utils.logging import set_logger 4 | 5 | 6 | def _record(logger): 7 | logger.debug("Debug!") 8 | logger.info("Info!") 9 | logger.warning("Warning!") 10 | logger.error("Error!") 11 | logger.critical("Critical!") 12 | 13 | 14 | def test_logger(): 15 | logger = set_logger("Loglevel_DEBUG", "DEBUG") 16 | _record(logger) 17 | logger = set_logger("Loglevel_INFO", "INFO") 18 | _record(logger) 19 | logger = set_logger("Loglevel_WARNING", "WARNING") 20 | _record(logger) 21 | logger = set_logger("Loglevel_ERROR", "ERROR") 22 | _record(logger) 23 | logger = set_logger("Loglevel_CRITICAL", "CRITICAL") 24 | _record(logger) 25 | 26 | with pytest.raises(ValueError) as excinfo: 27 | set_logger("Loglevel_INVALID", "INVALID") 28 | assert "INVALID" in str(excinfo.value) 29 | -------------------------------------------------------------------------------- /torchensemble/tests/test_neural_tree_ensemble.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import TensorDataset, DataLoader 4 | 5 | from torchensemble.utils import io 6 | from torchensemble.utils.logging import set_logger 7 | from torchensemble import NeuralForestClassifier, NeuralForestRegressor 8 | 9 | 10 | np.random.seed(0) 11 | torch.manual_seed(0) 12 | set_logger("pytest_neural_tree_ensemble") 13 | 14 | 15 | # Training data 16 | X_train = torch.Tensor( 17 | np.array(([0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4])) 18 | ) 19 | 20 | y_train_clf = torch.LongTensor(np.array(([0, 0, 1, 1]))) 21 | y_train_reg = torch.FloatTensor(np.array(([0.1, 0.2, 0.3, 0.4]))) 22 | y_train_reg = y_train_reg.view(-1, 1) 23 | 24 | 25 | # Testing data 26 | numpy_X_test = np.array(([0.5, 0.5], [0.6, 0.6])) 27 | X_test = torch.Tensor(numpy_X_test) 28 | 29 | y_test_clf = torch.LongTensor(np.array(([1, 0]))) 30 | y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6]))) 31 | y_test_reg = y_test_reg.view(-1, 1) 32 | 33 | 34 | def test_neural_forest_classifier(): 35 | """ 36 | This unit test checks the training
and evaluating stage of 37 | NeuralForestClassifier. 38 | """ 39 | epochs = 1 40 | n_estimators = 2 41 | depth = 3 42 | lamda = 1e-3 43 | 44 | model = NeuralForestClassifier( 45 | n_estimators=n_estimators, 46 | depth=depth, 47 | lamda=lamda, 48 | cuda=False, 49 | n_jobs=1, 50 | ) 51 | 52 | model.set_optimizer("Adam", lr=1e-4, weight_decay=5e-4) 53 | 54 | # Prepare data 55 | train = TensorDataset(X_train, y_train_clf) 56 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 57 | test = TensorDataset(X_test, y_test_clf) 58 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 59 | 60 | # Train 61 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 62 | 63 | # Evaluate 64 | model.evaluate(test_loader) 65 | 66 | # Predict 67 | for _, (data, target) in enumerate(test_loader): 68 | model.predict(data) 69 | break 70 | 71 | # Reload 72 | new_model = NeuralForestClassifier( 73 | n_estimators=n_estimators, 74 | depth=depth, 75 | lamda=lamda, 76 | cuda=False, 77 | n_jobs=1, 78 | ) 79 | io.load(new_model) 80 | 81 | new_model.evaluate(test_loader) 82 | 83 | for _, (data, target) in enumerate(test_loader): 84 | new_model.predict(data) 85 | break 86 | 87 | 88 | def test_neural_forest_regressor(): 89 | """ 90 | This unit test checks the training and evaluating stage of 91 | NeuralForestRegressor. 92 | """ 93 | epochs = 1 94 | n_estimators = 2 95 | depth = 3 96 | lamda = 1e-3 97 | 98 | model = NeuralForestRegressor( 99 | n_estimators=n_estimators, 100 | depth=depth, 101 | lamda=lamda, 102 | cuda=False, 103 | n_jobs=1, 104 | ) 105 | 106 | model.set_optimizer("Adam", lr=1e-4, weight_decay=5e-4) 107 | 108 | # Prepare data 109 | train = TensorDataset(X_train, y_train_reg) 110 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 111 | test = TensorDataset(X_test, y_test_reg) 112 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 113 | 114 | # Train 115 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 116 | 117 | # Evaluate 118 | model.evaluate(test_loader) 119 | 120 | # Predict 121 | for _, (data, target) in enumerate(test_loader): 122 | model.predict(data) 123 | break 124 | 125 | # Reload 126 | new_model = NeuralForestRegressor( 127 | n_estimators=n_estimators, 128 | depth=depth, 129 | lamda=lamda, 130 | cuda=False, 131 | n_jobs=1, 132 | ) 133 | io.load(new_model) 134 | 135 | new_model.evaluate(test_loader) 136 | 137 | for _, (data, target) in enumerate(test_loader): 138 | new_model.predict(data) 139 | break 140 | -------------------------------------------------------------------------------- /torchensemble/tests/test_operator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | from numpy.testing import assert_array_equal 5 | from numpy.testing import assert_array_almost_equal 6 | 7 | from torchensemble.utils import operator as op 8 | 9 | 10 | outputs = [ 11 | torch.FloatTensor(np.array(([1, 1], [1, 1]))), 12 | torch.FloatTensor(np.array(([2, 2], [2, 2]))), 13 | torch.FloatTensor(np.array(([3, 3], [3, 3]))), 14 | ] 15 | 16 | label = torch.LongTensor(np.array(([0, 1, 2, 1]))) 17 | n_classes = 3 18 | 19 | 20 | def test_average(): 21 | actual = op.average(outputs).numpy() 22 | expected = np.array(([2, 2], [2, 2])) 23 | assert_array_equal(actual, expected) 24 | 25 | 26 | def test_sum_with_multiplicative(): 27 | shrinkage_rate = 0.1 28 | actual = op.sum_with_multiplicative(outputs, shrinkage_rate).numpy() 29 | expected = np.array(([0.6, 0.6], [0.6,
0.6])) 30 | assert_array_almost_equal(actual, expected) 31 | 32 | 33 | def test_onehot_encoding(): 34 | actual = op.onehot_encoding(label, n_classes).numpy() 35 | expected = np.array(([1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0])) 36 | assert_array_almost_equal(actual, expected) 37 | 38 | 39 | def test_residual_regression_invalid_shape(): 40 | with pytest.raises(ValueError) as excinfo: 41 | op.pseudo_residual_regression( 42 | torch.FloatTensor(np.array(([1, 1], [1, 1]))), # 2 * 2 43 | label.view(-1, 1), # 4 * 1 44 | ) 45 | assert "should be the same as output" in str(excinfo.value) 46 | 47 | 48 | def test_majority_voting(): 49 | outputs = [ 50 | torch.FloatTensor(np.array(([0.9, 0.1], [0.2, 0.8]))), 51 | torch.FloatTensor(np.array(([0.7, 0.3], [0.1, 0.9]))), 52 | torch.FloatTensor(np.array(([0.1, 0.9], [0.8, 0.2]))), 53 | ] 54 | actual = op.majority_vote(outputs).numpy() 55 | expected = np.array(([1, 0], [0, 1])) 56 | assert_array_almost_equal(actual, expected) 57 | -------------------------------------------------------------------------------- /torchensemble/tests/test_set_optimizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torchensemble 3 | import torch.nn as nn 4 | 5 | 6 | optimizer_list = [ 7 | "Adadelta", 8 | "Adagrad", 9 | "Adam", 10 | "AdamW", 11 | "Adamax", 12 | "ASGD", 13 | "RMSprop", 14 | "Rprop", 15 | "SGD", 16 | ] 17 | 18 | 19 | # Base estimator 20 | class MLP(nn.Module): 21 | def __init__(self): 22 | super(MLP, self).__init__() 23 | self.linear1 = nn.Linear(2, 2) 24 | self.linear2 = nn.Linear(2, 2) 25 | 26 | def forward(self, X): 27 | X = X.view(X.size()[0], -1) 28 | output = self.linear1(X) 29 | output = self.linear2(output) 30 | return output 31 | 32 | 33 | @pytest.mark.parametrize("optimizer_name", optimizer_list) 34 | def test_set_optimizer_normal(optimizer_name): 35 | model = MLP() 36 | torchensemble.utils.set_module.set_optimizer( 37 | model, optimizer_name, lr=1e-3 38 | ) 39 | 40 | 41 | def test_set_optimizer_Unknown(): 42 | model = MLP() 43 | 44 | with pytest.raises(NotImplementedError) as excinfo: 45 | torchensemble.utils.set_module.set_optimizer(model, "Unknown") 46 | assert "Unrecognized optimizer" in str(excinfo.value) 47 | 48 | 49 | def test_update_lr(): 50 | cur_lr = 1e-4 51 | model = MLP() 52 | optimizer = torchensemble.utils.set_module.set_optimizer( 53 | model, "Adam", lr=1e-3 54 | ) 55 | 56 | optimizer = torchensemble.utils.set_module.update_lr(optimizer, cur_lr) 57 | 58 | for group in optimizer.param_groups: 59 | assert group["lr"] == cur_lr 60 | 61 | 62 | def test_update_lr_invalid(): 63 | cur_lr = 0 64 | model = MLP() 65 | optimizer = torchensemble.utils.set_module.set_optimizer( 66 | model, "Adam", lr=1e-3 67 | ) 68 | 69 | err_msg = ( 70 | "The learning rate should be strictly positive, but got" " {} instead." 
71 | ).format(cur_lr) 72 | with pytest.raises(ValueError, match=err_msg): 73 | torchensemble.utils.set_module.update_lr(optimizer, cur_lr) 74 | -------------------------------------------------------------------------------- /torchensemble/tests/test_set_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import torchensemble 4 | import numpy as np 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader, TensorDataset 7 | 8 | 9 | # All regressors 10 | all_reg = [ 11 | torchensemble.FusionRegressor, 12 | torchensemble.VotingRegressor, 13 | torchensemble.BaggingRegressor, 14 | torchensemble.GradientBoostingRegressor, 15 | torchensemble.AdversarialTrainingRegressor, 16 | torchensemble.FastGeometricRegressor, 17 | torchensemble.SoftGradientBoostingRegressor, 18 | ] 19 | 20 | # All classifiers 21 | all_clf = [ 22 | torchensemble.FusionClassifier, 23 | torchensemble.VotingClassifier, 24 | torchensemble.BaggingClassifier, 25 | torchensemble.GradientBoostingClassifier, 26 | torchensemble.AdversarialTrainingClassifier, 27 | torchensemble.FastGeometricClassifier, 28 | torchensemble.SoftGradientBoostingClassifier, 29 | ] 30 | 31 | 32 | # Base estimator 33 | class MLP(nn.Module): 34 | def __init__(self): 35 | super(MLP, self).__init__() 36 | self.linear1 = nn.Linear(2, 2) 37 | self.linear2 = nn.Linear(2, 2) 38 | 39 | def forward(self, X): 40 | X = X.view(X.size()[0], -1) 41 | output = self.linear1(X) 42 | output = self.linear2(output) 43 | return output 44 | 45 | 46 | class MLP_clf(nn.Module): 47 | def __init__(self): 48 | super(MLP_clf, self).__init__() 49 | self.linear1 = nn.Linear(2, 2) 50 | self.linear2 = nn.Linear(2, 2) 51 | 52 | def forward(self, X): 53 | X = X.view(X.size()[0], -1) 54 | output = self.linear1(X) 55 | output = self.linear2(output) 56 | return output 57 | 58 | 59 | # Training data 60 | X_train = torch.rand(2, 2) 61 | y_train_reg = torch.rand(2, 2) 62 | X_train_clf = torch.Tensor( 63 | np.array(([0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4])) 64 | ) 65 | y_train_clf = torch.LongTensor(np.array(([0, 0, 1, 1]))) 66 | 67 | # Testing data 68 | X_test = torch.rand(2, 2) 69 | y_test_reg = torch.rand(2, 2) 70 | numpy_X_test = np.array(([0.5, 0.5], [0.6, 0.6])) 71 | X_test_clf = torch.Tensor(numpy_X_test) 72 | y_test_clf = torch.LongTensor(np.array(([1, 0]))) 73 | 74 | train_data_reg = TensorDataset(X_train, y_train_reg) 75 | train_loader_reg = DataLoader(train_data_reg, batch_size=5, shuffle=True) 76 | test_data_reg = TensorDataset(X_test, y_test_reg) 77 | test_loader_reg = DataLoader(test_data_reg, batch_size=5, shuffle=True) 78 | 79 | train_data_clf = TensorDataset(X_train_clf, y_train_clf) 80 | train_loader_clf = DataLoader(train_data_clf, batch_size=5, shuffle=True) 81 | test_data_clf = TensorDataset(X_test_clf, y_test_clf) 82 | test_loader_clf = DataLoader(test_data_clf, batch_size=5, shuffle=True) 83 | 84 | 85 | def test_set_scheduler_LambdaLR(): 86 | model = MLP() 87 | optimizer = torch.optim.Adam(model.parameters()) 88 | lr_lambda = lambda x: x * 0.1 # noqa: E731 89 | torchensemble.utils.set_module.set_scheduler( 90 | optimizer, "LambdaLR", lr_lambda=lr_lambda 91 | ) 92 | 93 | 94 | def test_set_scheduler_MultiplicativeLR(): 95 | model = MLP() 96 | optimizer = torch.optim.Adam(model.parameters()) 97 | lr_lambda = lambda x: x * 0.1 # noqa: E731 98 | torchensemble.utils.set_module.set_scheduler( 99 | optimizer, "MultiplicativeLR", lr_lambda=lr_lambda 100 | ) 101 | 102 | 103 | def 
test_set_scheduler_StepLR(): 104 | model = MLP() 105 | optimizer = torch.optim.Adam(model.parameters()) 106 | torchensemble.utils.set_module.set_scheduler( 107 | optimizer, "StepLR", step_size=50 108 | ) 109 | 110 | 111 | def test_set_scheduler_MultiStepLR(): 112 | model = MLP() 113 | optimizer = torch.optim.Adam(model.parameters()) 114 | torchensemble.utils.set_module.set_scheduler( 115 | optimizer, "MultiStepLR", milestones=[50, 100] 116 | ) 117 | 118 | 119 | def test_set_scheduler_ExponentialLR(): 120 | model = MLP() 121 | optimizer = torch.optim.Adam(model.parameters()) 122 | torchensemble.utils.set_module.set_scheduler( 123 | optimizer, "ExponentialLR", gamma=0.1 124 | ) 125 | 126 | 127 | def test_set_scheduler_CosineAnnealingLR(): 128 | model = MLP() 129 | optimizer = torch.optim.Adam(model.parameters()) 130 | torchensemble.utils.set_module.set_scheduler( 131 | optimizer, "CosineAnnealingLR", T_max=100 132 | ) 133 | 134 | 135 | def test_set_scheduler_ReduceLROnPlateau(): 136 | model = MLP() 137 | optimizer = torch.optim.Adam(model.parameters()) 138 | torchensemble.utils.set_module.set_scheduler( 139 | optimizer, "ReduceLROnPlateau" 140 | ) 141 | 142 | 143 | @pytest.mark.parametrize( 144 | "scheduler_dict", 145 | [ 146 | {"scheduler_name": "ReduceLROnPlateau", "patience": 2, "min_lr": 1e-6}, 147 | {"scheduler_name": "LambdaLR", "lr_lambda": lambda x: x * 0.1}, 148 | {"scheduler_name": "StepLR", "step_size": 30}, 149 | {"scheduler_name": "MultiStepLR", "milestones": [20, 40]}, 150 | {"scheduler_name": "ExponentialLR", "gamma": 0.1}, 151 | {"scheduler_name": "CosineAnnealingLR", "T_max": 100}, 152 | ], 153 | ) 154 | @pytest.mark.parametrize( 155 | "test_dataloader", [test_loader_reg, None] 156 | ) # ReduceLROnPlateau scheduling should work when there is no test data 157 | @pytest.mark.parametrize( 158 | "n_estimators", [1, 5] 159 | ) # LR scheduling works for 1 as well as many estimators 160 | @pytest.mark.parametrize("ensemble_model", all_reg) 161 | def test_fit_w_all_schedulers( 162 | scheduler_dict, test_dataloader, n_estimators, ensemble_model 163 | ): 164 | """Test if LR schedulers work when `fit` is called.""" 165 | model = ensemble_model( 166 | estimator=MLP, n_estimators=n_estimators, cuda=False 167 | ) 168 | model.set_optimizer("Adam", lr=1e-1) 169 | model.set_scheduler(**scheduler_dict) 170 | model.fit( 171 | train_loader_reg, 172 | epochs=50, 173 | test_loader=test_dataloader, 174 | save_model=False, 175 | ) 176 | 177 | 178 | @pytest.mark.parametrize( 179 | "scheduler_dict", 180 | [ 181 | {"scheduler_name": "ReduceLROnPlateau", "patience": 2, "min_lr": 1e-6}, 182 | {"scheduler_name": "LambdaLR", "lr_lambda": lambda x: x * 0.1}, 183 | {"scheduler_name": "StepLR", "step_size": 30}, 184 | {"scheduler_name": "MultiStepLR", "milestones": [20, 40]}, 185 | {"scheduler_name": "ExponentialLR", "gamma": 0.1}, 186 | {"scheduler_name": "CosineAnnealingLR", "T_max": 100}, 187 | ], 188 | ) 189 | @pytest.mark.parametrize( 190 | "test_dataloader", [test_loader_clf, None] 191 | ) # ReduceLROnPlateau scheduling should work when there is no test data 192 | @pytest.mark.parametrize( 193 | "n_estimators", [1, 5] 194 | ) # LR scheduling works for 1 as well as many estimators 195 | @pytest.mark.parametrize("ensemble_model", all_clf) 196 | def test_fit_w_all_schedulers_clf( 197 | scheduler_dict, test_dataloader, n_estimators, ensemble_model 198 | ): 199 | """Test if LR schedulers work when `fit` is called.""" 200 | model = ensemble_model( 201 | estimator=MLP_clf, n_estimators=n_estimators, 
cuda=False 202 | ) 203 | model.set_optimizer("Adam", lr=1e-1) 204 | model.set_scheduler(**scheduler_dict) 205 | model.fit( 206 | train_loader_clf, 207 | epochs=50, 208 | test_loader=test_dataloader, 209 | save_model=False, 210 | ) 211 | 212 | 213 | def test_set_scheduler_Unknown(): 214 | model = MLP() 215 | optimizer = torch.optim.Adam(model.parameters()) 216 | 217 | with pytest.raises(NotImplementedError) as excinfo: 218 | torchensemble.utils.set_module.set_scheduler(optimizer, "Unknown") 219 | assert "Unrecognized scheduler" in str(excinfo.value) 220 | -------------------------------------------------------------------------------- /torchensemble/tests/test_tb_logging.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.utils.data import TensorDataset, DataLoader 6 | 7 | import torchensemble 8 | from torchensemble.utils import io 9 | from torchensemble.utils.logging import set_logger 10 | 11 | 12 | # All classifiers 13 | all_clf = [ 14 | torchensemble.FusionClassifier, 15 | torchensemble.VotingClassifier, 16 | torchensemble.BaggingClassifier, 17 | torchensemble.GradientBoostingClassifier, 18 | torchensemble.SnapshotEnsembleClassifier, 19 | torchensemble.AdversarialTrainingClassifier, 20 | torchensemble.FastGeometricClassifier, 21 | ] 22 | 23 | 24 | # All regressors 25 | all_reg = [ 26 | torchensemble.FusionRegressor, 27 | torchensemble.VotingRegressor, 28 | torchensemble.BaggingRegressor, 29 | torchensemble.GradientBoostingRegressor, 30 | torchensemble.SnapshotEnsembleRegressor, 31 | torchensemble.AdversarialTrainingRegressor, 32 | torchensemble.FastGeometricRegressor, 33 | ] 34 | 35 | 36 | np.random.seed(0) 37 | torch.manual_seed(0) 38 | set_logger("pytest_all_models", use_tb_logger=True) 39 | 40 | 41 | # Base estimator 42 | class MLP_clf(nn.Module): 43 | def __init__(self): 44 | super(MLP_clf, self).__init__() 45 | self.linear1 = nn.Linear(2, 2) 46 | self.linear2 = nn.Linear(2, 2) 47 | 48 | def forward(self, X): 49 | X = X.view(X.size()[0], -1) 50 | output = self.linear1(X) 51 | output = self.linear2(output) 52 | return output 53 | 54 | 55 | class MLP_reg(nn.Module): 56 | def __init__(self): 57 | super(MLP_reg, self).__init__() 58 | self.linear1 = nn.Linear(2, 2) 59 | self.linear2 = nn.Linear(2, 1) 60 | 61 | def forward(self, X): 62 | X = X.view(X.size()[0], -1) 63 | output = self.linear1(X) 64 | output = self.linear2(output) 65 | return output 66 | 67 | 68 | # Training data 69 | X_train = torch.Tensor( 70 | np.array(([0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4])) 71 | ) 72 | 73 | y_train_clf = torch.LongTensor(np.array(([0, 0, 1, 1]))) 74 | y_train_reg = torch.FloatTensor(np.array(([0.1, 0.2, 0.3, 0.4]))) 75 | y_train_reg = y_train_reg.view(-1, 1) 76 | 77 | 78 | # Testing data 79 | numpy_X_test = np.array(([0.5, 0.5], [0.6, 0.6])) 80 | X_test = torch.Tensor(numpy_X_test) 81 | 82 | y_test_clf = torch.LongTensor(np.array(([1, 0]))) 83 | y_test_reg = torch.FloatTensor(np.array(([0.5, 0.6]))) 84 | y_test_reg = y_test_reg.view(-1, 1) 85 | 86 | 87 | @pytest.mark.parametrize("clf", all_clf) 88 | def test_clf_class(clf): 89 | """ 90 | This unit test checks the training and evaluating stage of all classifiers.
91 | """ 92 | epochs = 1 93 | n_estimators = 2 94 | 95 | model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False) 96 | 97 | # Optimizer 98 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 99 | 100 | # Scheduler (Snapshot Ensemble Excluded) 101 | if not isinstance(model, torchensemble.SnapshotEnsembleClassifier): 102 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 103 | 104 | # Prepare data 105 | train = TensorDataset(X_train, y_train_clf) 106 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 107 | test = TensorDataset(X_test, y_test_clf) 108 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 109 | 110 | # Snapshot Ensemble needs more epochs 111 | if isinstance(model, torchensemble.SnapshotEnsembleClassifier): 112 | epochs = 6 113 | 114 | # Train 115 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 116 | 117 | # Evaluate 118 | model.evaluate(test_loader) 119 | 120 | # Predict 121 | for _, (data, target) in enumerate(test_loader): 122 | model.predict(data) 123 | break 124 | 125 | # Reload 126 | new_model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False) 127 | io.load(new_model) 128 | 129 | new_model.evaluate(test_loader) 130 | 131 | for _, (data, target) in enumerate(test_loader): 132 | new_model.predict(data) 133 | break 134 | 135 | 136 | @pytest.mark.parametrize("clf", all_clf) 137 | def test_clf_object(clf): 138 | """ 139 | This unit test checks the training and evaluating stage of all classifiers. 140 | """ 141 | epochs = 1 142 | n_estimators = 2 143 | 144 | model = clf(estimator=MLP_clf(), n_estimators=n_estimators, cuda=False) 145 | 146 | # Optimizer 147 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 148 | 149 | # Scheduler (Snapshot Ensemble Excluded) 150 | if not isinstance(model, torchensemble.SnapshotEnsembleClassifier): 151 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 152 | 153 | # Prepare data 154 | train = TensorDataset(X_train, y_train_clf) 155 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 156 | test = TensorDataset(X_test, y_test_clf) 157 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 158 | 159 | # Snapshot Ensemble needs more epochs 160 | if isinstance(model, torchensemble.SnapshotEnsembleClassifier): 161 | epochs = 6 162 | 163 | # Train 164 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 165 | 166 | # Evaluate 167 | model.evaluate(test_loader) 168 | 169 | # Predict 170 | for _, (data, target) in enumerate(test_loader): 171 | model.predict(data) 172 | break 173 | 174 | # Reload 175 | new_model = clf(estimator=MLP_clf(), n_estimators=n_estimators, cuda=False) 176 | io.load(new_model) 177 | 178 | new_model.evaluate(test_loader) 179 | 180 | for _, (data, target) in enumerate(test_loader): 181 | new_model.predict(data) 182 | break 183 | 184 | 185 | @pytest.mark.parametrize("reg", all_reg) 186 | def test_reg_class(reg): 187 | """ 188 | This unit test checks the training and evaluating stage of all regressors. 
189 | """ 190 | epochs = 1 191 | n_estimators = 2 192 | 193 | model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False) 194 | 195 | # Optimizer 196 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 197 | 198 | # Scheduler (Snapshot Ensemble Excluded) 199 | if not isinstance(model, torchensemble.SnapshotEnsembleRegressor): 200 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 201 | 202 | # Prepare data 203 | train = TensorDataset(X_train, y_train_reg) 204 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 205 | test = TensorDataset(X_test, y_test_reg) 206 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 207 | 208 | # Snapshot Ensemble needs more epochs 209 | if isinstance(model, torchensemble.SnapshotEnsembleRegressor): 210 | epochs = 6 211 | 212 | # Train 213 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 214 | 215 | # Evaluate 216 | model.evaluate(test_loader) 217 | 218 | # Predict 219 | for _, (data, target) in enumerate(test_loader): 220 | model.predict(data) 221 | break 222 | 223 | # Reload 224 | new_model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False) 225 | io.load(new_model) 226 | 227 | new_model.evaluate(test_loader) 228 | 229 | for _, (data, target) in enumerate(test_loader): 230 | new_model.predict(data) 231 | break 232 | 233 | 234 | @pytest.mark.parametrize("reg", all_reg) 235 | def test_reg_object(reg): 236 | """ 237 | This unit test checks the training and evaluating stage of all regressors. 238 | """ 239 | epochs = 1 240 | n_estimators = 2 241 | 242 | model = reg(estimator=MLP_reg(), n_estimators=n_estimators, cuda=False) 243 | 244 | # Optimizer 245 | model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4) 246 | 247 | # Scheduler (Snapshot Ensemble Excluded) 248 | if not isinstance(model, torchensemble.SnapshotEnsembleRegressor): 249 | model.set_scheduler("MultiStepLR", milestones=[2, 4]) 250 | 251 | # Prepare data 252 | train = TensorDataset(X_train, y_train_reg) 253 | train_loader = DataLoader(train, batch_size=2, shuffle=False) 254 | test = TensorDataset(X_test, y_test_reg) 255 | test_loader = DataLoader(test, batch_size=2, shuffle=False) 256 | 257 | # Snapshot Ensemble needs more epochs 258 | if isinstance(model, torchensemble.SnapshotEnsembleRegressor): 259 | epochs = 6 260 | 261 | # Train 262 | model.fit(train_loader, epochs=epochs, test_loader=test_loader) 263 | 264 | # Evaluate 265 | model.evaluate(test_loader) 266 | 267 | # Predict 268 | for _, (data, target) in enumerate(test_loader): 269 | model.predict(data) 270 | break 271 | 272 | # Reload 273 | new_model = reg(estimator=MLP_reg(), n_estimators=n_estimators, cuda=False) 274 | io.load(new_model) 275 | 276 | new_model.evaluate(test_loader) 277 | 278 | for _, (data, target) in enumerate(test_loader): 279 | new_model.predict(data) 280 | break 281 | -------------------------------------------------------------------------------- /torchensemble/tests/test_training_params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.utils.data import TensorDataset, DataLoader 6 | 7 | import torchensemble 8 | from torchensemble.utils.logging import set_logger 9 | 10 | 11 | parallel = [ 12 | torchensemble.FusionClassifier, 13 | torchensemble.VotingClassifier, 14 | torchensemble.BaggingClassifier, 15 | torchensemble.SoftGradientBoostingClassifier, 16 | ] 17 | 18 | set_logger("pytest_training_params") 19 | 20 | 21 | # 
Base estimator 22 | class MLP(nn.Module): 23 | def __init__(self): 24 | super(MLP, self).__init__() 25 | self.linear1 = nn.Linear(2, 2) 26 | self.linear2 = nn.Linear(2, 2) 27 | 28 | def forward(self, X): 29 | X = X.view(X.size()[0], -1) 30 | output = self.linear1(X) 31 | output = self.linear2(output) 32 | return output 33 | 34 | 35 | # Training data 36 | X_train = torch.Tensor(np.array(([1, 1], [2, 2], [3, 3], [4, 4]))) 37 | 38 | y_train = torch.LongTensor(np.array(([0, 0, 1, 1]))) 39 | 40 | 41 | # Prepare data 42 | train = TensorDataset(X_train, y_train) 43 | train_loader = DataLoader(train, batch_size=2) 44 | 45 | 46 | @pytest.mark.parametrize("method", parallel) 47 | def test_parallel(method): 48 | model = method(estimator=MLP, n_estimators=2, cuda=False) 49 | 50 | # Epochs 51 | with pytest.raises(ValueError) as excinfo: 52 | model.fit(train_loader, epochs=-1) 53 | assert "number of training epochs" in str(excinfo.value) 54 | 55 | # Log interval 56 | with pytest.raises(ValueError) as excinfo: 57 | model.fit(train_loader, log_interval=-1) 58 | assert "number of batches to wait" in str(excinfo.value) 59 | 60 | 61 | def test_gradient_boosting(): 62 | model = torchensemble.GradientBoostingClassifier( 63 | estimator=MLP, n_estimators=2, cuda=False 64 | ) 65 | 66 | # Epochs 67 | with pytest.raises(ValueError) as excinfo: 68 | model.fit(train_loader, epochs=-1) 69 | assert "number of training epochs" in str(excinfo.value) 70 | 71 | # Log interval 72 | with pytest.raises(ValueError) as excinfo: 73 | model.fit(train_loader, log_interval=-1) 74 | assert "number of batches to wait" in str(excinfo.value) 75 | 76 | # Early stopping rounds 77 | with pytest.raises(ValueError) as excinfo: 78 | model.fit(train_loader, early_stopping_rounds=0) 79 | assert "number of tolerant rounds" in str(excinfo.value) 80 | 81 | # Shrinkage rate 82 | model = torchensemble.GradientBoostingClassifier( 83 | estimator=MLP, n_estimators=2, shrinkage_rate=2, cuda=False 84 | ) 85 | with pytest.raises(ValueError) as excinfo: 86 | model.fit(train_loader) 87 | assert "shrinkage rate should be in the range" in str(excinfo.value) 88 | 89 | 90 | def test_snapshot_ensemble(): 91 | model = torchensemble.SnapshotEnsembleClassifier( 92 | estimator=MLP, n_estimators=2, cuda=False 93 | ) 94 | 95 | # Learning rate clip 96 | with pytest.raises(ValueError) as excinfo: 97 | model.fit(train_loader, lr_clip=1) 98 | assert "lr_clip should be a list or tuple" in str(excinfo.value) 99 | 100 | with pytest.raises(ValueError) as excinfo: 101 | model.fit(train_loader, lr_clip=[0, 1, 2]) 102 | assert "lr_clip should only have two elements" in str(excinfo.value) 103 | 104 | with pytest.raises(ValueError) as excinfo: 105 | model.fit(train_loader, lr_clip=[1, 0]) 106 | assert "should be smaller than the second" in str(excinfo.value) 107 | 108 | # Epochs 109 | with pytest.raises(ValueError) as excinfo: 110 | model.fit(train_loader, epochs=-1) 111 | assert "number of training epochs" in str(excinfo.value) 112 | 113 | # Log interval 114 | with pytest.raises(ValueError) as excinfo: 115 | model.fit(train_loader, log_interval=-1) 116 | assert "number of batches to wait" in str(excinfo.value) 117 | 118 | # Division 119 | with pytest.raises(ValueError) as excinfo: 120 | model.fit(train_loader, epochs=5) 121 | assert "should be a multiple of n_estimators" in str(excinfo.value) 122 | 123 | 124 | def test_adversarial_training(): 125 | model = torchensemble.AdversarialTrainingClassifier( 126 | estimator=MLP, n_estimators=2, cuda=False 127 | ) 128 | 129 | # Epochs
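# (passing a negative value such as epochs=-1 must raise a ValueError before any training happens)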
130 | with pytest.raises(ValueError) as excinfo: 131 | model.fit(train_loader, epochs=-1) 132 | assert "number of training epochs" in str(excinfo.value) 133 | 134 | # Epsilon 135 | with pytest.raises(ValueError) as excinfo: 136 | model.fit(train_loader, epsilon=2) 137 | assert "step used to generate adversarial samples" in str(excinfo.value) 138 | 139 | # Log interval 140 | with pytest.raises(ValueError) as excinfo: 141 | model.fit(train_loader, log_interval=-1) 142 | assert "number of batches to wait" in str(excinfo.value) 143 | 144 | 145 | def test_neural_forest(): 146 | model = torchensemble.NeuralForestClassifier(n_estimators=2, depth=-1) 147 | 148 | with pytest.raises(ValueError) as excinfo: 149 | model.fit(train_loader, epochs=1) 150 | assert "tree depth should be strictly positive" in str(excinfo.value) 151 | 152 | model = torchensemble.NeuralForestClassifier( 153 | n_estimators=2, 154 | depth=3, 155 | lamda=-1e-3, 156 | ) 157 | 158 | with pytest.raises(ValueError) as excinfo: 159 | model.fit(train_loader, epochs=1) 160 | assert "coefficient of the regularization term" in str(excinfo.value) 161 | -------------------------------------------------------------------------------- /torchensemble/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TorchEnsemble-Community/Ensemble-Pytorch/0e25f1f3f4ebe712288c5d37027d725df991f253/torchensemble/utils/__init__.py -------------------------------------------------------------------------------- /torchensemble/utils/dataloder.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | 4 | class FixedDataLoader(object): 5 | def __init__(self, dataloader): 6 | # Check input 7 | if not isinstance(dataloader, DataLoader): 8 | msg = ( 9 | "The input used to instantiate FixedDataLoader should be a" 10 | " DataLoader from `torch.utils.data`." 11 | ) 12 | raise ValueError(msg) 13 | 14 | self.elem_list = [] 15 | for _, elem in enumerate(dataloader): 16 | self.elem_list.append(elem) 17 | 18 | def __getitem__(self, index): 19 | return self.elem_list[index] 20 | 21 | def __len__(self): 22 | return len(self.elem_list) 23 | -------------------------------------------------------------------------------- /torchensemble/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def save(model, save_dir, logger): 6 | """Implement model serialization to the specified directory.""" 7 | if save_dir is None: 8 | save_dir = "./" 9 | 10 | if not os.path.isdir(save_dir): 11 | os.mkdir(save_dir) 12 | 13 | # Decide the base estimator name 14 | if isinstance(model.base_estimator_, type): 15 | base_estimator_name = model.base_estimator_.__name__ 16 | else: 17 | base_estimator_name = model.base_estimator_.__class__.__name__ 18 | 19 | # {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators} 20 | filename = "{}_{}_{}_ckpt.pth".format( 21 | type(model).__name__, 22 | base_estimator_name, 23 | model.n_estimators, 24 | ) 25 | 26 | # The real number of base estimators in some ensembles is not the same as 27 | # `n_estimators`.
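# (the checkpoint therefore stores len(model.estimators_) rather than model.n_estimators)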
28 | state = { 29 | "n_estimators": len(model.estimators_), 30 | "model": model.state_dict(), 31 | "_criterion": model._criterion, 32 | "n_outputs": model.n_outputs, 33 | } 34 | 35 | if hasattr(model, "n_inputs"): 36 | state.update({"n_inputs": model.n_inputs}) 37 | 38 | save_dir = os.path.join(save_dir, filename) 39 | 40 | logger.info("Saving the model to `{}`".format(save_dir)) 41 | 42 | # Save 43 | torch.save(state, save_dir) 44 | 45 | return 46 | 47 | 48 | def load(model, save_dir="./", map_location=None, logger=None): 49 | """Implement model deserialization from the specified directory.""" 50 | if not os.path.exists(save_dir): 51 | raise FileNotFoundError("`{}` does not exist".format(save_dir)) 52 | 53 | # Decide the base estimator name 54 | if isinstance(model.base_estimator_, type): 55 | base_estimator_name = model.base_estimator_.__name__ 56 | else: 57 | base_estimator_name = model.base_estimator_.__class__.__name__ 58 | 59 | # {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators} 60 | filename = "{}_{}_{}_ckpt.pth".format( 61 | type(model).__name__, 62 | base_estimator_name, 63 | model.n_estimators, 64 | ) 65 | save_dir = os.path.join(save_dir, filename) 66 | 67 | if logger: 68 | logger.info("Loading the model from `{}`".format(save_dir)) 69 | 70 | state = torch.load(save_dir, map_location=map_location) 71 | n_estimators = state["n_estimators"] 72 | model_params = state["model"] 73 | model._criterion = state["_criterion"] 74 | model.n_outputs = state["n_outputs"] 75 | if "n_inputs" in state: 76 | model.n_inputs = state["n_inputs"] 77 | 78 | # Pre-allocate and load all base estimators 79 | for _ in range(n_estimators): 80 | model.estimators_.append(model._make_estimator()) 81 | model.load_state_dict(model_params) 82 | 83 | 84 | def split_data_target(element, device, logger=None): 85 | """Split elements in dataloader according to pre-defined rules.""" 86 | if not (isinstance(element, list) or isinstance(element, tuple)): 87 | msg = ( 88 | "Invalid dataloader, please check if the input dataloader is valid." 89 | ) 90 | if logger: 91 | logger.error(msg) 92 | raise ValueError(msg) 93 | 94 | if len(element) == 2: 95 | # Dataloader with one input and one target 96 | data, target = element[0], element[1] 97 | return [data.to(device)], target.to(device) # tensor -> list 98 | elif len(element) > 2: 99 | # Dataloader with multiple inputs and one target 100 | data, target = element[:-1], element[-1] 101 | data_device = [tensor.to(device) for tensor in data] 102 | return data_device, target.to(device) 103 | else: 104 | # Dataloader with invalid input 105 | msg = ( 106 | "The input dataloader should at least contain two tensors - data" 107 | " and target."
108 | ) 109 | if logger: 110 | logger.error(msg) 111 | raise ValueError(msg) 112 | -------------------------------------------------------------------------------- /torchensemble/utils/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | 5 | 6 | _tb_logger = None 7 | __all__ = ["set_logger", "get_tb_logger"] 8 | 9 | 10 | def set_logger( 11 | log_file=None, 12 | log_console_level="info", 13 | log_file_level=None, 14 | use_tb_logger=False, 15 | ): 16 | """Bind the default logger with console and file stream output.""" 17 | 18 | def _get_level(level): 19 | if level.lower() == "debug": 20 | return logging.DEBUG 21 | elif level.lower() == "info": 22 | return logging.INFO 23 | elif level.lower() == "warning": 24 | return logging.WARN 25 | elif level.lower() == "error": 26 | return logging.ERROR 27 | elif level.lower() == "critical": 28 | return logging.CRITICAL 29 | else: 30 | msg = ( 31 | "`log_console_level` must be one of {{DEBUG, INFO," 32 | " WARNING, ERROR, CRITICAL}}, but got {} instead." 33 | ) 34 | raise ValueError(msg.format(level.upper())) 35 | 36 | _logger = logging.getLogger() 37 | 38 | # Reset 39 | for h in _logger.handlers: 40 | _logger.removeHandler(h) 41 | 42 | rq = time.strftime("%Y_%m_%d_%H_%M", time.localtime(time.time())) 43 | log_path = os.path.join(os.getcwd(), "logs") 44 | 45 | ch_formatter = logging.Formatter( 46 | "%(asctime)s - %(levelname)s: %(message)s" 47 | ) 48 | ch = logging.StreamHandler() 49 | ch.setLevel(_get_level(log_console_level)) 50 | ch.setFormatter(ch_formatter) 51 | _logger.addHandler(ch) 52 | 53 | if log_file is not None: 54 | print("Log will be saved in '{}'.".format(log_path)) 55 | if not os.path.exists(log_path): 56 | os.mkdir(log_path) 57 | print("Create folder 'logs/'") 58 | log_name = os.path.join(log_path, log_file + "-" + rq + ".log") 59 | print("Start logging into file {}...".format(log_name)) 60 | fh = logging.FileHandler(log_name, mode="w") 61 | fh.setLevel( 62 | logging.DEBUG 63 | if log_file_level is None 64 | else _get_level(log_file_level) 65 | ) 66 | fh_formatter = logging.Formatter( 67 | "%(asctime)s - %(filename)s[line:%(lineno)d] - " 68 | "%(levelname)s: %(message)s" 69 | ) 70 | fh.setFormatter(fh_formatter) 71 | _logger.addHandler(fh) 72 | _logger.setLevel("DEBUG") 73 | 74 | if use_tb_logger: 75 | tb_log_path = os.path.join( 76 | log_path, log_file + "-" + rq + "_tb_logger" 77 | ) 78 | os.mkdir(tb_log_path) 79 | init_tb_logger(log_dir=tb_log_path) 80 | 81 | return _logger 82 | 83 | 84 | def init_tb_logger(log_dir): 85 | try: 86 | import tensorboard # noqa: F401 87 | except ModuleNotFoundError: 88 | msg = ( 89 | "Cannot load the module tensorboard. Please make sure that" 90 | " tensorboard is installed." 
91 | ) 92 | raise ModuleNotFoundError(msg) 93 | 94 | from torch.utils.tensorboard import SummaryWriter 95 | 96 | global _tb_logger 97 | 98 | if not _tb_logger: 99 | _tb_logger = SummaryWriter(log_dir=log_dir) 100 | 101 | 102 | def get_tb_logger(): 103 | return _tb_logger 104 | -------------------------------------------------------------------------------- /torchensemble/utils/operator.py: -------------------------------------------------------------------------------- 1 | """This module collects operations on tensors used in Ensemble-PyTorch.""" 2 | 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import List 7 | 8 | 9 | __all__ = [ 10 | "average", 11 | "sum_with_multiplicative", 12 | "onehot_encoding", 13 | "pseudo_residual_classification", 14 | "pseudo_residual_regression", 15 | "majority_vote", 16 | "unsqueeze_tensor", 17 | ] 18 | 19 | 20 | def average(outputs): 21 | """Compute the average over a list of tensors with the same size.""" 22 | return sum(outputs) / len(outputs) 23 | 24 | 25 | def sum_with_multiplicative(outputs, factor): 26 | """ 27 | Compute the summation over a list of tensors, and multiply the result 28 | by a multiplicative factor. 29 | """ 30 | return factor * sum(outputs) 31 | 32 | 33 | def onehot_encoding(label, n_classes): 34 | """Conduct one-hot encoding on a label vector.""" 35 | label = label.view(-1) 36 | onehot = torch.zeros(label.size(0), n_classes).float().to(label.device) 37 | onehot.scatter_(1, label.view(-1, 1), 1) 38 | 39 | return onehot 40 | 41 | 42 | def pseudo_residual_classification(target, output, n_classes): 43 | """ 44 | Compute the pseudo residual for classification with cross-entropy loss.""" 45 | y_onehot = onehot_encoding(target, n_classes) 46 | 47 | return y_onehot - F.softmax(output, dim=1) 48 | 49 | 50 | def pseudo_residual_regression(target, output): 51 | """Compute the pseudo residual for regression with least squares error.""" 52 | if target.size() != output.size(): 53 | msg = "The shape of target {} should be the same as output {}." 54 | raise ValueError(msg.format(target.size(), output.size())) 55 | 56 | return target - output 57 | 58 | 59 | def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor: 60 | """Compute the majority vote for a list of model outputs. 61 | outputs: list of length (n_models) 62 | containing tensors with shape (n_samples, n_classes) 63 | majority_one_hots: (n_samples, n_classes) 64 | """ 65 | 66 | if len(outputs[0].shape) != 2: 67 | msg = """outputs should be a list of tensors of 68 | length (n_models), each with shape (n_samples, n_classes). 69 | The first tensor had shape {} """ 70 | raise ValueError(msg.format(outputs[0].shape)) 71 | 72 | votes = torch.stack(outputs).argmax(dim=2).mode(dim=0)[0] 73 | proba = torch.zeros_like(outputs[0]) 74 | majority_one_hots = proba.scatter_(1, votes.view(-1, 1), 1) 75 | 76 | return majority_one_hots 77 | 78 | 79 | def unsqueeze_tensor(tensor: torch.Tensor, dim=1) -> torch.Tensor: 80 | """Reshape 1-D tensor to 2-D for downstream operations.""" 81 | if tensor.ndim == 1: 82 | tensor = torch.unsqueeze(tensor, dim) 83 | 84 | return tensor 85 | -------------------------------------------------------------------------------- /torchensemble/utils/set_module.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def set_optimizer(model, optimizer_name, **kwargs): 5 | """ 6 | Set the parameter optimizer for the model.
7 | 8 | Reference: https://pytorch.org/docs/stable/optim.html#algorithms 9 | """ 10 | 11 | torch_optim_optimizers = [ 12 | "Adadelta", 13 | "Adagrad", 14 | "Adam", 15 | "AdamW", 16 | "Adamax", 17 | "ASGD", 18 | "RMSprop", 19 | "Rprop", 20 | "SGD", 21 | ] 22 | if optimizer_name not in torch_optim_optimizers: 23 | msg = "Unrecognized optimizer: {}, should be one of {}." 24 | raise NotImplementedError( 25 | msg.format(optimizer_name, ",".join(torch_optim_optimizers)) 26 | ) 27 | 28 | optimizer_cls = getattr( 29 | importlib.import_module("torch.optim"), optimizer_name 30 | ) 31 | optimizer = optimizer_cls(model.parameters(), **kwargs) 32 | 33 | return optimizer 34 | 35 | 36 | def update_lr(optimizer, lr): 37 | """ 38 | Manually update the learning rate of the optimizer. This function is used 39 | when the parallelization corrupts the bindings between the optimizer and 40 | the scheduler. 41 | """ 42 | 43 | if not lr > 0: 44 | msg = ( 45 | "The learning rate should be strictly positive, but got" 46 | " {} instead." 47 | ) 48 | raise ValueError(msg.format(lr)) 49 | 50 | for group in optimizer.param_groups: 51 | group["lr"] = lr 52 | 53 | return optimizer 54 | 55 | 56 | def set_scheduler(optimizer, scheduler_name, **kwargs): 57 | """ 58 | Set the scheduler on learning rate for the optimizer. 59 | 60 | Reference: 61 | https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate 62 | """ 63 | 64 | torch_lr_schedulers = [ 65 | "LambdaLR", 66 | "MultiplicativeLR", 67 | "StepLR", 68 | "MultiStepLR", 69 | "ExponentialLR", 70 | "CosineAnnealingLR", 71 | "ReduceLROnPlateau", 72 | "CyclicLR", 73 | "OneCycleLR", 74 | "CosineAnnealingWarmRestarts", 75 | ] 76 | if scheduler_name not in torch_lr_schedulers: 77 | msg = "Unrecognized scheduler: {}, should be one of {}." 78 | raise NotImplementedError( 79 | msg.format(scheduler_name, ",".join(torch_lr_schedulers)) 80 | ) 81 | 82 | scheduler_cls = getattr( 83 | importlib.import_module("torch.optim.lr_scheduler"), scheduler_name 84 | ) 85 | scheduler = scheduler_cls(optimizer, **kwargs) 86 | 87 | return scheduler 88 | --------------------------------------------------------------------------------
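A minimal end-to-end usage sketch tying the utilities above together (an illustrative addition, not part of the repository: the `MLP` estimator, the random data, and all hyperparameters below are placeholders, while the torchensemble calls mirror the tests earlier in this document):

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import torchensemble
from torchensemble.utils import io
from torchensemble.utils.logging import set_logger


# Placeholder base estimator, mirroring the MLP used throughout the tests.
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 2)

    def forward(self, X):
        X = X.view(X.size()[0], -1)
        return self.linear2(self.linear1(X))


# Console + file logging, as implemented in utils/logging.py
logger = set_logger("usage_sketch")

# Toy data; any DataLoader yielding (data, target) batches works.
X = torch.rand(8, 2)
y = torch.randint(0, 2, (8,))
loader = DataLoader(TensorDataset(X, y), batch_size=2)

model = torchensemble.VotingClassifier(
    estimator=MLP, n_estimators=3, cuda=False
)
model.set_optimizer("Adam", lr=1e-3)  # validated by set_module.set_optimizer
model.set_scheduler("CosineAnnealingLR", T_max=10)  # set_module.set_scheduler
model.fit(loader, epochs=2, test_loader=loader)  # checkpoints via io.save

# Reload into a fresh ensemble with a matching configuration; io.load
# pre-allocates the base estimators before restoring the state dict.
new_model = torchensemble.VotingClassifier(
    estimator=MLP, n_estimators=3, cuda=False
)
io.load(new_model)
new_model.evaluate(loader)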