├── .azure └── gpu-tests.yml ├── .codecov.yml ├── .github ├── CODEOWNERS ├── CONTRIBUTING.md ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── documentation.md │ ├── feature_request.md │ └── how-to-question.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── labeler.yml ├── mergify.yml ├── stale.yml └── workflows │ ├── ci-checks.yml │ ├── ci-docs.yml │ ├── ci-tests.yml │ ├── docs-check.yml │ ├── labeler.yml │ └── pypi-release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _images │ ├── gans │ │ ├── basic_gan_dloss.jpg │ │ ├── basic_gan_gloss.jpg │ │ ├── basic_gan_interpolate.jpg │ │ ├── dcgan_lsun_dloss.png │ │ ├── dcgan_lsun_gloss.png │ │ ├── dcgan_lsun_outputs.png │ │ ├── dcgan_mnist_dloss.png │ │ ├── dcgan_mnist_gloss.png │ │ ├── dcgan_mnist_outputs.png │ │ ├── srgan-celeba-scale_factor=2.png │ │ ├── srgan-celeba-scale_factor=4.png │ │ ├── srgan-mnist-scale_factor=2.png │ │ ├── srgan-mnist-scale_factor=4.png │ │ ├── srgan-stl10-scale_factor=2.png │ │ └── srgan-stl10-scale_factor=4.png │ ├── logos │ │ ├── bolts_logo.png │ │ ├── lightning_icon.svg │ │ ├── lightning_logo-large.svg │ │ ├── lightning_logo-name.svg │ │ ├── lightning_logo.png │ │ └── lightning_logo.svg │ ├── rl_benchmark │ │ ├── cartpole_a2c_results.jpg │ │ ├── dqn_ddqn_comparison.jpg │ │ ├── pendulum_sac_results.jpg │ │ ├── pong_double_dqn_baseline_results.jpg │ │ ├── pong_dqn_baseline_results.jpg │ │ ├── pong_dueling_dqn_comparison.jpg │ │ ├── pong_dueling_dqn_results.jpg │ │ ├── pong_noisy_dqn_comparison.jpg │ │ ├── pong_noisy_dqn_results.jpg │ │ ├── pong_nstep_dqn_1.jpg │ │ ├── pong_nstep_dqn_2.jpg │ │ ├── pong_per_dqn_baseline_results.jpg │ │ ├── pong_per_dqn_baseline_v1_results.jpg │ │ ├── pong_per_dqn_baseline_v1_results_comp.jpg │ │ └── pong_per_dqn_dqn_comparison.jpg │ └── vision │ │ └── confused_logit.png │ ├── _static │ └── 
copybutton.js │ ├── _templates │ ├── layout.html │ └── theme_variables.jinja │ ├── callbacks │ ├── monitor.rst │ ├── self_supervised.rst │ ├── sparseml.rst │ ├── torch_ort.rst │ ├── variational.rst │ └── vision.rst │ ├── conf.py │ ├── dataloaders │ └── async.rst │ ├── datamodules │ ├── sklearn.rst │ └── vision.rst │ ├── datasets │ └── debug.rst │ ├── governance.rst │ ├── index.rst │ ├── installation.rst │ ├── introduction_guide.rst │ ├── losses.rst │ ├── models │ ├── autoencoders.rst │ ├── classic_ml.rst │ ├── convolutional.rst │ ├── gans.rst │ ├── models_howto.rst │ ├── object_detection.rst │ ├── reinforce_learn.rst │ └── self_supervised.rst │ ├── schedulers │ └── warmup_cosine_annealing.rst │ ├── stability.rst │ ├── transforms │ ├── self_supervised.rst │ └── semi_supervised.rst │ └── vision_tasks.rst ├── pyproject.toml ├── requirements.txt ├── requirements ├── base.txt ├── devel.txt ├── loggers.txt ├── models.txt ├── test.txt └── typing.txt ├── setup.py ├── src └── pl_bolts │ ├── __about__.py │ ├── __init__.py │ ├── callbacks │ ├── __init__.py │ ├── byol_updates.py │ ├── data_monitor.py │ ├── knn_online.py │ ├── printing.py │ ├── sparseml.py │ ├── ssl_online.py │ ├── torch_ort.py │ ├── variational.py │ ├── verification │ │ ├── __init__.py │ │ ├── base.py │ │ └── batch_gradient.py │ └── vision │ │ ├── __init__.py │ │ ├── confused_logit.py │ │ ├── image_generation.py │ │ └── sr_image_logger.py │ ├── datamodules │ ├── __init__.py │ ├── async_dataloader.py │ ├── binary_emnist_datamodule.py │ ├── binary_mnist_datamodule.py │ ├── cifar10_datamodule.py │ ├── cityscapes_datamodule.py │ ├── emnist_datamodule.py │ ├── experience_source.py │ ├── fashion_mnist_datamodule.py │ ├── imagenet_datamodule.py │ ├── kitti_datamodule.py │ ├── mnist_datamodule.py │ ├── sklearn_datamodule.py │ ├── sr_datamodule.py │ ├── ssl_imagenet_datamodule.py │ ├── stl10_datamodule.py │ ├── vision_datamodule.py │ └── vocdetection_datamodule.py │ ├── datasets │ ├── __init__.py │ ├── array_dataset.py 
│ ├── base_dataset.py │ ├── cifar10_dataset.py │ ├── concat_dataset.py │ ├── dummy_dataset.py │ ├── emnist_dataset.py │ ├── imagenet_dataset.py │ ├── kitti_dataset.py │ ├── mnist_dataset.py │ ├── sr_celeba_dataset.py │ ├── sr_dataset_mixin.py │ ├── sr_mnist_dataset.py │ ├── sr_stl10_dataset.py │ ├── ssl_amdim_datasets.py │ └── utils.py │ ├── losses │ ├── __init__.py │ ├── rl.py │ └── self_supervised_learning.py │ ├── metrics │ ├── __init__.py │ └── aggregation.py │ ├── models │ ├── __init__.py │ ├── autoencoders │ │ ├── __init__.py │ │ ├── basic_ae │ │ │ ├── __init__.py │ │ │ └── basic_ae_module.py │ │ ├── basic_vae │ │ │ ├── __init__.py │ │ │ └── basic_vae_module.py │ │ └── components.py │ ├── detection │ │ ├── __init__.py │ │ ├── components │ │ │ ├── __init__.py │ │ │ ├── _supported_models.py │ │ │ └── torchvision_backbones.py │ │ ├── faster_rcnn │ │ │ ├── __init__.py │ │ │ ├── backbones.py │ │ │ └── faster_rcnn_module.py │ │ ├── retinanet │ │ │ ├── __init__.py │ │ │ ├── backbones.py │ │ │ └── retinanet_module.py │ │ └── yolo │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── darknet_network.py │ │ │ ├── layers.py │ │ │ ├── loss.py │ │ │ ├── target_matching.py │ │ │ ├── torch_networks.py │ │ │ ├── types.py │ │ │ ├── utils.py │ │ │ └── yolo_module.py │ ├── gans │ │ ├── __init__.py │ │ ├── basic │ │ │ ├── __init__.py │ │ │ ├── basic_gan_module.py │ │ │ └── components.py │ │ ├── dcgan │ │ │ ├── __init__.py │ │ │ ├── components.py │ │ │ └── dcgan_module.py │ │ ├── pix2pix │ │ │ ├── __init__.py │ │ │ ├── components.py │ │ │ └── pix2pix_module.py │ │ └── srgan │ │ │ ├── __init__.py │ │ │ ├── components.py │ │ │ ├── srgan_module.py │ │ │ └── srresnet_module.py │ ├── mnist_module.py │ ├── regression │ │ ├── __init__.py │ │ ├── linear_regression.py │ │ └── logistic_regression.py │ ├── rl │ │ ├── __init__.py │ │ ├── advantage_actor_critic_model.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── agents.py │ │ │ ├── cli.py │ │ │ ├── distributions.py │ │ │ ├── gym_wrappers.py 
│ │ │ ├── memory.py │ │ │ └── networks.py │ │ ├── double_dqn_model.py │ │ ├── dqn_model.py │ │ ├── dueling_dqn_model.py │ │ ├── noisy_dqn_model.py │ │ ├── per_dqn_model.py │ │ ├── ppo_model.py │ │ ├── reinforce_model.py │ │ ├── sac_model.py │ │ └── vanilla_policy_gradient_model.py │ ├── self_supervised │ │ ├── __init__.py │ │ ├── amdim │ │ │ ├── __init__.py │ │ │ ├── amdim_module.py │ │ │ ├── datasets.py │ │ │ └── networks.py │ │ ├── byol │ │ │ ├── __init__.py │ │ │ ├── byol_module.py │ │ │ └── models.py │ │ ├── cpc │ │ │ ├── __init__.py │ │ │ ├── cpc_finetuner.py │ │ │ ├── cpc_module.py │ │ │ └── networks.py │ │ ├── evaluator.py │ │ ├── moco │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── callbacks.py │ │ │ ├── moco_module.py │ │ │ └── utils.py │ │ ├── resnets.py │ │ ├── simclr │ │ │ ├── __init__.py │ │ │ ├── simclr_finetuner.py │ │ │ └── simclr_module.py │ │ ├── simsiam │ │ │ ├── __init__.py │ │ │ └── simsiam_module.py │ │ ├── ssl_finetuner.py │ │ └── swav │ │ │ ├── __init__.py │ │ │ ├── loss.py │ │ │ ├── swav_finetuner.py │ │ │ ├── swav_module.py │ │ │ └── swav_resnet.py │ └── vision │ │ ├── __init__.py │ │ ├── image_gpt │ │ ├── __init__.py │ │ ├── gpt2.py │ │ └── igpt_module.py │ │ ├── pixel_cnn.py │ │ ├── segmentation.py │ │ └── unet.py │ ├── optimizers │ ├── __init__.py │ ├── lars.py │ └── lr_scheduler.py │ ├── transforms │ ├── __init__.py │ ├── dataset_normalizations.py │ └── self_supervised │ │ ├── __init__.py │ │ ├── amdim_transforms.py │ │ ├── cpc_transforms.py │ │ ├── moco_transforms.py │ │ ├── simclr_transforms.py │ │ ├── ssl_transforms.py │ │ └── swav_transforms.py │ └── utils │ ├── __init__.py │ ├── _dependency.py │ ├── arguments.py │ ├── pretrained_weights.py │ ├── self_supervised.py │ ├── semi_supervised.py │ ├── shaping.py │ ├── stability.py │ ├── types.py │ └── warnings.py └── tests ├── README.md ├── __init__.py ├── _data_configs ├── yolo.cfg └── yolo_giou.cfg ├── callbacks ├── __init__.py ├── test_data_monitor.py ├── test_info_callbacks.py ├── 
test_ort.py ├── test_param_update_callbacks.py ├── test_sparseml.py ├── test_variational_callbacks.py └── verification │ ├── __init__.py │ ├── test_base.py │ └── test_batch_gradient.py ├── conftest.py ├── datamodules ├── __init__.py ├── test_dataloader.py ├── test_datamodules.py ├── test_experience_sources.py └── test_sklearn_dataloaders.py ├── datasets ├── __init__.py ├── test_array_dataset.py ├── test_base_dataset.py ├── test_datasets.py └── test_utils.py ├── helpers ├── __init__.py └── boring_model.py ├── losses ├── __init__.py └── test_rl_loss.py ├── metrics ├── __init__.py └── test_aggregation.py ├── models ├── __init__.py ├── gans │ ├── __init__.py │ ├── integration │ │ ├── __init__.py │ │ └── test_gans.py │ └── unit │ │ ├── __init__.py │ │ └── test_basic_components.py ├── regression │ ├── __init__.py │ └── test_logistic_regression.py ├── rl │ ├── __init__.py │ ├── integration │ │ ├── __init__.py │ │ ├── test_actor_critic_models.py │ │ ├── test_policy_models.py │ │ └── test_value_models.py │ └── unit │ │ ├── __init__.py │ │ ├── test_a2c.py │ │ ├── test_agents.py │ │ ├── test_memory.py │ │ ├── test_ppo.py │ │ ├── test_reinforce.py │ │ ├── test_sac.py │ │ ├── test_vpg.py │ │ └── test_wrappers.py ├── self_supervised │ ├── __init__.py │ ├── test_models.py │ ├── test_resnets.py │ └── unit │ │ ├── __init__.py │ │ └── test_transforms.py ├── test_autoencoders.py ├── test_classic_ml.py ├── test_detection.py ├── test_mnist_templates.py ├── test_scripts.py ├── test_vision.py └── yolo │ ├── __init__.py │ └── unit │ ├── __init__.py │ ├── test_darknet_network.py │ ├── test_target_matching.py │ └── test_utils.py ├── optimizers ├── __init__.py └── test_lr_scheduler.py ├── test_utilities.py ├── transforms ├── __init__.py ├── test_normalizations.py └── test_transforms.py └── utils ├── __init__.py ├── test_arguments.py ├── test_dependency.py ├── test_self_supervised.py └── test_semi_supervised.py /.codecov.yml: 
-------------------------------------------------------------------------------- 1 | # see https://docs.codecov.io/docs/codecov-yaml 2 | # Validation check: 3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate 4 | 5 | 6 | # https://docs.codecov.io/docs/codecovyml-reference 7 | codecov: 8 | bot: "codecov-io" 9 | strict_yaml_branch: "yaml-config" 10 | require_ci_to_pass: yes 11 | notify: 12 | after_n_builds: 17 13 | wait_for_ci: yes 14 | 15 | coverage: 16 | precision: 0 # 2 = xx.xx%, 0 = xx% 17 | round: nearest # how coverage is rounded: down/up/nearest 18 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 19 | status: 20 | # https://codecov.readme.io/v1.0/docs/commit-status 21 | project: 22 | default: 23 | informational: true 24 | target: 95% # specify the target coverage for each commit status 25 | threshold: 30% # allow this little decrease on project 26 | # https://github.com/codecov/support/wiki/Filtering-Branches 27 | # branches: master 28 | if_ci_failed: error 29 | # https://github.com/codecov/support/wiki/Patch-Status 30 | patch: 31 | default: 32 | informational: true 33 | threshold: 50% # allow this much decrease on patch 34 | changes: false 35 | 36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations 37 | github_checks: 38 | annotations: false 39 | 40 | parsers: 41 | gcov: 42 | branch_detection: 43 | conditional: true 44 | loop: true 45 | macro: false 46 | method: false 47 | javascript: 48 | enable_partials: false 49 | 50 | comment: 51 | layout: header, diff 52 | require_changes: false 53 | behavior: default # update if exists else create new 54 | # branches: * 55 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # This is a comment. 2 | # Each line is a file pattern followed by one or more owners. 
3 | 4 | # These owners will be the default owners for everything in 5 | # the repo. Unless a later match takes precedence, 6 | # @global-owner1 and @global-owner2 will be requested for 7 | # review when someone opens a pull request. 8 | * @borda @akihironitta @ethanwharris @awaelchli 9 | 10 | # owners 11 | /.github/CODEOWNERS @borda @ethanwharris 12 | # main 13 | /README.md @borda 14 | # installation 15 | /setup.py @borda @ethanwharris 16 | 17 | # CI/CD 18 | /.github/workflows/ @borda @ethanwharris @akihironitta 19 | /.github/*.yml @borda @ethanwharris @akihironitta 20 | /.azure/ @borda @ethanwharris @akihironitta 21 | /*.yml @borda @ethanwharris @akihironitta 22 | 23 | # Docs 24 | /docs/ @edenlightning @borda 25 | /.github/*.md @edenlightning @borda 26 | /.github/ISSUE_TEMPLATE/*.md @edenlightning @borda 27 | /docs/source/conf.py @borda @akihironitta 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🐛 Bug 10 | 11 | 12 | 13 | ### To Reproduce 14 | 15 | Steps to reproduce the behavior: 16 | 17 | 1. Go to '...' 18 | 1. Run '....' 19 | 1. Scroll down to '....' 20 | 1. 
See error 21 | 22 | 23 | 24 | #### Code sample 25 | 26 | 28 | 29 | ### Expected behavior 30 | 31 | 32 | 33 | ### Environment 34 | 35 | - PyTorch Version (e.g., 1.0): 36 | - OS (e.g., Linux): 37 | - How you installed PyTorch (`conda`, `pip`, source): 38 | - Build command you used (if compiling from source): 39 | - Python version: 40 | - CUDA/cuDNN version: 41 | - GPU models and configuration: 42 | - Any other relevant information: 43 | 44 | ### Additional context 45 | 46 | 47 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Typos and doc fixes 3 | about: Typos and doc fixes 4 | title: '' 5 | labels: documentation 6 | assignees: '' 7 | --- 8 | 9 | ## 📚 Documentation 10 | 11 | For typos and doc fixes, please go ahead and: 12 | 13 | 1. Create an issue. 14 | 1. Fix the typo. 15 | 1. Submit a PR. 16 | 17 | Thanks! 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | 13 | ### Motivation 14 | 15 | 16 | 17 | ### Pitch 18 | 19 | 20 | 21 | ### Alternatives 22 | 23 | 24 | 25 | ### Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/how-to-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: How to question 3 | about: Asking how-to questions 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | --- 8 | 9 | ## ❓ Questions and Help 10 | 11 | ### Before asking: 12 | 13 | 1. search the issues. 14 | 1. search the docs. 
15 | 16 | 17 | 18 | #### What is your question? 19 | 20 | #### Code 21 | 22 | 23 | 24 | #### What have you tried? 25 | 26 | #### What's your environment? 27 | 28 | - OS: \[e.g. iOS, Linux, Win\] 29 | - Packaging: \[e.g. pip, conda\] 30 | - Version: \[e.g. 0.5.2.1\] 31 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 8 | 9 | Fixes # (issue) 10 | 11 | ## Before submitting 12 | 13 | - [ ] Was this **discussed/approved** via a GitHub issue? (not needed for typos and docs improvements) 14 | - [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/lightning-bolts/blob/master/.github/CONTRIBUTING.md), Pull Request section? 15 | - [ ] Did you make sure your PR does **only one thing**, instead of bundling different changes together? 16 | - [ ] Did you make sure to **update the documentation** with your changes? 17 | - [ ] Did you write any **new necessary tests**? \[not needed for typos/docs\] 18 | - [ ] Did you verify new and **existing tests** pass locally with your changes? 19 | - [ ] If you made a notable change (that affects users), did you update the [CHANGELOG](https://github.com/PyTorchLightning/lightning-bolts/blob/master/CHANGELOG.md)? 20 | 21 | 22 | 23 | ## PR review 24 | 25 | - [ ] Is this pull request **ready for review**? (if not, please submit in draft mode) 26 | 27 | Anyone in the community is free to review the PR once the tests have passed. 28 | If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged. 29 | 30 | ## Did you have fun?
31 | 32 | Make sure you had fun coding 🙃 33 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 2 | version: 2 3 | updates: 4 | 5 | # Enable version updates for Python 6 | - package-ecosystem: "pip" 7 | # Look for requirements files in the `/requirements` directory 8 | directory: "/requirements" 9 | # Check for updates once a week 10 | schedule: 11 | interval: "weekly" 12 | # Labels on pull requests for version updates only 13 | labels: ["ci/cd"] 14 | pull-request-branch-name: 15 | # Separate sections of the branch name with a hyphen 16 | # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` 17 | separator: "-" 18 | # Allow up to 5 open pull requests for pip dependencies 19 | open-pull-requests-limit: 5 20 | reviewers: 21 | - "Lightning-AI/core-bolts" 22 | 23 | # Enable version updates for GitHub Actions 24 | - package-ecosystem: "github-actions" 25 | directory: "/" 26 | # Check for updates once a week 27 | schedule: 28 | interval: "weekly" 29 | # Labels on pull requests for version updates only 30 | labels: ["ci/cd"] 31 | pull-request-branch-name: 32 | # Separate sections of the branch name with a hyphen 33 | # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` 34 | separator: "-" 35 | # Allow up to 5 open pull requests for GitHub Actions 36 | open-pull-requests-limit: 5 37 | reviewers: 38 | - "Lightning-AI/core-bolts" 39 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | documentation: 2 | - docs/**/* 3 | 4 | callback: 5 | - pl_bolts/callbacks/**/* 6 | 7 | datamodule: 8 | - pl_bolts/datamodules/**/* 9 | 10 | model: 11 | - pl_bolts/models/**/* 12 |
-------------------------------------------------------------------------------- /.github/mergify.yml: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | pull_request_rules: 16 | 17 | - name: warn on conflicts 18 | conditions: 19 | - conflict 20 | - -draft # filter-out GH draft PRs 21 | - -label="has conflicts" 22 | actions: 23 | # comment: 24 | # message: This pull request is now in conflict... 
:( 25 | label: 26 | add: [ "has conflicts" ] 27 | 28 | - name: resolved conflicts 29 | conditions: 30 | - -conflict 31 | - label="has conflicts" 32 | - -draft # filter-out GH draft PRs 33 | - -merged # not merged yet 34 | - -closed 35 | actions: 36 | label: 37 | remove: [ "has conflicts" ] 38 | 39 | - name: Ready to Go 40 | conditions: 41 | - -conflict 42 | - -draft # filter-out GH draft PRs 43 | - -title~=(?i)wip # skip all PRs whose title contains “WIP” (ignoring case) 44 | - "#approved-reviews-by>=1" # at least one approving review 45 | - "#changes-requested-reviews-by=0" # no requested changes 46 | - "#check-pending<=2" 47 | - "#check-failure<5" 48 | - "#check-skipped<5" 49 | actions: 50 | label: 51 | add: [ "ready" ] 52 | 53 | - name: Not ready yet 54 | conditions: 55 | - or: 56 | - "#approved-reviews-by=0" # no review approvals yet 57 | - "#changes-requested-reviews-by>=1" # at least one reviewer requested changes 58 | - "#check-failure>=5" 59 | - "#check-skipped>=5" 60 | actions: 61 | label: 62 | remove: [ "ready" ] 63 | 64 | - name: update PR 65 | conditions: 66 | - -conflict 67 | - -draft # filter-out GH draft PRs 68 | - base=master # apply only on master 69 | - -title~=(?i)wip # skip all PRs whose title contains “WIP” (ignoring case) 70 | - "#approved-reviews-by>=1" # at least one approving review 71 | - "#changes-requested-reviews-by=0" # no requested changes 72 | - "#check-failure<5" 73 | actions: 74 | update: {} 75 | 76 | - name: add core reviewer 77 | conditions: 78 | - -conflict # skip if conflict 79 | - -draft # filter-out GH draft PRs 80 | - label="ready" 81 | - "#approved-reviews-by<2" # fewer than two approvals so far 82 | - "#review-requested<2" # fewer than two reviews requested 83 | actions: 84 | request_reviews: 85 | teams: 86 | - "@Lightning-Universe/core-bolts" 87 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | #
https://github.com/marketplace/stale 2 | 3 | # Number of days of inactivity before an issue becomes stale 4 | daysUntilStale: 60 5 | # Number of days of inactivity before a stale issue is closed 6 | daysUntilClose: 9 7 | # Issues with these labels will never be considered stale 8 | exemptLabels: 9 | - pinned 10 | - security 11 | # Label to use when marking an issue as stale 12 | staleLabel: won't fix 13 | # Comment to post when marking an issue as stale. Set to `false` to disable 14 | markComment: > 15 | This issue has been automatically marked as stale because it has not had 16 | recent activity. It will be closed if no further activity occurs. Thank you 17 | for your contributions. 18 | # Comment to post when closing a stale issue. Set to `false` to disable 19 | closeComment: false 20 | 21 | # Set to true to ignore issues in a project (defaults to false) 22 | exemptProjects: true 23 | # Set to true to ignore issues in a milestone (defaults to false) 24 | exemptMilestones: true 25 | # Set to true to ignore issues with an assignee (defaults to false) 26 | exemptAssignees: false 27 | -------------------------------------------------------------------------------- /.github/workflows/ci-checks.yml: -------------------------------------------------------------------------------- 1 | name: General Checks 2 | 3 | on: 4 | push: 5 | branches: [master, "release/*"] 6 | pull_request: 7 | branches: [master, "release/*"] 8 | 9 | jobs: 10 | check-schema: 11 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.2 12 | with: 13 | azure-dir: ".azure" 14 | 15 | check-code: 16 | uses: Lightning-AI/utilities/.github/workflows/check-code.yml@v0.11.2 17 | with: 18 | actions-ref: v0.11.2 19 | extra-typing: typing 20 | 21 | check-package: 22 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.2 23 | with: 24 | actions-ref: v0.11.2 25 | artifact-name: dist-packages-${{ github.sha }} 26 | import-name: "pl_bolts" 27 | testing-matrix: | 28 | { 29 | 
"os": ["ubuntu-20.04", "macos-11", "windows-2022"], 30 | "python-version": ["3.8"] 31 | } 32 | -------------------------------------------------------------------------------- /.github/workflows/ci-docs.yml: -------------------------------------------------------------------------------- 1 | name: RTFD Preview 2 | on: 3 | pull_request_target: 4 | types: 5 | - opened 6 | 7 | permissions: 8 | pull-requests: write 9 | 10 | jobs: 11 | documentation-links: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: readthedocs/actions/preview@v1 15 | with: 16 | project-slug: "lightning-bolts" 17 | -------------------------------------------------------------------------------- /.github/workflows/docs-check.yml: -------------------------------------------------------------------------------- 1 | name: "Docs check" 2 | # https://github.com/marketplace/actions/sphinx-build 3 | 4 | on: # Trigger the workflow on push or pull request, but only for the master branch 5 | push: 6 | branches: [master] 7 | pull_request: {} 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} 12 | 13 | env: 14 | FREEZE_REQUIREMENTS: 1 15 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 16 | TRANSFORMERS_CACHE: _hf_cache 17 | 18 | defaults: 19 | run: 20 | shell: bash 21 | 22 | jobs: 23 | 24 | test-docs: 25 | runs-on: ubuntu-20.04 26 | steps: 27 | - uses: actions/checkout@v4 28 | - uses: actions/setup-python@v5 29 | with: 30 | python-version: 3.9 31 | 32 | - name: Cache pip 33 | uses: actions/cache@v4 34 | with: 35 | path: ~/.cache/pip 36 | key: pip-${{ hashFiles('requirements/*.txt') }} 37 | restore-keys: pip- 38 | 39 | - name: Install package & dependencies 40 | run: | 41 | pip install "pip==22.2.1" # todo: drop after resolving extras 42 | pip install -e . 
-r requirements/devel.txt -r docs/requirements.txt -f $TORCH_URL 43 | pip list 44 | 45 | - name: Test Documentation 46 | working-directory: docs/ 47 | env: 48 | SPHINX_MOCK_REQUIREMENTS: 0 49 | run: | 50 | make doctest 51 | make coverage 52 | 53 | 54 | make-docs: 55 | runs-on: ubuntu-20.04 56 | steps: 57 | - uses: actions/checkout@v4 58 | - uses: actions/setup-python@v5 59 | with: 60 | python-version: 3.9 61 | 62 | - name: Cache pip 63 | uses: actions/cache@v4 64 | with: 65 | path: ~/.cache/pip 66 | key: pip-${{ hashFiles('requirements/*.txt') }} 67 | restore-keys: pip- 68 | 69 | - name: Install package & dependencies 70 | run: | 71 | # install TeX Live, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux 72 | sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures 73 | pip install "pip==22.2.1" # todo: drop after resolving extras 74 | pip install -e . -r docs/requirements.txt -f $TORCH_URL 75 | pip list 76 | 77 | - name: Make Documentation 78 | working-directory: docs/ 79 | run: | 80 | make clean 81 | make html --debug --jobs 2 SPHINXOPTS="-W" 82 | 83 | - name: Upload built docs 84 | uses: actions/upload-artifact@v3 85 | with: 86 | name: docs-results-${{ github.sha }} 87 | path: docs/build/html/ 88 | # Upload only when the docs build succeeded; use always() instead to also upload after a failure 89 | if: success() 90 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: Labeler 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened, reopened] 6 | 7 | jobs: 8 | triage: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/labeler@main 12 | with: 13 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 14 | -------------------------------------------------------------------------------- /.github/workflows/pypi-release.yml:
-------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger on pushes to master and release/* branches, and when a GitHub release is published 5 | push: 6 | branches: [master, "release/*"] 7 | release: 8 | types: [published] 9 | 10 | # based on https://github.com/pypa/gh-action-pypi-publish 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: actions/setup-python@v5 19 | with: 20 | python-version: 3.8 21 | 22 | - name: Install dependencies 23 | run: python -m pip install -U setuptools wheel build twine 24 | - name: Build package 25 | run: python -m build 26 | - name: Check package 27 | run: twine check dist/* 28 | 29 | - name: Upload to release 30 | uses: AButler/upload-release-assets@v3.0 31 | with: 32 | files: 'dist/*' 33 | repo-token: ${{ secrets.GITHUB_TOKEN }} 34 | 35 | # Publish to Test PyPI first; a failure there is less critical than one on the real PyPI 36 | - name: Publish to Test PyPI 37 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 38 | uses: pypa/gh-action-pypi-publish@master 39 | with: 40 | user: __token__ 41 | password: ${{ secrets.test_pypi_password }} 42 | repository_url: https://test.pypi.org/legacy/ 43 | 44 | - name: Publish distribution 📦 to PyPI 45 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 46 | uses: pypa/gh-action-pypi-publish@master 47 | with: 48 | user: __token__ 49 | password: ${{ secrets.pypi_password }} 50 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | autoupdate_schedule: monthly 8 | # submodules: true 9 |
| 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v4.4.0 13 | hooks: 14 | - id: end-of-file-fixer 15 | - id: trailing-whitespace 16 | - id: check-case-conflict 17 | - id: check-yaml 18 | - id: check-toml 19 | - id: pretty-format-json 20 | - id: check-added-large-files 21 | - id: check-docstring-first 22 | - id: detect-private-key 23 | 24 | - repo: https://github.com/asottile/pyupgrade 25 | rev: v3.8.0 26 | hooks: 27 | - id: pyupgrade 28 | args: [--py38-plus] 29 | name: Upgrade code 30 | 31 | - repo: https://github.com/PyCQA/docformatter 32 | rev: v1.7.3 33 | hooks: 34 | - id: docformatter 35 | additional_dependencies: [tomli] 36 | args: ["--in-place"] 37 | 38 | - repo: https://github.com/executablebooks/mdformat 39 | rev: 0.7.16 40 | hooks: 41 | - id: mdformat 42 | additional_dependencies: 43 | - mdformat-gfm 44 | - mdformat-black 45 | - mdformat_frontmatter 46 | exclude: CHANGELOG.md 47 | 48 | - repo: https://github.com/psf/black 49 | rev: 23.3.0 50 | hooks: 51 | - id: black 52 | name: Format code 53 | 54 | - repo: https://github.com/asottile/yesqa 55 | rev: v1.5.0 56 | hooks: 57 | - id: yesqa 58 | additional_dependencies: 59 | - flake8-pytest-style 60 | - flake8-bandit 61 | - pep8-naming 62 | 63 | - repo: https://github.com/astral-sh/ruff-pre-commit 64 | rev: v0.0.276 65 | hooks: 66 | - id: ruff 67 | args: ["--fix"] 68 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | fail_on_warning: true 12 | 13 | # Optionally build your docs in additional formats such as PDF and ePub 14 | formats: 15 | - htmlzip 16 | - 
pdf 17 | 18 | # Optionally set the version of Python and requirements required to build your docs 19 | python: 20 | version: 3.8 21 | install: 22 | - requirements: docs/requirements.txt 23 | #- requirements: requirements.txt 24 | - method: pip 25 | path: . 26 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Manifest syntax https://docs.python.org/2/distutils/sourcedist.html 2 | graft wheelhouse 3 | 4 | recursive-exclude __pycache__ *.py[cod] *.orig 5 | 6 | # Include the README and CHANGELOG 7 | include *.md 8 | recursive-include pl_bolts *.md 9 | 10 | # Include the license file 11 | include LICENSE 12 | 13 | exclude *.sh 14 | exclude *.toml 15 | exclude *.svg 16 | recursive-include pytorch_lightning *.py 17 | 18 | # Exclude tests from the package 19 | recursive-exclude tests * 20 | recursive-exclude site * 21 | exclude tests 22 | 23 | # Exclude the documentation files 24 | recursive-exclude docs * 25 | exclude docs 26 | 27 | # Include the requirements files 28 | include requirements.txt 29 | recursive-include requirements *.txt 30 | 31 | # Exclude build configs 32 | exclude *.yml 33 | 34 | # Exclude Makefile 35 | exclude Makefile 36 | 37 | prune .git 38 | prune .github 39 | prune .circleci 40 | prune notebook* 41 | prune temp* 42 | prune test* 43 | prune benchmark* 44 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test clean docs env 2 | 3 | # assumes the needed packages are already installed 4 | export SPHINX_MOCK_REQUIREMENTS=1 5 | 6 | clean: 7 | # clean all temp runs 8 | rm -rf $(shell find .
-name "lightning_log") 9 | rm -rf _ckpt_* 10 | rm -rf .mypy_cache 11 | rm -rf .pytest_cache 12 | rm -rf ./docs/build 13 | rm -rf ./docs/source/generated 14 | rm -rf ./docs/source/*/generated 15 | rm -rf ./docs/source/api 16 | rm -rf ./_datasets 17 | 18 | test: clean env 19 | 20 | # run tests with coverage 21 | python -m coverage run --source pl_bolts -m pytest pl_bolts tests -v 22 | python -m coverage report 23 | 24 | doctest: clean env 25 | pip install --quiet -r docs/requirements.txt 26 | SPHINX_MOCK_REQUIREMENTS=0 make -C docs doctest 27 | 28 | docs: clean 29 | pip install --quiet -r docs/requirements.txt 30 | python -m sphinx -b html -W docs/source docs/build 31 | 32 | env: 33 | pip install -q -r requirements/devel.txt 34 | 35 | format: 36 | isort . 37 | black . 38 | flake8 . 39 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W 6 | SPHINXBUILD = $(shell which sphinx-build) 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx ==6.2 2 | recommonmark # fails with badges 3 | mistune < 2.0 4 | m2r # fails with multi-line text 5 | nbsphinx >0.8, <=0.9.2 6 | pandoc >2.0, <=2.3 7 | # docutils 8 | sphinxcontrib-fulltoc >1.0, <=1.2.0 9 | sphinxcontrib-mockautodoc 10 | sphinx-autodoc-typehints >1.0, <=1.17.1 11 | sphinx-paramlinks >0.4.0, <=0.5.4 12 | sphinx-togglebutton >0.2, <=0.3.2 13 | sphinx-copybutton >0.3, <=0.5.2 14 | 15 | pt-lightning-sphinx-theme @ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip 16 | -------------------------------------------------------------------------------- 
/docs/source/_images/gans/basic_gan_dloss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/basic_gan_dloss.jpg -------------------------------------------------------------------------------- /docs/source/_images/gans/basic_gan_gloss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/basic_gan_gloss.jpg -------------------------------------------------------------------------------- /docs/source/_images/gans/basic_gan_interpolate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/basic_gan_interpolate.jpg -------------------------------------------------------------------------------- /docs/source/_images/gans/dcgan_lsun_dloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_lsun_dloss.png -------------------------------------------------------------------------------- /docs/source/_images/gans/dcgan_lsun_gloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_lsun_gloss.png -------------------------------------------------------------------------------- /docs/source/_images/gans/dcgan_lsun_outputs.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_lsun_outputs.png -------------------------------------------------------------------------------- /docs/source/_images/gans/dcgan_mnist_dloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_mnist_dloss.png -------------------------------------------------------------------------------- /docs/source/_images/gans/dcgan_mnist_gloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_mnist_gloss.png -------------------------------------------------------------------------------- /docs/source/_images/gans/dcgan_mnist_outputs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_mnist_outputs.png -------------------------------------------------------------------------------- /docs/source/_images/gans/srgan-celeba-scale_factor=2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-celeba-scale_factor=2.png -------------------------------------------------------------------------------- /docs/source/_images/gans/srgan-celeba-scale_factor=4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-celeba-scale_factor=4.png -------------------------------------------------------------------------------- /docs/source/_images/gans/srgan-mnist-scale_factor=2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-mnist-scale_factor=2.png -------------------------------------------------------------------------------- /docs/source/_images/gans/srgan-mnist-scale_factor=4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-mnist-scale_factor=4.png -------------------------------------------------------------------------------- /docs/source/_images/gans/srgan-stl10-scale_factor=2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-stl10-scale_factor=2.png -------------------------------------------------------------------------------- /docs/source/_images/gans/srgan-stl10-scale_factor=4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-stl10-scale_factor=4.png -------------------------------------------------------------------------------- /docs/source/_images/logos/bolts_logo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/logos/bolts_logo.png -------------------------------------------------------------------------------- /docs/source/_images/logos/lightning_icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /docs/source/_images/logos/lightning_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/logos/lightning_logo.png -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/dqn_ddqn_comparison.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/dqn_ddqn_comparison.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pendulum_sac_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pendulum_sac_results.jpg 
-------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_double_dqn_baseline_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_double_dqn_baseline_results.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_dqn_baseline_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_dqn_baseline_results.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_dueling_dqn_comparison.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_dueling_dqn_comparison.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_dueling_dqn_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_dueling_dqn_results.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_noisy_dqn_comparison.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_noisy_dqn_comparison.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_noisy_dqn_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_noisy_dqn_results.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_nstep_dqn_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_nstep_dqn_1.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_nstep_dqn_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_nstep_dqn_2.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_per_dqn_baseline_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_results.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_per_dqn_baseline_v1_results.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_v1_results.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_per_dqn_baseline_v1_results_comp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_v1_results_comp.jpg -------------------------------------------------------------------------------- /docs/source/_images/rl_benchmark/pong_per_dqn_dqn_comparison.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_per_dqn_dqn_comparison.jpg -------------------------------------------------------------------------------- /docs/source/_images/vision/confused_logit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/vision/confused_logit.png -------------------------------------------------------------------------------- /docs/source/_static/copybutton.js: -------------------------------------------------------------------------------- 1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */ 2 | $(document).ready(function() { 3 | /* Add a [>>>] button on the top-right corner of code samples to hide 4 | * the >>> and ... prompts and the output and thus make the code 5 | * copyable. 
*/ 6 | var div = $('.highlight-python .highlight,' + 7 | '.highlight-python3 .highlight,' + 8 | '.highlight-pycon .highlight,' + 9 | '.highlight-default .highlight'); 10 | var pre = div.find('pre'); 11 | 12 | // get the styles from the current theme 13 | pre.parent().parent().css('position', 'relative'); 14 | var hide_text = 'Hide the prompts and output'; 15 | var show_text = 'Show the prompts and output'; 16 | var border_width = pre.css('border-top-width'); 17 | var border_style = pre.css('border-top-style'); 18 | var border_color = pre.css('border-top-color'); 19 | var button_styles = { 20 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 21 | 'border-color': border_color, 'border-style': border_style, 22 | 'border-width': border_width, 'color': border_color, 'text-size': '75%', 23 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 24 | 'border-radius': '0 3px 0 0' 25 | } 26 | 27 | // create and add the button to all the code blocks that contain >>> 28 | div.each(function(index) { 29 | var jthis = $(this); 30 | if (jthis.find('.gp').length > 0) { 31 | var button = $('>>>'); 32 | button.css(button_styles) 33 | button.attr('title', hide_text); 34 | button.data('hidden', 'false'); 35 | jthis.prepend(button); 36 | } 37 | // tracebacks (.gt) contain bare text elements that need to be 38 | // wrapped in a span to work with .nextUntil() (see later) 39 | jthis.find('pre:has(.gt)').contents().filter(function() { 40 | return ((this.nodeType == 3) && (this.data.trim().length > 0)); 41 | }).wrap(''); 42 | }); 43 | 44 | // define the behavior of the button when it's clicked 45 | $('.copybutton').click(function(e){ 46 | e.preventDefault(); 47 | var button = $(this); 48 | if (button.data('hidden') === 'false') { 49 | // hide the code output 50 | button.parent().find('.go, .gp, .gt').hide(); 51 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); 52 | button.css('text-decoration', 'line-through'); 
53 | button.attr('title', show_text); 54 | button.data('hidden', 'true'); 55 | } else { 56 | // show the code output 57 | button.parent().find('.go, .gp, .gt').show(); 58 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); 59 | button.css('text-decoration', 'none'); 60 | button.attr('title', hide_text); 61 | button.data('hidden', 'false'); 62 | } 63 | }); 64 | }); 65 | -------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | 4 | {% block footer %} 5 | {{ super() }} 6 | 9 | 10 | {% endblock %} 11 | -------------------------------------------------------------------------------- /docs/source/_templates/theme_variables.jinja: -------------------------------------------------------------------------------- 1 | {%- set external_urls = { 2 | 'github': 'https://github.com/PytorchLightning/lightning-bolts', 3 | 'github_issues': 'https://github.com/PytorchLightning/lightning-bolts/issues', 4 | 'contributing': 'https://github.com/PytorchLightning/lightning-bolts/blob/master/CONTRIBUTING.md', 5 | 'governance': 'https://github.com/PytorchLightning/pytorch-lightning/blob/master/governance.md', 6 | 'docs': 'https://lightning-bolts.rtfd.io/en/latest', 7 | 'twitter': 'https://twitter.com/PyTorchLightnin', 8 | 'discuss': 'https://pytorch-lightning.slack.com', 9 | 'tutorials': 'https://pytorch-lightning.readthedocs.io/en/latest/#tutorials', 10 | 'previous_pytorch_versions': 'https://pytorch.rtfd.io/en/latest/', 11 | 'home': 'https://lightning-bolts.rtfd.io/en/latest/', 12 | 'get_started': 'https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html', 13 | 'features': 'https://lightning-bolts.rtfd.io/en/latest/', 14 | 'blog': 'https://www.pytorchlightning.ai/blog', 15 | 'resources': 'https://pytorch-lightning.readthedocs.io/en/latest/#community-examples', 
16 | 'support': 'https://lightning-bolts.rtfd.io/en/latest/', 17 | 'community': 'https://pytorch-lightning.slack.com', 18 | 'forums': 'https://pytorch-lightning.slack.com', 19 | 'pldocs': 'https://pytorch-lightning.readthedocs.io/en/stable/', 20 | 'tmdocs': 'https://torchmetrics.readthedocs.io/en/stable/', 21 | 'fldocs': 'https://lightning-flash.readthedocs.io/en/stable/', 22 | 'lbdocs': 'https://lightning-bolts.readthedocs.io/en/stable/', 23 | 'gdocs': 'https://docs.grid.ai/', 24 | } 25 | -%} 26 | -------------------------------------------------------------------------------- /docs/source/callbacks/self_supervised.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Self-supervised Callbacks 5 | ========================= 6 | Useful callbacks for self-supervised learning models. 7 | 8 | .. note:: 9 | 10 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 11 | 12 | 13 | --------------- 14 | 15 | BYOLMAWeightUpdate 16 | ------------------ 17 | The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL). 18 | 19 | .. autoclass:: pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate 20 | :noindex: 21 | 22 | ---------------- 23 | 24 | SSLOnlineEvaluator 25 | ------------------ 26 | Appends an MLP for fine-tuning to the given model. The callback has its own mini inner loop. 27 | 28 | .. autoclass:: pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator 29 | :noindex: 30 | -------------------------------------------------------------------------------- /docs/source/callbacks/sparseml.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | SparseML Callback 3 | ================= 4 | 5 | `SparseML `__ allows you to leverage sparsity to improve inference times substantially.
6 | 7 | SparseML requires you to fine-tune your model with the ``SparseMLCallback`` + a SparseML Recipe. By training with the ``SparseMLCallback``, you can leverage the `DeepSparse `__ engine to exploit the introduced sparsity, resulting in large performance improvements. 8 | 9 | .. warning:: 10 | 11 | The SparseML callback requires the model to be ONNX exportable. This can be tricky for models that require dynamic sequence lengths, such as RNNs. 12 | 13 | To leverage SparseML & DeepSparse, follow the steps below: 14 | 15 | 1. Choose your Sparse Recipe 16 | ---------------------------- 17 | 18 | To choose a recipe, have a look at `recipes `__ and `Sparse Zoo `__. 19 | 20 | It may be easier to infer a recipe via the UI dashboard using `Sparsify `__, which allows you to tweak and configure a recipe. 21 | This requires importing an ONNX model, which you can get from your ``LightningModule`` by calling ``model.to_onnx(output_path)``. 22 | 23 | 2. Train with SparseMLCallback 24 | ------------------------------ 25 | 26 | .. testcode:: 27 | :skipif: not _SPARSEML_TORCH_SATISFIED 28 | 29 | from pytorch_lightning import LightningModule, Trainer 30 | from pl_bolts.callbacks import SparseMLCallback 31 | 32 | class MyModel(LightningModule): 33 | ... 34 | 35 | model = MyModel() 36 | 37 | trainer = Trainer( 38 | callbacks=SparseMLCallback(recipe_path='recipe.yaml') 39 | ) 40 | 41 | 3. Export to ONNX! 42 | ------------------ 43 | 44 | Using the helper function, we handle any quantization/pruning internally and export the model into ONNX format. 45 | Note that this assumes either that you have implemented the ``example_input_array`` property in the model or that you provide a sample batch, as shown below. 46 | 47 | .. testcode:: 48 | :skipif: not _SPARSEML_TORCH_SATISFIED 49 | 50 | import torch 51 | 52 | model = MyModel() 53 | ...
54 | 55 | # export the onnx model, using the `model.example_input_array` 56 | SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/') 57 | 58 | # export the onnx model, providing a sample batch 59 | SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/', sample_batch=torch.randn(1, 128, 128, dtype=torch.float32)) 60 | 61 | 62 | Once your model has been exported, you can import it into either `Sparsify `__ or `DeepSparse `__. 63 | -------------------------------------------------------------------------------- /docs/source/callbacks/torch_ort.rst: -------------------------------------------------------------------------------- 1 | ================== 2 | Torch ORT Callback 3 | ================== 4 | 5 | `Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. See installation instructions `here `__. 6 | 7 | This is primarily useful when training Transformer models. The ORT callback works when a single model is specified as ``self.model`` within the ``LightningModule``, as shown below. 8 | 9 | .. note:: 10 | 11 | Not all Transformer models are supported. See `this table `__ for supported models and branches containing fixes for certain models. 12 | 13 | .. code-block:: python 14 | 15 | from pytorch_lightning import LightningModule, Trainer 16 | from transformers import AutoModel 17 | 18 | from pl_bolts.callbacks import ORTCallback 19 | 20 | 21 | class MyTransformerModel(LightningModule): 22 | 23 | def __init__(self): 24 | super().__init__() 25 | self.model = AutoModel.from_pretrained('bert-base-cased') 26 | 27 | ... 28 | 29 | 30 | model = MyTransformerModel() 31 | trainer = Trainer(gpus=1, callbacks=ORTCallback()) 32 | trainer.fit(model) 33 | 34 | 35 | For even easier setup and integration, have a look at our Lightning Flash integration for :ref:`Text Classification `, :ref:`Translation ` and :ref:`Summarization `.
36 | -------------------------------------------------------------------------------- /docs/source/callbacks/variational.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Variational Callbacks 5 | ===================== 6 | Useful callbacks for GANs, variational autoencoders or anything with latent spaces. 7 | 8 | .. note:: 9 | 10 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 11 | 12 | --------------- 13 | 14 | Latent Dim Interpolator 15 | ----------------------- 16 | Interpolates latent dims. 17 | 18 | Example output: 19 | 20 | .. image:: ../_images/gans/basic_gan_interpolate.jpg 21 | :width: 400 22 | :alt: Example latent space interpolation 23 | 24 | .. autoclass:: pl_bolts.callbacks.variational.LatentDimInterpolator 25 | :noindex: 26 | -------------------------------------------------------------------------------- /docs/source/callbacks/vision.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Vision Callbacks 5 | ================ 6 | Useful callbacks for vision models. 7 | 8 | .. note:: 9 | 10 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 11 | 12 | 13 | --------------- 14 | 15 | Confused Logit 16 | -------------- 17 | Shows how the input would have to change to move the prediction from one logit to the other. 18 | 19 | 20 | Example outputs: 21 | 22 | .. image:: ../_images/vision/confused_logit.png 23 | :width: 400 24 | :alt: Example of prediction confused between 5 and 8 25 | 26 | .. 
autoclass:: pl_bolts.callbacks.vision.confused_logit.ConfusedLogitCallback 27 | :noindex: 28 | 29 | 30 | --------------- 31 | 32 | Tensorboard Image Generator 33 | --------------------------- 34 | Generates images from a generative model and plots them to TensorBoard. 35 | 36 | .. autoclass:: pl_bolts.callbacks.vision.image_generation.TensorboardGenerativeModelImageSampler 37 | :noindex: 38 | -------------------------------------------------------------------------------- /docs/source/dataloaders/async.rst: -------------------------------------------------------------------------------- 1 | AsynchronousLoader 2 | ------------------ 3 | This dataloader behaves identically to the standard PyTorch dataloader, but will transfer 4 | data to the GPU asynchronously, concurrently with training. You can also use it to wrap an existing dataloader. 5 | 6 | .. note:: 7 | 8 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 9 | 10 | Example: 11 | 12 | .. code-block:: python 13 | 14 | dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) 15 | 16 | for b in dataloader: 17 | ... 18 | 19 | .. autoclass:: pl_bolts.datamodules.async_dataloader.AsynchronousLoader 20 | :noindex: 21 | -------------------------------------------------------------------------------- /docs/source/datamodules/sklearn.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Sklearn Datamodule 5 | ================== 6 | Utilities to map sklearn or numpy datasets to PyTorch Dataloaders with automatic data splits and GPU/TPU support. 7 | 8 | ..
code-block:: python 9 | 10 | from sklearn.datasets import load_diabetes 11 | from pl_bolts.datamodules import SklearnDataModule 12 | 13 | X, y = load_diabetes(return_X_y=True) 14 | loaders = SklearnDataModule(X, y) 15 | 16 | train_loader = loaders.train_dataloader(batch_size=32) 17 | val_loader = loaders.val_dataloader(batch_size=32) 18 | test_loader = loaders.test_dataloader(batch_size=32) 19 | 20 | Or build your own torch datasets: 21 | 22 | .. code-block:: python 23 | 24 | from sklearn.datasets import load_diabetes 25 | from pl_bolts.datamodules import SklearnDataset 26 | 27 | X, y = load_diabetes(return_X_y=True) 28 | dataset = SklearnDataset(X, y) 29 | loader = DataLoader(dataset) 30 | 31 | ---------------- 32 | 33 | Sklearn Dataset Class 34 | --------------------- 35 | Transforms a sklearn or numpy dataset to a PyTorch Dataset. 36 | 37 | .. autoclass:: pl_bolts.datamodules.sklearn_datamodule.SklearnDataset 38 | :noindex: 39 | 40 | ---------------- 41 | 42 | Sklearn DataModule Class 43 | ------------------------ 44 | Automatically generates the train, validation and test splits for a Numpy dataset. 45 | They are set up as dataloaders for convenience. Optionally, you can pass in your own validation and test splits. 46 | 47 | .. autoclass:: pl_bolts.datamodules.sklearn_datamodule.SklearnDataModule 48 | -------------------------------------------------------------------------------- /docs/source/datamodules/vision.rst: -------------------------------------------------------------------------------- 1 | Vision DataModules 2 | ================== 3 | The following are pre-built datamodules for computer vision. 4 | 5 | ------------- 6 | 7 | Supervised learning 8 | ------------------- 9 | These are standard vision datasets with the train, test, val splits pre-generated in DataLoaders with 10 | the standard transforms and normalization values. 11 | 12 | BinaryEMNIST 13 | ^^^^^^^^^^^^ 14 | 15 | ..
autoclass:: pl_bolts.datamodules.binary_emnist_datamodule.BinaryEMNISTDataModule 16 | 17 | BinaryMNIST 18 | ^^^^^^^^^^^ 19 | 20 | .. autoclass:: pl_bolts.datamodules.binary_mnist_datamodule.BinaryMNISTDataModule 21 | 22 | CityScapes 23 | ^^^^^^^^^^ 24 | 25 | .. autoclass:: pl_bolts.datamodules.cityscapes_datamodule.CityscapesDataModule 26 | 27 | CIFAR-10 28 | ^^^^^^^^ 29 | 30 | .. autoclass:: pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule 31 | 32 | EMNIST 33 | ^^^^^^ 34 | 35 | .. autoclass:: pl_bolts.datamodules.emnist_datamodule.EMNISTDataModule 36 | 37 | FashionMNIST 38 | ^^^^^^^^^^^^ 39 | 40 | .. autoclass:: pl_bolts.datamodules.fashion_mnist_datamodule.FashionMNISTDataModule 41 | 42 | Imagenet 43 | ^^^^^^^^ 44 | 45 | .. autoclass:: pl_bolts.datamodules.imagenet_datamodule.ImagenetDataModule 46 | 47 | MNIST 48 | ^^^^^ 49 | 50 | .. autoclass:: pl_bolts.datamodules.mnist_datamodule.MNISTDataModule 51 | 52 | Semi-supervised learning 53 | ------------------------ 54 | The following datasets have support for unlabeled training and semi-supervised learning where only a few examples 55 | are labeled. 56 | 57 | Imagenet (ssl) 58 | ^^^^^^^^^^^^^^ 59 | 60 | .. autoclass:: pl_bolts.datamodules.ssl_imagenet_datamodule.SSLImagenetDataModule 61 | 62 | STL-10 63 | ^^^^^^ 64 | 65 | .. autoclass:: pl_bolts.datamodules.stl10_datamodule.STL10DataModule 66 | -------------------------------------------------------------------------------- /docs/source/datasets/debug.rst: -------------------------------------------------------------------------------- 1 | ************** 2 | Debug Datasets 3 | ************** 4 | 5 | DummyDataset 6 | ============ 7 | 8 | .. autoclass:: pl_bolts.datasets.dummy_dataset.DummyDataset 9 | :noindex: 10 | 11 | DummyDetectionDataset 12 | ===================== 13 | 14 | .. autoclass:: pl_bolts.datasets.dummy_dataset.DummyDetectionDataset 15 | :noindex: 16 | 17 | RandomDataset 18 | ============= 19 | 20 | .. 
autoclass:: pl_bolts.datasets.dummy_dataset.RandomDataset 21 | :noindex: 22 | 23 | RandomDictDataset 24 | ================= 25 | 26 | .. autoclass:: pl_bolts.datasets.dummy_dataset.RandomDictDataset 27 | :noindex: 28 | 29 | RandomDictStringDataset 30 | ======================= 31 | 32 | .. autoclass:: pl_bolts.datasets.dummy_dataset.RandomDictStringDataset 33 | :noindex: 34 | -------------------------------------------------------------------------------- /docs/source/governance.rst: -------------------------------------------------------------------------------- 1 | .. _governance: 2 | 3 | PL Bolts Governance | Persons of interest 4 | ========================================= 5 | 6 | Core Maintainers 7 | ---------------- 8 | - William Falcon (`williamFalcon `_) (Lightning founder) 9 | - Jirka Borovec (`Borda `_) 10 | - Akihiro Nitta (`akihironitta `_) 11 | 12 | Core Contributors 13 | ----------------- 14 | - Atharva Phatak (`Atharva-Phatak `_) 15 | - Shion Matsumoto (`matsumotosan `_) 16 | - JongMok Lee (`lijm1358 `_) 17 | 18 | Alumni 19 | ------ 20 | - Ananya Harsh Jha (`ananyahjha93 `_) 21 | - Annika Brundyn (`annikabrundyn `_) 22 | - Ota Jašek (`otaj `_) 23 | - Teddy Koker (`teddykoker `_) 24 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. PyTorchLightning-Bolts documentation master file, created by 2 | sphinx-quickstart on Wed Mar 25 21:34:07 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Lightning-Bolts documentation 7 | ============================= 8 | .. toctree:: 9 | :maxdepth: 1 10 | :name: start 11 | :caption: Start here 12 | 13 | installation 14 | introduction_guide 15 | 16 | .. 
toctree:: 17 | :maxdepth: 2 18 | :name: callbacks 19 | :caption: Callbacks 20 | 21 | callbacks/monitor 22 | callbacks/torch_ort 23 | callbacks/sparseml 24 | callbacks/self_supervised 25 | callbacks/variational 26 | callbacks/vision 27 | 28 | .. toctree:: 29 | :maxdepth: 2 30 | :name: datamodules 31 | :caption: DataModules 32 | 33 | datamodules/sklearn 34 | datamodules/vision 35 | 36 | .. toctree:: 37 | :maxdepth: 2 38 | :name: datasets 39 | :caption: Datasets 40 | 41 | datasets/debug 42 | 43 | .. toctree:: 44 | :maxdepth: 2 45 | :name: dataloaders 46 | :caption: DataLoaders 47 | 48 | dataloaders/async 49 | 50 | .. toctree:: 51 | :maxdepth: 2 52 | :name: losses 53 | :caption: Losses 54 | 55 | losses 56 | 57 | .. toctree:: 58 | :maxdepth: 2 59 | :name: models 60 | :caption: Models 61 | 62 | models/models_howto 63 | models/autoencoders 64 | models/convolutional 65 | models/gans 66 | models/object_detection 67 | models/reinforce_learn 68 | models/self_supervised 69 | models/classic_ml 70 | 71 | .. toctree:: 72 | :maxdepth: 2 73 | :name: learning_rate_schedulers 74 | :caption: Learning Rate Schedulers 75 | 76 | schedulers/warmup_cosine_annealing.rst 77 | 78 | .. toctree:: 79 | :maxdepth: 2 80 | :name: transforms 81 | :caption: Data Processing 82 | 83 | transforms/self_supervised 84 | transforms/semi_supervised 85 | 86 | .. toctree:: 87 | :maxdepth: 2 88 | :name: ssl 89 | :caption: Tasks 90 | 91 | vision_tasks 92 | 93 | .. 
toctree:: 94 | :maxdepth: 1 95 | :name: community 96 | :caption: Community 97 | 98 | CONTRIBUTING.md 99 | governance.md 100 | stability.md 101 | CHANGELOG.md 102 | 103 | 104 | Indices and tables 105 | ================== 106 | 107 | * :ref:`genindex` 108 | * :ref:`modindex` 109 | * :ref:`search` 110 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | You can install using `pip `_ 5 | 6 | .. code-block:: bash 7 | 8 | pip install lightning-bolts 9 | 10 | Install the bleeding-edge version (no guarantees): 11 | 12 | .. code-block:: bash 13 | 14 | pip install git+https://github.com/PytorchLightning/lightning-bolts.git@master --upgrade 15 | 16 | To get the full experience, you can install all optional packages at once: 17 | 18 | .. code-block:: bash 19 | 20 | pip install lightning-bolts["extra"] 21 | -------------------------------------------------------------------------------- /docs/source/losses.rst: -------------------------------------------------------------------------------- 1 | Reinforcement Learning 2 | ====================== 3 | These are common losses used in RL. 4 | 5 | --------------- 6 | 7 | DQN Loss 8 | -------- 9 | 10 | .. autofunction:: pl_bolts.losses.rl.dqn_loss 11 | :noindex: 12 | 13 | Double DQN Loss 14 | --------------- 15 | 16 | .. autofunction:: pl_bolts.losses.rl.double_dqn_loss 17 | :noindex: 18 | 19 | PER DQN Loss 20 | ------------ 21 | 22 | .. autofunction:: pl_bolts.losses.rl.per_dqn_loss 23 | :noindex: 24 | -------------------------------------------------------------------------------- /docs/source/models/convolutional.rst: -------------------------------------------------------------------------------- 1 | Convolutional Architectures 2 | =========================== 3 | This package lists contributed convolutional architectures. 4 | 5 | ..
note:: 6 | 7 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 8 | 9 | -------------- 10 | 11 | 12 | GPT-2 13 | ----- 14 | 15 | .. autoclass:: pl_bolts.models.vision.GPT2 16 | :noindex: 17 | 18 | ------------- 19 | 20 | Image GPT 21 | --------- 22 | 23 | .. autoclass:: pl_bolts.models.vision.ImageGPT 24 | :noindex: 25 | 26 | ------------- 27 | 28 | Pixel CNN 29 | --------- 30 | 31 | .. autoclass:: pl_bolts.models.vision.PixelCNN 32 | :noindex: 33 | 34 | ------------- 35 | 36 | UNet 37 | ---- 38 | 39 | .. autoclass:: pl_bolts.models.vision.UNet 40 | :noindex: 41 | 42 | ------------- 43 | 44 | Semantic Segmentation 45 | --------------------- 46 | Model template to use for semantic segmentation tasks. The model uses a UNet architecture by default. Override any part 47 | of this model to build your own variation. 48 | 49 | .. code-block:: python 50 | 51 | from pl_bolts.models.vision import SemSegment 52 | from pl_bolts.datamodules import KittiDataModule 53 | import pytorch_lightning as pl 54 | 55 | dm = KittiDataModule('path/to/kitti/dataset/', batch_size=4) 56 | model = SemSegment() 57 | trainer = pl.Trainer() 58 | trainer.fit(model, datamodule=dm) 59 | 60 | .. autoclass:: pl_bolts.models.vision.SemSegment 61 | :noindex: 62 | -------------------------------------------------------------------------------- /docs/source/models/object_detection.rst: -------------------------------------------------------------------------------- 1 | Object Detection 2 | ================ 3 | This package lists contributed object detection models. 4 | 5 | .. note:: 6 | 7 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 8 | 9 | -------------- 10 | 11 | 12 | Faster R-CNN 13 | ------------ 14 | 15 | ..
autoclass:: pl_bolts.models.detection.faster_rcnn.faster_rcnn_module.FasterRCNN 16 | :noindex: 17 | 18 | ------------- 19 | 20 | RetinaNet 21 | --------- 22 | 23 | .. autoclass:: pl_bolts.models.detection.retinanet.retinanet_module.RetinaNet 24 | :noindex: 25 | 26 | ------------- 27 | 28 | YOLO 29 | ---- 30 | 31 | .. autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO 32 | :noindex: 33 | -------------------------------------------------------------------------------- /docs/source/schedulers/warmup_cosine_annealing.rst: -------------------------------------------------------------------------------- 1 | 2 | Linear Warmup Cosine Annealing 3 | ------------------------------ 4 | 5 | .. note:: 6 | 7 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 8 | 9 | .. autoclass:: pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR 10 | :noindex: 11 | -------------------------------------------------------------------------------- /docs/source/stability.rst: -------------------------------------------------------------------------------- 1 | .. _stability: 2 | 3 | Bolts stability 4 | =============== 5 | 6 | We are currently going through a major revision of Bolts to ensure all of the code is stable and compatible with the rest of the Lightning ecosystem. 7 | For this reason, all of our features are either stable or in need of review. Stable features carry no marker, while features to be reviewed are explicitly marked. 8 | 9 | At the beginning of the aforementioned revision, **ALL** of the features currently in the project were marked as under review and will undergo rigorous review and testing before they can be marked as stable. See `this GitHub issue `_ to check the progress of the revision. 10 | 11 | This document is intended to help you know what to expect and to outline our commitment to stability.
12 | 13 | Stable 14 | ______ 15 | 16 | For stable features, all of the following are true: 17 | 18 | - the API isn’t expected to change 19 | - if anything does change, incorrect usage will give a deprecation warning for **one minor release** before the breaking change is made 20 | - the API has been tested for compatibility with the latest releases of PyTorch Lightning and Flash 21 | 22 | Under Review 23 | ____________ 24 | 25 | For features to be reviewed, any or all of the following may be true: 26 | 27 | - the feature has unstable dependencies 28 | - the API may change without notice in future versions 29 | - the performance of the feature has not been verified 30 | - the docs for this feature are under active development 31 | 32 | 33 | Before a feature can be moved to Stable, it needs to satisfy the following conditions: 34 | 35 | - Have appropriate tests that check not only the correctness of the feature but also its compatibility with the current versions. 36 | - Not duplicate code from across the Lightning ecosystem or more mature OSS projects. 37 | - Pass a review process. 38 | -------------------------------------------------------------------------------- /docs/source/transforms/semi_supervised.rst: -------------------------------------------------------------------------------- 1 | Semi-supervised learning 2 | ======================== 3 | Collection of utilities for semi-supervised learning where some part of the data is labeled 4 | and the other part is not. 5 | 6 | .. note:: 7 | 8 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 9 | 10 | -------------- 11 | 12 | Balanced classes 13 | ---------------- 14 | 15 | Example:: 16 | 17 | from pl_bolts.utils.semi_supervised import balance_classes 18 | 19 | 20 | ..
autofunction:: pl_bolts.utils.semi_supervised.balance_classes 21 | :noindex: 22 | 23 | Half-labeled batches 24 | -------------------- 25 | 26 | Example:: 27 | 28 | from pl_bolts.utils.semi_supervised import generate_half_labeled_batches 29 | 30 | .. autofunction:: pl_bolts.utils.semi_supervised.generate_half_labeled_batches 31 | :noindex: 32 | -------------------------------------------------------------------------------- /docs/source/vision_tasks.rst: -------------------------------------------------------------------------------- 1 | Self-supervised Learning 2 | ======================== 3 | This section implements popular contrastive learning tasks used in self-supervised learning. 4 | 5 | .. note:: 6 | 7 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix! 8 | 9 | --------------- 10 | 11 | FeatureMapContrastiveTask 12 | ------------------------- 13 | This task compares sets of feature maps. 14 | 15 | In general, the feature map comparison pretext task uses triplets of features. 16 | Here are the abstract steps of the comparison. 17 | 18 | Generate multiple views of the same image 19 | 20 | .. code-block:: python 21 | 22 | x1_view_1 = data_augmentation(x1) 23 | x1_view_2 = data_augmentation(x1) 24 | 25 | Use a different example to generate additional views (usually within the same batch or a pool of candidates) 26 | 27 | .. code-block:: python 28 | 29 | x2_view_1 = data_augmentation(x2) 30 | x2_view_2 = data_augmentation(x2) 31 | 32 | Pick 3 views to compare; these are the anchor, positive and negative features 33 | 34 | .. code-block:: python 35 | 36 | anchor = x1_view_1 37 | positive = x1_view_2 38 | negative = x2_view_1 39 | 40 | Generate feature maps for each view 41 | 42 | .. code-block:: python 43 | 44 | (a0, a1, a2) = encoder(anchor) 45 | (p0, p1, p2) = encoder(positive) 46 | 47 | Make a comparison for a set of feature maps 48 | 49 | ..
code-block:: python 50 | 51 | phi = some_score_function() 52 | 53 | # the '01' comparison 54 | score = phi(a0, p1) 55 | 56 | # and can be bidirectional 57 | score = phi(p0, a1) 58 | 59 | In practice, the contrastive task creates a BxB matrix where B is the batch size. The diagonals for set 1 of feature maps 60 | are the anchors, the diagonals of set 2 of the feature maps are the positives, the non-diagonals of set 1 are the 61 | negatives. 62 | 63 | .. autoclass:: pl_bolts.losses.self_supervised_learning.FeatureMapContrastiveTask 64 | :noindex: 65 | 66 | -------------- 67 | 68 | Context prediction tasks 69 | ------------------------ 70 | The following tasks aim to predict a target using a context representation. 71 | 72 | CPCTask 73 | ^^^^^^^ 74 | This is the predictive task from CPC (v2). 75 | 76 | .. code-block:: 77 | 78 | task = CPCTask(num_input_channels=32) 79 | 80 | # (batch, channels, rows, cols) 81 | # this should be thought of as 49 feature vectors, each with 32 dims 82 | Z = torch.rand(3, 32, 7, 7) 83 | 84 | loss = task(Z) 85 | 86 | 87 | ..
autoclass:: pl_bolts.losses.self_supervised_learning.CPCTask 88 | :noindex: 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/base.txt 2 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | numpy <1.26.0 2 | pytorch-lightning >1.7.0, <2.0.0 # strict 3 | torchmetrics >=0.10.0, <0.12.0 4 | lightning-utilities >0.3.1 # this is needed for PL 1.7 5 | torchvision >=0.10.0 # todo: move to topic related extras 6 | tensorboard >=2.9.1, <2.14.0 # for `TensorBoardLogger` 7 | -------------------------------------------------------------------------------- /requirements/devel.txt: -------------------------------------------------------------------------------- 1 | # install all mandatory dependencies 2 | -r ./base.txt 3 | 4 | # install all extra dependencies for full package experience 5 | -r ./models.txt 6 | -r ./loggers.txt 7 | 8 | # extended list of dependencies to run lint and tests 9 | -r ./test.txt 10 | -------------------------------------------------------------------------------- /requirements/loggers.txt: -------------------------------------------------------------------------------- 1 | # test_tube >=0.7.5 2 | # trains >=0.14.1 3 | matplotlib >3.0.0, <3.8.0 4 | wandb >0.13.0, <0.16.0 5 | scipy >1.5.0, <1.12.0 6 | -------------------------------------------------------------------------------- /requirements/models.txt: -------------------------------------------------------------------------------- 1 | torchvision >=0.10.0 2 | scikit-learn >=1.0.2 3 | Pillow >9.0.0 4 | gym[atari] >=0.17.2, <0.22.0 # strict 5 | atari-py >0.2, !=0.2.6, <0.3 # strict 6 | box2d-py >2.3, <2.4 # strict 7 | opencv-python-headless >=4.5.5.62 8 | 
-------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | coverage[toml] >7.0.0, <8.0.0 2 | pytest ==7.4.0 3 | pytest-cov ==4.1.0 4 | pytest-timeout ==2.1.0 5 | pytest-rerunfailures ==12.0 6 | 7 | scikit-learn >=1.0.2, <=1.3.0 8 | sparseml >1.0.0, <1.6.0 9 | ale-py >=0.7, <=0.8.1 10 | jsonargparse[signatures] >4.0.0, <=4.22.1 # for LightningCLI 11 | -------------------------------------------------------------------------------- /requirements/typing.txt: -------------------------------------------------------------------------------- 1 | mypy ==1.7.1 2 | -------------------------------------------------------------------------------- /src/pl_bolts/__about__.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | __version__ = "0.7.0" 4 | __author__ = "Lightning AI et al." 5 | __author_email__ = "pytorch@lightning.ai" 6 | __license__ = "Apache-2.0" 7 | __copyright__ = f"Copyright (c) 2020-{time.strftime('%Y')}, {__author__}." 8 | __homepage__ = "https://github.com/Lightning-AI/lightning-bolts" 9 | __docs__ = "Lightning Bolts is a community contribution for ML researchers." 10 | __long_doc__ = """ 11 | What is it? 12 | ----------- 13 | Bolts is a collection of useful models and templates to bootstrap your DL research even faster. 14 | It's designed to work with PyTorch Lightning 15 | 16 | Subclass Example 17 | ---------------- 18 | Use `pl_bolts` models to remove boilerplate for common approaches and architectures. 19 | Because it uses LightningModules under the hood, you just need to overwrite 20 | the relevant parts to your research. 21 | 22 | How to add a model 23 | ------------------ 24 | This repository is meant for model contributions from the community. 25 | To add a model, you can start with the MNIST template (or any other model in the repo).
26 | Please organize the functions of your lightning module. 27 | """ 28 | 29 | __all__ = [ 30 | "__author__", 31 | "__author_email__", 32 | "__copyright__", 33 | "__docs__", 34 | "__homepage__", 35 | "__license__", 36 | "__version__", 37 | ] 38 | -------------------------------------------------------------------------------- /src/pl_bolts/__init__.py: -------------------------------------------------------------------------------- 1 | """Root package crossroad.""" 2 | 3 | import os 4 | 5 | import numpy 6 | 7 | from pl_bolts.__about__ import * # noqa: F401, F403 8 | 9 | # numpy >= 1.24 removed the deprecated aliases numpy.object/bool/int/float; re-add them (pointing at the builtins) for code that still uses them 10 | for tp_name, tp_ins in [("object", object), ("bool", bool), ("int", int), ("float", float)]: 11 | if not hasattr(numpy, tp_name): 12 | setattr(numpy, tp_name, tp_ins) 13 | 14 | _PACKAGE_ROOT = os.path.dirname(__file__) 15 | _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) 16 | _HTTPS_AWS_HUB = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com" 17 | 18 | from pl_bolts import ( # noqa: E402 19 | callbacks, 20 | datamodules, 21 | datasets, 22 | losses, 23 | metrics, 24 | models, 25 | optimizers, 26 | transforms, 27 | utils, 28 | ) 29 | 30 | __all__ = [ 31 | "callbacks", 32 | "datamodules", 33 | "datasets", 34 | "losses", 35 | "metrics", 36 | "models", 37 | "optimizers", 38 | "transforms", 39 | "utils", 40 | ] 41 | -------------------------------------------------------------------------------- /src/pl_bolts/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | """Collection of PyTorchLightning callbacks.""" 2 | from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate 3 | from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor 4 | from pl_bolts.callbacks.printing import PrintTableMetricsCallback 5 | from pl_bolts.callbacks.sparseml import SparseMLCallback 6 | from pl_bolts.callbacks.ssl_online import 
SSLOnlineEvaluator 7 | from pl_bolts.callbacks.torch_ort import
ORTCallback 8 | from pl_bolts.callbacks.variational import LatentDimInterpolator 9 | from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore 10 | from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback 11 | from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler 12 | from pl_bolts.callbacks.vision.sr_image_logger import SRImageLoggerCallback 13 | 14 | __all__ = [ 15 | "BatchGradientVerificationCallback", 16 | "BYOLMAWeightUpdate", 17 | "ModuleDataMonitor", 18 | "TrainingDataMonitor", 19 | "PrintTableMetricsCallback", 20 | "SSLOnlineEvaluator", 21 | "LatentDimInterpolator", 22 | "ConfusedLogitCallback", 23 | "TensorboardGenerativeModelImageSampler", 24 | "SRImageLoggerCallback", 25 | "ORTCallback", 26 | "SparseMLCallback", 27 | ] 28 | -------------------------------------------------------------------------------- /src/pl_bolts/callbacks/byol_updates.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Sequence, Union 3 | 4 | import torch.nn as nn 5 | from pytorch_lightning import Callback, LightningModule, Trainer 6 | from torch import Tensor 7 | 8 | 9 | class BYOLMAWeightUpdate(Callback): 10 | """Weight update rule from Bootstrap Your Own Latent (BYOL). 11 | 12 | Updates the target_network params using an exponential moving average update rule weighted by tau. 13 | BYOL claims this keeps the online_network from collapsing. 14 | 15 | The PyTorch Lightning module being trained should have: 16 | 17 | - ``self.online_network`` 18 | - ``self.target_network`` 19 | 20 | .. note:: Automatically increases tau from ``initial_tau`` to 1.0 over training following a cosine schedule; tau is recomputed after every training batch 21 | 22 | Args: 23 | initial_tau (float, optional): starting tau, must lie in ``[0, 1]`` (a ``ValueError`` is raised otherwise). 
Auto-updates with every training step 24 | 25 | Example:: 26 | 27 | # model must have 2 attributes 28 | model = Model() 29 | model.online_network = ...
30 | model.target_network = ... 31 | 32 | trainer = Trainer(callbacks=[BYOLMAWeightUpdate()]) 33 | 34 | """ 35 | 36 | def __init__(self, initial_tau: float = 0.996) -> None: 37 | if not 0.0 <= initial_tau <= 1.0: 38 | raise ValueError(f"initial tau should be between 0 and 1 instead of {initial_tau}.") 39 | 40 | super().__init__() 41 | self.initial_tau = initial_tau 42 | self.current_tau = initial_tau 43 | 44 | def on_train_batch_end( 45 | self, 46 | trainer: Trainer, 47 | pl_module: LightningModule, 48 | outputs: Sequence, 49 | batch: Sequence, 50 | batch_idx: int, 51 | ) -> None: 52 | # get networks 53 | online_net = pl_module.online_network 54 | target_net = pl_module.target_network 55 | 56 | # update target network weights 57 | self.update_weights(online_net, target_net) 58 | 59 | # update tau after the weight update, so the new tau is used from the next batch on 60 | self.update_tau(pl_module, trainer) 61 | 62 | def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> None: 63 | """Recompute tau with a cosine schedule: ``initial_tau`` at step 0, rising to 1.0 at ``max_steps``.""" 64 | max_steps = len(trainer.train_dataloader) * trainer.max_epochs # NOTE(review): assumes a sized train_dataloader and a positive max_epochs; max_epochs=-1 (unlimited) would make max_steps negative - confirm 65 | self.current_tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2 66 | 67 | def update_weights(self, online_net: Union[nn.Module, Tensor], target_net: Union[nn.Module, Tensor]) -> None: 68 | """Update each target parameter as an EMA: ``target = tau * target + (1 - tau) * online``.""" 69 | for online_p, target_p in zip(online_net.parameters(), target_net.parameters()): 70 | target_p.data = self.current_tau * target_p.data + (1.0 - self.current_tau) * online_p.data 71 | -------------------------------------------------------------------------------- /src/pl_bolts/callbacks/torch_ort.py: 
-------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from pytorch_lightning import Callback, LightningModule, Trainer 15 | from pytorch_lightning.utilities.exceptions import MisconfigurationException 16 | 17 | from pl_bolts.utils import _TORCH_ORT_AVAILABLE 18 | 19 | if _TORCH_ORT_AVAILABLE: 20 | from torch_ort import ORTModule 21 | 22 | from pl_bolts.utils.stability import under_review 23 | 24 | 25 | @under_review() 26 | class ORTCallback(Callback): 27 | """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime. 28 | 29 | Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for 30 | training and inference. The ``LightningModule`` must expose the network to wrap as its ``model`` attribute; it is wrapped in ``ORTModule`` unless it already is one. A ``MisconfigurationException`` is raised if torch-ort is not installed or if ``model`` is missing. 31 | 32 | Usage: 33 | 34 | # via Transformer Tasks 35 | model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True) 36 | 37 | # or via the trainer 38 | trainer = flash.Trainer(callbacks=ORTCallback()) 39 | """ 40 | 41 | def __init__(self) -> None: 42 | if not _TORCH_ORT_AVAILABLE: 43 | raise MisconfigurationException( 44 | "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort" 45 | ) 46 | 47 | def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None: # NOTE(review): PL deprecated this hook in v1.6 in favour of ``setup``; confirm it is still called under the pinned pytorch-lightning (>1.7.0, <2.0.0) 48 | if not hasattr(pl_module, "model"): 49 | raise MisconfigurationException( 50 | "Torch ORT requires to wrap a single model that defines a forward function " 51 | "assigned as `model` inside the `LightningModule`."
52 | ) 53 | if not isinstance(pl_module.model, ORTModule): 54 | pl_module.model = ORTModule(pl_module.model) 55 | -------------------------------------------------------------------------------- /src/pl_bolts/callbacks/verification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/callbacks/verification/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/callbacks/vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/callbacks/vision/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/callbacks/vision/sr_image_logger.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | import torch.nn.functional as F # noqa: N812 6 | from pytorch_lightning import Callback 7 | 8 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 9 | from pl_bolts.utils.stability import under_review 10 | from pl_bolts.utils.warnings import warn_missing_pkg 11 | 12 | if _TORCHVISION_AVAILABLE: 13 | from torchvision.utils import make_grid 14 | else: # pragma: no cover 15 | warn_missing_pkg("torchvision") 16 | 17 | 18 | @under_review() 19 | class SRImageLoggerCallback(Callback): 20 | """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard. Your model must implement 21 | the ``forward`` function for generation.
22 | 23 | Requirements:: 24 | 25 | # model forward must work generating high-res from low-res image 26 | hr_fake = pl_module(lr_image) 27 | 28 | Example:: 29 | 30 | from pl_bolts.callbacks import SRImageLoggerCallback 31 | 32 | trainer = Trainer(callbacks=[SRImageLoggerCallback()]) 33 | """ 34 | 35 | def __init__(self, log_interval: int = 1000, scale_factor: int = 4, num_samples: int = 5) -> None: 36 | """ 37 | Args: 38 | log_interval: Number of global steps between logging. Default: ``1000``. 39 | scale_factor: Scale factor between the low-res and high-res images; the low-res input is upsampled by it for display. Default: ``4``. 40 | num_samples: Number of images from the batch displayed in each grid. Default: ``5``. 41 | """ 42 | super().__init__() 43 | self.log_interval = log_interval 44 | self.scale_factor = scale_factor 45 | self.num_samples = num_samples 46 | 47 | def on_train_batch_end( 48 | self, 49 | trainer: pl.Trainer, 50 | pl_module: pl.LightningModule, 51 | outputs: torch.Tensor, 52 | batch: Tuple[torch.Tensor, torch.Tensor], 53 | batch_idx: int, 54 | ) -> None: 55 | global_step = trainer.global_step 56 | if global_step % self.log_interval == 0: 57 | hr_image, lr_image = batch 58 | hr_image, lr_image = hr_image.to(pl_module.device), lr_image.to(pl_module.device) 59 | hr_fake = pl_module(lr_image) # NOTE(review): runs with autograd enabled and in the module's current mode; consider wrapping in torch.no_grad() 60 | lr_image = F.interpolate(lr_image, scale_factor=self.scale_factor) # upsample so the three grids share a height for torch.cat below 61 | 62 | lr_image_grid = make_grid(lr_image[: self.num_samples], nrow=1, normalize=True) 63 | hr_fake_grid = make_grid(hr_fake[: self.num_samples], nrow=1, normalize=True) 64 | hr_image_grid = make_grid(hr_image[: self.num_samples], nrow=1, normalize=True) 65 | 66 | grid = torch.cat((lr_image_grid, hr_fake_grid, hr_image_grid), -1) 67 | title = "sr_images" 68 | trainer.logger.experiment.add_image(title, grid, global_step=global_step) # assumes a 
TensorBoard-style logger whose experiment exposes add_image 69 | -------------------------------------------------------------------------------- /src/pl_bolts/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from
pl_bolts.datamodules.async_dataloader import AsynchronousLoader 2 | from pl_bolts.datamodules.binary_emnist_datamodule import BinaryEMNISTDataModule 3 | from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule 4 | from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule 5 | from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule 6 | from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule 7 | from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset 8 | from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule 9 | from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule 10 | from pl_bolts.datamodules.kitti_datamodule import KittiDataModule 11 | from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule 12 | from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset 13 | from pl_bolts.datamodules.sr_datamodule import TVTDataModule 14 | from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule 15 | from pl_bolts.datamodules.stl10_datamodule import STL10DataModule 16 | from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule 17 | from pl_bolts.datasets.kitti_dataset import KittiDataset 18 | 19 | __all__ = [ 20 | "AsynchronousLoader", 21 | "BinaryMNISTDataModule", 22 | "CIFAR10DataModule", 23 | "TinyCIFAR10DataModule", 24 | "CityscapesDataModule", 25 | "DiscountedExperienceSource", 26 | "ExperienceSource", 27 | "ExperienceSourceDataset", 28 | "FashionMNISTDataModule", 29 | "ImagenetDataModule", 30 | "KittiDataModule", 31 | "MNISTDataModule", 32 | "SklearnDataModule", 33 | "SklearnDataset", 34 | "TVTDataModule", 35 | "SSLImagenetDataModule", 36 | "STL10DataModule", 37 | "VOCDetectionDataModule", 38 | "KittiDataset", 39 | "EMNISTDataModule", 40 | "BinaryEMNISTDataModule", 41 | ] 42 | 
-------------------------------------------------------------------------------- /src/pl_bolts/datamodules/sr_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader, Dataset 5 | 6 | from pl_bolts.utils.stability import under_review 7 | 8 | 9 | @under_review() 10 | class TVTDataModule(LightningDataModule): 11 | """Simple DataModule creating train, val, and test dataloaders from given train, val, and test datasets. 12 | 13 | Example:: 14 | from pl_bolts.datamodules import TVTDataModule 15 | from pl_bolts.datasets.sr_mnist_dataset import SRMNIST 16 | from torch.utils.data import random_split 17 | dataset_dev = SRMNIST(scale_factor=4, root=".", train=True) 18 | dataset_train, dataset_val = random_split(dataset_dev, lengths=[55_000, 5_000]) 19 | dataset_test = SRMNIST(scale_factor=4, root=".", train=False) 20 | dm = TVTDataModule(dataset_train, dataset_val, dataset_test) 21 | 22 | """ 23 | 24 | def __init__( 25 | self, 26 | dataset_train: Dataset, 27 | dataset_val: Dataset, 28 | dataset_test: Dataset, 29 | batch_size: int = 16, 30 | shuffle: bool = True, 31 | num_workers: int = 8, 32 | pin_memory: bool = True, 33 | drop_last: bool = True, 34 | *args: Any, 35 | **kwargs: Any, # NOTE(review): *args/**kwargs are accepted but not used 36 | ) -> None: 37 | """ 38 | Args: 39 | dataset_train: Train dataset 40 | dataset_val: Val dataset 41 | dataset_test: Test dataset 42 | batch_size: How many samples per batch to load 43 | num_workers: How many workers to use for loading data 44 | shuffle: If true shuffles the train data every epoch (val and test data are never shuffled) 45 | pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before 46 | returning them 47 | drop_last: If true drops the last incomplete batch (applies to the train, val and test loaders alike) 48 | """ 49 | super().__init__() 50 | 51 | self.dataset_train = dataset_train 52 | self.dataset_val = dataset_val 53 | self.dataset_test = dataset_test 54 | self.num_workers = num_workers 55 | self.batch_size =
batch_size 56 | self.shuffle = shuffle 57 | self.pin_memory = pin_memory 58 | self.drop_last = drop_last 59 | 60 | def train_dataloader(self) -> DataLoader: 61 | return self._dataloader(self.dataset_train, shuffle=self.shuffle) 62 | 63 | def val_dataloader(self) -> DataLoader: 64 | return self._dataloader(self.dataset_val, shuffle=False) 65 | 66 | def test_dataloader(self) -> DataLoader: 67 | return self._dataloader(self.dataset_test, shuffle=False) 68 | 69 | def _dataloader(self, dataset: Dataset, shuffle: bool = True) -> DataLoader: # shared factory: batch_size/num_workers/drop_last/pin_memory apply to every split 70 | return DataLoader( 71 | dataset, 72 | batch_size=self.batch_size, 73 | shuffle=shuffle, 74 | num_workers=self.num_workers, 75 | drop_last=self.drop_last, 76 | pin_memory=self.pin_memory, 77 | ) 78 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | 3 | from pl_bolts.datasets.array_dataset import ArrayDataset 4 | from pl_bolts.datasets.base_dataset import DataModel, LightDataset 5 | from pl_bolts.datasets.cifar10_dataset import CIFAR10, TrialCIFAR10 6 | from pl_bolts.datasets.concat_dataset import ConcatDataset 7 | from pl_bolts.datasets.dummy_dataset import ( 8 | DummyDataset, 9 | DummyDetectionDataset, 10 | RandomDataset, 11 | RandomDictDataset, 12 | RandomDictStringDataset, 13 | ) 14 | from pl_bolts.datasets.emnist_dataset import BinaryEMNIST 15 | from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet, extract_archive, parse_devkit_archive 16 | from pl_bolts.datasets.kitti_dataset import KittiDataset 17 | from pl_bolts.datasets.mnist_dataset import MNIST, BinaryMNIST 18 | from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin 19 | 20 | __all__ = [ 21 | "ArrayDataset", 22 | "DataModel", 23 | "LightDataset", 24 | "CIFAR10", 25 | "TrialCIFAR10", 26 | "ConcatDataset", 27 | 
"DummyDataset", 28 | "DummyDetectionDataset", 29 | "MNIST",
30 | "RandomDataset", 31 | "RandomDictDataset", 32 | "RandomDictStringDataset", 33 | "extract_archive", 34 | "parse_devkit_archive", 35 | "UnlabeledImagenet", 36 | "KittiDataset", 37 | "BinaryMNIST", 38 | "CIFAR10Mixed", 39 | "SSLDatasetMixin", 40 | "BinaryEMNIST", 41 | ] 42 | 43 | # TorchVision hotfix https://github.com/pytorch/vision/issues/1938 44 | opener = urllib.request.build_opener() 45 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 46 | urllib.request.install_opener(opener) 47 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/array_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | from pytorch_lightning.utilities import exceptions 4 | from torch.utils.data import Dataset 5 | 6 | from pl_bolts.datasets.base_dataset import DataModel, TArrays 7 | 8 | 9 | class ArrayDataset(Dataset): 10 | """Dataset wrapping tensors, lists, numpy arrays. 11 | 12 | Any number of ARRAYS can be inputted into the dataset. The ARRAYS are transformed on each `__getitem__`. When 13 | transforming, please refrain from chaning the hape of ARRAYS in the first demension. 14 | 15 | Attributes: 16 | data_models: Sequence of data models. 17 | 18 | Raises: 19 | MisconfigurationException: if there is a shape mismatch between arrays in the first dimension. 
20 | 21 | Example: 22 | >>> from pl_bolts.datasets import ArrayDataset, DataModel 23 | >>> from pl_bolts.datasets.utils import to_tensor 24 | 25 | >>> features = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3]], transform=to_tensor) 26 | >>> target = DataModel(data=[1, 0, 0], transform=to_tensor) 27 | 28 | >>> ds = ArrayDataset(features, target) 29 | >>> len(ds) 30 | 3 31 | 32 | """ 33 | 34 | def __init__(self, *data_models: DataModel) -> None: 35 | """Initialises class and checks if arrays are the same shape in the first dimension.""" 36 | self.data_models = data_models 37 | 38 | if not self._equal_size(): 39 | raise exceptions.MisconfigurationException("Shape mismatch between arrays in the first dimension") 40 | 41 | def __len__(self) -> int: 42 | return len(self.data_models[0].data) 43 | 44 | def __getitem__(self, idx: int) -> Tuple[Union[TArrays, float], ...]: 45 | return tuple(data_model.process(data_model.data[idx]) for data_model in self.data_models) 46 | 47 | def _equal_size(self) -> bool: 48 | """Checks the size of the data_models are equal in the first dimension. 49 | 50 | Returns: 51 | bool: True if size of data_models are equal in the first dimension. False, if not. 
52 | 53 | """ 54 | return len({len(data_model.data) for data_model in self.data_models}) == 1 55 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import urllib.parse 4 | import urllib.request 5 | from abc import ABC 6 | from dataclasses import dataclass 7 | from typing import Callable, Optional, Sequence, Tuple, Union 8 | from urllib.error import HTTPError 9 | 10 | from torch import Tensor 11 | from torch.utils.data import Dataset 12 | 13 | from pl_bolts.utils.stability import under_review 14 | from pl_bolts.utils.types import TArrays 15 | 16 | 17 | @under_review() 18 | class LightDataset(ABC, Dataset): 19 | data: Tensor 20 | targets: Tensor 21 | normalize: tuple 22 | dir_path: str 23 | cache_folder_name: str 24 | DATASET_NAME = "light" 25 | 26 | def __len__(self) -> int: 27 | return len(self.data) 28 | 29 | @property 30 | def cached_folder_path(self) -> str: 31 | return os.path.join(self.dir_path, self.DATASET_NAME, self.cache_folder_name) 32 | 33 | @staticmethod 34 | def _prepare_subset( 35 | full_data: Tensor, 36 | full_targets: Tensor, 37 | num_samples: int, 38 | labels: Sequence, 39 | ) -> Tuple[Tensor, Tensor]: 40 | """Prepare a subset of a common dataset.""" 41 | classes = {d: 0 for d in labels} 42 | indexes = [] 43 | for idx, target in enumerate(full_targets): 44 | label = target.item() 45 | if classes.get(label, float("inf")) >= num_samples: 46 | continue 47 | indexes.append(idx) 48 | classes[label] += 1 49 | if all(classes[k] >= num_samples for k in classes): 50 | break 51 | data = full_data[indexes] 52 | targets = full_targets[indexes] 53 | return data, targets 54 | 55 | def _download_from_url(self, base_url: str, data_folder: str, file_name: str): 56 | url = urllib.parse.urljoin(base_url, file_name) 57 | logging.info(f"Downloading {url}") 58 | fpath = 
os.path.join(data_folder, file_name) 59 | try: 60 | urllib.request.urlretrieve(url, fpath) # noqa: S310 61 | except HTTPError as err: 62 | raise RuntimeError(f"Failed download from {url}") from err 63 | 64 | 65 | @dataclass 66 | class DataModel: 67 | """Data model dataclass. 68 | 69 | Ties together data and callable transforms. 70 | 71 | Attributes: 72 | data: Sequence of indexables. 73 | transform: Callable to transform data. The transform is called on a subset of data. 74 | 75 | """ 76 | 77 | data: TArrays 78 | transform: Optional[Callable[[TArrays], TArrays]] = None 79 | 80 | def process(self, subset: Union[TArrays, float]) -> Union[TArrays, float]: 81 | """Transforms a subset of data. 82 | 83 | Args: 84 | subset: Sequence of indexables. 85 | 86 | Returns: 87 | data: Transformed data if transform is not None. 88 | 89 | """ 90 | if self.transform is not None: 91 | subset = self.transform(subset) 92 | return subset 93 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/concat_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | from pl_bolts.utils.stability import under_review 4 | 5 | 6 | @under_review() 7 | class ConcatDataset(Dataset): 8 | def __init__(self, *datasets) -> None: 9 | self.datasets = datasets 10 | 11 | def __getitem__(self, i): 12 | result = [] 13 | for dataset in self.datasets: 14 | cycled_i = i % len(dataset) 15 | result.append(dataset[cycled_i]) 16 | 17 | return tuple(result) 18 | 19 | def __len__(self) -> int: 20 | return max(len(d) for d in self.datasets) 21 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Union 2 | 3 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE 4 | from 
pl_bolts.utils.warnings import warn_missing_pkg 5 | 6 | if _TORCHVISION_AVAILABLE: 7 | from torchvision.datasets import MNIST 8 | else: # pragma: no cover 9 | warn_missing_pkg("torchvision") 10 | MNIST = object 11 | 12 | if _PIL_AVAILABLE: 13 | from PIL import Image 14 | else: # pragma: no cover 15 | warn_missing_pkg("PIL", pypi_name="Pillow") 16 | 17 | 18 | class BinaryMNIST(MNIST): 19 | """Binarized MNIST Dataset. 20 | 21 | MNIST dataset binarized using a thresholding operation. Default threshold value is 127. 22 | Note that the images are binarized prior to the application of any transforms. 23 | 24 | Args: 25 | root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte`` 26 | and ``MNIST/raw/t10k-images-idx3-ubyte`` exist. 27 | threshold (Union[int, float], optional): Threshold value for binarizing image. 28 | Pixel value is set to 255 if value is greater than threshold, otherwise 0. 29 | train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, 30 | otherwise from ``t10k-images-idx3-ubyte``. 31 | download (bool, optional): If True, downloads the dataset from the internet and 32 | puts it in root directory. If dataset is already downloaded, it is not 33 | downloaded again. 34 | transform (callable, optional): A function/transform that takes in an PIL image 35 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 36 | target_transform (callable, optional): A function/transform that takes in the 37 | target and transforms it. 
38 | 39 | Note: 40 | Documentation is based on https://pytorch.org/vision/main/generated/torchvision.datasets.EMNIST.html 41 | 42 | """ 43 | 44 | def __init__(self, root: str, threshold: Union[int, float] = 127.0, **kwargs: Any) -> None: 45 | super().__init__(root, **kwargs) 46 | self.threshold = threshold 47 | 48 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 49 | """ 50 | Args: 51 | index (int): Index 52 | 53 | Returns: 54 | tuple: (image, target) where target is index of the target class. 55 | """ 56 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 57 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") 58 | 59 | img, target = self.data[idx], int(self.targets[idx]) 60 | 61 | # Convert to PIL Image (8-bit BW) 62 | img = Image.fromarray(img.numpy(), mode="L") 63 | 64 | # Binarize image at threshold 65 | img = img.point(lambda p: 255 if p > self.threshold else 0) 66 | 67 | if self.transform is not None: 68 | img = self.transform(img) 69 | 70 | if self.target_transform is not None: 71 | target = self.target_transform(target) 72 | 73 | return img, target 74 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/sr_celeba_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin 5 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE 6 | from pl_bolts.utils.stability import under_review 7 | from pl_bolts.utils.warnings import warn_missing_pkg 8 | 9 | if _PIL_AVAILABLE: 10 | from PIL import Image 11 | else: # pragma: no cover 12 | warn_missing_pkg("PIL", pypi_name="Pillow") 13 | 14 | if _TORCHVISION_AVAILABLE: 15 | from torchvision.datasets import CelebA 16 | else: # pragma: no cover 17 | warn_missing_pkg("torchvision") 18 | CelebA = object 19 | 20 | 21 | @under_review() 22 | class SRCelebA(SRDatasetMixin, CelebA): 
23 | """CelebA dataset that can be used to train Super Resolution models. 24 | 25 | Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. 26 | 27 | """ 28 | 29 | def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: 30 | hr_image_size = 128 31 | lr_image_size = hr_image_size // scale_factor 32 | self.image_channels = 3 33 | super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) 34 | 35 | def _get_image(self, index: int): 36 | return Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 37 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/sr_dataset_mixin.py: -------------------------------------------------------------------------------- 1 | """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public.""" 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | 6 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.stability import under_review 8 | from pl_bolts.utils.warnings import warn_missing_pkg 9 | 10 | if _PIL_AVAILABLE: 11 | from PIL import Image 12 | else: # pragma: no cover 13 | warn_missing_pkg("PIL", pypi_name="Pillow") 14 | 15 | if _TORCHVISION_AVAILABLE: 16 | from torchvision import transforms as transform_lib 17 | else: # pragma: no cover 18 | warn_missing_pkg("torchvision") 19 | 20 | 21 | @under_review() 22 | class SRDatasetMixin: 23 | """Mixin for Super Resolution datasets. 24 | 25 | Scales range of high resolution images to [-1, 1] and range or low resolution images to [0, 1]. 
26 | 27 | """ 28 | 29 | def __init__(self, hr_image_size: int, lr_image_size: int, image_channels: int, *args: Any, **kwargs: Any) -> None: 30 | super().__init__(*args, **kwargs) 31 | 32 | self.hr_transforms = transform_lib.Compose( 33 | [ 34 | transform_lib.RandomCrop(hr_image_size), 35 | transform_lib.ToTensor(), 36 | transform_lib.Normalize(mean=(0.5,) * image_channels, std=(0.5,) * image_channels), 37 | ] 38 | ) 39 | 40 | self.lr_transforms = transform_lib.Compose( 41 | [ 42 | transform_lib.Normalize(mean=(-1.0,) * image_channels, std=(2.0,) * image_channels), 43 | transform_lib.ToPILImage(), 44 | transform_lib.Resize(lr_image_size, Image.BICUBIC), 45 | transform_lib.ToTensor(), 46 | ] 47 | ) 48 | 49 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: 50 | image = self._get_image(index) 51 | 52 | hr_image = self.hr_transforms(image) 53 | lr_image = self.lr_transforms(hr_image) 54 | 55 | return hr_image, lr_image 56 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/sr_mnist_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pl_bolts.datasets.mnist_dataset import MNIST 4 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin 5 | from pl_bolts.utils import _PIL_AVAILABLE 6 | from pl_bolts.utils.stability import under_review 7 | from pl_bolts.utils.warnings import warn_missing_pkg 8 | 9 | if _PIL_AVAILABLE: 10 | from PIL import Image 11 | else: # pragma: no cover 12 | warn_missing_pkg("PIL", pypi_name="Pillow") 13 | 14 | 15 | @under_review() 16 | class SRMNIST(SRDatasetMixin, MNIST): 17 | """MNIST dataset that can be used to train Super Resolution models. 18 | 19 | Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. 
20 | 21 | """ 22 | 23 | def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: 24 | hr_image_size = 28 25 | lr_image_size = hr_image_size // scale_factor 26 | self.image_channels = 1 27 | super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) 28 | 29 | def _get_image(self, index: int): 30 | return Image.fromarray(self.data[index].numpy(), mode="L") 31 | -------------------------------------------------------------------------------- /src/pl_bolts/datasets/sr_stl10_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | 5 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin 6 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.stability import under_review 8 | from pl_bolts.utils.warnings import warn_missing_pkg 9 | 10 | if _PIL_AVAILABLE: 11 | import PIL 12 | else: # pragma: no cover 13 | warn_missing_pkg("PIL", pypi_name="Pillow") 14 | 15 | if _TORCHVISION_AVAILABLE: 16 | from torchvision.datasets import STL10 17 | else: # pragma: no cover 18 | warn_missing_pkg("torchvision") 19 | STL10 = object 20 | 21 | 22 | @under_review() 23 | class SRSTL10(SRDatasetMixin, STL10): 24 | """STL10 dataset that can be used to train Super Resolution models. 25 | 26 | Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. 
27 | 28 | """ 29 | 30 | def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: 31 | hr_image_size = 96 32 | lr_image_size = hr_image_size // scale_factor 33 | self.image_channels = 3 34 | super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) 35 | 36 | def _get_image(self, index: int): 37 | return PIL.Image.fromarray(np.transpose(self.data[index], (1, 2, 0))) 38 | -------------------------------------------------------------------------------- /src/pl_bolts/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/losses/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.metrics.aggregation import accuracy, mean, precision_at_k 2 | 3 | __all__ = [ 4 | "accuracy", 5 | "mean", 6 | "precision_at_k", 7 | ] 8 | -------------------------------------------------------------------------------- /src/pl_bolts/metrics/aggregation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mean(res, key): 5 | # recursive mean for multilevel dicts 6 | return torch.stack([x[key] if isinstance(x, dict) else mean(x, key) for x in res]).mean() 7 | 8 | 9 | def accuracy(preds, labels): 10 | preds = preds.float() 11 | max_lgt = torch.max(preds, 1)[1] 12 | num_correct = (max_lgt == labels).sum().item() 13 | num_correct = torch.tensor(num_correct).float() 14 | return num_correct / len(labels) 15 | 16 | 17 | def precision_at_k(output, target, top_k=(1,)): 18 | """Computes the accuracy over the k top predictions for the specified values of k.""" 19 | with torch.no_grad(): 20 | maxk = max(top_k) 21 | batch_size = target.size(0) 22 | 23 | _, pred = 
output.topk(maxk, 1, True, True) 24 | pred = pred.t() 25 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 26 | 27 | res = [] 28 | for k in top_k: 29 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 30 | res.append(correct_k.mul_(100.0 / batch_size)) 31 | return res 32 | -------------------------------------------------------------------------------- /src/pl_bolts/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Collection of PyTorchLightning models.""" 2 | from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE 3 | from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE 4 | from pl_bolts.models.mnist_module import LitMNIST 5 | from pl_bolts.models.regression import LinearRegression, LogisticRegression 6 | 7 | __all__ = [ 8 | "AE", 9 | "VAE", 10 | "LitMNIST", 11 | "LinearRegression", 12 | "LogisticRegression", 13 | ] 14 | -------------------------------------------------------------------------------- /src/pl_bolts/models/autoencoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Here are a VAE and GAN.""" 2 | 3 | from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE 4 | from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE 5 | from pl_bolts.models.autoencoders.components import ( 6 | resnet18_decoder, 7 | resnet18_encoder, 8 | resnet50_decoder, 9 | resnet50_encoder, 10 | ) 11 | 12 | __all__ = [ 13 | "AE", 14 | "VAE", 15 | "resnet18_decoder", 16 | "resnet18_encoder", 17 | "resnet50_decoder", 18 | "resnet50_encoder", 19 | ] 20 | -------------------------------------------------------------------------------- /src/pl_bolts/models/autoencoders/basic_ae/__init__.py: -------------------------------------------------------------------------------- 1 | """AE Template. 2 | 3 | This is a basic template for implementing an Autoencoder in PyTorch Lightning. 
4 | 5 | A default encoder and decoder have been provided but can easily be replaced by custom models. 6 | 7 | This template uses the CIFAR10 dataset but image data of any dimension can be fed in as long as the image 8 | width and image height are even values. For other types of data, such as sound, it will be necessary 9 | to change the Encoder and Decoder. 10 | 11 | The default encoder is a resnet18 backbone followed by linear layers which map representations to latent space. 12 | The default decoder mirrors the encoder architecture and is similar to an inverted resnet18. 13 | 14 | .. code-block:: python 15 | 16 | from pl_bolts.models.autoencoders import AE 17 | 18 | model = AE() 19 | trainer = pl.Trainer() 20 | trainer.fit(model) 21 | 22 | """ 23 | -------------------------------------------------------------------------------- /src/pl_bolts/models/autoencoders/basic_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """VAE Template. 2 | 3 | This is a basic template for implementing a Variational Autoencoder in PyTorch Lightning. 4 | 5 | A default encoder and decoder have been provided but can easily be replaced by custom models. 6 | 7 | This template uses the CIFAR10 dataset but image data of any dimension can be fed in as long as the image width and 8 | image height are even values. For other types of data, such as sound, it will be necessary to change the Encoder and 9 | Decoder. 10 | 11 | The default encoder is a resnet18 backbone followed by linear layers which map representations to mu and var. The 12 | default decoder mirrors the encoder architecture and is similar to an inverted resnet18. The model also assumes a 13 | Gaussian prior and a Gaussian approximate posterior distribution. 
14 | 15 | """ 16 | -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.detection import components 2 | from pl_bolts.models.detection.faster_rcnn import FasterRCNN 3 | from pl_bolts.models.detection.retinanet import RetinaNet 4 | from pl_bolts.models.detection.yolo.darknet_network import DarknetNetwork 5 | from pl_bolts.models.detection.yolo.torch_networks import ( 6 | YOLOV4Backbone, 7 | YOLOV4Network, 8 | YOLOV4P6Network, 9 | YOLOV4TinyBackbone, 10 | YOLOV4TinyNetwork, 11 | YOLOV5Backbone, 12 | YOLOV5Network, 13 | YOLOV7Backbone, 14 | YOLOV7Network, 15 | YOLOXNetwork, 16 | ) 17 | from pl_bolts.models.detection.yolo.yolo_module import YOLO 18 | 19 | __all__ = [ 20 | "components", 21 | "FasterRCNN", 22 | "RetinaNet", 23 | "DarknetNetwork", 24 | "YOLOV4Backbone", 25 | "YOLOV4Network", 26 | "YOLOV4P6Network", 27 | "YOLOV4TinyBackbone", 28 | "YOLOV4TinyNetwork", 29 | "YOLOV5Backbone", 30 | "YOLOV5Network", 31 | "YOLOV7Backbone", 32 | "YOLOV7Network", 33 | "YOLOXNetwork", 34 | "YOLO", 35 | ] 36 | -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/components/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.detection.components.torchvision_backbones import create_torchvision_backbone 2 | 3 | __all__ = ["create_torchvision_backbone"] 4 | -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/components/_supported_models.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 2 | from pl_bolts.utils.warnings import warn_missing_pkg 3 | 4 | if _TORCHVISION_AVAILABLE: 5 | import torchvision 6 | 7 | TORCHVISION_MODEL_ZOO = { 8 | 
"vgg11": torchvision.models.vgg11, 9 | "vgg13": torchvision.models.vgg13, 10 | "vgg16": torchvision.models.vgg16, 11 | "vgg19": torchvision.models.vgg19, 12 | "resnet18": torchvision.models.resnet18, 13 | "resnet34": torchvision.models.resnet34, 14 | "resnet50": torchvision.models.resnet50, 15 | "resnet101": torchvision.models.resnet101, 16 | "resnet152": torchvision.models.resnet152, 17 | "resnext50_32x4d": torchvision.models.resnext50_32x4d, 18 | "resnext50_32x8d": torchvision.models.resnext101_32x8d, 19 | "mnasnet0_5": torchvision.models.mnasnet0_5, 20 | "mnasnet0_75": torchvision.models.mnasnet0_75, 21 | "mnasnet1_0": torchvision.models.mnasnet1_0, 22 | "mnasnet1_3": torchvision.models.mnasnet1_3, 23 | "mobilenet_v2": torchvision.models.mobilenet_v2, 24 | } 25 | 26 | else: # pragma: no cover 27 | warn_missing_pkg("torchvision") 28 | TORCHVISION_MODEL_ZOO = {} 29 | -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/faster_rcnn/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.detection.faster_rcnn.backbones import create_fasterrcnn_backbone 2 | from pl_bolts.models.detection.faster_rcnn.faster_rcnn_module import FasterRCNN 3 | 4 | __all__ = ["create_fasterrcnn_backbone", "FasterRCNN"] 5 | -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/faster_rcnn/backbones.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import torch.nn as nn 4 | 5 | from pl_bolts.models.detection.components import create_torchvision_backbone 6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.stability import under_review 8 | from pl_bolts.utils.warnings import warn_missing_pkg 9 | 10 | if _TORCHVISION_AVAILABLE: 11 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone 12 | 
else: # pragma: no cover 13 | warn_missing_pkg("torchvision") 14 | 15 | 16 | @under_review() 17 | def create_fasterrcnn_backbone( 18 | backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any 19 | ) -> nn.Module: 20 | """ 21 | Args: 22 | backbone: 23 | Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152", 24 | "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", 25 | as resnets with fpn backbones. 26 | Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101", 27 | "resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19", 28 | fpn: If True then constructs fpn as well. 29 | pretrained: If None creates imagenet weights backbone. 30 | trainable_backbone_layers: number of trainable resnet layers starting from final block. 31 | """ 32 | 33 | if fpn: 34 | # Creates a torchvision resnet model with fpn added. 35 | backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs) 36 | else: 37 | # This does not create fpn backbone, it is supported for all models 38 | backbone, _ = create_torchvision_backbone(backbone, pretrained) 39 | return backbone 40 | -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/retinanet/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.detection.retinanet.backbones import create_retinanet_backbone 2 | from pl_bolts.models.detection.retinanet.retinanet_module import RetinaNet 3 | 4 | __all__ = ["create_retinanet_backbone", "RetinaNet"] 5 | -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/retinanet/backbones.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import 
torch.nn as nn 4 | 5 | from pl_bolts.models.detection.components import create_torchvision_backbone 6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.stability import under_review 8 | from pl_bolts.utils.warnings import warn_missing_pkg 9 | 10 | if _TORCHVISION_AVAILABLE: 11 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone 12 | else: # pragma: no cover 13 | warn_missing_pkg("torchvision") 14 | 15 | 16 | @under_review() 17 | def create_retinanet_backbone( 18 | backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any 19 | ) -> nn.Module: 20 | """ 21 | Args: 22 | backbone: 23 | Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152", 24 | "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", 25 | as resnets with fpn backbones. 26 | Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101", 27 | "resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19", 28 | fpn: If True then constructs fpn as well. 29 | pretrained: If None creates imagenet weights backbone. 30 | trainable_backbone_layers: number of trainable resnet layers starting from final block. 31 | """ 32 | 33 | if fpn: 34 | # Creates a torchvision resnet model with fpn added. 
35 | backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs) 36 | else: 37 | # This does not create fpn backbone, it is supported for all models 38 | backbone, _ = create_torchvision_backbone(backbone, pretrained) 39 | return backbone 40 | -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/yolo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/detection/yolo/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/detection/yolo/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple, Union 2 | 3 | from torch import Tensor 4 | 5 | IMAGES = Union[Tuple[Tensor, ...], List[Tensor]] 6 | PRED = Dict[str, Any] 7 | PREDS = Union[Tuple[PRED, ...], List[PRED]] 8 | TARGET = Dict[str, Any] 9 | TARGETS = Union[Tuple[TARGET, ...], List[TARGET]] 10 | BATCH = Tuple[IMAGES, TARGETS] 11 | NETWORK_OUTPUT = Tuple[List[Tensor], List[Tensor], List[int]] # detections, losses, hits 12 | -------------------------------------------------------------------------------- /src/pl_bolts/models/gans/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.gans.basic.basic_gan_module import GAN 2 | from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN 3 | from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix 4 | from pl_bolts.models.gans.srgan.srgan_module import SRGAN 5 | from pl_bolts.models.gans.srgan.srresnet_module import SRResNet 6 | 7 | __all__ = [ 8 | "GAN", 9 | "DCGAN", 10 | "Pix2Pix", 11 | "SRGAN", 12 | "SRResNet", 13 | ] 14 | 
-------------------------------------------------------------------------------- /src/pl_bolts/models/gans/basic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/gans/basic/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/gans/basic/components.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F # noqa: N812 5 | 6 | 7 | class Generator(nn.Module): 8 | def __init__(self, latent_dim, img_shape, hidden_dim=256) -> None: 9 | super().__init__() 10 | feats = int(np.prod(img_shape)) 11 | self.img_shape = img_shape 12 | self.fc1 = nn.Linear(latent_dim, hidden_dim) 13 | self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features * 2) 14 | self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features * 2) 15 | self.fc4 = nn.Linear(self.fc3.out_features, feats) 16 | 17 | # forward method 18 | def forward(self, z): 19 | z = F.leaky_relu(self.fc1(z), 0.2) 20 | z = F.leaky_relu(self.fc2(z), 0.2) 21 | z = F.leaky_relu(self.fc3(z), 0.2) 22 | img = torch.tanh(self.fc4(z)) 23 | return img.view(img.size(0), *self.img_shape) 24 | 25 | 26 | class Discriminator(nn.Module): 27 | def __init__(self, img_shape, hidden_dim=1024) -> None: 28 | super().__init__() 29 | in_dim = int(np.prod(img_shape)) 30 | self.fc1 = nn.Linear(in_dim, hidden_dim) 31 | self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2) 32 | self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2) 33 | self.fc4 = nn.Linear(self.fc3.out_features, 1) 34 | 35 | # forward method 36 | def forward(self, img): 37 | x = img.view(img.size(0), -1) 38 | x = F.leaky_relu(self.fc1(x), 0.2) 39 | x = F.dropout(x, 0.3) 40 | x = 
F.leaky_relu(self.fc2(x), 0.2) 41 | x = F.dropout(x, 0.3) 42 | x = F.leaky_relu(self.fc3(x), 0.2) 43 | x = F.dropout(x, 0.3) 44 | return torch.sigmoid(self.fc4(x)) 45 | -------------------------------------------------------------------------------- /src/pl_bolts/models/gans/dcgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/gans/dcgan/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/gans/pix2pix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/gans/pix2pix/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/gans/srgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/gans/srgan/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/regression/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.regression.linear_regression import LinearRegression 2 | from pl_bolts.models.regression.logistic_regression import LogisticRegression 3 | 4 | __all__ = [ 5 | "LinearRegression", 6 | "LogisticRegression", 7 | ] 8 | -------------------------------------------------------------------------------- /src/pl_bolts/models/rl/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic 2 | from 
pl_bolts.models.rl.double_dqn_model import DoubleDQN 3 | from pl_bolts.models.rl.dqn_model import DQN 4 | from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN 5 | from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN 6 | from pl_bolts.models.rl.per_dqn_model import PERDQN 7 | from pl_bolts.models.rl.reinforce_model import Reinforce 8 | from pl_bolts.models.rl.sac_model import SAC 9 | from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient 10 | 11 | __all__ = [ 12 | "AdvantageActorCritic", 13 | "DoubleDQN", 14 | "DQN", 15 | "DuelingDQN", 16 | "NoisyDQN", 17 | "PERDQN", 18 | "Reinforce", 19 | "SAC", 20 | "VanillaPolicyGradient", 21 | ] 22 | -------------------------------------------------------------------------------- /src/pl_bolts/models/rl/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/rl/common/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/rl/common/cli.py: -------------------------------------------------------------------------------- 1 | """Contains generic arguments used for all models.""" 2 | 3 | import argparse 4 | 5 | from pl_bolts.utils.stability import under_review 6 | 7 | 8 | @under_review() 9 | def add_base_args(parent) -> argparse.ArgumentParser: 10 | """Adds arguments for DQN model. 11 | 12 | Note: 13 | These params are fine tuned for Pong env. 
14 | 15 | Args: 16 | parent 17 | 18 | """ 19 | arg_parser = argparse.ArgumentParser(parents=[parent]) 20 | 21 | arg_parser.add_argument("--algo", type=str, default="dqn", help="algorithm to use for training") 22 | arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") 23 | arg_parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") 24 | 25 | arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag") 26 | arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") 27 | 28 | arg_parser.add_argument("--episode_length", type=int, default=500, help="max length of an episode") 29 | arg_parser.add_argument("--max_episode_reward", type=int, default=18, help="max episode reward in the environment") 30 | arg_parser.add_argument( 31 | "--n_steps", 32 | type=int, 33 | default=4, 34 | help="how many steps to unroll for each update", 35 | ) 36 | arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run") 37 | arg_parser.add_argument("--epoch_len", type=int, default=1000, help="how many batches per epoch") 38 | arg_parser.add_argument("--num_envs", type=int, default=1, help="number of environments to run at once") 39 | arg_parser.add_argument( 40 | "--avg_reward_len", type=int, default=100, help="how many episodes to include in avg reward" 41 | ) 42 | 43 | arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run") 44 | return arg_parser 45 | -------------------------------------------------------------------------------- /src/pl_bolts/models/rl/common/distributions.py: -------------------------------------------------------------------------------- 1 | """Distributions used in some continuous RL algorithms.""" 2 | import torch 3 | 4 | from pl_bolts.utils.stability import under_review 5 | 6 | 7 | @under_review() 8 | class TanhMultivariateNormal(torch.distributions.MultivariateNormal): 9 | """The distribution of X is an 
affine of tanh applied on a normal distribution:: 10 | 11 | X = action_scale * tanh(Z) + action_bias 12 | Z ~ Normal(mean, variance) 13 | 14 | """ 15 | 16 | def __init__(self, action_bias, action_scale, **kwargs) -> None: 17 | super().__init__(**kwargs) 18 | 19 | self.action_bias = action_bias 20 | self.action_scale = action_scale 21 | 22 | def rsample_with_z(self, sample_shape=torch.Size()): 23 | """Samples X using reparametrization trick with the intermediate variable Z. 24 | 25 | Returns: 26 | Sampled X and Z 27 | 28 | """ 29 | z = super().rsample() 30 | return self.action_scale * torch.tanh(z) + self.action_bias, z 31 | 32 | def log_prob_with_z(self, value, z): 33 | """Computes the log probability of a sampled X. 34 | 35 | Refer to the original paper of SAC for more details in equation (20), (21) 36 | 37 | Args: 38 | value: the value of X 39 | z: the value of Z 40 | Returns: 41 | Log probability of the sample 42 | 43 | """ 44 | value = (value - self.action_bias) / self.action_scale 45 | z_logprob = super().log_prob(z) 46 | correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) 47 | return z_logprob - correction 48 | 49 | def rsample_and_log_prob(self, sample_shape=torch.Size()): 50 | """Samples X and computes the log probability of the sample. 
51 | 52 | Returns: 53 | Sampled X and log probability 54 | 55 | """ 56 | z = super().rsample() 57 | z_logprob = super().log_prob(z) 58 | value = torch.tanh(z) 59 | correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) 60 | return self.action_scale * value + self.action_bias, z_logprob - correction 61 | 62 | def rsample(self, sample_shape=torch.Size()): 63 | fz, z = self.rsample_with_z(sample_shape) 64 | return fz 65 | 66 | def log_prob(self, value): 67 | value = (value - self.action_bias) / self.action_scale 68 | z = torch.log(1 + value) / 2 - torch.log(1 - value) / 2 69 | return self.log_prob_with_z(value, z) 70 | -------------------------------------------------------------------------------- /src/pl_bolts/models/rl/dueling_dqn_model.py: -------------------------------------------------------------------------------- 1 | """Dueling DQN.""" 2 | import argparse 3 | 4 | from pytorch_lightning import Trainer 5 | 6 | from pl_bolts.models.rl.common.networks import DuelingCNN 7 | from pl_bolts.models.rl.dqn_model import DQN 8 | from pl_bolts.utils.stability import under_review 9 | 10 | 11 | @under_review() 12 | class DuelingDQN(DQN): 13 | """PyTorch Lightning implementation of `Dueling DQN `_ 14 | 15 | Paper authors: Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas 16 | 17 | Model implemented by: 18 | 19 | - `Donal Byrne ` 20 | 21 | Example: 22 | 23 | >>> from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN 24 | ... 25 | >>> model = DuelingDQN("PongNoFrameskip-v4") 26 | 27 | Train:: 28 | 29 | trainer = Trainer() 30 | trainer.fit(model) 31 | 32 | .. 
note:: Currently only supports CPU and single GPU training with `accelerator=dp` 33 | """ 34 | 35 | def build_networks(self) -> None: 36 | """Initializes the Dueling DQN train and target networks.""" 37 | self.net = DuelingCNN(self.obs_shape, self.n_actions) 38 | self.target_net = DuelingCNN(self.obs_shape, self.n_actions) 39 | 40 | 41 | @under_review() 42 | def cli_main(): 43 | parser = argparse.ArgumentParser(add_help=False) 44 | 45 | # trainer args 46 | parser = Trainer.add_argparse_args(parser) 47 | 48 | # model args 49 | parser = DuelingDQN.add_model_specific_args(parser) 50 | args = parser.parse_args() 51 | 52 | model = DuelingDQN(**args.__dict__) 53 | 54 | trainer = Trainer.from_argparse_args(args) 55 | trainer.fit(model) 56 | 57 | 58 | if __name__ == "__main__": 59 | cli_main() 60 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/__init__.py: -------------------------------------------------------------------------------- 1 | """These models have been pre-trained using self-supervised learning. The models can also be used without pre- training 2 | and overwritten for your own research. 3 | 4 | Here's an example for using these as pretrained models. 5 | 6 | .. 
code-block :: 7 | 8 | from pl_bolts.models.self_supervised import CPC_v2 9 | 10 | images = get_imagenet_batch() 11 | 12 | # extract unsupervised representations 13 | pretrained = CPC_v2(pretrained=True) 14 | representations = pretrained(images) 15 | 16 | # use these in classification or any downstream task 17 | classifications = classifier(representations) 18 | 19 | """ 20 | from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM 21 | from pl_bolts.models.self_supervised.byol.byol_module import BYOL 22 | from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2 23 | from pl_bolts.models.self_supervised.evaluator import SSLEvaluator 24 | from pl_bolts.models.self_supervised.moco.moco_module import MoCo 25 | from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR 26 | from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam 27 | from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner 28 | from pl_bolts.models.self_supervised.swav.swav_module import SwAV 29 | 30 | __all__ = [ 31 | "AMDIM", 32 | "BYOL", 33 | "CPC_v2", 34 | "SSLEvaluator", 35 | "MoCo", 36 | "SimCLR", 37 | "SimSiam", 38 | "SSLFineTuner", 39 | "SwAV", 40 | ] 41 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/amdim/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM 2 | from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder 3 | from pl_bolts.transforms.self_supervised.amdim_transforms import ( 4 | AMDIMEvalTransformsCIFAR10, 5 | AMDIMEvalTransformsImageNet128, 6 | AMDIMEvalTransformsSTL10, 7 | AMDIMTrainTransformsCIFAR10, 8 | AMDIMTrainTransformsImageNet128, 9 | AMDIMTrainTransformsSTL10, 10 | ) 11 | 12 | __all__ = [ 13 | "AMDIM", 14 | "AMDIMEncoder", 15 | "AMDIMEvalTransformsCIFAR10", 16 | "AMDIMEvalTransformsImageNet128", 17 | 
"AMDIMEvalTransformsSTL10", 18 | "AMDIMTrainTransformsCIFAR10", 19 | "AMDIMTrainTransformsImageNet128", 20 | "AMDIMTrainTransformsSTL10", 21 | ] 22 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/byol/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/self_supervised/byol/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/byol/models.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | from torch import Tensor, nn 4 | 5 | from pl_bolts.utils.self_supervised import torchvision_ssl_encoder 6 | 7 | 8 | class MLP(nn.Module): 9 | """MLP architecture used as projectors in online and target networks and predictors in the online network. 10 | 11 | Args: 12 | input_dim (int, optional): Input dimension. Defaults to 2048. 13 | hidden_dim (int, optional): Hidden layer dimension. Defaults to 4096. 14 | output_dim (int, optional): Output dimension. Defaults to 256. 15 | 16 | Note: 17 | Default values for input, hidden, and output dimensions are based on values used in BYOL. 18 | 19 | """ 20 | 21 | def __init__(self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256) -> None: 22 | super().__init__() 23 | 24 | self.model = nn.Sequential( 25 | nn.Linear(input_dim, hidden_dim, bias=False), 26 | nn.BatchNorm1d(hidden_dim), 27 | nn.ReLU(inplace=True), 28 | nn.Linear(hidden_dim, output_dim, bias=True), 29 | ) 30 | 31 | def forward(self, x: Tensor) -> Tensor: 32 | return self.model(x) 33 | 34 | 35 | class SiameseArm(nn.Module): 36 | """SiameseArm consolidates the encoder and projector networks of BYOL's symmetric architecture into a single class. 
37 | 38 | Args: 39 | encoder (Union[str, nn.Module], optional): Online and target network encoder architecture. 40 | Defaults to "resnet50". 41 | encoder_out_dim (int, optional): Output dimension of encoder. Defaults to 2048. 42 | projector_hidden_dim (int, optional): Online and target network projector network hidden dimension. 43 | Defaults to 4096. 44 | projector_out_dim (int, optional): Online and target network projector network output dimension. 45 | Defaults to 256. 46 | 47 | """ 48 | 49 | def __init__( 50 | self, 51 | encoder: Union[str, nn.Module] = "resnet50", 52 | encoder_out_dim: int = 2048, 53 | projector_hidden_dim: int = 4096, 54 | projector_out_dim: int = 256, 55 | ) -> None: 56 | super().__init__() 57 | 58 | if isinstance(encoder, str): 59 | self.encoder = torchvision_ssl_encoder(encoder) 60 | else: 61 | self.encoder = encoder 62 | 63 | self.projector = MLP(encoder_out_dim, projector_hidden_dim, projector_out_dim) 64 | 65 | def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: 66 | y = self.encoder(x)[0] 67 | z = self.projector(y) 68 | return y, z 69 | 70 | def encode(self, x: Tensor) -> Tensor: 71 | """Returns the encoded representation of a view. This method does not calculate the projection as in the forward 72 | method. 
73 | 74 | Args: 75 | x (Tensor): sample to be encoded 76 | 77 | """ 78 | return self.encoder(x)[0] 79 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/cpc/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2 2 | from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet50, cpc_resnet101 3 | from pl_bolts.transforms.self_supervised.cpc_transforms import ( 4 | CPCEvalTransformsCIFAR10, 5 | CPCEvalTransformsImageNet128, 6 | CPCEvalTransformsSTL10, 7 | CPCTrainTransformsCIFAR10, 8 | CPCTrainTransformsImageNet128, 9 | CPCTrainTransformsSTL10, 10 | ) 11 | 12 | __all__ = [ 13 | "CPC_v2", 14 | "cpc_resnet50", 15 | "cpc_resnet101", 16 | "CPCEvalTransformsCIFAR10", 17 | "CPCEvalTransformsImageNet128", 18 | "CPCEvalTransformsSTL10", 19 | "CPCTrainTransformsCIFAR10", 20 | "CPCTrainTransformsImageNet128", 21 | "CPCTrainTransformsSTL10", 22 | ] 23 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | from pytorch_lightning import Trainer, seed_everything 5 | 6 | from pl_bolts.models.self_supervised import CPC_v2, SSLFineTuner 7 | from pl_bolts.transforms.self_supervised.cpc_transforms import ( 8 | CPCEvalTransformsCIFAR10, 9 | CPCEvalTransformsSTL10, 10 | CPCTrainTransformsCIFAR10, 11 | CPCTrainTransformsSTL10, 12 | ) 13 | from pl_bolts.utils.stability import under_review 14 | 15 | 16 | @under_review() 17 | def cli_main(): # pragma: no cover 18 | from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule 19 | 20 | seed_everything(1234) 21 | 22 | parser = ArgumentParser() 23 | parser = Trainer.add_argparse_args(parser) 24 | parser.add_argument("--dataset", 
type=str, help="stl10, cifar10", default="cifar10") 25 | parser.add_argument("--ckpt_path", type=str, help="path to ckpt") 26 | parser.add_argument("--data_dir", type=str, help="path to ckpt", default=os.getcwd()) 27 | args = parser.parse_args() 28 | 29 | # load the backbone 30 | backbone = CPC_v2.load_from_checkpoint(args.ckpt_path, strict=False) 31 | 32 | if args.dataset == "cifar10": 33 | dm = CIFAR10DataModule.from_argparse_args(args) 34 | dm.train_transforms = CPCTrainTransformsCIFAR10() 35 | dm.val_transforms = CPCEvalTransformsCIFAR10() 36 | dm.test_transforms = CPCEvalTransformsCIFAR10() 37 | 38 | elif args.dataset == "stl10": 39 | dm = STL10DataModule.from_argparse_args(args) 40 | dm.train_dataloader = dm.train_dataloader_labeled 41 | dm.val_dataloader = dm.val_dataloader_labeled 42 | dm.train_transforms = CPCTrainTransformsSTL10() 43 | dm.val_transforms = CPCEvalTransformsSTL10() 44 | dm.test_transforms = CPCEvalTransformsSTL10() 45 | 46 | # finetune 47 | tuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes) 48 | trainer = Trainer.from_argparse_args(args, early_stop_callback=True) 49 | trainer.fit(tuner, dm) 50 | 51 | trainer.test(datamodule=dm) 52 | 53 | 54 | if __name__ == "__main__": 55 | cli_main() 56 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/evaluator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from pl_bolts.utils.stability import under_review 4 | 5 | 6 | @under_review() 7 | class SSLEvaluator(nn.Module): 8 | def __init__(self, n_input, n_classes, n_hidden=512, p=0.1) -> None: 9 | super().__init__() 10 | self.n_input = n_input 11 | self.n_classes = n_classes 12 | self.n_hidden = n_hidden 13 | if n_hidden is None: 14 | # use linear classifier 15 | self.block_forward = nn.Sequential(Flatten(), nn.Dropout(p=p), nn.Linear(n_input, n_classes, bias=True)) 16 | else: 17 
| # use simple MLP classifier 18 | self.block_forward = nn.Sequential( 19 | Flatten(), 20 | nn.Dropout(p=p), 21 | nn.Linear(n_input, n_hidden, bias=False), 22 | nn.BatchNorm1d(n_hidden), 23 | nn.ReLU(inplace=True), 24 | nn.Dropout(p=p), 25 | nn.Linear(n_hidden, n_classes, bias=True), 26 | ) 27 | 28 | def forward(self, x): 29 | return self.block_forward(x) 30 | 31 | 32 | @under_review() 33 | class Flatten(nn.Module): 34 | def __init__(self) -> None: 35 | super().__init__() 36 | 37 | def forward(self, input_tensor): 38 | return input_tensor.view(input_tensor.size(0), -1) 39 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/moco/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.transforms.self_supervised.moco_transforms import ( # noqa: F401 2 | MoCo2EvalCIFAR10Transforms, 3 | MoCo2EvalImagenetTransforms, 4 | MoCo2EvalSTL10Transforms, 5 | MoCo2TrainCIFAR10Transforms, 6 | MoCo2TrainImagenetTransforms, 7 | MoCo2TrainSTL10Transforms, 8 | ) 9 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/moco/callbacks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from pytorch_lightning import Callback 4 | 5 | from pl_bolts.utils.stability import under_review 6 | 7 | 8 | @under_review() 9 | class MoCoLRScheduler(Callback): 10 | def __init__(self, initial_lr=0.03, use_cosine_scheduler=False, schedule=(120, 160), max_epochs=200) -> None: 11 | super().__init__() 12 | self.lr = initial_lr 13 | self.use_cosine_scheduler = use_cosine_scheduler 14 | self.schedule = schedule 15 | self.max_epochs = max_epochs 16 | 17 | def on_train_epoch_start(self, trainer, pl_module): 18 | epoch = trainer.current_epoch 19 | lr = self.lr 20 | 21 | if self.use_cosine_scheduler: # cosine lr schedule 22 | lr *= 0.5 * (1.0 + math.cos(math.pi * epoch / 
self.max_epochs)) 23 | else: # stepwise lr schedule 24 | for milestone in self.schedule: 25 | lr *= 0.1 if epoch >= milestone else 1.0 26 | 27 | optimizer = trainer.optimizers[0] 28 | for param_group in optimizer.param_groups: 29 | param_group["lr"] = lr 30 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/simclr/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.transforms.self_supervised.simclr_transforms import ( # noqa: F401 2 | SimCLREvalDataTransform, 3 | SimCLRTrainDataTransform, 4 | ) 5 | -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/simsiam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/self_supervised/simsiam/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/self_supervised/swav/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.self_supervised.swav.loss import SWAVLoss 2 | from pl_bolts.models.self_supervised.swav.swav_module import SwAV 3 | from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 4 | from pl_bolts.transforms.self_supervised.swav_transforms import ( 5 | SwAVEvalDataTransform, 6 | SwAVFinetuneTransform, 7 | SwAVTrainDataTransform, 8 | ) 9 | 10 | __all__ = [ 11 | "SwAV", 12 | "resnet18", 13 | "resnet50", 14 | "SwAVEvalDataTransform", 15 | "SwAVFinetuneTransform", 16 | "SwAVTrainDataTransform", 17 | "SWAVLoss", 18 | ] 19 | -------------------------------------------------------------------------------- /src/pl_bolts/models/vision/__init__.py: 
-------------------------------------------------------------------------------- 1 | from pl_bolts.models.vision.image_gpt.gpt2 import GPT2 2 | from pl_bolts.models.vision.image_gpt.igpt_module import ImageGPT 3 | from pl_bolts.models.vision.pixel_cnn import PixelCNN 4 | from pl_bolts.models.vision.segmentation import SemSegment 5 | from pl_bolts.models.vision.unet import UNet 6 | 7 | __all__ = [ 8 | "GPT2", 9 | "ImageGPT", 10 | "PixelCNN", 11 | "SemSegment", 12 | "UNet", 13 | ] 14 | -------------------------------------------------------------------------------- /src/pl_bolts/models/vision/image_gpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/vision/image_gpt/__init__.py -------------------------------------------------------------------------------- /src/pl_bolts/models/vision/pixel_cnn.py: -------------------------------------------------------------------------------- 1 | """PixelCNN. 2 | 3 | Implemented by: William Falcon Reference 4 | : https: //arxiv.org/pdf/1905.09272.pdf (page 15 Accessed: May 14, 2020. 5 | 6 | """ 7 | from torch import nn 8 | from torch.nn import functional as F # noqa: N812 9 | 10 | from pl_bolts.utils.stability import under_review 11 | 12 | 13 | @under_review() 14 | class PixelCNN(nn.Module): 15 | """Implementation of `Pixel CNN `_. 16 | 17 | Paper authors: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex Graves, 18 | Koray Kavukcuoglu 19 | 20 | Implemented by: 21 | 22 | - William Falcon 23 | 24 | Example:: 25 | 26 | >>> from pl_bolts.models.vision import PixelCNN 27 | >>> import torch 28 | ... 29 | >>> model = PixelCNN(input_channels=3) 30 | >>> x = torch.rand(5, 3, 64, 64) 31 | >>> out = model(x) 32 | ... 
33 | >>> out.shape 34 | torch.Size([5, 3, 64, 64]) 35 | """ 36 | 37 | def __init__(self, input_channels: int, hidden_channels: int = 256, num_blocks=5) -> None: 38 | super().__init__() 39 | self.input_channels = input_channels 40 | self.hidden_channels = hidden_channels 41 | 42 | self.blocks = nn.ModuleList([self.conv_block(input_channels) for _ in range(num_blocks)]) 43 | 44 | def conv_block(self, input_channels): 45 | c1 = nn.Conv2d(in_channels=input_channels, out_channels=self.hidden_channels, kernel_size=(1, 1)) 46 | act1 = nn.ReLU() 47 | c2 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=(1, 3)) 48 | pad = nn.ConstantPad2d((0, 0, 1, 0, 0, 0, 0, 0), 1) 49 | c3 = nn.Conv2d( 50 | in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=(2, 1), padding=(0, 1) 51 | ) 52 | act2 = nn.ReLU() 53 | c4 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=input_channels, kernel_size=(1, 1)) 54 | 55 | return nn.Sequential(c1, act1, c2, pad, c3, act2, c4) 56 | 57 | def forward(self, z): 58 | c = z 59 | for conv_block in self.blocks: 60 | c = c + conv_block(c) 61 | 62 | return F.relu(c) 63 | -------------------------------------------------------------------------------- /src/pl_bolts/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.optimizers.lars import LARS 2 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR, linear_warmup_decay 3 | 4 | __all__ = [ 5 | "LARS", 6 | "LinearWarmupCosineAnnealingLR", 7 | "linear_warmup_decay", 8 | ] 9 | -------------------------------------------------------------------------------- /src/pl_bolts/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/transforms/__init__.py 
-------------------------------------------------------------------------------- /src/pl_bolts/transforms/dataset_normalizations.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 4 | from pl_bolts.utils.warnings import warn_missing_pkg 5 | 6 | if _TORCHVISION_AVAILABLE: 7 | from torchvision import transforms 8 | else: # pragma: no cover 9 | warn_missing_pkg("torchvision") 10 | 11 | 12 | def imagenet_normalization() -> Callable: 13 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 14 | raise ModuleNotFoundError( 15 | "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." 16 | ) 17 | 18 | return transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 19 | 20 | 21 | def cifar10_normalization() -> Callable: 22 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 23 | raise ModuleNotFoundError( 24 | "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." 25 | ) 26 | 27 | return transforms.Normalize( 28 | mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 29 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]], 30 | ) 31 | 32 | 33 | def stl10_normalization() -> Callable: 34 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 35 | raise ModuleNotFoundError( 36 | "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." 37 | ) 38 | 39 | return transforms.Normalize(mean=(0.43, 0.42, 0.39), std=(0.27, 0.26, 0.27)) 40 | 41 | 42 | def emnist_normalization(split: str) -> Callable: 43 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 44 | raise ModuleNotFoundError( 45 | "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." 46 | ) 47 | 48 | # `stats` contains mean and std for each `split`. 
49 | stats = { 50 | "balanced": (0.175, 0.333), 51 | "byclass": (0.174, 0.332), 52 | "bymerge": (0.174, 0.332), 53 | "digits": (0.173, 0.332), 54 | "letters": (0.172, 0.331), 55 | "mnist": (0.173, 0.332), 56 | } 57 | 58 | return transforms.Normalize(mean=stats[split][0], std=stats[split][1]) 59 | -------------------------------------------------------------------------------- /src/pl_bolts/transforms/self_supervised/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.transforms.self_supervised.ssl_transforms import Patchify, RandomTranslateWithReflect # noqa: F401 2 | -------------------------------------------------------------------------------- /src/pl_bolts/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import platform 3 | 4 | import torch 5 | from lightning_utilities.core.imports import compare_version, module_available 6 | 7 | from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore 8 | 9 | _NATIVE_AMP_AVAILABLE: bool = module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") 10 | _IS_WINDOWS = platform.system() == "Windows" 11 | _TORCH_ORT_AVAILABLE = module_available("torch_ort") 12 | _TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0") 13 | _TORCHVISION_AVAILABLE: bool = module_available("torchvision") 14 | _TORCHVISION_LESS_THAN_0_9_1: bool = compare_version("torchvision", operator.lt, "0.9.1") 15 | _TORCHVISION_LESS_THAN_0_13: bool = compare_version("torchvision", operator.le, "0.13.0") 16 | _TORCHMETRICS_DETECTION_AVAILABLE: bool = module_available("torchmetrics.detection") 17 | _PL_GREATER_EQUAL_1_4 = compare_version("pytorch_lightning", operator.ge, "1.4.0") 18 | _PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5") 19 | _GYM_AVAILABLE: bool = module_available("gym") 20 | _GYM_GREATER_EQUAL_0_20 
= compare_version("gym", operator.ge, "0.20.0") 21 | _SKLEARN_AVAILABLE: bool = module_available("sklearn") 22 | _PIL_AVAILABLE: bool = module_available("PIL") 23 | _SPARSEML_AVAILABLE = module_available("sparseml") 24 | _OPENCV_AVAILABLE: bool = module_available("cv2") 25 | _WANDB_AVAILABLE: bool = module_available("wandb") 26 | _MATPLOTLIB_AVAILABLE: bool = module_available("matplotlib") 27 | _JSONARGPARSE_GREATER_THAN_4_16_0 = compare_version("jsonargparse", operator.gt, "4.16.0") 28 | 29 | _SPARSEML_TORCH_SATISFIED = _SPARSEML_AVAILABLE 30 | _SPARSEML_TORCH_SATISFIED_ERROR = None 31 | if _SPARSEML_AVAILABLE: 32 | try: 33 | from sparseml.pytorch.base import _TORCH_MAX_VERSION # noqa: F401 34 | 35 | except (ImportError, ValueError) as err: 36 | _SPARSEML_TORCH_SATISFIED = False 37 | _SPARSEML_TORCH_SATISFIED_ERROR = str(err) 38 | 39 | __all__ = ["BatchGradientVerification"] 40 | -------------------------------------------------------------------------------- /src/pl_bolts/utils/_dependency.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from typing import Any, Callable 4 | 5 | from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache 6 | 7 | 8 | # ToDo: replace with utils wrapper after 0.10 is released 9 | def requires(*module_path_version: str) -> Callable: 10 | """Wrapper for enforcing certain requirements for a particular class or function.""" 11 | 12 | def decorator(func: Callable) -> Callable: 13 | reqs = [ 14 | ModuleAvailableCache(mod_ver) if "." 
in mod_ver else RequirementCache(mod_ver) 15 | for mod_ver in module_path_version 16 | ] 17 | available = all(map(bool, reqs)) 18 | if not available: 19 | 20 | @functools.wraps(func) 21 | def wrapper(*args: Any, **kwargs: Any) -> Any: 22 | msg = os.linesep.join([repr(r) for r in reqs if not bool(r)]) 23 | raise ModuleNotFoundError(f"Required dependencies not available: \n{msg}") 24 | 25 | return wrapper 26 | return func 27 | 28 | return decorator 29 | -------------------------------------------------------------------------------- /src/pl_bolts/utils/pretrained_weights.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pytorch_lightning import LightningModule 4 | 5 | from pl_bolts.utils.stability import under_review 6 | 7 | vae_imagenet2012 = ( 8 | "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/imagenet_06_22_2019/checkpoints/epoch%3D63.ckpt" 9 | ) 10 | 11 | cpcv2_resnet18 = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/resnet18-v6/epoch%3D85.ckpt" 12 | urls = {"vae-imagenet2012": vae_imagenet2012, "CPC_v2-resnet18": cpcv2_resnet18} 13 | 14 | 15 | @under_review() 16 | def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None: # pragma: no cover 17 | if class_name is None: 18 | class_name = model.__class__.__name__ 19 | ckpt_url = urls[class_name] 20 | weights_model = model.__class__.load_from_checkpoint(ckpt_url) 21 | model.load_state_dict(weights_model.state_dict()) 22 | -------------------------------------------------------------------------------- /src/pl_bolts/utils/self_supervised.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | 3 | from pl_bolts.models.self_supervised import resnets 4 | from pl_bolts.utils.semi_supervised import Identity 5 | from pl_bolts.utils.stability import under_review 6 | 7 | 8 | @under_review() 9 | def torchvision_ssl_encoder( 10 | name: 
str, 11 | pretrained: bool = False, 12 | return_all_feature_maps: bool = False, 13 | ) -> Module: 14 | pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps) 15 | pretrained_model.fc = Identity() 16 | return pretrained_model 17 | -------------------------------------------------------------------------------- /src/pl_bolts/utils/shaping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | 5 | from pl_bolts.utils.stability import under_review 6 | 7 | 8 | @under_review() 9 | def tile(a: Tensor, dim: int, n_tile: int) -> Tensor: 10 | init_dim = a.size(dim) 11 | repeat_idx = [1] * a.dim() 12 | repeat_idx[dim] = n_tile 13 | a = a.repeat(*repeat_idx) 14 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 15 | return torch.index_select(a, dim, order_index) 16 | -------------------------------------------------------------------------------- /src/pl_bolts/utils/types.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | TArrays = Union[torch.Tensor, np.ndarray, Sequence[float], Sequence["TArrays"]] 7 | -------------------------------------------------------------------------------- /src/pl_bolts/utils/warnings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Callable, Dict, Optional 4 | 5 | MISSING_PACKAGE_WARNINGS: Dict[str, int] = {} 6 | 7 | WARN_MISSING_PACKAGE = int(os.environ.get("WARN_MISSING_PACKAGE", False)) 8 | 9 | 10 | def warn_missing_pkg( 11 | pkg_name: str, 12 | pypi_name: Optional[str] = None, 13 | extra_text: Optional[str] = None, 14 | stdout_func: Callable = warnings.warn, 15 | ) -> int: 16 | """Template for warning on missing 
packages, show them just once. 17 | 18 | Args: 19 | pkg_name: Name of missing package 20 | pypi_name: In case that package name differ from PyPI name 21 | extra_text: Additional text after the base warning 22 | stdout_func: Define used function for streaming warning, use ``warnings.warn`` or ``logging.warning`` 23 | 24 | Returns: 25 | Number of warning calls 26 | 27 | """ 28 | if not WARN_MISSING_PACKAGE: 29 | return -1 30 | 31 | if pkg_name not in MISSING_PACKAGE_WARNINGS: 32 | extra_text = os.linesep + extra_text if extra_text else "" 33 | if not pypi_name: 34 | pypi_name = pkg_name 35 | stdout_func( 36 | f"You want to use `{pkg_name}` which is not installed yet," 37 | f" install it with `pip install {pypi_name}`." + extra_text 38 | ) 39 | MISSING_PACKAGE_WARNINGS[pkg_name] = 1 40 | else: 41 | MISSING_PACKAGE_WARNINGS[pkg_name] += 1 42 | 43 | return MISSING_PACKAGE_WARNINGS[pkg_name] 44 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | ## How to run tests 2 | 3 | The easiest way to run the tests is with `make test`. Please note that this may take quite some time, especially on an unreliable internet connection, since it downloads datasets. 4 | 5 | If you want to run specific tests, first make sure your environment is up to date by running `make env && make clean`. Then you can target specific tests such as `pytest "tests/models/gans/integration/test_gans.py::test_gan"`. 6 | 7 | If you want to run doctests, run `make doctest` from the root directory. 8 | 9 | ## Testing compatibility 10 | 11 | Following PR [#844](https://github.com/Lightning-AI/lightning-bolts/pull/844), all new tests are required to use the `catch_warnings` fixture in order to filter out compatibility issues. In the future, this fixture will be marked as `autouse=True`; until then, please add it yourself.
12 | 13 | ### Example 14 | 15 | ```python 16 | def test_this_thing(catch_warnings): ... 17 | ``` 18 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from pytorch_lightning import seed_everything 5 | 6 | TEST_ROOT = os.path.realpath(os.path.dirname(__file__)) 7 | PROJECT_ROOT = os.path.dirname(TEST_ROOT) 8 | DATASETS_PATH = os.environ.get("DATASETS_PATH", os.path.join(PROJECT_ROOT, "_datasets")) 9 | # generate a list of random seeds for each test 10 | ROOT_SEED = 1234 11 | 12 | _MARK_REQUIRE_GPU = {"condition": not torch.cuda.is_available(), "reason": "test requires GPU machine"} 13 | 14 | 15 | def reset_seed(): 16 | seed_everything() 17 | -------------------------------------------------------------------------------- /tests/_data_configs/yolo.cfg: -------------------------------------------------------------------------------- 1 | [net] 2 | width=256 3 | height=256 4 | channels=3 5 | 6 | [convolutional] 7 | batch_normalize=1 8 | filters=8 9 | size=3 10 | stride=1 11 | pad=1 12 | activation=leaky 13 | 14 | [route] 15 | layers=-1 16 | groups=2 17 | group_id=1 18 | 19 | [maxpool] 20 | size=2 21 | stride=2 22 | 23 | [convolutional] 24 | batch_normalize=1 25 | filters=2 26 | size=1 27 | stride=1 28 | pad=1 29 | activation=mish 30 | 31 | [convolutional] 32 | batch_normalize=1 33 | filters=4 34 | size=3 35 | stride=1 36 | pad=1 37 | activation=mish 38 | 39 | [shortcut] 40 | from=-3 41 | activation=linear 42 | 43 | [convolutional] 44 | size=1 45 | stride=1 46 | pad=1 47 | filters=14 48 | activation=linear 49 | 50 | [yolo] 51 | mask=2,3 52 | anchors=1,2, 3,4, 5,6, 9,10 53 | classes=2 54 | scale_x_y=1.05 55 | cls_normalizer=1.0 56 | iou_normalizer=0.07 57 | ignore_thresh=0.7 58 | 59 | [route] 60 | layers = -4 61 | 62 | [upsample] 63 | stride=2 64 | 65 | [convolutional] 66 | size=1 67 | stride=1 68 | 
pad=1 69 | filters=14 70 | activation=linear 71 | 72 | [yolo] 73 | mask=0,1 74 | anchors=1,2, 3,4, 5,6, 9,10 75 | classes=2 76 | scale_x_y=1.05 77 | cls_normalizer=1.0 78 | iou_normalizer=0.07 79 | ignore_thresh=0.7 80 | -------------------------------------------------------------------------------- /tests/_data_configs/yolo_giou.cfg: -------------------------------------------------------------------------------- 1 | [net] 2 | width=256 3 | height=256 4 | channels=3 5 | 6 | [convolutional] 7 | batch_normalize=1 8 | filters=8 9 | size=3 10 | stride=1 11 | pad=1 12 | activation=leaky 13 | 14 | [route] 15 | layers=-1 16 | groups=2 17 | group_id=1 18 | 19 | [maxpool] 20 | size=2 21 | stride=2 22 | 23 | [convolutional] 24 | batch_normalize=1 25 | filters=2 26 | size=1 27 | stride=1 28 | pad=1 29 | activation=mish 30 | 31 | [convolutional] 32 | batch_normalize=1 33 | filters=4 34 | size=3 35 | stride=1 36 | pad=1 37 | activation=mish 38 | 39 | [shortcut] 40 | from=-3 41 | activation=linear 42 | 43 | [convolutional] 44 | size=1 45 | stride=1 46 | pad=1 47 | filters=14 48 | activation=linear 49 | 50 | [yolo] 51 | mask=2,3 52 | anchors=1,2, 3,4, 5,6, 9,10 53 | classes=2 54 | iou_loss=giou 55 | scale_x_y=1.05 56 | cls_normalizer=1.0 57 | iou_normalizer=0.07 58 | ignore_thresh=0.7 59 | 60 | [route] 61 | layers = -4 62 | 63 | [upsample] 64 | stride=2 65 | 66 | [convolutional] 67 | size=1 68 | stride=1 69 | pad=1 70 | filters=14 71 | activation=linear 72 | 73 | [yolo] 74 | mask=0,1 75 | anchors=1,2, 3,4, 5,6, 9,10 76 | classes=2 77 | iou_loss=giou 78 | scale_x_y=1.05 79 | cls_normalizer=1.0 80 | iou_normalizer=0.07 81 | ignore_thresh=0.7 82 | -------------------------------------------------------------------------------- /tests/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/callbacks/__init__.py 
-------------------------------------------------------------------------------- /tests/callbacks/test_info_callbacks.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.callbacks import PrintTableMetricsCallback 2 | 3 | 4 | def test_printtable_metrics_callback(): 5 | callback = PrintTableMetricsCallback() 6 | 7 | metrics_a = {"loss": 1.0, "epoch": 0} 8 | metrics_b = {"loss": 0.5, "epoch": 2} 9 | 10 | class FakeTrainer: 11 | def __init__(self) -> None: 12 | self.callback_metrics = {} 13 | 14 | fake_trainer = FakeTrainer() 15 | fake_trainer.callback_metrics = metrics_a 16 | callback.on_train_epoch_end(fake_trainer, None) 17 | fake_trainer.callback_metrics = metrics_b 18 | callback.on_train_epoch_end(fake_trainer, None) 19 | 20 | assert len(callback.metrics) == 2 21 | assert callback.metrics[0] == metrics_a 22 | assert callback.metrics[1] == metrics_b 23 | -------------------------------------------------------------------------------- /tests/callbacks/test_ort.py: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pytest 16 | from pl_bolts.callbacks import ORTCallback 17 | from pl_bolts.utils import _TORCH_ORT_AVAILABLE 18 | from pytorch_lightning import Callback, Trainer 19 | from pytorch_lightning.core.lightning import LightningModule 20 | from pytorch_lightning.utilities.exceptions import MisconfigurationException 21 | 22 | from tests.helpers.boring_model import BoringModel 23 | 24 | if _TORCH_ORT_AVAILABLE: 25 | from torch_ort import ORTModule 26 | 27 | 28 | @pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") 29 | def test_init_train_enable_ort(tmpdir): 30 | class TestCallback(Callback): 31 | def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 32 | assert isinstance(pl_module.model, ORTModule) 33 | 34 | class TestModel(BoringModel): 35 | def __init__(self) -> None: 36 | super().__init__() 37 | self.model = self.layer 38 | 39 | def forward(self, x): 40 | return self.model(x) 41 | 42 | model = TestModel() 43 | trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[ORTCallback(), TestCallback()]) 44 | trainer.fit(model) 45 | trainer.test(model) 46 | 47 | 48 | @pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.") 49 | def test_ort_callback_fails_no_model(tmpdir): 50 | model = BoringModel() 51 | trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback()) 52 | with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"): 53 | trainer.fit( 54 | model, 55 | ) 56 | -------------------------------------------------------------------------------- /tests/callbacks/test_param_update_callbacks.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import pytest 4 | import torch 5 | from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate 6 | from torch import nn 7 | 8 | 9 | @pytest.mark.parametrize("initial_tau", [-0.1, 0.0, 
0.996, 1.0, 1.1]) 10 | def test_byol_ma_weight_single_update_callback(initial_tau, catch_warnings): 11 | """Check BYOL exponential moving average weight update rule for a single update.""" 12 | if 0.0 <= initial_tau <= 1.0: 13 | # Create simple one layer network and their copies 14 | online_network = nn.Linear(100, 10) 15 | target_network = deepcopy(online_network) 16 | online_network_copy = deepcopy(online_network) 17 | target_network_copy = deepcopy(target_network) 18 | 19 | # Check parameters are equal 20 | assert torch.equal(next(iter(online_network.parameters()))[0], next(iter(target_network.parameters()))[0]) 21 | 22 | # Simulate weight update 23 | opt = torch.optim.SGD(online_network.parameters(), lr=0.1) 24 | y = online_network(torch.randn(3, 100)) 25 | loss = y.sum() 26 | loss.backward() 27 | opt.step() 28 | opt.zero_grad() 29 | 30 | # Check online network update 31 | assert not torch.equal( 32 | next(iter(online_network.parameters()))[0], next(iter(online_network_copy.parameters()))[0] 33 | ) 34 | 35 | # Update target network weights via callback 36 | cb = BYOLMAWeightUpdate(initial_tau) 37 | cb.update_weights(online_network, target_network) 38 | 39 | # Check target network update according to value of tau 40 | if initial_tau == 0.0: 41 | assert torch.equal(next(iter(target_network.parameters()))[0], next(iter(online_network.parameters()))[0]) 42 | elif initial_tau == 1.0: 43 | assert torch.equal( 44 | next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0] 45 | ) 46 | else: 47 | for online_p, target_p in zip(online_network.parameters(), target_network_copy.parameters()): 48 | target_p.data = initial_tau * target_p.data + (1.0 - initial_tau) * online_p.data 49 | 50 | assert torch.equal( 51 | next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0] 52 | ) 53 | else: 54 | with pytest.raises(ValueError, match="initial tau should be"): 55 | cb = BYOLMAWeightUpdate(initial_tau) 56 | 
-------------------------------------------------------------------------------- /tests/callbacks/test_variational_callbacks.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.callbacks import LatentDimInterpolator 2 | from pl_bolts.models.gans import GAN 3 | 4 | try: 5 | from pytorch_lightning.loggers.logger import DummyLogger # PL v1.9+ 6 | except ModuleNotFoundError: 7 | from pytorch_lightning.loggers.base import DummyLogger # PL v1.8 8 | 9 | 10 | def test_latent_dim_interpolator(): 11 | class FakeTrainer: 12 | def __init__(self) -> None: 13 | self.current_epoch = 1 14 | self.global_step = 1 15 | self.logger = DummyLogger() 16 | 17 | model = GAN(3, 28, 28) 18 | cb = LatentDimInterpolator(interpolate_epoch_interval=2) 19 | 20 | cb.on_train_epoch_end(FakeTrainer(), model) 21 | -------------------------------------------------------------------------------- /tests/callbacks/verification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/callbacks/verification/__init__.py -------------------------------------------------------------------------------- /tests/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/datamodules/__init__.py -------------------------------------------------------------------------------- /tests/datamodules/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pl_bolts.datamodules.async_dataloader import AsynchronousLoader 3 | from pl_bolts.datasets.cifar10_dataset import CIFAR10 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | def test_async_dataloader(datadir): 8 | ds = 
CIFAR10(data_dir=datadir) 9 | 10 | if torch.cuda.device_count() > 0: # Can only run this test with a GPU 11 | device = torch.device("cuda", 0) 12 | dataloader = AsynchronousLoader(ds, device=device) 13 | 14 | for b in dataloader: 15 | pass 16 | 17 | dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) 18 | for b in dataloader: 19 | pass 20 | -------------------------------------------------------------------------------- /tests/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/datasets/__init__.py -------------------------------------------------------------------------------- /tests/datasets/test_array_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from pl_bolts.datasets import ArrayDataset, DataModel 5 | from pl_bolts.datasets.utils import to_tensor 6 | from pytorch_lightning.utilities import exceptions 7 | 8 | 9 | class TestArrayDataset: 10 | @pytest.fixture() 11 | def array_dataset(self): 12 | features_1 = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]], transform=to_tensor) 13 | target_1 = DataModel(data=[1, 0, 0, 1], transform=to_tensor) 14 | 15 | features_2 = DataModel(data=np.array([[2, 1, -5, 1], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]])) 16 | target_2 = DataModel(data=[1, 0, 1, 1]) 17 | return ArrayDataset(features_1, target_1, features_2, target_2) 18 | 19 | def test_len(self, array_dataset): 20 | assert len(array_dataset) == 4 21 | 22 | def test_getitem_with_transforms(self, array_dataset): 23 | assert len(array_dataset[0]) == 4 24 | assert len(array_dataset[1]) == 4 25 | assert len(array_dataset[2]) == 4 26 | assert len(array_dataset[3]) == 4 27 | torch.testing.assert_close(array_dataset[0][0], torch.tensor([1, 0, -1, 2])) 
28 | torch.testing.assert_close(array_dataset[0][1], torch.tensor(1)) 29 | np.testing.assert_array_equal(array_dataset[0][2], np.array([2, 1, -5, 1])) 30 | assert array_dataset[0][3] == 1 31 | torch.testing.assert_close(array_dataset[1][0], torch.tensor([1, 0, -2, -1])) 32 | torch.testing.assert_close(array_dataset[1][1], torch.tensor(0)) 33 | np.testing.assert_array_equal(array_dataset[1][2], np.array([1, 0, -2, -1])) 34 | assert array_dataset[1][3] == 0 35 | torch.testing.assert_close(array_dataset[2][0], torch.tensor([2, 5, 0, 3])) 36 | torch.testing.assert_close(array_dataset[2][1], torch.tensor(0)) 37 | np.testing.assert_array_equal(array_dataset[2][2], np.array([2, 5, 0, 3])) 38 | assert array_dataset[2][3] == 1 39 | torch.testing.assert_close(array_dataset[3][0], torch.tensor([-7, 1, 2, 2])) 40 | torch.testing.assert_close(array_dataset[3][1], torch.tensor(1)) 41 | np.testing.assert_array_equal(array_dataset[3][2], np.array([-7, 1, 2, 2])) 42 | assert array_dataset[3][3] == 1 43 | 44 | def test__equal_size_true(self, array_dataset): 45 | assert array_dataset._equal_size() is True 46 | 47 | def test__equal_size_false(self): 48 | features = DataModel(data=[[1, 0, 1]]) 49 | target = DataModel([1, 0, 1]) 50 | with pytest.raises(exceptions.MisconfigurationException): 51 | ArrayDataset(features, target) 52 | -------------------------------------------------------------------------------- /tests/datasets/test_base_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from pl_bolts.datasets.base_dataset import DataModel 5 | from pl_bolts.datasets.utils import to_tensor 6 | from pl_bolts.utils import _IS_WINDOWS 7 | 8 | 9 | class TestDataModel: 10 | @pytest.fixture() 11 | def data(self): 12 | return np.array([[1, 0, 0, 1], [0, 1, 1, 0]]) 13 | 14 | def test_process_transform_is_none(self, data): 15 | dm = DataModel(data=data) 16 | 
np.testing.assert_array_equal(dm.process(data[0]), data[0]) 17 | np.testing.assert_array_equal(dm.process(data[1]), data[1]) 18 | 19 | @pytest.mark.skipif( # todo 20 | _IS_WINDOWS, reason="AssertionError: The values for attribute 'dtype' do not match: torch.int32 != torch.int64" 21 | ) 22 | def test_process_transform_is_not_none(self, data): 23 | dm = DataModel(data=data, transform=to_tensor) 24 | torch.testing.assert_close(dm.process(data[0]), torch.tensor([1, 0, 0, 1])) 25 | torch.testing.assert_close(dm.process(data[1]), torch.tensor([0, 1, 1, 0])) 26 | -------------------------------------------------------------------------------- /tests/datasets/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.testing 3 | from pl_bolts.datasets.utils import to_tensor 4 | 5 | 6 | class TestToTensor: 7 | def test_to_tensor_list(self): 8 | _list = [1, 2, 3] 9 | torch.testing.assert_close(to_tensor(_list), torch.tensor(_list)) 10 | 11 | def test_to_tensor_array(self): 12 | _array = np.array([1, 2, 3]) 13 | torch.testing.assert_close(to_tensor(_array), torch.tensor(_array)) 14 | 15 | def test_to_tensor_sequence_(self): 16 | _sequence = [[1.0, 2.0, 3.0]] 17 | torch.testing.assert_close(to_tensor(_sequence), torch.tensor(_sequence)) 18 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/losses/__init__.py 
-------------------------------------------------------------------------------- /tests/losses/test_rl_loss.py: -------------------------------------------------------------------------------- 1 | """Test RL Loss Functions.""" 2 | 3 | from unittest import TestCase 4 | 5 | import numpy as np 6 | import torch 7 | from pl_bolts.losses.rl import double_dqn_loss, dqn_loss, per_dqn_loss 8 | from pl_bolts.models.rl.common.gym_wrappers import make_environment 9 | from pl_bolts.models.rl.common.networks import CNN 10 | from torch import Tensor 11 | 12 | 13 | class TestRLLoss(TestCase): 14 | def setUp(self) -> None: 15 | self.state = torch.rand(32, 4, 84, 84) 16 | self.next_state = torch.rand(32, 4, 84, 84) 17 | self.action = torch.ones([32]) 18 | self.reward = torch.ones([32]) 19 | self.done = torch.zeros([32]).long() 20 | 21 | self.batch = (self.state, self.action, self.reward, self.done, self.next_state) 22 | 23 | self.env = make_environment("PongNoFrameskip-v4") 24 | self.obs_shape = self.env.observation_space.shape 25 | self.n_actions = self.env.action_space.n 26 | self.net = CNN(self.obs_shape, self.n_actions) 27 | self.target_net = CNN(self.obs_shape, self.n_actions) 28 | 29 | def test_dqn_loss(self): 30 | """Test the dqn loss function.""" 31 | 32 | loss = dqn_loss(self.batch, self.net, self.target_net) 33 | assert isinstance(loss, Tensor) 34 | 35 | def test_double_dqn_loss(self): 36 | """Test the double dqn loss function.""" 37 | 38 | loss = double_dqn_loss(self.batch, self.net, self.target_net) 39 | assert isinstance(loss, Tensor) 40 | 41 | def test_per_dqn_loss(self): 42 | """Test the per dqn loss function, which also returns per-sample batch weights as a numpy array.""" 43 | prios = torch.ones([32]) 44 | 45 | loss, batch_weights = per_dqn_loss(self.batch, prios, self.net, self.target_net) 46 | assert isinstance(loss, Tensor) 47 | assert isinstance(batch_weights, np.ndarray) 48 | -------------------------------------------------------------------------------- /tests/metrics/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/metrics/__init__.py -------------------------------------------------------------------------------- /tests/metrics/test_aggregation.py: -------------------------------------------------------------------------------- 1 | """Test Aggregation Metric Functions.""" 2 | 3 | import pytest 4 | import torch 5 | from pl_bolts.metrics.aggregation import accuracy, mean, precision_at_k 6 | 7 | 8 | @pytest.mark.parametrize( 9 | ("preds", "expected_mean"), 10 | [(torch.tensor([[100.0, 100.0, 200.0, 200.0]]), torch.tensor(150.0))], 11 | ) 12 | def test_mean(preds, expected_mean): 13 | x = {"test": preds} 14 | torch.testing.assert_close(mean([x], "test"), expected_mean) 15 | 16 | 17 | @pytest.mark.parametrize( 18 | ("preds", "target", "expected_accuracy"), 19 | [ 20 | (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[2]]), torch.tensor(1.0)), 21 | (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[0]]), torch.tensor(0.0)), 22 | ], 23 | ) 24 | def test_accuracy(preds, target, expected_accuracy): 25 | torch.testing.assert_close(accuracy(preds, target), expected_accuracy) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | ("output", "target", "expected_precision_at_k"), 30 | [ 31 | (torch.tensor([[100.0, 100.0, 200.0, 200.0]]), torch.tensor([[2]]), [torch.tensor([100.0])]), 32 | (torch.tensor([[100.0, 100.0, 200.0, 200.0]]), torch.tensor([[1]]), [torch.tensor([0.0])]), 33 | ], 34 | ) 35 | def test_precision_at_k(output, target, expected_precision_at_k): 36 | torch.testing.assert_close(precision_at_k(output, target), expected_precision_at_k) 37 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/gans/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/gans/__init__.py -------------------------------------------------------------------------------- /tests/models/gans/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/gans/integration/__init__.py -------------------------------------------------------------------------------- /tests/models/gans/integration/test_gans.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule 5 | from pl_bolts.datasets.sr_mnist_dataset import SRMNIST 6 | from pl_bolts.models.gans import DCGAN, GAN, SRGAN, SRResNet 7 | from pl_bolts.utils import _IS_WINDOWS 8 | from pytorch_lightning import Trainer, seed_everything 9 | from pytorch_lightning.utilities.warnings import PossibleUserWarning 10 | from torch.utils.data.dataloader import DataLoader 11 | from torchvision import transforms as transform_lib 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "dm_cls", 16 | [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")], 17 | ) 18 | def test_gan(tmpdir, datadir, catch_warnings, dm_cls): 19 | # Validation loop for GANs is not well defined! 20 | warnings.filterwarnings( 21 | "ignore", 22 | message="You passed in a `val_dataloader` but have no `validation_step`. 
Skipping val loop.", 23 | category=UserWarning, 24 | ) 25 | warnings.filterwarnings( 26 | "ignore", 27 | message="The dataloader, train_dataloader, does not have many workers which may be a bottleneck", 28 | category=PossibleUserWarning, 29 | ) 30 | seed_everything(1234) 31 | dm = dm_cls(data_dir=datadir, num_workers=0) 32 | model = GAN(*dm.dims) 33 | 34 | trainer = Trainer( 35 | default_root_dir=tmpdir, 36 | max_epochs=-1, 37 | fast_dev_run=True, 38 | log_every_n_steps=1, 39 | accelerator="auto", 40 | # TODO: We need to be able to support multiple GPUs in such a simple scenario. 41 | # But, DDP is throwing ugly errors at me at the moment 42 | devices=1, 43 | ) 44 | trainer.fit(model, datamodule=dm) 45 | 46 | 47 | @pytest.mark.parametrize( 48 | "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")] 49 | ) 50 | def test_dcgan(tmpdir, datadir, dm_cls): 51 | seed_everything() 52 | 53 | transforms = transform_lib.Compose([transform_lib.Resize(64), transform_lib.ToTensor()]) 54 | dm = dm_cls(data_dir=datadir, train_transforms=transforms, val_transforms=transforms, test_transforms=transforms) 55 | 56 | model = DCGAN(image_channels=dm.dims[0]) 57 | trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) 58 | trainer.fit(model, dm) 59 | 60 | 61 | @pytest.mark.skipif(_IS_WINDOWS, reason="failing...") # todo 62 | @pytest.mark.parametrize("sr_module_cls", [SRResNet, SRGAN]) 63 | @pytest.mark.parametrize("scale_factor", [2, 4]) 64 | def test_sr_modules(tmpdir, datadir, sr_module_cls, scale_factor): 65 | seed_everything(42) 66 | 67 | dl = DataLoader(SRMNIST(scale_factor, root=datadir, download=True)) 68 | model = sr_module_cls(image_channels=1, scale_factor=scale_factor) 69 | trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) 70 | trainer.fit(model, dl) 71 | -------------------------------------------------------------------------------- /tests/models/gans/unit/__init__.py: 
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/gans/unit/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/gans/unit/test_basic_components.py:
# --------------------------------------------------------------------------------
import pytest
import torch
from pl_bolts.models.gans.basic.components import Discriminator, Generator
from pytorch_lightning import seed_everything

# Number of samples pushed through each network in one forward pass.
_N_SAMPLES = 10
_SEED = 1234

_GENERATOR_CASES = [
    pytest.param(100, (3, 28, 28), id="100-multichannel"),
    pytest.param(100, (1, 28, 28), id="100-singlechannel"),
]
_DISCRIMINATOR_CASES = [
    pytest.param((3, 28, 28), id="discriminator-multichannel"),
    pytest.param((1, 28, 28), id="discriminator-singlechannel"),
]


@pytest.mark.parametrize(("latent_dim", "img_shape"), _GENERATOR_CASES)
def test_generator(catch_warnings, latent_dim, img_shape):
    """Generator output is a batch of images shaped ``img_shape``."""
    seed_everything(_SEED)
    net = Generator(latent_dim=latent_dim, img_shape=img_shape)
    latent = torch.randn(_N_SAMPLES, latent_dim)
    produced = net(latent)
    assert produced.shape == (_N_SAMPLES, *img_shape)


@pytest.mark.parametrize("img_shape", _DISCRIMINATOR_CASES)
def test_discriminator(catch_warnings, img_shape):
    """Discriminator yields one score per sample, each within [0, 1]."""
    seed_everything(_SEED)
    net = Discriminator(img_shape=img_shape)
    images = torch.randn(_N_SAMPLES, *img_shape)
    scores = net(images)
    assert scores.shape == (_N_SAMPLES, 1)
    # Equivalent to "clamping to [0, 1] is a no-op": every score already lies in the unit interval.
    assert ((scores >= 0) & (scores <= 1)).all()


# --------------------------------------------------------------------------------
# /tests/models/regression/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/regression/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/regression/test_logistic_regression.py:
# --------------------------------------------------------------------------------
import functools
import operator

import pytorch_lightning as pl
from pl_bolts import datamodules
from pl_bolts.models import regression


def test_logistic_regression_model(datadir):
    """Train logistic regression on MNIST for 3 epochs, then check test accuracy and loss thresholds."""
    pl.seed_everything(0)
    mnist = datamodules.MNISTDataModule(datadir)

    # Product of the datamodule's ``dims`` -> number of flattened input features.
    n_features = functools.reduce(operator.mul, mnist.dims, 1)
    model = regression.LogisticRegression(input_dim=n_features, num_classes=10, learning_rate=0.001)

    trainer = pl.Trainer(max_epochs=3, logger=False, enable_checkpointing=False)
    trainer.fit(model, datamodule=mnist)
    trainer.test(model, datamodule=mnist)

    metrics = trainer.callback_metrics
    assert trainer.state.finished
    assert metrics["test_acc"] > 0.9
    assert metrics["test_loss"] < 0.3


# --------------------------------------------------------------------------------
# /tests/models/rl/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/rl/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/rl/integration/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/rl/integration/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/rl/integration/test_actor_critic_models.py:
# --------------------------------------------------------------------------------
import argparse

import pytest
import torch.cuda
from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
from pl_bolts.models.rl.sac_model import SAC
from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20
from pytorch_lightning import Trainer


def test_a2c_cli():
    """Smoke test that the A2C model runs."""

    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser)
    args_list = ["--env", "CartPole-v0"]
    hparams = parent_parser.parse_args(args_list)

    trainer = Trainer(
        gpus=int(torch.cuda.is_available()),
        max_steps=100,
        max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
        val_check_interval=1,  # This just needs 'some' value, does not affect training right now
        fast_dev_run=True,
    )
    model = AdvantageActorCritic(hparams.env)
    trainer.fit(model)


@pytest.mark.skipif(_GYM_GREATER_EQUAL_0_20, reason="gym.error.DeprecatedEnv: Env Pendulum-v0 not found")
def test_sac_cli():
    """Smoke test that the SAC model runs."""

    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = Trainer.add_argparse_args(parent_parser)
    parent_parser = SAC.add_model_specific_args(parent_parser)
    args_list = [
        "--warm_start_size",
        "100",
        "--gpus",
        "0",  # fixme: RuntimeError on GPU: Expected all tensors to be on the same device
        "--env",
        "Pendulum-v0",
        "--batch_size",
        "10",
    ]
    hparams = parent_parser.parse_args(args_list)

    trainer = Trainer(
        gpus=hparams.gpus,
        max_steps=100,
        max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
        val_check_interval=1,  # This just needs 'some' value, does not affect training right now
        fast_dev_run=True,
    )
    model = SAC(**hparams.__dict__)
    trainer.fit(model)


# --------------------------------------------------------------------------------
# /tests/models/rl/integration/test_policy_models.py:
# --------------------------------------------------------------------------------
import argparse
from unittest import TestCase

import torch
from pl_bolts.models.rl.reinforce_model import Reinforce
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
from pytorch_lightning import Trainer


class TestPolicyModels(TestCase):
    """Smoke tests that the policy-based RL models can be fitted for a fast_dev_run."""

    def setUp(self) -> None:
        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = VanillaPolicyGradient.add_model_specific_args(parent_parser)
        args_list = ["--env", "CartPole-v0"]
        self.hparams = parent_parser.parse_args(args_list)

        self.trainer = Trainer(
            gpus=int(torch.cuda.is_available()),
            max_steps=100,
            max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
            val_check_interval=1,  # This just needs 'some' value, does not affect training right now
            fast_dev_run=True,
        )

    def test_reinforce(self):
        """Smoke test that the reinforce model runs."""

        model = Reinforce(self.hparams.env)
        self.trainer.fit(model)

    def test_policy_gradient(self):
        """Smoke test that the policy gradient model runs."""
        model = VanillaPolicyGradient(self.hparams.env)
        self.trainer.fit(model)


# --------------------------------------------------------------------------------
# /tests/models/rl/integration/test_value_models.py:
# --------------------------------------------------------------------------------
import argparse
from unittest import TestCase

import pytest
import torch
from pl_bolts.models.rl.double_dqn_model import DoubleDQN
from pl_bolts.models.rl.dqn_model import DQN
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
from pl_bolts.models.rl.per_dqn_model import PERDQN
from pl_bolts.utils import _IS_WINDOWS
from pytorch_lightning import Trainer


class TestValueModels(TestCase):
    """Smoke tests that the value-based (DQN family) models can be fitted for a fast_dev_run."""

    def setUp(self) -> None:
        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = Trainer.add_argparse_args(parent_parser)
        parent_parser = DQN.add_model_specific_args(parent_parser)
        args_list = [
            "--warm_start_size",
            "100",
            "--gpus",
            str(int(torch.cuda.is_available())),
            "--env",
            "PongNoFrameskip-v4",
        ]
        self.hparams = parent_parser.parse_args(args_list)

        self.trainer = Trainer(
            gpus=self.hparams.gpus,
            max_steps=100,
            max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
            val_check_interval=1,  # This just needs 'some' value, does not affect training right now
            fast_dev_run=True,
        )

    @pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut or MemoryError")  # todo
    def test_dqn(self):
        """Smoke test that the DQN model runs."""
        model = DQN(self.hparams.env, num_envs=5)
        self.trainer.fit(model)

    def test_double_dqn(self):
        """Smoke test that the Double DQN model runs."""
        model = DoubleDQN(self.hparams.env)
        self.trainer.fit(model)

    @pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut or MemoryError")  # todo
    def test_dueling_dqn(self):
        """Smoke test that the Dueling DQN model runs."""
        model = DuelingDQN(self.hparams.env)
        self.trainer.fit(model)

    @pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut or MemoryError")  # todo
    def test_noisy_dqn(self):
        """Smoke test that the Noisy DQN model runs."""
        model = NoisyDQN(self.hparams.env)
        self.trainer.fit(model)

    @pytest.mark.skip(reason="CI is killing this test")
    def test_per_dqn(self):
        """Smoke test that the PER DQN model runs."""
        model = PERDQN(self.hparams.env)
        self.trainer.fit(model)

    # TODO: add an N-step DQN smoke test, e.g. fitting ``DQN(self.hparams.env, n_steps=self.hparams.n_steps)``.


# --------------------------------------------------------------------------------
# /tests/models/rl/unit/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/rl/unit/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/rl/unit/test_a2c.py:
# --------------------------------------------------------------------------------
import argparse

import torch
from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
from torch import Tensor


def test_a2c_loss():
    """Test the A2C (advantage actor-critic) loss function."""
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser)
    args_list = ["--env", "CartPole-v0", "--batch_size", "32"]
    hparams = parent_parser.parse_args(args_list)
    model = AdvantageActorCritic(**vars(hparams))

    batch_states = torch.rand(32, 4)
    batch_actions = torch.rand(32).long()
    batch_qvals = torch.rand(32)

    loss = model.loss(batch_states, batch_actions, batch_qvals)

    assert isinstance(loss, Tensor)


def test_a2c_train_batch():
    """Tests that a single batch generates correctly."""
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser)
    args_list = ["--env", "CartPole-v0", "--batch_size", "32"]
    hparams = parent_parser.parse_args(args_list)
    model = AdvantageActorCritic(**vars(hparams))

    model.n_steps = 4
    model.hparams.batch_size = 1
    xp_dataloader = model.train_dataloader()

    batch = next(iter(xp_dataloader))

    assert len(batch) == 3
    assert len(batch[0]) == model.hparams.batch_size
    assert isinstance(batch, list)
    assert isinstance(batch[0], Tensor)
    assert isinstance(batch[1], Tensor)
    assert isinstance(batch[2], Tensor)


# --------------------------------------------------------------------------------
# /tests/models/rl/unit/test_agents.py:
# --------------------------------------------------------------------------------
# Tests that the agent module works correctly.
from unittest import TestCase
from unittest.mock import Mock

import gym
import numpy as np
import torch
from pl_bolts.models.rl.common.agents import ActorCriticAgent, Agent, PolicyAgent, ValueAgent
from torch import Tensor


class TestAgents(TestCase):
    def setUp(self) -> None:
        self.env = gym.make("CartPole-v0")
        self.state = self.env.reset()
        self.net = Mock()

    def test_base_agent(self):
        agent = Agent(self.net)
        action = agent(self.state, "cuda:0")
        assert isinstance(action, list)


class TestValueAgent(TestCase):
    def setUp(self) -> None:
        self.env = gym.make("CartPole-v0")
        # The mocked network always scores action 1 far higher than action 0.
        self.net = Mock(return_value=Tensor([[0.0, 100.0]]))
        self.state = [self.env.reset()]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.value_agent = ValueAgent(self.net, self.env.action_space.n)

    def test_value_agent(self):
        action = self.value_agent(self.state, self.device)
        assert isinstance(action, list)
        assert isinstance(action[0], int)

    def test_value_agent_get_action(self):
        action = self.value_agent.get_action(self.state, self.device)
        assert isinstance(action, np.ndarray)
        assert action[0] == 1

    def test_value_agent_random(self):
        action = self.value_agent.get_random_action(self.state)
        assert isinstance(action[0], int)


class TestPolicyAgent(TestCase):
    def setUp(self) -> None:
        self.env = gym.make("CartPole-v0")
        self.net = Mock(return_value=Tensor([[0.0, 100.0]]))
        self.states = [self.env.reset()]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def test_policy_agent(self):
        policy_agent = PolicyAgent(self.net)
        action = policy_agent(self.states, self.device)
        assert isinstance(action, list)
        assert action[0] == 1


def test_a2c_agent():
    env = gym.make("CartPole-v0")
    # Explicit dim: relying on the implicit dimension of log_softmax is deprecated and warns.
    logprobs = torch.nn.functional.log_softmax(Tensor([[0.0, 100.0]]), dim=-1)
    net = Mock(return_value=(logprobs, Tensor([[1]])))
    states = [env.reset()]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    a2c_agent = ActorCriticAgent(net)
    action = a2c_agent(states, device)
    assert isinstance(action, list)
    assert action[0] == 1


# --------------------------------------------------------------------------------
# /tests/models/rl/unit/test_ppo.py:
# --------------------------------------------------------------------------------
import numpy as np
import pytest
import torch
from pl_bolts.models.rl.ppo_model import PPO
from pytorch_lightning import Trainer
from torch import Tensor


def test_discount_rewards():
    """Test calculation of discounted rewards."""
    model = PPO(env="CartPole-v0", batch_size=16, gamma=0.99)

    rewards = np.ones(4)
    gt_qvals = [3.9403989999999998, 2.9701, 1.99, 1.0]

    qvals = model.discount_rewards(rewards, discount=0.99)

    # Tolerant float comparison: exact equality depends on accumulation order.
    assert qvals == pytest.approx(gt_qvals)


def test_critic_loss():
    """Test the critic loss function."""

    model = PPO(env="CartPole-v0", batch_size=16, gamma=0.99)
    obs_dim = model.env.observation_space.shape[0]

    batch_states = torch.rand(8, obs_dim)
    batch_qvals = torch.rand(8)

    loss = model.critic_loss(batch_states, batch_qvals)

    assert isinstance(loss, Tensor)


def test_actor_loss_categorical():
    """Test the actor loss function on categorical action-space environment."""

    model = PPO(env="CartPole-v0", batch_size=16, gamma=0.99)
    obs_dim = model.env.observation_space.shape[0]

    batch_states = torch.rand(8, obs_dim)
    batch_actions = torch.rand(8).long()
    batch_logp_old = torch.rand(8)
    batch_adv = torch.rand(8)

    loss = model.actor_loss(batch_states, batch_actions, batch_logp_old, batch_adv)

    assert isinstance(loss, Tensor)


def test_actor_loss_continuous():
    """Test the actor loss function on continuous action-space environment."""

    model = PPO(env="MountainCarContinuous-v0", batch_size=16, gamma=0.99)
    obs_dim = model.env.observation_space.shape[0]
    action_dim = model.env.action_space.shape[0]

    batch_states = torch.rand(8, obs_dim)
    batch_actions = torch.rand(8, action_dim)
    batch_logp_old = torch.rand(8)
    batch_adv = torch.rand(8)

    loss = model.actor_loss(batch_states, batch_actions, batch_logp_old, batch_adv)

    assert isinstance(loss, Tensor)


def test_generate_trajectory_samples():
    """The trajectory generator yields (state, action, logp_old, qval, adv) tuples."""
    torch.manual_seed(123)
    model = PPO("CartPole-v0", batch_size=16)

    obs_dim = model.env.observation_space.shape[0]

    sample_gen = model.generate_trajectory_samples()
    state, action, logp_old, qval, adv = next(sample_gen)

    assert state.shape[0] == obs_dim
    assert isinstance(action, torch.Tensor)
    assert logp_old
    assert qval
    assert adv


def test_training_categorical():
    model = PPO("CartPole-v0", batch_size=16)
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)


def test_training_continous():
    model = PPO("MountainCarContinuous-v0", batch_size=16)
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)


# --------------------------------------------------------------------------------
# /tests/models/rl/unit/test_reinforce.py:
# --------------------------------------------------------------------------------
import argparse
from unittest import TestCase

import gym
import numpy as np
import pytest
import torch
from pl_bolts.datamodules.experience_source import DiscountedExperienceSource
from pl_bolts.models.rl.common.agents import Agent
from pl_bolts.models.rl.common.gym_wrappers import ToTensor
from pl_bolts.models.rl.common.networks import MLP
from pl_bolts.models.rl.reinforce_model import Reinforce
from torch import Tensor


class TestReinforce(TestCase):
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)
        self.exp_source = DiscountedExperienceSource(self.env, self.agent)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = Reinforce.add_model_specific_args(parent_parser)
        args_list = ["--env", "CartPole-v0", "--batch_size", "32", "--gamma", "0.99"]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = Reinforce(**vars(self.hparams))

        self.rl_dataloader = self.model.train_dataloader()

    def test_loss(self):
        """Test the reinforce loss function."""

        batch_states = torch.rand(16, 4)
        batch_actions = torch.rand(16).long()
        batch_qvals = torch.rand(16)

        loss = self.model.loss(batch_states, batch_actions, batch_qvals)

        assert isinstance(loss, Tensor)

    def test_get_qvals(self):
        """Test that given a batch of episodes it will return a list of qvals for each episode."""

        batch_qvals = []
        rewards = np.ones(32)
        out = self.model.calc_qvals(rewards)
        batch_qvals.append(out)

        assert isinstance(batch_qvals[0][0], float)
        # With unit rewards, q_t = 1 + gamma * q_{t+1}.
        assert batch_qvals[0][0] == pytest.approx(batch_qvals[0][1] * self.hparams.gamma + 1.0)

    def test_calc_q_vals(self):
        rewards = np.ones(4)
        gt_qvals = [3.9403989999999998, 2.9701, 1.99, 1.0]

        qvals = self.model.calc_qvals(rewards)

        assert qvals == pytest.approx(gt_qvals)


# --------------------------------------------------------------------------------
# /tests/models/rl/unit/test_sac.py:
# --------------------------------------------------------------------------------
import argparse

import pytest
import torch
from pl_bolts.models.rl.sac_model import SAC
from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20
from torch import Tensor


@pytest.mark.skipif(_GYM_GREATER_EQUAL_0_20, reason="gym.error.DeprecatedEnv: Env Pendulum-v0 not found")
def test_sac_loss():
    """Test the SAC loss function (policy loss and both Q losses)."""
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = SAC.add_model_specific_args(parent_parser)
    args_list = ["--env", "Pendulum-v0", "--batch_size", "32"]
    hparams = parent_parser.parse_args(args_list)
    model = SAC(**vars(hparams))

    batch_states = torch.rand(32, 3)
    batch_actions = torch.rand(32, 1)
    batch_rewards = torch.rand(32)
    batch_dones = torch.ones(32)
    batch_next_states = torch.rand(32, 3)
    batch = (batch_states, batch_actions, batch_rewards, batch_dones, batch_next_states)

    policy_loss, q1_loss, q2_loss = model.loss(batch)

    assert isinstance(policy_loss, Tensor)
    assert isinstance(q1_loss, Tensor)
    assert isinstance(q2_loss, Tensor)


@pytest.mark.skipif(_GYM_GREATER_EQUAL_0_20, reason="gym.error.DeprecatedEnv: Env Pendulum-v0 not found")
def test_sac_train_batch():
    """Tests that a single batch generates correctly."""
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser = SAC.add_model_specific_args(parent_parser)
    args_list = ["--env", "Pendulum-v0", "--batch_size", "32"]
    hparams = parent_parser.parse_args(args_list)
    model = SAC(**vars(hparams))

    xp_dataloader = model.train_dataloader()

    batch = next(iter(xp_dataloader))

    assert len(batch) == 5
    assert len(batch[0]) == model.hparams.batch_size
    assert isinstance(batch, list)
    assert all(isinstance(batch[i], Tensor) for i in range(5))


# --------------------------------------------------------------------------------
# /tests/models/rl/unit/test_vpg.py:
# --------------------------------------------------------------------------------
import argparse
from unittest import TestCase

import gym
import pytest
import torch
from pl_bolts.models.rl.common.agents import Agent
from pl_bolts.models.rl.common.gym_wrappers import ToTensor
from pl_bolts.models.rl.common.networks import MLP
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
from pl_bolts.utils import _IS_WINDOWS
from torch import Tensor


class TestPolicyGradient(TestCase):
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        parent_parser = argparse.ArgumentParser(add_help=False)
        parent_parser = VanillaPolicyGradient.add_model_specific_args(parent_parser)
        args_list = ["--env", "CartPole-v0", "--batch_size", "32"]
        self.hparams = parent_parser.parse_args(args_list)
        self.model = VanillaPolicyGradient(**vars(self.hparams))

    @pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut or MemoryError")  # todo
    def test_loss(self):
        """Test the vanilla policy gradient loss function."""

        batch_states = torch.rand(32, 4)
        batch_actions = torch.rand(32).long()
        batch_qvals = torch.rand(32)

        loss = self.model.loss(batch_states, batch_actions, batch_qvals)

        assert isinstance(loss, Tensor)

    def test_train_batch(self):
        """Tests that a single batch generates correctly."""

        self.model.n_steps = 4
        self.model.batch_size = 1
        xp_dataloader = self.model.train_dataloader()

        batch = next(iter(xp_dataloader))
        assert len(batch) == 3
        assert len(batch[0]) == self.model.batch_size
        assert isinstance(batch, list)
        assert isinstance(batch[0], Tensor)
        assert isinstance(batch[1], list)
        assert isinstance(batch[1][0], Tensor)
        assert isinstance(batch[2], Tensor)


# --------------------------------------------------------------------------------
# /tests/models/rl/unit/test_wrappers.py:
# --------------------------------------------------------------------------------
from unittest import TestCase

import gym
from pl_bolts.models.rl.common.gym_wrappers import ToTensor
from torch import Tensor


class TestToTensor(TestCase):
    def setUp(self) -> None:
        self.env = ToTensor(gym.make("CartPole-v0"))

    def test_wrapper(self):
        """Both reset() and step() observations are converted to tensors."""
        state = self.env.reset()
        assert isinstance(state, Tensor)

        new_state, _, _, _ = self.env.step(1)
        assert isinstance(new_state, Tensor)


# --------------------------------------------------------------------------------
# /tests/models/self_supervised/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/self_supervised/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/self_supervised/test_resnets.py:
# --------------------------------------------------------------------------------
import pytest
import torch
from pl_bolts.models.self_supervised.amdim import AMDIMEncoder
from pl_bolts.models.self_supervised.cpc import cpc_resnet50
from pl_bolts.models.self_supervised.resnets import (
    resnet18,
    resnet34,
    resnet50,
    resnet101,
    resnet152,
    resnext50_32x4d,
    resnext101_32x8d,
    wide_resnet50_2,
    wide_resnet101_2,
)
from pl_bolts.utils import _IS_WINDOWS

# todo: these Windows skips hide unexplained failures and should eventually be investigated.
_SKIP_WIN_MEMORY = pytest.mark.skipif(_IS_WINDOWS, reason="strange MemoryError")
_SKIP_WIN_FAILING = pytest.mark.skipif(_IS_WINDOWS, reason="failing...")


def _random_images():
    """Return a random input batch of shape (3, 3, 64, 64)."""
    return torch.rand(3, 3, 64, 64)


@_SKIP_WIN_MEMORY
@torch.no_grad()
def test_cpc_resnet():
    images = _random_images()
    net = cpc_resnet50(images)
    net(images)


@pytest.mark.parametrize(
    "model_class",
    [
        resnet18,
        resnet34,
        resnet50,
        resnet101,
        resnet152,
        resnext50_32x4d,
        pytest.param(resnext101_32x8d, marks=_SKIP_WIN_MEMORY),
        wide_resnet50_2,
        pytest.param(wide_resnet101_2, marks=_SKIP_WIN_MEMORY),
    ],
)
@torch.no_grad()
def test_torchvision_resnets(model_class):
    images = _random_images()
    net = model_class(pretrained=False)
    net(images)


@pytest.mark.parametrize(
    "size",
    [32, pytest.param(64, marks=_SKIP_WIN_FAILING), pytest.param(128, marks=_SKIP_WIN_FAILING)],
)
@torch.no_grad()
def test_amdim_encoder(size):
    zeros = torch.zeros((2, 3, size, size))
    encoder = AMDIMEncoder(zeros, encoder_size=size)
    encoder.init_weights()
    encoder(zeros)


# --------------------------------------------------------------------------------
# /tests/models/self_supervised/unit/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/self_supervised/unit/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/test_classic_ml.py:
# --------------------------------------------------------------------------------
import numpy as np
import pytest
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset
from pl_bolts.models.regression import LinearRegression
from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import DataLoader


@pytest.mark.flaky(reruns=3)
def test_linear_regression_model(tmpdir):
    """Linear regression recovers the coefficients [1, 2] of y = x @ [1, 2] + 3."""
    seed_everything()

    # --------------------
    # numpy data
    # --------------------
    x = np.array([[1.0, 1], [1, 2], [2, 2], [2, 3], [3, 3], [3, 4], [4, 4], [4, 5]])
    y = np.dot(x, np.array([1.0, 2])) + 3
    y = y[:, np.newaxis]
    loader = DataLoader(SklearnDataset(x, y), batch_size=2)

    model = LinearRegression(input_dim=2, learning_rate=0.6)
    trainer = Trainer(
        max_epochs=400,
        default_root_dir=tmpdir,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(
        model,
        loader,
        loader,
    )

    coeffs = model.linear.weight.detach().numpy().flatten()
    np.testing.assert_allclose(coeffs, [1, 2], rtol=1e-3)
    trainer.test(model, loader)


# --------------------------------------------------------------------------------
# /tests/models/test_mnist_templates.py:
# --------------------------------------------------------------------------------
import warnings

from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models import LitMNIST
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning


def test_mnist(tmpdir, datadir, catch_warnings):
    """LitMNIST trained on 2% of MNIST for one epoch reaches a train loss of at most 2.3."""
    warnings.filterwarnings(
        "ignore",
        message=".+does not have many workers which may be a bottleneck.+",
        category=PossibleUserWarning,
    )

    seed_everything(1234)

    datamodule = MNISTDataModule(data_dir=datadir, num_workers=0)
    model = LitMNIST()
    trainer = Trainer(
        limit_train_batches=0.02,
        limit_val_batches=0.02,
        max_epochs=1,
        limit_test_batches=0.02,
        default_root_dir=tmpdir,
        log_every_n_steps=5,
        accelerator="auto",
    )
    trainer.fit(model, datamodule=datamodule)
    loss = trainer.callback_metrics["train_loss"]
    assert loss <= 2.3, "mnist failed"


# --------------------------------------------------------------------------------
# /tests/models/yolo/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/yolo/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/yolo/unit/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/yolo/unit/__init__.py
# --------------------------------------------------------------------------------
# /tests/models/yolo/unit/test_target_matching.py:
# --------------------------------------------------------------------------------
import torch
from pl_bolts.models.detection.yolo.target_matching import _sim_ota_match


def test_sim_ota_match(catch_warnings):
    # For each of the two targets, k will be the sum of the IoUs. 2 and 1 predictions will be selected for the first and
    # the second target respectively.
    ious = torch.tensor([[0.1, 0.2], [0.1, 0.3], [0.9, 0.4], [0.9, 0.1]])
    # Costs will determine that the first and the last prediction will be selected for the first target, and the first
    # prediction will be selected for the second target. The first prediction was selected for two targets, but it will
    # be matched to the best target only (the second one).
    costs = torch.tensor([[0.3, 0.1], [0.5, 0.2], [0.4, 0.5], [0.3, 0.3]])
    matched_preds, matched_targets = _sim_ota_match(costs, ious)

    # The first and the last prediction were matched.
    assert len(matched_preds) == 4
    assert matched_preds[0]
    assert not matched_preds[1]
    assert not matched_preds[2]
    assert matched_preds[3]

    # The first prediction was matched to the target 1 and the last prediction was matched to target 0.
    assert len(matched_targets) == 2
    assert matched_targets[0] == 1
    assert matched_targets[1] == 0


# --------------------------------------------------------------------------------
# /tests/optimizers/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/optimizers/__init__.py
# --------------------------------------------------------------------------------
# /tests/test_utilities.py:
# --------------------------------------------------------------------------------
import os

from tests import PROJECT_ROOT


def test_paths():
    assert os.path.isdir(PROJECT_ROOT)


# --------------------------------------------------------------------------------
# /tests/transforms/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/transforms/__init__.py
# --------------------------------------------------------------------------------
# /tests/transforms/test_normalizations.py:
# --------------------------------------------------------------------------------
import pytest
import torch
from pl_bolts.transforms.dataset_normalizations import (
    cifar10_normalization,
    emnist_normalization,
    imagenet_normalization,
    stl10_normalization,
)
from pytorch_lightning import seed_everything


@pytest.mark.parametrize(
    "normalization",
    [cifar10_normalization, imagenet_normalization, stl10_normalization],
)
def test_normalizations(normalization, catch_warnings):
    """Test normalizations for CIFAR10, ImageNet, STL10."""
    seed_everything(1234)
    x = torch.rand(3, 32, 32)
    assert normalization()(x).shape == (3, 32, 32)
    # NOTE(review): these check the *input* range after the call; presumably they guard against the
    # transform normalizing ``x`` in place (which would push values outside [0, 1]) — confirm intent.
    assert x.min() >= 0.0
    assert x.max() <= 1.0


@pytest.mark.parametrize(
    "split",
    ["balanced", "byclass", "bymerge", "digits", "letters", "mnist"],
)
def test_emnist_normalizations(split, catch_warnings):
    """Test normalizations for each EMNIST dataset split."""
    x = torch.rand(1, 28, 28)
    assert emnist_normalization(split)(x).shape == (1, 28, 28)
    assert x.min() >= 0.0
    assert x.max() <= 1.0


# --------------------------------------------------------------------------------
# /tests/utils/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/utils/__init__.py
# --------------------------------------------------------------------------------
# /tests/utils/test_dependency.py:
# --------------------------------------------------------------------------------
import pytest
from pl_bolts.utils._dependency import requires


@requires("torch")
def using_torch():
    return True


@requires("torch.anything.wrong")
def using_torch_wrong_path():
    return True


@requires("torch>99.0")
def using_torch_bad_version():
    return True


def test_requires_pass():
    assert using_torch() is True


def test_requires_fail():
    """Unresolvable module paths and unsatisfiable version pins both raise ModuleNotFoundError."""
    with pytest.raises(ModuleNotFoundError, match="Required dependencies not available"):
        using_torch_wrong_path()
    with pytest.raises(ModuleNotFoundError, match="Required dependencies not available"):
        using_torch_bad_version()


# --------------------------------------------------------------------------------
# /tests/utils/test_self_supervised.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/utils/test_self_supervised.py
# --------------------------------------------------------------------------------
# /tests/utils/test_semi_supervised.py:
# --------------------------------------------------------------------------------
from collections import Counter
from warnings import warn

import numpy as np
import pytest
import torch
from pl_bolts.utils.semi_supervised import balance_classes, generate_half_labeled_batches

try:
    from sklearn.utils import shuffle as sk_shuffle  # noqa: F401

    _SKLEARN_AVAILABLE = True
except ImportError:
    warn("Failing to import `sklearn` correctly")
    _SKLEARN_AVAILABLE = False


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="failing to import SKLearn")
def test_balance_classes():
    """Smoke test: balance_classes accepts an imbalanced 20/20/60 label split."""
    x = torch.rand(100, 3, 32, 32)
    c1 = torch.zeros(20, 1)
    c2 = torch.ones(20, 1)
    c3 = torch.full((60, 1), 2.0)

    y = torch.cat([c1, c2, c3], dim=0).int().cpu().numpy().flatten().tolist()
    balance_classes(x, y, batch_size=10)


def test_generate_half_labeled_batches():
    """Every batch of 10 mixes 5 labeled samples with 5 unlabeled (-1) samples."""
    smaller_set_x = np.random.rand(100, 3, 32, 32)
    smaller_set_y = np.random.randint(0, 3, (100, 1))
    larger_set_x = np.random.rand(100, 3, 32, 32)
    # -1 marks "unlabeled"; same dtype/shape as the labeled targets.
    larger_set_y = np.full_like(smaller_set_y, -1)

    _mixed_x, mixed_y = generate_half_labeled_batches(smaller_set_x, smaller_set_y, larger_set_x, larger_set_y, 10)

    # only half the batch should be unlabeled
    for i in range(0, len(mixed_y), 10):
        batch = mixed_y[i : i + 10]
        counts = Counter(batch.flatten().tolist())
        assert counts[-1] == 5


# --------------------------------------------------------------------------------