├── .azure
└── gpu-tests.yml
├── .codecov.yml
├── .github
├── CODEOWNERS
├── CONTRIBUTING.md
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ ├── documentation.md
│ ├── feature_request.md
│ └── how-to-question.md
├── PULL_REQUEST_TEMPLATE.md
├── dependabot.yml
├── labeler.yml
├── mergify.yml
├── stale.yml
└── workflows
│ ├── ci-checks.yml
│ ├── ci-docs.yml
│ ├── ci-tests.yml
│ ├── docs-check.yml
│ ├── labeler.yml
│ └── pypi-release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── CHANGELOG.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── docs
├── Makefile
├── make.bat
├── requirements.txt
└── source
│ ├── _images
│ ├── gans
│ │ ├── basic_gan_dloss.jpg
│ │ ├── basic_gan_gloss.jpg
│ │ ├── basic_gan_interpolate.jpg
│ │ ├── dcgan_lsun_dloss.png
│ │ ├── dcgan_lsun_gloss.png
│ │ ├── dcgan_lsun_outputs.png
│ │ ├── dcgan_mnist_dloss.png
│ │ ├── dcgan_mnist_gloss.png
│ │ ├── dcgan_mnist_outputs.png
│ │ ├── srgan-celeba-scale_factor=2.png
│ │ ├── srgan-celeba-scale_factor=4.png
│ │ ├── srgan-mnist-scale_factor=2.png
│ │ ├── srgan-mnist-scale_factor=4.png
│ │ ├── srgan-stl10-scale_factor=2.png
│ │ └── srgan-stl10-scale_factor=4.png
│ ├── logos
│ │ ├── bolts_logo.png
│ │ ├── lightning_icon.svg
│ │ ├── lightning_logo-large.svg
│ │ ├── lightning_logo-name.svg
│ │ ├── lightning_logo.png
│ │ └── lightning_logo.svg
│ ├── rl_benchmark
│ │ ├── cartpole_a2c_results.jpg
│ │ ├── dqn_ddqn_comparison.jpg
│ │ ├── pendulum_sac_results.jpg
│ │ ├── pong_double_dqn_baseline_results.jpg
│ │ ├── pong_dqn_baseline_results.jpg
│ │ ├── pong_dueling_dqn_comparison.jpg
│ │ ├── pong_dueling_dqn_results.jpg
│ │ ├── pong_noisy_dqn_comparison.jpg
│ │ ├── pong_noisy_dqn_results.jpg
│ │ ├── pong_nstep_dqn_1.jpg
│ │ ├── pong_nstep_dqn_2.jpg
│ │ ├── pong_per_dqn_baseline_results.jpg
│ │ ├── pong_per_dqn_baseline_v1_results.jpg
│ │ ├── pong_per_dqn_baseline_v1_results_comp.jpg
│ │ └── pong_per_dqn_dqn_comparison.jpg
│ └── vision
│ │ └── confused_logit.png
│ ├── _static
│ └── copybutton.js
│ ├── _templates
│ ├── layout.html
│ └── theme_variables.jinja
│ ├── callbacks
│ ├── monitor.rst
│ ├── self_supervised.rst
│ ├── sparseml.rst
│ ├── torch_ort.rst
│ ├── variational.rst
│ └── vision.rst
│ ├── conf.py
│ ├── dataloaders
│ └── async.rst
│ ├── datamodules
│ ├── sklearn.rst
│ └── vision.rst
│ ├── datasets
│ └── debug.rst
│ ├── governance.rst
│ ├── index.rst
│ ├── installation.rst
│ ├── introduction_guide.rst
│ ├── losses.rst
│ ├── models
│ ├── autoencoders.rst
│ ├── classic_ml.rst
│ ├── convolutional.rst
│ ├── gans.rst
│ ├── models_howto.rst
│ ├── object_detection.rst
│ ├── reinforce_learn.rst
│ └── self_supervised.rst
│ ├── schedulers
│ └── warmup_cosine_annealing.rst
│ ├── stability.rst
│ ├── transforms
│ ├── self_supervised.rst
│ └── semi_supervised.rst
│ └── vision_tasks.rst
├── pyproject.toml
├── requirements.txt
├── requirements
├── base.txt
├── devel.txt
├── loggers.txt
├── models.txt
├── test.txt
└── typing.txt
├── setup.py
├── src
└── pl_bolts
│ ├── __about__.py
│ ├── __init__.py
│ ├── callbacks
│ ├── __init__.py
│ ├── byol_updates.py
│ ├── data_monitor.py
│ ├── knn_online.py
│ ├── printing.py
│ ├── sparseml.py
│ ├── ssl_online.py
│ ├── torch_ort.py
│ ├── variational.py
│ ├── verification
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── batch_gradient.py
│ └── vision
│ │ ├── __init__.py
│ │ ├── confused_logit.py
│ │ ├── image_generation.py
│ │ └── sr_image_logger.py
│ ├── datamodules
│ ├── __init__.py
│ ├── async_dataloader.py
│ ├── binary_emnist_datamodule.py
│ ├── binary_mnist_datamodule.py
│ ├── cifar10_datamodule.py
│ ├── cityscapes_datamodule.py
│ ├── emnist_datamodule.py
│ ├── experience_source.py
│ ├── fashion_mnist_datamodule.py
│ ├── imagenet_datamodule.py
│ ├── kitti_datamodule.py
│ ├── mnist_datamodule.py
│ ├── sklearn_datamodule.py
│ ├── sr_datamodule.py
│ ├── ssl_imagenet_datamodule.py
│ ├── stl10_datamodule.py
│ ├── vision_datamodule.py
│ └── vocdetection_datamodule.py
│ ├── datasets
│ ├── __init__.py
│ ├── array_dataset.py
│ ├── base_dataset.py
│ ├── cifar10_dataset.py
│ ├── concat_dataset.py
│ ├── dummy_dataset.py
│ ├── emnist_dataset.py
│ ├── imagenet_dataset.py
│ ├── kitti_dataset.py
│ ├── mnist_dataset.py
│ ├── sr_celeba_dataset.py
│ ├── sr_dataset_mixin.py
│ ├── sr_mnist_dataset.py
│ ├── sr_stl10_dataset.py
│ ├── ssl_amdim_datasets.py
│ └── utils.py
│ ├── losses
│ ├── __init__.py
│ ├── rl.py
│ └── self_supervised_learning.py
│ ├── metrics
│ ├── __init__.py
│ └── aggregation.py
│ ├── models
│ ├── __init__.py
│ ├── autoencoders
│ │ ├── __init__.py
│ │ ├── basic_ae
│ │ │ ├── __init__.py
│ │ │ └── basic_ae_module.py
│ │ ├── basic_vae
│ │ │ ├── __init__.py
│ │ │ └── basic_vae_module.py
│ │ └── components.py
│ ├── detection
│ │ ├── __init__.py
│ │ ├── components
│ │ │ ├── __init__.py
│ │ │ ├── _supported_models.py
│ │ │ └── torchvision_backbones.py
│ │ ├── faster_rcnn
│ │ │ ├── __init__.py
│ │ │ ├── backbones.py
│ │ │ └── faster_rcnn_module.py
│ │ ├── retinanet
│ │ │ ├── __init__.py
│ │ │ ├── backbones.py
│ │ │ └── retinanet_module.py
│ │ └── yolo
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── darknet_network.py
│ │ │ ├── layers.py
│ │ │ ├── loss.py
│ │ │ ├── target_matching.py
│ │ │ ├── torch_networks.py
│ │ │ ├── types.py
│ │ │ ├── utils.py
│ │ │ └── yolo_module.py
│ ├── gans
│ │ ├── __init__.py
│ │ ├── basic
│ │ │ ├── __init__.py
│ │ │ ├── basic_gan_module.py
│ │ │ └── components.py
│ │ ├── dcgan
│ │ │ ├── __init__.py
│ │ │ ├── components.py
│ │ │ └── dcgan_module.py
│ │ ├── pix2pix
│ │ │ ├── __init__.py
│ │ │ ├── components.py
│ │ │ └── pix2pix_module.py
│ │ └── srgan
│ │ │ ├── __init__.py
│ │ │ ├── components.py
│ │ │ ├── srgan_module.py
│ │ │ └── srresnet_module.py
│ ├── mnist_module.py
│ ├── regression
│ │ ├── __init__.py
│ │ ├── linear_regression.py
│ │ └── logistic_regression.py
│ ├── rl
│ │ ├── __init__.py
│ │ ├── advantage_actor_critic_model.py
│ │ ├── common
│ │ │ ├── __init__.py
│ │ │ ├── agents.py
│ │ │ ├── cli.py
│ │ │ ├── distributions.py
│ │ │ ├── gym_wrappers.py
│ │ │ ├── memory.py
│ │ │ └── networks.py
│ │ ├── double_dqn_model.py
│ │ ├── dqn_model.py
│ │ ├── dueling_dqn_model.py
│ │ ├── noisy_dqn_model.py
│ │ ├── per_dqn_model.py
│ │ ├── ppo_model.py
│ │ ├── reinforce_model.py
│ │ ├── sac_model.py
│ │ └── vanilla_policy_gradient_model.py
│ ├── self_supervised
│ │ ├── __init__.py
│ │ ├── amdim
│ │ │ ├── __init__.py
│ │ │ ├── amdim_module.py
│ │ │ ├── datasets.py
│ │ │ └── networks.py
│ │ ├── byol
│ │ │ ├── __init__.py
│ │ │ ├── byol_module.py
│ │ │ └── models.py
│ │ ├── cpc
│ │ │ ├── __init__.py
│ │ │ ├── cpc_finetuner.py
│ │ │ ├── cpc_module.py
│ │ │ └── networks.py
│ │ ├── evaluator.py
│ │ ├── moco
│ │ │ ├── LICENSE
│ │ │ ├── __init__.py
│ │ │ ├── callbacks.py
│ │ │ ├── moco_module.py
│ │ │ └── utils.py
│ │ ├── resnets.py
│ │ ├── simclr
│ │ │ ├── __init__.py
│ │ │ ├── simclr_finetuner.py
│ │ │ └── simclr_module.py
│ │ ├── simsiam
│ │ │ ├── __init__.py
│ │ │ └── simsiam_module.py
│ │ ├── ssl_finetuner.py
│ │ └── swav
│ │ │ ├── __init__.py
│ │ │ ├── loss.py
│ │ │ ├── swav_finetuner.py
│ │ │ ├── swav_module.py
│ │ │ └── swav_resnet.py
│ └── vision
│ │ ├── __init__.py
│ │ ├── image_gpt
│ │ ├── __init__.py
│ │ ├── gpt2.py
│ │ └── igpt_module.py
│ │ ├── pixel_cnn.py
│ │ ├── segmentation.py
│ │ └── unet.py
│ ├── optimizers
│ ├── __init__.py
│ ├── lars.py
│ └── lr_scheduler.py
│ ├── transforms
│ ├── __init__.py
│ ├── dataset_normalizations.py
│ └── self_supervised
│ │ ├── __init__.py
│ │ ├── amdim_transforms.py
│ │ ├── cpc_transforms.py
│ │ ├── moco_transforms.py
│ │ ├── simclr_transforms.py
│ │ ├── ssl_transforms.py
│ │ └── swav_transforms.py
│ └── utils
│ ├── __init__.py
│ ├── _dependency.py
│ ├── arguments.py
│ ├── pretrained_weights.py
│ ├── self_supervised.py
│ ├── semi_supervised.py
│ ├── shaping.py
│ ├── stability.py
│ ├── types.py
│ └── warnings.py
└── tests
├── README.md
├── __init__.py
├── _data_configs
├── yolo.cfg
└── yolo_giou.cfg
├── callbacks
├── __init__.py
├── test_data_monitor.py
├── test_info_callbacks.py
├── test_ort.py
├── test_param_update_callbacks.py
├── test_sparseml.py
├── test_variational_callbacks.py
└── verification
│ ├── __init__.py
│ ├── test_base.py
│ └── test_batch_gradient.py
├── conftest.py
├── datamodules
├── __init__.py
├── test_dataloader.py
├── test_datamodules.py
├── test_experience_sources.py
└── test_sklearn_dataloaders.py
├── datasets
├── __init__.py
├── test_array_dataset.py
├── test_base_dataset.py
├── test_datasets.py
└── test_utils.py
├── helpers
├── __init__.py
└── boring_model.py
├── losses
├── __init__.py
└── test_rl_loss.py
├── metrics
├── __init__.py
└── test_aggregation.py
├── models
├── __init__.py
├── gans
│ ├── __init__.py
│ ├── integration
│ │ ├── __init__.py
│ │ └── test_gans.py
│ └── unit
│ │ ├── __init__.py
│ │ └── test_basic_components.py
├── regression
│ ├── __init__.py
│ └── test_logistic_regression.py
├── rl
│ ├── __init__.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── test_actor_critic_models.py
│ │ ├── test_policy_models.py
│ │ └── test_value_models.py
│ └── unit
│ │ ├── __init__.py
│ │ ├── test_a2c.py
│ │ ├── test_agents.py
│ │ ├── test_memory.py
│ │ ├── test_ppo.py
│ │ ├── test_reinforce.py
│ │ ├── test_sac.py
│ │ ├── test_vpg.py
│ │ └── test_wrappers.py
├── self_supervised
│ ├── __init__.py
│ ├── test_models.py
│ ├── test_resnets.py
│ └── unit
│ │ ├── __init__.py
│ │ └── test_transforms.py
├── test_autoencoders.py
├── test_classic_ml.py
├── test_detection.py
├── test_mnist_templates.py
├── test_scripts.py
├── test_vision.py
└── yolo
│ ├── __init__.py
│ └── unit
│ ├── __init__.py
│ ├── test_darknet_network.py
│ ├── test_target_matching.py
│ └── test_utils.py
├── optimizers
├── __init__.py
└── test_lr_scheduler.py
├── test_utilities.py
├── transforms
├── __init__.py
├── test_normalizations.py
└── test_transforms.py
└── utils
├── __init__.py
├── test_arguments.py
├── test_dependency.py
├── test_self_supervised.py
└── test_semi_supervised.py
/.codecov.yml:
--------------------------------------------------------------------------------
1 | # see https://docs.codecov.io/docs/codecov-yaml
2 | # Validation check:
3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate
4 |
5 |
6 | # https://docs.codecov.io/docs/codecovyml-reference
7 | codecov:
8 | bot: "codecov-io"
9 | strict_yaml_branch: "yaml-config"
10 | require_ci_to_pass: yes
11 | notify:
12 | after_n_builds: 17
13 | wait_for_ci: yes
14 |
15 | coverage:
16 | precision: 0 # 2 = xx.xx%, 0 = xx%
17 | round: nearest # how coverage is rounded: down/up/nearest
18 | range: 40...100 # custom range of coverage colors from red -> yellow -> green
19 | status:
20 | # https://codecov.readme.io/v1.0/docs/commit-status
21 | project:
22 | default:
23 | informational: true
24 | target: 95% # specify the target coverage for each commit status
25 |       threshold: 30% # allow this much decrease on project
26 | # https://github.com/codecov/support/wiki/Filtering-Branches
27 | # branches: master
28 | if_ci_failed: error
29 | # https://github.com/codecov/support/wiki/Patch-Status
30 | patch:
31 | default:
32 | informational: true
33 | threshold: 50% # allow this much decrease on patch
34 | changes: false
35 |
36 | # https://docs.codecov.com/docs/github-checks#disabling-github-checks-patch-annotations
37 | github_checks:
38 | annotations: false
39 |
40 | parsers:
41 | gcov:
42 | branch_detection:
43 | conditional: true
44 | loop: true
45 | macro: false
46 | method: false
47 | javascript:
48 | enable_partials: false
49 |
50 | comment:
51 | layout: header, diff
52 | require_changes: false
53 | behavior: default # update if exists else create new
54 | # branches: *
55 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # This is a comment.
2 | # Each line is a file pattern followed by one or more owners.
3 |
4 | # These owners will be the default owners for everything in
5 | # the repo. Unless a later match takes precedence,
6 | # @global-owner1 and @global-owner2 will be requested for
7 | # review when someone opens a pull request.
8 | * @borda @akihironitta @ethanwharris @awaelchli
9 |
10 | # owners
11 | /.github/CODEOWNERS @borda @ethanwharris
12 | # main
13 | /README.md @borda
14 | # installation
15 | /setup.py @borda @ethanwharris
16 |
17 | # CI/CD
18 | /.github/workflows/ @borda @ethanwharris @akihironitta
19 | /.github/*.yml @borda @ethanwharris @akihironitta
20 | /.azure/ @borda @ethanwharris @akihironitta
21 | /*.yml @borda @ethanwharris @akihironitta
22 |
23 | # Docs
24 | /docs/ @edenlightning @borda
25 | /.github/*.md @edenlightning @borda
26 | /.github/ISSUE_TEMPLATE/*.md @edenlightning @borda
27 | /docs/source/conf.py @borda @akihironitta
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug, help wanted
6 | assignees: ''
7 | ---
8 |
9 | ## 🐛 Bug
10 |
11 |
12 |
13 | ### To Reproduce
14 |
15 | Steps to reproduce the behavior:
16 |
17 | 1. Go to '...'
18 | 1. Run '....'
19 | 1. Scroll down to '....'
20 | 1. See error
21 |
22 |
23 |
24 | #### Code sample
25 |
26 |
28 |
29 | ### Expected behavior
30 |
31 |
32 |
33 | ### Environment
34 |
35 | - PyTorch Version (e.g., 1.0):
36 | - OS (e.g., Linux):
37 | - How you installed PyTorch (`conda`, `pip`, source):
38 | - Build command you used (if compiling from source):
39 | - Python version:
40 | - CUDA/cuDNN version:
41 | - GPU models and configuration:
42 | - Any other relevant information:
43 |
44 | ### Additional context
45 |
46 |
47 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Typos and doc fixes
3 | about: Typos and doc fixes
4 | title: ''
5 | labels: documentation
6 | assignees: ''
7 | ---
8 |
9 | ## 📚 Documentation
10 |
11 | For typos and doc fixes, please go ahead and:
12 |
13 | 1. Create an issue.
14 | 1. Fix the typo.
15 | 1. Submit a PR.
16 |
17 | Thanks!
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement, help wanted
6 | assignees: ''
7 | ---
8 |
9 | ## 🚀 Feature
10 |
11 |
12 |
13 | ### Motivation
14 |
15 |
16 |
17 | ### Pitch
18 |
19 |
20 |
21 | ### Alternatives
22 |
23 |
24 |
25 | ### Additional context
26 |
27 |
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/how-to-question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: How to question
3 | about: Asking how-to questions
4 | title: ''
5 | labels: question
6 | assignees: ''
7 | ---
8 |
9 | ## ❓ Questions and Help
10 |
11 | ### Before asking:
12 |
13 | 1. search the issues.
14 | 1. search the docs.
15 |
16 |
17 |
18 | #### What is your question?
19 |
20 | #### Code
21 |
22 |
23 |
24 | #### What have you tried?
25 |
26 | #### What's your environment?
27 |
28 | - OS: \[e.g. iOS, Linux, Win\]
29 | - Packaging \[e.g. pip, conda\]
30 | - Version \[e.g. 0.5.2.1\]
31 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## What does this PR do?
2 |
3 |
8 |
9 | Fixes # (issue)
10 |
11 | ## Before submitting
12 |
13 | - [ ] Was this **discussed/approved** via a Github issue? (no need for typos and docs improvements)
14 | - [ ] Did you read the [contributor guideline](https://github.com/PyTorchLightning/lightning-bolts/blob/master/.github/CONTRIBUTING.md), Pull Request section?
15 | - [ ] Did you make sure your PR does **only one thing**, instead of bundling different changes together?
16 | - [ ] Did you make sure to **update the documentation** with your changes?
17 | - [ ] Did you write any **new necessary tests**? \[not needed for typos/docs\]
18 | - [ ] Did you verify new and **existing tests** pass locally with your changes?
19 | - [ ] If you made a notable change (that affects users), did you update the [CHANGELOG](https://github.com/PyTorchLightning/lightning-bolts/blob/master/CHANGELOG.md)?
20 |
21 |
22 |
23 | ## PR review
24 |
25 | - [ ] Is this pull request **ready for review**? (if not, please submit in draft mode)
26 |
27 | Anyone in the community is free to review the PR once the tests have passed.
28 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
29 |
30 | ## Did you have fun?
31 |
32 | Make sure you had fun coding 🙃
33 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
2 | version: 2
3 | updates:
4 |
5 | # Enable version updates for python
6 | - package-ecosystem: "pip"
7 |     # Look for requirements files in the `/requirements` directory
8 | directory: "/requirements"
9 | # Check for updates once a week
10 | schedule:
11 | interval: "weekly"
12 | # Labels on pull requests for version updates only
13 | labels: ["ci/cd"]
14 | pull-request-branch-name:
15 | # Separate sections of the branch name with a hyphen
16 | # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1`
17 | separator: "-"
18 | # Allow up to 5 open pull requests for pip dependencies
19 | open-pull-requests-limit: 5
20 | reviewers:
21 | - "Lightning-AI/core-bolts"
22 |
23 | # Enable version updates for GitHub Actions
24 | - package-ecosystem: "github-actions"
25 | directory: "/"
26 | # Check for updates once a week
27 | schedule:
28 | interval: "weekly"
29 | # Labels on pull requests for version updates only
30 | labels: ["ci/cd"]
31 | pull-request-branch-name:
32 | # Separate sections of the branch name with a hyphen
33 | # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1`
34 | separator: "-"
35 | # Allow up to 5 open pull requests for GitHub Actions
36 | open-pull-requests-limit: 5
37 | reviewers:
38 | - "Lightning-AI/core-bolts"
39 |
--------------------------------------------------------------------------------
/.github/labeler.yml:
--------------------------------------------------------------------------------
1 | documentation:
2 | - docs/**/*
3 |
4 | callback:
5 | - pl_bolts/callbacks/**/*
6 |
7 | datamodule:
8 | - pl_bolts/datamodules/**/*
9 |
10 | model:
11 | - pl_bolts/models/**/*
12 |
--------------------------------------------------------------------------------
/.github/mergify.yml:
--------------------------------------------------------------------------------
1 | # Copyright The PyTorch Lightning team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | pull_request_rules:
16 |
17 | - name: warn on conflicts
18 | conditions:
19 | - conflict
20 | - -draft # filter-out GH draft PRs
21 | - -label="has conflicts"
22 | actions:
23 | # comment:
24 | # message: This pull request is now in conflict... :(
25 | label:
26 | add: [ "has conflicts" ]
27 |
28 | - name: resolved conflicts
29 | conditions:
30 | - -conflict
31 | - label="has conflicts"
32 | - -draft # filter-out GH draft PRs
33 | - -merged # not merged yet
34 | - -closed
35 | actions:
36 | label:
37 | remove: [ "has conflicts" ]
38 |
39 | - name: Ready to Go
40 | conditions:
41 | - -conflict
42 | - -draft # filter-out GH draft PRs
43 |       - -title~=(?i)wip # skip all PRs whose title contains “WIP” (ignoring case)
44 | - "#approved-reviews-by>=1" # number of review approvals
45 | - "#changes-requested-reviews-by=0" # no requested changes
46 | - "#check-pending<=2"
47 | - "#check-failure<5"
48 | - "#check-skipped<5"
49 | actions:
50 | label:
51 | add: [ "ready" ]
52 |
53 | - name: Not ready yet
54 | conditions:
55 | - or:
56 | - "#approved-reviews-by=0" # number of review approvals
57 |         - "#changes-requested-reviews-by>=1" # at least one change requested
58 | - "#check-failure>=5"
59 | - "#check-skipped>=5"
60 | actions:
61 | label:
62 | remove: [ "ready" ]
63 |
64 | - name: update PR
65 | conditions:
66 | - -conflict
67 | - -draft # filter-out GH draft PRs
68 | - base=master # apply only on master
69 |       - -title~=(?i)wip # skip all PRs whose title contains “WIP” (ignoring case)
70 | - "#approved-reviews-by>=1" # number of review approvals
71 | - "#changes-requested-reviews-by=0" # no requested changes
72 | - "#check-failure<5"
73 | actions:
74 | update: {}
75 |
76 | - name: add core reviewer
77 | conditions:
78 | - -conflict # skip if conflict
79 | - -draft # filter-out GH draft PRs
80 | - label="ready"
81 | - "#approved-reviews-by<2" # number of review approvals
82 | - "#review-requested<2" # number of requested reviews
83 | actions:
84 | request_reviews:
85 | teams:
86 | - "@Lightning-Universe/core-bolts"
87 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # https://github.com/marketplace/stale
2 |
3 | # Number of days of inactivity before an issue becomes stale
4 | daysUntilStale: 60
5 | # Number of days of inactivity before a stale issue is closed
6 | daysUntilClose: 9
7 | # Issues with these labels will never be considered stale
8 | exemptLabels:
9 | - pinned
10 | - security
11 | # Label to use when marking an issue as stale
12 | staleLabel: won't fix
13 | # Comment to post when marking an issue as stale. Set to `false` to disable
14 | markComment: >
15 | This issue has been automatically marked as stale because it has not had
16 | recent activity. It will be closed if no further activity occurs. Thank you
17 | for your contributions.
18 | # Comment to post when closing a stale issue. Set to `false` to disable
19 | closeComment: false
20 |
21 | # Set to true to ignore issues in a project (defaults to false)
22 | exemptProjects: true
23 | # Set to true to ignore issues in a milestone (defaults to false)
24 | exemptMilestones: true
25 | # Set to true to ignore issues with an assignee (defaults to false)
26 | exemptAssignees: false
27 |
--------------------------------------------------------------------------------
/.github/workflows/ci-checks.yml:
--------------------------------------------------------------------------------
1 | name: General Checks
2 |
3 | on:
4 | push:
5 | branches: [master, "release/*"]
6 | pull_request:
7 | branches: [master, "release/*"]
8 |
9 | jobs:
10 | check-schema:
11 | uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.2
12 | with:
13 | azure-dir: ".azure"
14 |
15 | check-code:
16 | uses: Lightning-AI/utilities/.github/workflows/check-code.yml@v0.11.2
17 | with:
18 | actions-ref: v0.11.2
19 | extra-typing: typing
20 |
21 | check-package:
22 | uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.2
23 | with:
24 | actions-ref: v0.11.2
25 | artifact-name: dist-packages-${{ github.sha }}
26 | import-name: "pl_bolts"
27 | testing-matrix: |
28 | {
29 | "os": ["ubuntu-20.04", "macos-11", "windows-2022"],
30 | "python-version": ["3.8"]
31 | }
32 |
--------------------------------------------------------------------------------
/.github/workflows/ci-docs.yml:
--------------------------------------------------------------------------------
1 | name: RTFD Preview
2 | on:
3 | pull_request_target:
4 | types:
5 | - opened
6 |
7 | permissions:
8 | pull-requests: write
9 |
10 | jobs:
11 | documentation-links:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: readthedocs/actions/preview@v1
15 | with:
16 | project-slug: "lightning-bolts"
17 |
--------------------------------------------------------------------------------
/.github/workflows/docs-check.yml:
--------------------------------------------------------------------------------
1 | name: "Docs check"
2 | # https://github.com/marketplace/actions/sphinx-build
3 |
4 | on: # Trigger the workflow on push or pull request, but only for the master branch
5 | push:
6 | branches: [master]
7 | pull_request: {}
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
12 |
13 | env:
14 | FREEZE_REQUIREMENTS: 1
15 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
16 | TRANSFORMERS_CACHE: _hf_cache
17 |
18 | defaults:
19 | run:
20 | shell: bash
21 |
22 | jobs:
23 |
24 | test-docs:
25 | runs-on: ubuntu-20.04
26 | steps:
27 | - uses: actions/checkout@v4
28 | - uses: actions/setup-python@v5
29 | with:
30 | python-version: 3.9
31 |
32 | - name: Cache pip
33 | uses: actions/cache@v4
34 | with:
35 | path: ~/.cache/pip
36 | key: pip-${{ hashFiles('requirements/*.txt') }}
37 | restore-keys: pip-
38 |
39 | - name: Install package & dependencies
40 | run: |
41 | pip install "pip==22.2.1" # todo: drop after resolving extras
42 | pip install -e . -r requirements/devel.txt -r docs/requirements.txt -f $TORCH_URL
43 | pip list
44 |
45 | - name: Test Documentation
46 | working-directory: docs/
47 | env:
48 | SPHINX_MOCK_REQUIREMENTS: 0
49 | run: |
50 | make doctest
51 | make coverage
52 |
53 |
54 | make-docs:
55 | runs-on: ubuntu-20.04
56 | steps:
57 | - uses: actions/checkout@v4
58 | - uses: actions/setup-python@v5
59 | with:
60 | python-version: 3.9
61 |
62 | - name: Cache pip
63 | uses: actions/cache@v4
64 | with:
65 | path: ~/.cache/pip
66 | key: pip-${{ hashFiles('requirements/*.txt') }}
67 | restore-keys: pip-
68 |
69 | - name: Install package & dependencies
70 | run: |
71 | # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
72 | sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures
73 | pip install "pip==22.2.1" # todo: drop after resolving extras
74 | pip install -e . -r docs/requirements.txt -f $TORCH_URL
75 | pip list
76 |
77 | - name: Make Documentation
78 | working-directory: docs/
79 | run: |
80 | make clean
81 | make html --debug --jobs 2 SPHINXOPTS="-W"
82 |
83 | - name: Upload built docs
84 | uses: actions/upload-artifact@v3
85 | with:
86 | name: docs-results-${{ github.sha }}
87 | path: docs/build/html/
88 |         # Upload the built docs only when the build succeeded
89 | if: success()
90 |
--------------------------------------------------------------------------------
/.github/workflows/labeler.yml:
--------------------------------------------------------------------------------
1 | name: Labeler
2 |
3 | on:
4 | pull_request_target:
5 | types: [opened, reopened]
6 |
7 | jobs:
8 | triage:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/labeler@main
12 | with:
13 | repo-token: "${{ secrets.GITHUB_TOKEN }}"
14 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-release.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Release
2 |
3 | # https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | on: # Trigger the workflow on push or pull request, but only for the master branch
5 | push:
6 | branches: [master, "release/*"]
7 | release:
8 | types: [published]
9 |
10 | # based on https://github.com/pypa/gh-action-pypi-publish
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 | - uses: actions/setup-python@v5
19 | with:
20 | python-version: 3.8
21 |
22 | - name: Install dependencies
23 | run: python -m pip install -U setuptools wheel build twine
24 | - name: Build package
25 | run: python -m build
26 | - name: Check package
27 | run: twine check dist/*
28 |
29 | - name: Upload to release
30 | uses: AButler/upload-release-assets@v3.0
31 | with:
32 | files: 'dist/*'
33 | repo-token: ${{ secrets.GITHUB_TOKEN }}
34 |
35 | # We do this, since failures on test.pypi aren't that bad
36 | - name: Publish to Test PyPI
37 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
38 | uses: pypa/gh-action-pypi-publish@master
39 | with:
40 | user: __token__
41 | password: ${{ secrets.test_pypi_password }}
42 | repository_url: https://test.pypi.org/legacy/
43 |
44 | - name: Publish distribution 📦 to PyPI
45 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
46 | uses: pypa/gh-action-pypi-publish@master
47 | with:
48 | user: __token__
49 | password: ${{ secrets.pypi_password }}
50 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
7 | autoupdate_schedule: monthly
8 | # submodules: true
9 |
10 | repos:
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v4.4.0
13 | hooks:
14 | - id: end-of-file-fixer
15 | - id: trailing-whitespace
16 | - id: check-case-conflict
17 | - id: check-yaml
18 | - id: check-toml
19 | - id: pretty-format-json
20 | - id: check-added-large-files
21 | - id: check-docstring-first
22 | - id: detect-private-key
23 |
24 | - repo: https://github.com/asottile/pyupgrade
25 | rev: v3.8.0
26 | hooks:
27 | - id: pyupgrade
28 | args: [--py38-plus]
29 | name: Upgrade code
30 |
31 | - repo: https://github.com/PyCQA/docformatter
32 | rev: v1.7.3
33 | hooks:
34 | - id: docformatter
35 | additional_dependencies: [tomli]
36 | args: ["--in-place"]
37 |
38 | - repo: https://github.com/executablebooks/mdformat
39 | rev: 0.7.16
40 | hooks:
41 | - id: mdformat
42 | additional_dependencies:
43 | - mdformat-gfm
44 | - mdformat-black
45 | - mdformat_frontmatter
46 | exclude: CHANGELOG.md
47 |
48 | - repo: https://github.com/psf/black
49 | rev: 23.3.0
50 | hooks:
51 | - id: black
52 | name: Format code
53 |
54 | - repo: https://github.com/asottile/yesqa
55 | rev: v1.5.0
56 | hooks:
57 | - id: yesqa
58 | additional_dependencies:
59 | - flake8-pytest-style
60 | - flake8-bandit
61 | - pep8-naming
62 |
63 | - repo: https://github.com/astral-sh/ruff-pre-commit
64 | rev: v0.0.276
65 | hooks:
66 | - id: ruff
67 | args: ["--fix"]
68 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Build documentation in the docs/ directory with Sphinx
9 | sphinx:
10 | configuration: docs/source/conf.py
11 | fail_on_warning: true
12 |
13 | # Optionally build your docs in additional formats such as PDF and ePub
14 | formats:
15 | - htmlzip
16 | - pdf
17 |
18 | # Optionally set the version of Python and requirements required to build your docs
19 | python:
20 | version: "3.8"
21 | install:
22 | - requirements: docs/requirements.txt
23 | #- requirements: requirements.txt
24 | - method: pip
25 | path: .
26 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # Manifest syntax https://packaging.python.org/en/latest/guides/using-manifest-in/
2 | graft wheelhouse
3 |
4 | recursive-exclude __pycache__ *.py[cod] *.orig
5 |
6 | # Include the README and CHANGELOG
7 | include *.md
8 | recursive-include pl_bolts *.md
9 |
10 | # Include the license file
11 | include LICENSE
12 |
13 | exclude *.sh
14 | exclude *.toml
15 | exclude *.svg
16 | recursive-include pl_bolts *.py
17 |
18 | # exclude tests from package
19 | recursive-exclude tests *
20 | recursive-exclude site *
21 | exclude tests
22 |
23 | # Exclude the documentation files
24 | recursive-exclude docs *
25 | exclude docs
26 |
27 | # Include the Requirements
28 | include requirements.txt
29 | recursive-include requirements *.txt
30 |
31 | # Exclude build configs
32 | exclude *.yml
33 |
34 | # Exclude Makefile
35 | exclude Makefile
36 |
37 | prune .git
38 | prune .github
39 | prune .circleci
40 | prune notebook*
41 | prune temp*
42 | prune test*
43 | prune benchmark*
44 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test clean docs env
2 |
3 | # assume you have installed need packages
4 | export SPHINX_MOCK_REQUIREMENTS=1
5 |
6 | clean:
7 | # clean all temp runs
8 | rm -rf $(shell find . -name "lightning_log")
9 | rm -rf _ckpt_*
10 | rm -rf .mypy_cache
11 | rm -rf .pytest_cache
12 | rm -rf ./docs/build
13 | rm -rf ./docs/source/generated
14 | rm -rf ./docs/source/*/generated
15 | rm -rf ./docs/source/api
16 | rm -rf ./_datasets
17 |
18 | test: clean env
19 |
20 | # run tests with coverage
21 | python -m coverage run --source pl_bolts -m pytest pl_bolts tests -v
22 | python -m coverage report
23 |
24 | doctest: clean env
25 | pip install --quiet -r docs/requirements.txt
26 | SPHINX_MOCK_REQUIREMENTS=0 make -C docs doctest
27 |
28 | docs: clean
29 | pip install --quiet -r docs/requirements.txt
30 | python -m sphinx -b html -W docs/source docs/build
31 |
32 | env:
33 | pip install -q -r requirements/devel.txt
34 |
35 | format:
36 | isort .
37 | black .
38 | flake8 .
39 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = -W
6 | SPHINXBUILD = $(shell which sphinx-build)
7 | SOURCEDIR = source
8 | BUILDDIR = build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
20 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx ==6.2
2 | recommonmark # fails with badges
3 | mistune < 2.0
4 | m2r # fails with multi-line text
5 | nbsphinx >0.8, <=0.9.2
6 | pandoc >2.0, <=2.3
7 | # docutils
8 | sphinxcontrib-fulltoc >1.0, <=1.2.0
9 | sphinxcontrib-mockautodoc
10 | sphinx-autodoc-typehints >1.0, <=1.17.1
11 | sphinx-paramlinks >0.4.0, <=0.5.4
12 | sphinx-togglebutton >0.2, <=0.3.2
13 | sphinx-copybutton >0.3, <=0.5.2
14 |
15 | pt-lightning-sphinx-theme @ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip
16 |
--------------------------------------------------------------------------------
/docs/source/_images/gans/basic_gan_dloss.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/basic_gan_dloss.jpg
--------------------------------------------------------------------------------
/docs/source/_images/gans/basic_gan_gloss.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/basic_gan_gloss.jpg
--------------------------------------------------------------------------------
/docs/source/_images/gans/basic_gan_interpolate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/basic_gan_interpolate.jpg
--------------------------------------------------------------------------------
/docs/source/_images/gans/dcgan_lsun_dloss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_lsun_dloss.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/dcgan_lsun_gloss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_lsun_gloss.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/dcgan_lsun_outputs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_lsun_outputs.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/dcgan_mnist_dloss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_mnist_dloss.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/dcgan_mnist_gloss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_mnist_gloss.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/dcgan_mnist_outputs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/dcgan_mnist_outputs.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/srgan-celeba-scale_factor=2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-celeba-scale_factor=2.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/srgan-celeba-scale_factor=4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-celeba-scale_factor=4.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/srgan-mnist-scale_factor=2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-mnist-scale_factor=2.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/srgan-mnist-scale_factor=4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-mnist-scale_factor=4.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/srgan-stl10-scale_factor=2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-stl10-scale_factor=2.png
--------------------------------------------------------------------------------
/docs/source/_images/gans/srgan-stl10-scale_factor=4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/gans/srgan-stl10-scale_factor=4.png
--------------------------------------------------------------------------------
/docs/source/_images/logos/bolts_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/logos/bolts_logo.png
--------------------------------------------------------------------------------
/docs/source/_images/logos/lightning_icon.svg:
--------------------------------------------------------------------------------
1 |
9 |
--------------------------------------------------------------------------------
/docs/source/_images/logos/lightning_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/logos/lightning_logo.png
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/dqn_ddqn_comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/dqn_ddqn_comparison.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pendulum_sac_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pendulum_sac_results.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_double_dqn_baseline_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_double_dqn_baseline_results.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_dqn_baseline_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_dqn_baseline_results.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_dueling_dqn_comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_dueling_dqn_comparison.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_dueling_dqn_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_dueling_dqn_results.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_noisy_dqn_comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_noisy_dqn_comparison.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_noisy_dqn_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_noisy_dqn_results.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_nstep_dqn_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_nstep_dqn_1.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_nstep_dqn_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_nstep_dqn_2.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_results.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_v1_results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_v1_results.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_v1_results_comp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_per_dqn_baseline_v1_results_comp.jpg
--------------------------------------------------------------------------------
/docs/source/_images/rl_benchmark/pong_per_dqn_dqn_comparison.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/rl_benchmark/pong_per_dqn_dqn_comparison.jpg
--------------------------------------------------------------------------------
/docs/source/_images/vision/confused_logit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/docs/source/_images/vision/confused_logit.png
--------------------------------------------------------------------------------
/docs/source/_static/copybutton.js:
--------------------------------------------------------------------------------
1 | /* Copied from the official Python docs: https://docs.python.org/3/_static/copybutton.js */
2 | $(document).ready(function() {
3 | /* Add a [>>>] button on the top-right corner of code samples to hide
4 | * the >>> and ... prompts and the output and thus make the code
5 | * copyable. */
6 | var div = $('.highlight-python .highlight,' +
7 | '.highlight-python3 .highlight,' +
8 | '.highlight-pycon .highlight,' +
9 | '.highlight-default .highlight');
10 | var pre = div.find('pre');
11 |
12 | // get the styles from the current theme
13 | pre.parent().parent().css('position', 'relative');
14 | var hide_text = 'Hide the prompts and output';
15 | var show_text = 'Show the prompts and output';
16 | var border_width = pre.css('border-top-width');
17 | var border_style = pre.css('border-top-style');
18 | var border_color = pre.css('border-top-color');
19 | var button_styles = {
20 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0',
21 | 'border-color': border_color, 'border-style': border_style,
22 | 'border-width': border_width, 'color': border_color, 'text-size': '75%',
23 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em',
24 | 'border-radius': '0 3px 0 0'
25 | }
26 |
27 | // create and add the button to all the code blocks that contain >>>
28 | div.each(function(index) {
29 | var jthis = $(this);
30 | if (jthis.find('.gp').length > 0) {
31 | var button = $('>>>');
32 | button.css(button_styles)
33 | button.attr('title', hide_text);
34 | button.data('hidden', 'false');
35 | jthis.prepend(button);
36 | }
37 | // tracebacks (.gt) contain bare text elements that need to be
38 | // wrapped in a span to work with .nextUntil() (see later)
39 | jthis.find('pre:has(.gt)').contents().filter(function() {
40 | return ((this.nodeType == 3) && (this.data.trim().length > 0));
41 | }).wrap('');
42 | });
43 |
44 | // define the behavior of the button when it's clicked
45 | $('.copybutton').click(function(e){
46 | e.preventDefault();
47 | var button = $(this);
48 | if (button.data('hidden') === 'false') {
49 | // hide the code output
50 | button.parent().find('.go, .gp, .gt').hide();
51 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden');
52 | button.css('text-decoration', 'line-through');
53 | button.attr('title', show_text);
54 | button.data('hidden', 'true');
55 | } else {
56 | // show the code output
57 | button.parent().find('.go, .gp, .gt').show();
58 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible');
59 | button.css('text-decoration', 'none');
60 | button.attr('title', hide_text);
61 | button.data('hidden', 'false');
62 | }
63 | });
64 | });
65 |
--------------------------------------------------------------------------------
/docs/source/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 |
4 | {% block footer %}
5 | {{ super() }}
6 |
9 |
10 | {% endblock %}
11 |
--------------------------------------------------------------------------------
/docs/source/_templates/theme_variables.jinja:
--------------------------------------------------------------------------------
1 | {%- set external_urls = {
2 | 'github': 'https://github.com/PytorchLightning/lightning-bolts',
3 | 'github_issues': 'https://github.com/PytorchLightning/lightning-bolts/issues',
4 | 'contributing': 'https://github.com/PytorchLightning/lightning-bolts/blob/master/CONTRIBUTING.md',
5 | 'governance': 'https://github.com/PytorchLightning/pytorch-lightning/blob/master/governance.md',
6 | 'docs': 'https://lightning-bolts.rtfd.io/en/latest',
7 | 'twitter': 'https://twitter.com/PyTorchLightnin',
8 | 'discuss': 'https://pytorch-lightning.slack.com',
9 | 'tutorials': 'https://pytorch-lightning.readthedocs.io/en/latest/#tutorials',
10 | 'previous_pytorch_versions': 'https://pytorch.rtfd.io/en/latest/',
11 | 'home': 'https://lightning-bolts.rtfd.io/en/latest/',
12 | 'get_started': 'https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html',
13 | 'features': 'https://lightning-bolts.rtfd.io/en/latest/',
14 | 'blog': 'https://www.pytorchlightning.ai/blog',
15 | 'resources': 'https://pytorch-lightning.readthedocs.io/en/latest/#community-examples',
16 | 'support': 'https://lightning-bolts.rtfd.io/en/latest/',
17 | 'community': 'https://pytorch-lightning.slack.com',
18 | 'forums': 'https://pytorch-lightning.slack.com',
19 | 'pldocs': 'https://pytorch-lightning.readthedocs.io/en/stable/',
20 | 'tmdocs': 'https://torchmetrics.readthedocs.io/en/stable/',
21 | 'fldocs': 'https://lightning-flash.readthedocs.io/en/stable/',
22 | 'lbdocs': 'https://lightning-bolts.readthedocs.io/en/stable/',
23 | 'gdocs': 'https://docs.grid.ai/',
24 | }
25 | -%}
26 |
--------------------------------------------------------------------------------
/docs/source/callbacks/self_supervised.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | Self-supervised Callbacks
5 | =========================
6 | Useful callbacks for self-supervised learning models.
7 |
8 | .. note::
9 |
10 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
11 |
12 |
13 | ---------------
14 |
15 | BYOLMAWeightUpdate
16 | ------------------
17 | The exponential moving average weight-update rule from Bootstrap Your Own Latent (BYOL).
18 |
19 | .. autoclass:: pl_bolts.callbacks.byol_updates.BYOLMAWeightUpdate
20 | :noindex:
21 |
22 | ----------------
23 |
24 | SSLOnlineEvaluator
25 | ------------------
26 | Appends an MLP for fine-tuning to the given model. The callback has its own mini inner loop.
27 |
28 | .. autoclass:: pl_bolts.callbacks.ssl_online.SSLOnlineEvaluator
29 | :noindex:
30 |
--------------------------------------------------------------------------------
/docs/source/callbacks/sparseml.rst:
--------------------------------------------------------------------------------
1 | =================
2 | SparseML Callback
3 | =================
4 |
5 | `SparseML `__ allows you to leverage sparsity to improve inference times substantially.
6 |
7 | SparseML requires you to fine-tune your model with the ``SparseMLCallback`` + a SparseML Recipe. By training with the ``SparseMLCallback``, you can leverage the `DeepSparse `__ engine to exploit the introduced sparsity, resulting in large performance improvements.
8 |
9 | .. warning::
10 |
11 | The SparseML callback requires the model to be ONNX exportable. This can be tricky when the model requires dynamic sequence lengths such as RNNs.
12 |
13 | To leverage SparseML & DeepSparse, follow the steps below:
14 |
15 | 1. Choose your Sparse Recipe
16 | ----------------------------
17 |
18 | To choose a recipe, have a look at `recipes `__ and `Sparse Zoo `__.
19 |
20 | It may be easier to infer a recipe via the UI dashboard using `Sparsify `__ which allows you to tweak and configure a recipe.
21 | This requires to import an ONNX model, which you can get from your ``LightningModule`` by doing ``model.to_onnx(output_path)``.
22 |
23 | 2. Train with SparseMLCallback
24 | ------------------------------
25 |
26 | .. testcode::
27 | :skipif: not _SPARSEML_TORCH_SATISFIED
28 |
29 | from pytorch_lightning import LightningModule, Trainer
30 | from pl_bolts.callbacks import SparseMLCallback
31 |
32 | class MyModel(LightningModule):
33 | ...
34 |
35 | model = MyModel()
36 |
37 | trainer = Trainer(
38 | callbacks=SparseMLCallback(recipe_path='recipe.yaml')
39 | )
40 |
41 | 3. Export to ONNX!
42 | ------------------
43 |
44 | Using the helper function, we handle any quantization/pruning internally and export the model into ONNX format.
45 | Note this assumes either you have implemented the property ``example_input_array`` in the model or you must provide a sample batch as below.
46 |
47 | .. testcode::
48 | :skipif: not _SPARSEML_TORCH_SATISFIED
49 |
50 | import torch
51 |
52 | model = MyModel()
53 | ...
54 |
55 | # export the onnx model, using the `model.example_input_array`
56 | SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/')
57 |
58 | # export the onnx model, providing a sample batch
59 | SparseMLCallback.export_to_sparse_onnx(model, 'onnx_export/', sample_batch=torch.randn(1, 128, 128, dtype=torch.float32))
60 |
61 |
62 | Once your model has been exported, you can import this into either `Sparsify `__ or `DeepSparse `__.
63 |
--------------------------------------------------------------------------------
/docs/source/callbacks/torch_ort.rst:
--------------------------------------------------------------------------------
1 | ==================
2 | Torch ORT Callback
3 | ==================
4 |
5 | `Torch ORT `__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. See installation instructions `here `__.
6 |
7 | This is primarily useful when training with a Transformer model. The ORT callback works when a single model is specified as ``self.model`` within the ``LightningModule`` as shown below.
8 |
9 | .. note::
10 |
11 | Not all Transformer models are supported. See `this table `__ for supported models + branches containing fixes for certain models.
12 |
13 | .. code-block:: python
14 |
15 | from pytorch_lightning import LightningModule, Trainer
16 | from transformers import AutoModel
17 |
18 | from pl_bolts.callbacks import ORTCallback
19 |
20 |
21 | class MyTransformerModel(LightningModule):
22 |
23 | def __init__(self):
24 | super().__init__()
25 | self.model = AutoModel.from_pretrained('bert-base-cased')
26 |
27 | ...
28 |
29 |
30 | model = MyTransformerModel()
31 | trainer = Trainer(gpus=1, callbacks=ORTCallback())
32 | trainer.fit(model)
33 |
34 |
35 | For even easier setup and integration, have a look at our Lightning Flash integration for :ref:`Text Classification `, :ref:`Translation ` and :ref:`Summarization `.
36 |
--------------------------------------------------------------------------------
/docs/source/callbacks/variational.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | Variational Callbacks
5 | =====================
6 | Useful callbacks for GANs, variational-autoencoders or anything with latent spaces.
7 |
8 | .. note::
9 |
10 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
11 |
12 | ---------------
13 |
14 | Latent Dim Interpolator
15 | -----------------------
16 | Interpolates latent dims.
17 |
18 | Example output:
19 |
20 | .. image:: ../_images/gans/basic_gan_interpolate.jpg
21 | :width: 400
22 | :alt: Example latent space interpolation
23 |
24 | .. autoclass:: pl_bolts.callbacks.variational.LatentDimInterpolator
25 | :noindex:
26 |
--------------------------------------------------------------------------------
/docs/source/callbacks/vision.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | Vision Callbacks
5 | ================
6 | Useful callbacks for vision models.
7 |
8 | .. note::
9 |
10 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
11 |
12 |
13 | ---------------
14 |
15 | Confused Logit
16 | --------------
17 | Shows how the input would have to change to move the prediction from one logit to the other.
18 |
19 |
20 | Example outputs:
21 |
22 | .. image:: ../_images/vision/confused_logit.png
23 | :width: 400
24 | :alt: Example of prediction confused between 5 and 8
25 |
26 | .. autoclass:: pl_bolts.callbacks.vision.confused_logit.ConfusedLogitCallback
27 | :noindex:
28 |
29 |
30 | ---------------
31 |
32 | Tensorboard Image Generator
33 | ---------------------------
34 | Generates images from a generative model and plots to tensorboard
35 |
36 | .. autoclass:: pl_bolts.callbacks.vision.image_generation.TensorboardGenerativeModelImageSampler
37 | :noindex:
38 |
--------------------------------------------------------------------------------
/docs/source/dataloaders/async.rst:
--------------------------------------------------------------------------------
1 | AsynchronousLoader
2 | ------------------
3 | This dataloader behaves identically to the standard PyTorch dataloader, but transfers
4 | data to the GPU asynchronously, overlapping transfers with training. You can also use it to wrap an existing dataloader.
5 |
6 | .. note::
7 |
8 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
9 |
10 | Example:
11 |
12 | .. code-block:: python
13 |
14 | dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device)
15 |
16 | for b in dataloader:
17 | ...
18 |
19 | .. autoclass:: pl_bolts.datamodules.async_dataloader.AsynchronousLoader
20 | :noindex:
21 |
--------------------------------------------------------------------------------
/docs/source/datamodules/sklearn.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 | Sklearn Datamodule
5 | ==================
6 | Utilities to map sklearn or numpy datasets to PyTorch Dataloaders with automatic data splits and GPU/TPU support.
7 |
8 | .. code-block:: python
9 |
10 | from sklearn.datasets import load_diabetes
11 | from pl_bolts.datamodules import SklearnDataModule
12 |
13 | X, y = load_diabetes(return_X_y=True)
14 | loaders = SklearnDataModule(X, y)
15 |
16 | train_loader = loaders.train_dataloader(batch_size=32)
17 | val_loader = loaders.val_dataloader(batch_size=32)
18 | test_loader = loaders.test_dataloader(batch_size=32)
19 |
20 | Or build your own torch datasets
21 |
22 | .. code-block:: python
23 |
24 | from sklearn.datasets import load_diabetes
25 | from pl_bolts.datamodules import SklearnDataset
26 |
27 | X, y = load_diabetes(return_X_y=True)
28 | dataset = SklearnDataset(X, y)
29 | loader = DataLoader(dataset)
30 |
31 | ----------------
32 |
33 | Sklearn Dataset Class
34 | ---------------------
35 | Transforms a sklearn or numpy dataset to a PyTorch Dataset.
36 |
37 | .. autoclass:: pl_bolts.datamodules.sklearn_datamodule.SklearnDataset
38 | :noindex:
39 |
40 | ----------------
41 |
42 | Sklearn DataModule Class
43 | ------------------------
44 | Automatically generates the train, validation and test splits for a Numpy dataset.
45 | They are set up as dataloaders for convenience. Optionally, you can pass in your own validation and test splits.
46 |
47 | .. autoclass:: pl_bolts.datamodules.sklearn_datamodule.SklearnDataModule
48 |
--------------------------------------------------------------------------------
/docs/source/datamodules/vision.rst:
--------------------------------------------------------------------------------
1 | Vision DataModules
2 | ==================
3 | The following are pre-built datamodules for computer-vision.
4 |
5 | -------------
6 |
7 | Supervised learning
8 | -------------------
9 | These are standard vision datasets with the train, test, val splits pre-generated in DataLoaders with
10 | the standard transforms (and Normalization) values
11 |
12 | BinaryEMNIST
13 | ^^^^^^^^^^^^
14 |
15 | .. autoclass:: pl_bolts.datamodules.binary_emnist_datamodule.BinaryEMNISTDataModule
16 |
17 | BinaryMNIST
18 | ^^^^^^^^^^^
19 |
20 | .. autoclass:: pl_bolts.datamodules.binary_mnist_datamodule.BinaryMNISTDataModule
21 |
22 | CityScapes
23 | ^^^^^^^^^^
24 |
25 | .. autoclass:: pl_bolts.datamodules.cityscapes_datamodule.CityscapesDataModule
26 |
27 | CIFAR-10
28 | ^^^^^^^^
29 |
30 | .. autoclass:: pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule
31 |
32 | EMNIST
33 | ^^^^^^
34 |
35 | .. autoclass:: pl_bolts.datamodules.emnist_datamodule.EMNISTDataModule
36 |
37 | FashionMNIST
38 | ^^^^^^^^^^^^
39 |
40 | .. autoclass:: pl_bolts.datamodules.fashion_mnist_datamodule.FashionMNISTDataModule
41 |
42 | Imagenet
43 | ^^^^^^^^
44 |
45 | .. autoclass:: pl_bolts.datamodules.imagenet_datamodule.ImagenetDataModule
46 |
47 | MNIST
48 | ^^^^^
49 |
50 | .. autoclass:: pl_bolts.datamodules.mnist_datamodule.MNISTDataModule
51 |
52 | Semi-supervised learning
53 | ------------------------
54 | The following datasets have support for unlabeled training and semi-supervised learning where only a few examples
55 | are labeled.
56 |
57 | Imagenet (ssl)
58 | ^^^^^^^^^^^^^^
59 |
60 | .. autoclass:: pl_bolts.datamodules.ssl_imagenet_datamodule.SSLImagenetDataModule
61 |
62 | STL-10
63 | ^^^^^^
64 |
65 | .. autoclass:: pl_bolts.datamodules.stl10_datamodule.STL10DataModule
66 |
--------------------------------------------------------------------------------
/docs/source/datasets/debug.rst:
--------------------------------------------------------------------------------
1 | **************
2 | Debug Datasets
3 | **************
4 |
5 | DummyDataset
6 | ============
7 |
8 | .. autoclass:: pl_bolts.datasets.dummy_dataset.DummyDataset
9 | :noindex:
10 |
11 | DummyDetectionDataset
12 | =====================
13 |
14 | .. autoclass:: pl_bolts.datasets.dummy_dataset.DummyDetectionDataset
15 | :noindex:
16 |
17 | RandomDataset
18 | =============
19 |
20 | .. autoclass:: pl_bolts.datasets.dummy_dataset.RandomDataset
21 | :noindex:
22 |
23 | RandomDictDataset
24 | =================
25 |
26 | .. autoclass:: pl_bolts.datasets.dummy_dataset.RandomDictDataset
27 | :noindex:
28 |
29 | RandomDictStringDataset
30 | =======================
31 |
32 | .. autoclass:: pl_bolts.datasets.dummy_dataset.RandomDictStringDataset
33 | :noindex:
34 |
--------------------------------------------------------------------------------
/docs/source/governance.rst:
--------------------------------------------------------------------------------
1 | .. _governance:
2 |
3 | PL Bolts Governance | Persons of interest
4 | =========================================
5 |
6 | Core Maintainers
7 | ----------------
8 | - William Falcon (`williamFalcon <https://github.com/williamFalcon>`_) (Lightning founder)
9 | - Jirka Borovec (`Borda <https://github.com/Borda>`_)
10 | - Akihiro Nitta (`akihironitta <https://github.com/akihironitta>`_)
11 |
12 | Core Contributors
13 | -----------------
14 | - Atharva Phatak (`Atharva-Phatak <https://github.com/Atharva-Phatak>`_)
15 | - Shion Matsumoto (`matsumotosan <https://github.com/matsumotosan>`_)
16 | - JongMok Lee (`lijm1358 <https://github.com/lijm1358>`_)
17 |
18 | Alumni
19 | ------
20 | - Ananya Harsh Jha (`ananyahjha93 <https://github.com/ananyahjha93>`_)
21 | - Annika Brundyn (`annikabrundyn <https://github.com/annikabrundyn>`_)
22 | - Ota Jašek (`otaj <https://github.com/otaj>`_)
23 | - Teddy Koker (`teddykoker <https://github.com/teddykoker>`_)
24 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. PyTorchLightning-Bolts documentation master file, created by
2 | sphinx-quickstart on Wed Mar 25 21:34:07 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Lightning-Bolts documentation
7 | =============================
8 | .. toctree::
9 | :maxdepth: 1
10 | :name: start
11 | :caption: Start here
12 |
13 | installation
14 | introduction_guide
15 |
16 | .. toctree::
17 | :maxdepth: 2
18 | :name: callbacks
19 | :caption: Callbacks
20 |
21 | callbacks/monitor
22 | callbacks/torch_ort
23 | callbacks/sparseml
24 | callbacks/self_supervised
25 | callbacks/variational
26 | callbacks/vision
27 |
28 | .. toctree::
29 | :maxdepth: 2
30 | :name: datamodules
31 | :caption: DataModules
32 |
33 | datamodules/sklearn
34 | datamodules/vision
35 |
36 | .. toctree::
37 | :maxdepth: 2
38 | :name: datasets
39 | :caption: Datasets
40 |
41 | datasets/debug
42 |
43 | .. toctree::
44 | :maxdepth: 2
45 | :name: dataloaders
46 | :caption: DataLoaders
47 |
48 | dataloaders/async
49 |
50 | .. toctree::
51 | :maxdepth: 2
52 | :name: losses
53 | :caption: Losses
54 |
55 | losses
56 |
57 | .. toctree::
58 | :maxdepth: 2
59 | :name: models
60 | :caption: Models
61 |
62 | models/models_howto
63 | models/autoencoders
64 | models/convolutional
65 | models/gans
66 | models/object_detection
67 | models/reinforce_learn
68 | models/self_supervised
69 | models/classic_ml
70 |
71 | .. toctree::
72 | :maxdepth: 2
73 | :name: learning_rate_schedulers
74 | :caption: Learning Rate Schedulers
75 |
76 | schedulers/warmup_cosine_annealing.rst
77 |
78 | .. toctree::
79 | :maxdepth: 2
80 | :name: transforms
81 | :caption: Data Processing
82 |
83 | transforms/self_supervised
84 | transforms/semi_supervised
85 |
86 | .. toctree::
87 | :maxdepth: 2
88 | :name: ssl
89 | :caption: Tasks
90 |
91 | vision_tasks
92 |
93 | .. toctree::
94 | :maxdepth: 1
95 | :name: community
96 | :caption: Community
97 |
98 | CONTRIBUTING.md
99 | governance.md
100 | stability.md
101 | CHANGELOG.md
102 |
103 |
104 | Indices and tables
105 | ==================
106 |
107 | * :ref:`genindex`
108 | * :ref:`modindex`
109 | * :ref:`search`
110 |
--------------------------------------------------------------------------------
/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | You can install using `pip <https://pypi.org/project/lightning-bolts/>`_
5 |
6 | .. code-block:: bash
7 |
8 | pip install lightning-bolts
9 |
10 | Install bleeding-edge (no guarantees)
11 |
12 | .. code-block:: bash
13 |
14 | pip install git+https://github.com/Lightning-AI/lightning-bolts.git@master --upgrade
15 |
16 | In case you want to have full experience you can install all optional packages at once
17 |
18 | .. code-block:: bash
19 |
20 | pip install lightning-bolts["extra"]
21 |
--------------------------------------------------------------------------------
/docs/source/losses.rst:
--------------------------------------------------------------------------------
1 | Reinforcement Learning
2 | ======================
3 | These are common losses used in RL.
4 |
5 | ---------------
6 |
7 | DQN Loss
8 | --------
9 |
10 | .. autofunction:: pl_bolts.losses.rl.dqn_loss
11 | :noindex:
12 |
13 | Double DQN Loss
14 | ---------------
15 |
16 | .. autofunction:: pl_bolts.losses.rl.double_dqn_loss
17 | :noindex:
18 |
19 | Per DQN Loss
20 | ------------
21 |
22 | .. autofunction:: pl_bolts.losses.rl.per_dqn_loss
23 | :noindex:
24 |
--------------------------------------------------------------------------------
/docs/source/models/convolutional.rst:
--------------------------------------------------------------------------------
1 | Convolutional Architectures
2 | ===========================
3 | This package lists contributed convolutional architectures.
4 |
5 | .. note::
6 |
7 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
8 |
9 | --------------
10 |
11 |
12 | GPT-2
13 | -----
14 |
15 | .. autoclass:: pl_bolts.models.vision.GPT2
16 | :noindex:
17 |
18 | -------------
19 |
20 | Image GPT
21 | ---------
22 |
23 | .. autoclass:: pl_bolts.models.vision.ImageGPT
24 | :noindex:
25 |
26 | -------------
27 |
28 | Pixel CNN
29 | ---------
30 |
31 | .. autoclass:: pl_bolts.models.vision.PixelCNN
32 | :noindex:
33 |
34 | -------------
35 |
36 | UNet
37 | ----
38 |
39 | .. autoclass:: pl_bolts.models.vision.UNet
40 | :noindex:
41 |
42 | -------------
43 |
44 | Semantic Segmentation
45 | ---------------------
46 | Model template to use for semantic segmentation tasks. The model uses a UNet architecture by default. Override any part
47 | of this model to build your own variation.
48 |
49 | .. code-block:: python
50 |
51 | from pl_bolts.models.vision import SemSegment
52 | from pl_bolts.datamodules import KittiDataModule
53 | import pytorch_lightning as pl
54 |
55 | dm = KittiDataModule('path/to/kitti/dataset/', batch_size=4)
56 | model = SemSegment(datamodule=dm)
57 | trainer = pl.Trainer()
58 | trainer.fit(model)
59 |
60 | .. autoclass:: pl_bolts.models.vision.SemSegment
61 | :noindex:
62 |
--------------------------------------------------------------------------------
/docs/source/models/object_detection.rst:
--------------------------------------------------------------------------------
1 | Object Detection
2 | ================
3 | This package lists contributed object detection models.
4 |
5 | .. note::
6 |
7 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
8 |
9 | --------------
10 |
11 |
12 | Faster R-CNN
13 | ------------
14 |
15 | .. autoclass:: pl_bolts.models.detection.faster_rcnn.faster_rcnn_module.FasterRCNN
16 | :noindex:
17 |
18 | -------------
19 |
20 | RetinaNet
21 | ---------
22 |
23 | .. autoclass:: pl_bolts.models.detection.retinanet.retinanet_module.RetinaNet
24 | :noindex:
25 |
26 | -------------
27 |
28 | YOLO
29 | ----
30 |
31 | .. autoclass:: pl_bolts.models.detection.yolo.yolo_module.YOLO
32 | :noindex:
33 |
--------------------------------------------------------------------------------
/docs/source/schedulers/warmup_cosine_annealing.rst:
--------------------------------------------------------------------------------
1 |
2 | Linear Warmup Cosine Annealing
3 | ------------------------------
4 |
5 | .. note::
6 |
7 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
8 |
9 | .. autoclass:: pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR
10 | :noindex:
11 |
--------------------------------------------------------------------------------
/docs/source/stability.rst:
--------------------------------------------------------------------------------
1 | .. _stability:
2 |
3 | Bolts stability
4 | ===============
5 |
6 | Currently we are going through major revision of Bolts to ensure all of the code is stable and compatible with the rest of the Lightning ecosystem.
7 | For this reason, all of our features are either marked as stable or in need of review. Stable features are implicit, features to be reviewed are explicitly marked.
8 |
9 | At the beginning of the aforementioned revision, **ALL** of the features currently in the project have been marked as to be reviewed and will undergo rigorous review and testing before they can be marked as stable. See `this GitHub issue <https://github.com/Lightning-AI/lightning-bolts/issues>`_ to check progress of the revision.
10 |
11 | This document is intended to help you know what to expect and to outline our commitment to stability.
12 |
13 | Stable
14 | ______
15 |
16 | For stable features, all of the following are true:
17 |
18 | - the API isn’t expected to change
19 | - if anything does change, incorrect usage will give a deprecation warning for **one minor release** before the breaking change is made
20 | - the API has been tested for compatibility with latest releases of PyTorch Lightning and Flash
21 |
22 | Under Review
23 | ____________
24 |
25 | For features to be reviewed, any or all of the following may be true:
26 |
27 | - the feature has unstable dependencies
28 | - the API may change without notice in future versions
29 | - the performance of the feature has not been verified
30 | - the docs for this feature are under active development
31 |
32 |
33 | Before a feature can be moved to Stable it needs to satisfy following conditions:
34 |
35 | - Have appropriate tests, that will check not only correctness of the feature, but also compatibility with the current versions.
36 | - Not have duplicate code across Lightning ecosystem and more mature OSS projects.
37 | - Pass a review process.
38 |
--------------------------------------------------------------------------------
/docs/source/transforms/semi_supervised.rst:
--------------------------------------------------------------------------------
1 | Semi-supervised learning
2 | ========================
3 | Collection of utilities for semi-supervised learning where some part of the data is labeled
4 | and the other part is not.
5 |
6 | .. note::
7 |
8 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
9 |
10 | --------------
11 |
12 | Balanced classes
13 | ----------------
14 |
15 | Example::
16 |
17 | from pl_bolts.utils.semi_supervised import balance_classes
18 |
19 |
20 | .. autofunction:: pl_bolts.utils.semi_supervised.balance_classes
21 | :noindex:
22 |
23 | half labeled batches
24 | --------------------
25 |
26 | Example::
27 |
28 | from pl_bolts.utils.semi_supervised import generate_half_labeled_batches
29 |
30 | .. autofunction:: pl_bolts.utils.semi_supervised.generate_half_labeled_batches
31 | :noindex:
32 |
--------------------------------------------------------------------------------
/docs/source/vision_tasks.rst:
--------------------------------------------------------------------------------
1 | Self-supervised Learning
2 | ========================
3 | This section implements popular contrastive learning tasks used in self-supervised learning.
4 |
5 | .. note::
6 |
7 | We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!
8 |
9 | ---------------
10 |
11 | FeatureMapContrastiveTask
12 | -------------------------
13 | This task compares sets of feature maps.
14 |
15 | In general the feature map comparison pretext task uses triplets of features.
16 | Here are the abstract steps of comparison.
17 |
18 | Generate multiple views of the same image
19 |
20 | .. code-block:: python
21 |
22 | x1_view_1 = data_augmentation(x1)
23 | x1_view_2 = data_augmentation(x1)
24 |
25 | Use a different example to generate additional views (usually within the same batch or a pool of candidates)
26 |
27 | .. code-block:: python
28 |
29 | x2_view_1 = data_augmentation(x2)
30 | x2_view_2 = data_augmentation(x2)
31 |
32 | Pick 3 views to compare, these are the anchor, positive and negative features
33 |
34 | .. code-block:: python
35 |
36 | anchor = x1_view_1
37 | positive = x1_view_2
38 | negative = x2_view_1
39 |
40 | Generate feature maps for each view
41 |
42 | .. code-block:: python
43 |
44 | (a0, a1, a2) = encoder(anchor)
45 | (p0, p1, p2) = encoder(positive)
46 |
47 | Make a comparison for a set of feature maps
48 |
49 | .. code-block:: python
50 |
51 | phi = some_score_function()
52 |
53 | # the '01' comparison
54 | score = phi(a0, p1)
55 |
56 | # and can be bidirectional
57 | score = phi(p0, a1)
58 |
59 | In practice the contrastive task creates a BxB matrix where B is the batch size. The diagonals for set 1 of feature maps
60 | are the anchors, the diagonals of set 2 of the feature maps are the positives, the non-diagonals of set 1 are the
61 | negatives.
62 |
63 | .. autoclass:: pl_bolts.losses.self_supervised_learning.FeatureMapContrastiveTask
64 | :noindex:
65 |
66 | --------------
67 |
68 | Context prediction tasks
69 | ------------------------
70 | The following tasks aim to predict a target using a context representation.
71 |
72 | CPCContrastiveTask
73 | ^^^^^^^^^^^^^^^^^^
74 | This is the predictive task from CPC (v2).
75 |
76 | .. code-block::
77 |
78 | task = CPCTask(num_input_channels=32)
79 |
80 | # (batch, channels, rows, cols)
81 | # this should be thought of as 49 feature vectors, each with 32 dims
82 | Z = torch.random.rand(3, 32, 7, 7)
83 |
84 | loss = task(Z)
85 |
86 |
87 | .. autoclass:: pl_bolts.losses.self_supervised_learning.CPCTask
88 | :noindex:
89 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/base.txt
2 |
--------------------------------------------------------------------------------
/requirements/base.txt:
--------------------------------------------------------------------------------
1 | numpy <1.26.0
2 | pytorch-lightning >1.7.0, <2.0.0 # strict
3 | torchmetrics >=0.10.0, <0.12.0
4 | lightning-utilities >0.3.1 # this is needed for PL 1.7
5 | torchvision >=0.10.0 # todo: move to topic related extras
6 | tensorboard >=2.9.1, <2.14.0 # for `TensorBoardLogger`
7 |
--------------------------------------------------------------------------------
/requirements/devel.txt:
--------------------------------------------------------------------------------
1 | # install all mandatory dependencies
2 | -r ./base.txt
3 |
4 | # install all extra dependencies for full package experience
5 | -r ./models.txt
6 | -r ./loggers.txt
7 |
8 | # extended list of dependencies to run lint and tests
9 | -r ./test.txt
10 |
--------------------------------------------------------------------------------
/requirements/loggers.txt:
--------------------------------------------------------------------------------
1 | # test_tube >=0.7.5
2 | # trains >=0.14.1
3 | matplotlib >3.0.0, <3.8.0
4 | wandb >0.13.0, <0.16.0
5 | scipy >1.5.0, <1.12.0
6 |
--------------------------------------------------------------------------------
/requirements/models.txt:
--------------------------------------------------------------------------------
1 | torchvision >=0.10.0
2 | scikit-learn >=1.0.2
3 | Pillow >9.0.0
4 | gym[atari] >=0.17.2, <0.22.0 # strict
5 | atari-py >0.2, !=0.2.6, <0.3 # strict
6 | box2d-py >2.3, <2.4 # strict
7 | opencv-python-headless >=4.5.5.62
8 |
--------------------------------------------------------------------------------
/requirements/test.txt:
--------------------------------------------------------------------------------
1 | coverage[toml] >7.0.0, <8.0.0
2 | pytest ==7.4.0
3 | pytest-cov ==4.1.0
4 | pytest-timeout ==2.1.0
5 | pytest-rerunfailures ==12.0
6 |
7 | scikit-learn >=1.0.2, <=1.3.0
8 | sparseml >1.0.0, <1.6.0
9 | ale-py >=0.7, <=0.8.1
10 | jsonargparse[signatures] >4.0.0, <=4.22.1 # for LightningCLI
11 |
--------------------------------------------------------------------------------
/requirements/typing.txt:
--------------------------------------------------------------------------------
1 | mypy ==1.7.1
2 |
--------------------------------------------------------------------------------
/src/pl_bolts/__about__.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | __version__ = "0.7.0"
4 | __author__ = "Lightning AI et al."
5 | __author_email__ = "pytorch@lightning.ai"
6 | __license__ = "Apache-2.0"
7 | __copyright__ = f"Copyright (c) 2020-{time.strftime('%Y')}, {__author__}."
8 | __homepage__ = "https://github.com/Lightning-AI/lightning-bolts"
9 | __docs__ = "Lightning Bolts is a community contribution for ML researchers."
10 | __long_doc__ = """
11 | What is it?
12 | -----------
13 | Bolts is a collection of useful models and templates to bootstrap your DL research even faster.
14 | It's designed to work with PyTorch Lightning
15 |
16 | Subclass Example
17 | ----------------
18 | Use `pl_bolts` models to remove boilerplate for common approaches and architectures.
19 | Because it uses LightningModules under the hood, you just need to overwrite
20 | the relevant parts to your research.
21 |
22 | How to add a model
23 | ------------------
24 | This repository is meant for model contributions from the community.
25 | To add a model, you can start with the MNIST template (or any other model in the repo).
26 | Please organize the functions of your lightning module.
27 | """
28 |
29 | __all__ = [
30 | "__author__",
31 | "__author_email__",
32 | "__copyright__",
33 | "__docs__",
34 | "__homepage__",
35 | "__license__",
36 | "__version__",
37 | ]
38 |
--------------------------------------------------------------------------------
/src/pl_bolts/__init__.py:
--------------------------------------------------------------------------------
1 | """Root package crossroad."""
2 |
3 | import os
4 |
5 | import numpy
6 |
7 | from pl_bolts.__about__ import * # noqa: F401, F403
8 |
9 | # adding compatibility for numpy >= 1.24
10 | for tp_name, tp_ins in [("object", object), ("bool", bool), ("int", int), ("float", float)]:
11 | if not hasattr(numpy, tp_name):
12 | setattr(numpy, tp_name, tp_ins)
13 |
14 | _PACKAGE_ROOT = os.path.dirname(__file__)
15 | _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
16 | _HTTPS_AWS_HUB = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"
17 |
18 | from pl_bolts import ( # noqa: E402
19 | callbacks,
20 | datamodules,
21 | datasets,
22 | losses,
23 | metrics,
24 | models,
25 | optimizers,
26 | transforms,
27 | utils,
28 | )
29 |
30 | __all__ = [
31 | "callbacks",
32 | "datamodules",
33 | "datasets",
34 | "losses",
35 | "metrics",
36 | "models",
37 | "optimizers",
38 | "transforms",
39 | "utils",
40 | ]
41 |
--------------------------------------------------------------------------------
/src/pl_bolts/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | """Collection of PyTorchLightning callbacks."""
2 | from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
3 | from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor
4 | from pl_bolts.callbacks.printing import PrintTableMetricsCallback
5 | from pl_bolts.callbacks.sparseml import SparseMLCallback
6 | from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
7 | from pl_bolts.callbacks.torch_ort import ORTCallback
8 | from pl_bolts.callbacks.variational import LatentDimInterpolator
9 | from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore
10 | from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback
11 | from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler
12 | from pl_bolts.callbacks.vision.sr_image_logger import SRImageLoggerCallback
13 |
14 | __all__ = [
15 | "BatchGradientVerificationCallback",
16 | "BYOLMAWeightUpdate",
17 | "ModuleDataMonitor",
18 | "TrainingDataMonitor",
19 | "PrintTableMetricsCallback",
20 | "SSLOnlineEvaluator",
21 | "LatentDimInterpolator",
22 | "ConfusedLogitCallback",
23 | "TensorboardGenerativeModelImageSampler",
24 | "SRImageLoggerCallback",
25 | "ORTCallback",
26 | "SparseMLCallback",
27 | ]
28 |
--------------------------------------------------------------------------------
/src/pl_bolts/callbacks/byol_updates.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Sequence, Union
3 |
4 | import torch.nn as nn
5 | from pytorch_lightning import Callback, LightningModule, Trainer
6 | from torch import Tensor
7 |
8 |
9 | class BYOLMAWeightUpdate(Callback):
10 | """Weight update rule from Bootstrap Your Own Latent (BYOL).
11 |
12 | Updates the target_network params using an exponential moving average update rule weighted by tau.
13 | BYOL claims this keeps the online_network from collapsing.
14 |
15 | The PyTorch Lightning module being trained should have:
16 |
17 | - ``self.online_network``
18 | - ``self.target_network``
19 |
20 | .. note:: Automatically increases tau from ``initial_tau`` to 1.0 with every training step
21 |
22 | Args:
23 | initial_tau (float, optional): starting tau. Auto-updates with every training step
24 |
25 | Example::
26 |
27 | # model must have 2 attributes
28 | model = Model()
29 | model.online_network = ...
30 | model.target_network = ...
31 |
32 | trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
33 |
34 | """
35 |
36 | def __init__(self, initial_tau: float = 0.996) -> None:
37 | if not 0.0 <= initial_tau <= 1.0:
38 | raise ValueError(f"initial tau should be between 0 and 1 instead of {initial_tau}.")
39 |
40 | super().__init__()
41 | self.initial_tau = initial_tau
42 | self.current_tau = initial_tau
43 |
44 | def on_train_batch_end(
45 | self,
46 | trainer: Trainer,
47 | pl_module: LightningModule,
48 | outputs: Sequence,
49 | batch: Sequence,
50 | batch_idx: int,
51 | ) -> None:
52 | # get networks
53 | online_net = pl_module.online_network
54 | target_net = pl_module.target_network
55 |
56 | # update target network weights
57 | self.update_weights(online_net, target_net)
58 |
59 | # update tau after
60 | self.update_tau(pl_module, trainer)
61 |
62 | def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> None:
63 | """Update tau value for next update."""
64 | max_steps = len(trainer.train_dataloader) * trainer.max_epochs
65 | self.current_tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
66 |
67 | def update_weights(self, online_net: Union[nn.Module, Tensor], target_net: Union[nn.Module, Tensor]) -> None:
68 | """Update target network parameters."""
69 | for online_p, target_p in zip(online_net.parameters(), target_net.parameters()):
70 | target_p.data = self.current_tau * target_p.data + (1.0 - self.current_tau) * online_p.data
71 |
--------------------------------------------------------------------------------
/src/pl_bolts/callbacks/torch_ort.py:
--------------------------------------------------------------------------------
1 | # Copyright The PyTorch Lightning team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from pytorch_lightning import Callback, LightningModule, Trainer
15 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
16 |
17 | from pl_bolts.utils import _TORCH_ORT_AVAILABLE
18 |
19 | if _TORCH_ORT_AVAILABLE:
20 | from torch_ort import ORTModule
21 |
22 | from pl_bolts.utils.stability import under_review
23 |
24 |
25 | @under_review()
26 | class ORTCallback(Callback):
27 | """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.
28 |
29 | Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for
30 | training and inference.
31 |
32 | Usage:
33 |
34 | # via Transformer Tasks
35 | model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
36 |
37 | # or via the trainer
38 | trainer = flash.Trainer(callbacks=ORTCallback())
39 | """
40 |
41 | def __init__(self) -> None:
42 | if not _TORCH_ORT_AVAILABLE:
43 | raise MisconfigurationException(
44 | "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
45 | )
46 |
47 | def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
48 | if not hasattr(pl_module, "model"):
49 | raise MisconfigurationException(
50 | "Torch ORT requires to wrap a single model that defines a forward function "
51 | "assigned as `model` inside the `LightningModule`."
52 | )
53 | if not isinstance(pl_module.model, ORTModule):
54 | pl_module.model = ORTModule(pl_module.model)
55 |
--------------------------------------------------------------------------------
/src/pl_bolts/callbacks/verification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/callbacks/verification/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/callbacks/vision/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/callbacks/vision/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/callbacks/vision/sr_image_logger.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn.functional as F # noqa: N812
6 | from pytorch_lightning import Callback
7 |
8 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
9 | from pl_bolts.utils.stability import under_review
10 | from pl_bolts.utils.warnings import warn_missing_pkg
11 |
12 | if _TORCHVISION_AVAILABLE:
13 | from torchvision.utils import make_grid
14 | else: # pragma: no cover
15 | warn_missing_pkg("torchvision")
16 |
17 |
@under_review()
class SRImageLoggerCallback(Callback):
    """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard. Your model must
    implement the ``forward`` function for generation.

    Requirements::

        # model forward must work generating high-res from low-res image
        hr_fake = pl_module(lr_image)

    Example::

        from pl_bolts.callbacks import SRImageLoggerCallback

        trainer = Trainer(callbacks=[SRImageLoggerCallback()])
    """

    def __init__(self, log_interval: int = 1000, scale_factor: int = 4, num_samples: int = 5) -> None:
        """
        Args:
            log_interval: Number of steps between logging. Default: ``1000``.
            scale_factor: Scale factor used for downsampling the high-res images. Default: ``4``.
            num_samples: Number of images displayed in the grid. Default: ``5``.
        """
        super().__init__()
        self.log_interval = log_interval
        self.scale_factor = scale_factor
        self.num_samples = num_samples

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs: torch.Tensor,
        batch: Tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> None:
        step = trainer.global_step
        if step % self.log_interval != 0:
            return

        hr_image, lr_image = batch
        hr_image = hr_image.to(pl_module.device)
        lr_image = lr_image.to(pl_module.device)
        hr_fake = pl_module(lr_image)
        # Upscale the low-res input so all three images share the same spatial size in the grid.
        lr_image = F.interpolate(lr_image, scale_factor=self.scale_factor)

        # One single-column grid per image kind, concatenated side by side:
        # low-res | generated high-res | ground-truth high-res.
        grids = [
            make_grid(images[: self.num_samples], nrow=1, normalize=True)
            for images in (lr_image, hr_fake, hr_image)
        ]
        trainer.logger.experiment.add_image("sr_images", torch.cat(grids, -1), global_step=step)
69 |
--------------------------------------------------------------------------------
/src/pl_bolts/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
2 | from pl_bolts.datamodules.binary_emnist_datamodule import BinaryEMNISTDataModule
3 | from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
4 | from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
5 | from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule
6 | from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule
7 | from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset
8 | from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
9 | from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
10 | from pl_bolts.datamodules.kitti_datamodule import KittiDataModule
11 | from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
12 | from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset
13 | from pl_bolts.datamodules.sr_datamodule import TVTDataModule
14 | from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
15 | from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
16 | from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
17 | from pl_bolts.datasets.kitti_dataset import KittiDataset
18 |
19 | __all__ = [
20 | "AsynchronousLoader",
21 | "BinaryMNISTDataModule",
22 | "CIFAR10DataModule",
23 | "TinyCIFAR10DataModule",
24 | "CityscapesDataModule",
25 | "DiscountedExperienceSource",
26 | "ExperienceSource",
27 | "ExperienceSourceDataset",
28 | "FashionMNISTDataModule",
29 | "ImagenetDataModule",
30 | "KittiDataModule",
31 | "MNISTDataModule",
32 | "SklearnDataModule",
33 | "SklearnDataset",
34 | "TVTDataModule",
35 | "SSLImagenetDataModule",
36 | "STL10DataModule",
37 | "VOCDetectionDataModule",
38 | "KittiDataset",
39 | "EMNISTDataModule",
40 | "BinaryEMNISTDataModule",
41 | ]
42 |
--------------------------------------------------------------------------------
/src/pl_bolts/datamodules/sr_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import DataLoader, Dataset
5 |
6 | from pl_bolts.utils.stability import under_review
7 |
8 |
@under_review()
class TVTDataModule(LightningDataModule):
    """Simple DataModule creating train, val, and test dataloaders from given train, val, and test dataset.

    Example::
        from pl_bolts.datamodules import TVTDataModule
        from pl_bolts.datasets.sr_mnist_dataset import SRMNIST

        dataset_dev = SRMNIST(scale_factor=4, root=".", train=True)
        dataset_train, dataset_val = random_split(dataset_dev, lengths=[55_000, 5_000])
        dataset_test = SRMNIST(scale_factor=4, root=".", train=False)
        dm = TVTDataModule(dataset_train, dataset_val, dataset_test)

    """

    def __init__(
        self,
        dataset_train: Dataset,
        dataset_val: Dataset,
        dataset_test: Dataset,
        batch_size: int = 16,
        shuffle: bool = True,
        num_workers: int = 8,
        pin_memory: bool = True,
        drop_last: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            dataset_train: Train dataset
            dataset_val: Val dataset
            dataset_test: Test dataset
            batch_size: How many samples per batch to load
            shuffle: If true shuffles the train data every epoch
            num_workers: How many workers to use for loading data
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                returning them
            drop_last: If true drops the last incomplete batch
        """
        super().__init__()

        self.dataset_train = dataset_train
        self.dataset_val = dataset_val
        self.dataset_test = dataset_test
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.drop_last = drop_last

    def train_dataloader(self) -> DataLoader:
        # Only the training split honours the configured ``shuffle`` flag.
        return self._dataloader(self.dataset_train, shuffle=self.shuffle)

    def val_dataloader(self) -> DataLoader:
        return self._dataloader(self.dataset_val, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._dataloader(self.dataset_test, shuffle=False)

    def _dataloader(self, dataset: Dataset, shuffle: bool = True) -> DataLoader:
        # Build a DataLoader sharing the module-wide loading configuration.
        loader_kwargs = {
            "batch_size": self.batch_size,
            "shuffle": shuffle,
            "num_workers": self.num_workers,
            "drop_last": self.drop_last,
            "pin_memory": self.pin_memory,
        }
        return DataLoader(dataset, **loader_kwargs)
78 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import urllib
2 |
3 | from pl_bolts.datasets.array_dataset import ArrayDataset
4 | from pl_bolts.datasets.base_dataset import DataModel, LightDataset
5 | from pl_bolts.datasets.cifar10_dataset import CIFAR10, TrialCIFAR10
6 | from pl_bolts.datasets.concat_dataset import ConcatDataset
7 | from pl_bolts.datasets.dummy_dataset import (
8 | DummyDataset,
9 | DummyDetectionDataset,
10 | RandomDataset,
11 | RandomDictDataset,
12 | RandomDictStringDataset,
13 | )
14 | from pl_bolts.datasets.emnist_dataset import BinaryEMNIST
15 | from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet, extract_archive, parse_devkit_archive
16 | from pl_bolts.datasets.kitti_dataset import KittiDataset
17 | from pl_bolts.datasets.mnist_dataset import MNIST, BinaryMNIST
18 | from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin
19 |
20 | __all__ = [
21 | "ArrayDataset",
22 | "DataModel",
23 | "LightDataset",
24 | "CIFAR10",
25 | "TrialCIFAR10",
26 | "ConcatDataset",
27 | "DummyDataset",
28 | "DummyDetectionDataset",
29 | "MNIST",
30 | "RandomDataset",
31 | "RandomDictDataset",
32 | "RandomDictStringDataset",
33 | "extract_archive",
34 | "parse_devkit_archive",
35 | "UnlabeledImagenet",
36 | "KittiDataset",
37 | "BinaryMNIST",
38 | "CIFAR10Mixed",
39 | "SSLDatasetMixin",
40 | "BinaryEMNIST",
41 | ]
42 |
# TorchVision hotfix https://github.com/pytorch/vision/issues/1938
# Install a global URL opener with a browser-like User-agent so dataset downloads
# are not rejected with HTTP 403 by servers that block the default Python agent.
# Import the submodule explicitly: a bare ``import urllib`` does not guarantee
# ``urllib.request`` is available as an attribute.
import urllib.request

opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)
47 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/array_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | from pytorch_lightning.utilities import exceptions
4 | from torch.utils.data import Dataset
5 |
6 | from pl_bolts.datasets.base_dataset import DataModel, TArrays
7 |
8 |
class ArrayDataset(Dataset):
    """Dataset wrapping tensors, lists, numpy arrays.

    Any number of ARRAYS can be inputted into the dataset. The ARRAYS are transformed on each
    `__getitem__`. When transforming, please refrain from changing the shape of ARRAYS in the
    first dimension.

    Attributes:
        data_models: Sequence of data models.

    Raises:
        MisconfigurationException: if there is a shape mismatch between arrays in the first dimension.

    Example:
        >>> from pl_bolts.datasets import ArrayDataset, DataModel
        >>> from pl_bolts.datasets.utils import to_tensor

        >>> features = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3]], transform=to_tensor)
        >>> target = DataModel(data=[1, 0, 0], transform=to_tensor)

        >>> ds = ArrayDataset(features, target)
        >>> len(ds)
        3

    """

    def __init__(self, *data_models: DataModel) -> None:
        """Initialises class and checks if arrays are the same shape in the first dimension."""
        self.data_models = data_models

        if not self._equal_size():
            raise exceptions.MisconfigurationException("Shape mismatch between arrays in the first dimension")

    def __len__(self) -> int:
        # All data models have equal length (validated in __init__), so the first one suffices.
        return len(self.data_models[0].data)

    def __getitem__(self, idx: int) -> Tuple[Union[TArrays, float], ...]:
        # Apply each model's transform to its element at ``idx`` and bundle the results.
        items = []
        for data_model in self.data_models:
            items.append(data_model.process(data_model.data[idx]))
        return tuple(items)

    def _equal_size(self) -> bool:
        """Checks the size of the data_models are equal in the first dimension.

        Returns:
            bool: True if size of data_models are equal in the first dimension. False, if not.

        """
        first_dim_sizes = {len(data_model.data) for data_model in self.data_models}
        return len(first_dim_sizes) == 1
55 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import urllib.parse
4 | import urllib.request
5 | from abc import ABC
6 | from dataclasses import dataclass
7 | from typing import Callable, Optional, Sequence, Tuple, Union
8 | from urllib.error import HTTPError
9 |
10 | from torch import Tensor
11 | from torch.utils.data import Dataset
12 |
13 | from pl_bolts.utils.stability import under_review
14 | from pl_bolts.utils.types import TArrays
15 |
16 |
@under_review()
class LightDataset(ABC, Dataset):
    """Abstract base for lightweight in-memory datasets with a cached folder layout."""

    # Tensors holding the samples and their labels.
    data: Tensor
    targets: Tensor
    # Normalization constants used by subclasses.
    normalize: tuple
    # Root directory and per-dataset cache subfolder.
    dir_path: str
    cache_folder_name: str
    DATASET_NAME = "light"

    def __len__(self) -> int:
        return len(self.data)

    @property
    def cached_folder_path(self) -> str:
        # <dir_path>/<DATASET_NAME>/<cache_folder_name>
        return os.path.join(self.dir_path, self.DATASET_NAME, self.cache_folder_name)

    @staticmethod
    def _prepare_subset(
        full_data: Tensor,
        full_targets: Tensor,
        num_samples: int,
        labels: Sequence,
    ) -> Tuple[Tensor, Tensor]:
        """Prepare a subset of a common dataset with at most ``num_samples`` items per label."""
        # Count of collected samples per requested label.
        counts = dict.fromkeys(labels, 0)
        selected = []
        for idx, target in enumerate(full_targets):
            label = target.item()
            # Labels not in ``labels`` default to +inf and are therefore always skipped.
            if counts.get(label, float("inf")) >= num_samples:
                continue
            selected.append(idx)
            counts[label] += 1
            # Stop early once every requested class has enough samples.
            if all(count >= num_samples for count in counts.values()):
                break
        return full_data[selected], full_targets[selected]

    def _download_from_url(self, base_url: str, data_folder: str, file_name: str):
        """Download ``file_name`` from ``base_url`` into ``data_folder``; raise on HTTP failure."""
        url = urllib.parse.urljoin(base_url, file_name)
        logging.info(f"Downloading {url}")
        destination = os.path.join(data_folder, file_name)
        try:
            urllib.request.urlretrieve(url, destination)  # noqa: S310
        except HTTPError as err:
            raise RuntimeError(f"Failed download from {url}") from err
63 |
64 |
@dataclass
class DataModel:
    """Data model dataclass.

    Ties together data and callable transforms.

    Attributes:
        data: Sequence of indexables.
        transform: Callable to transform data. The transform is called on a subset of data.

    """

    data: TArrays
    transform: Optional[Callable[[TArrays], TArrays]] = None

    def process(self, subset: Union[TArrays, float]) -> Union[TArrays, float]:
        """Transforms a subset of data.

        Args:
            subset: Sequence of indexables.

        Returns:
            data: Transformed data if transform is not None, otherwise the subset unchanged.

        """
        # No-op when no transform was configured.
        return subset if self.transform is None else self.transform(subset)
93 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/concat_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 | from pl_bolts.utils.stability import under_review
4 |
5 |
@under_review()
class ConcatDataset(Dataset):
    """Zips several datasets element-wise, cycling shorter datasets.

    ``__getitem__(i)`` returns a tuple with one item per wrapped dataset; the index wraps
    around (modulo each dataset's length) so shorter datasets repeat.
    """

    def __init__(self, *datasets) -> None:
        self.datasets = datasets

    def __getitem__(self, i):
        # Wrap the index for datasets shorter than the longest one.
        return tuple(dataset[i % len(dataset)] for dataset in self.datasets)

    def __len__(self) -> int:
        # The longest wrapped dataset determines the overall length.
        return max(len(d) for d in self.datasets)
21 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/mnist_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple, Union
2 |
3 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
4 | from pl_bolts.utils.warnings import warn_missing_pkg
5 |
6 | if _TORCHVISION_AVAILABLE:
7 | from torchvision.datasets import MNIST
8 | else: # pragma: no cover
9 | warn_missing_pkg("torchvision")
10 | MNIST = object
11 |
12 | if _PIL_AVAILABLE:
13 | from PIL import Image
14 | else: # pragma: no cover
15 | warn_missing_pkg("PIL", pypi_name="Pillow")
16 |
17 |
class BinaryMNIST(MNIST):
    """Binarized MNIST Dataset.

    MNIST dataset binarized using a thresholding operation. Default threshold value is 127.
    Note that the images are binarized prior to the application of any transforms.

    Args:
        root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        threshold (Union[int, float], optional): Threshold value for binarizing image.
            Pixel value is set to 255 if value is greater than threshold, otherwise 0.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Note:
        Documentation is based on https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html

    """

    def __init__(self, root: str, threshold: Union[int, float] = 127.0, **kwargs: Any) -> None:
        super().__init__(root, **kwargs)
        # Pixels strictly above this value become 255; all others become 0.
        self.threshold = threshold

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """
        Args:
            idx (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
            raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")

        target = int(self.targets[idx])
        # Convert the raw tensor to an 8-bit grayscale PIL image.
        img = Image.fromarray(self.data[idx].numpy(), mode="L")
        # Binarize: pixels above the threshold become white (255), the rest black (0).
        img = img.point(lambda p: 255 if p > self.threshold else 0)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
74 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/sr_celeba_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Any
3 |
4 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin
5 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
6 | from pl_bolts.utils.stability import under_review
7 | from pl_bolts.utils.warnings import warn_missing_pkg
8 |
9 | if _PIL_AVAILABLE:
10 | from PIL import Image
11 | else: # pragma: no cover
12 | warn_missing_pkg("PIL", pypi_name="Pillow")
13 |
14 | if _TORCHVISION_AVAILABLE:
15 | from torchvision.datasets import CelebA
16 | else: # pragma: no cover
17 | warn_missing_pkg("torchvision")
18 | CelebA = object
19 |
20 |
@under_review()
class SRCelebA(SRDatasetMixin, CelebA):
    """CelebA dataset that can be used to train Super Resolution models.

    Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.

    """

    def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
        # High-res crop size is fixed at 128x128; the low-res size follows from the scale factor.
        hr_image_size = 128
        self.image_channels = 3  # RGB
        super().__init__(hr_image_size, hr_image_size // scale_factor, self.image_channels, *args, **kwargs)

    def _get_image(self, index: int):
        # Load the aligned face image for the requested sample.
        image_path = os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])
        return Image.open(image_path)
37 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/sr_dataset_mixin.py:
--------------------------------------------------------------------------------
1 | """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""
2 | from typing import Any, Tuple
3 |
4 | import torch
5 |
6 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
7 | from pl_bolts.utils.stability import under_review
8 | from pl_bolts.utils.warnings import warn_missing_pkg
9 |
10 | if _PIL_AVAILABLE:
11 | from PIL import Image
12 | else: # pragma: no cover
13 | warn_missing_pkg("PIL", pypi_name="Pillow")
14 |
15 | if _TORCHVISION_AVAILABLE:
16 | from torchvision import transforms as transform_lib
17 | else: # pragma: no cover
18 | warn_missing_pkg("torchvision")
19 |
20 |
@under_review()
class SRDatasetMixin:
    """Mixin for Super Resolution datasets.

    Scales range of high resolution images to [-1, 1] and range of low resolution images to [0, 1].

    Subclasses must provide ``_get_image(index)`` returning a PIL image; this mixin supplies
    ``__getitem__`` on top of it.
    """

    def __init__(self, hr_image_size: int, lr_image_size: int, image_channels: int, *args: Any, **kwargs: Any) -> None:
        """
        Args:
            hr_image_size: Side length (pixels) of the random crop used for the high-res image.
            lr_image_size: Side length (pixels) the low-res image is resized down to.
            image_channels: Number of image channels (e.g. 1 for grayscale, 3 for RGB).
        """
        super().__init__(*args, **kwargs)

        # High-res pipeline: random crop; ToTensor scales pixels to [0, 1]; Normalize with
        # mean=std=0.5 computes (x - 0.5) / 0.5, mapping [0, 1] -> [-1, 1].
        self.hr_transforms = transform_lib.Compose(
            [
                transform_lib.RandomCrop(hr_image_size),
                transform_lib.ToTensor(),
                transform_lib.Normalize(mean=(0.5,) * image_channels, std=(0.5,) * image_channels),
            ]
        )

        # Low-res pipeline operates on the already-normalized high-res tensor:
        # Normalize with mean=-1, std=2 computes (x + 1) / 2, inverting the step above so the
        # tensor is back in [0, 1] before converting to PIL and downscaling with bicubic
        # resampling; the final ToTensor keeps the low-res image in [0, 1].
        # NOTE(review): ``Resize(size, Image.BICUBIC)`` passes the interpolation positionally;
        # newer torchvision prefers ``interpolation=InterpolationMode.BICUBIC`` — confirm
        # against the pinned torchvision version.
        self.lr_transforms = transform_lib.Compose(
            [
                transform_lib.Normalize(mean=(-1.0,) * image_channels, std=(2.0,) * image_channels),
                transform_lib.ToPILImage(),
                transform_lib.Resize(lr_image_size, Image.BICUBIC),
                transform_lib.ToTensor(),
            ]
        )

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Returns (high-res image in [-1, 1], low-res image in [0, 1]).
        image = self._get_image(index)

        hr_image = self.hr_transforms(image)
        # The low-res image is derived from the cropped high-res tensor, not the raw image,
        # so both views show the same crop.
        lr_image = self.lr_transforms(hr_image)

        return hr_image, lr_image
56 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/sr_mnist_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from pl_bolts.datasets.mnist_dataset import MNIST
4 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin
5 | from pl_bolts.utils import _PIL_AVAILABLE
6 | from pl_bolts.utils.stability import under_review
7 | from pl_bolts.utils.warnings import warn_missing_pkg
8 |
9 | if _PIL_AVAILABLE:
10 | from PIL import Image
11 | else: # pragma: no cover
12 | warn_missing_pkg("PIL", pypi_name="Pillow")
13 |
14 |
@under_review()
class SRMNIST(SRDatasetMixin, MNIST):
    """MNIST dataset that can be used to train Super Resolution models.

    Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.

    """

    def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
        # MNIST digits are 28x28 grayscale; the low-res size follows from the scale factor.
        hr_image_size = 28
        self.image_channels = 1
        super().__init__(hr_image_size, hr_image_size // scale_factor, self.image_channels, *args, **kwargs)

    def _get_image(self, index: int):
        # Convert the raw tensor to an 8-bit grayscale PIL image.
        return Image.fromarray(self.data[index].numpy(), mode="L")
31 |
--------------------------------------------------------------------------------
/src/pl_bolts/datasets/sr_stl10_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import numpy as np
4 |
5 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin
6 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
7 | from pl_bolts.utils.stability import under_review
8 | from pl_bolts.utils.warnings import warn_missing_pkg
9 |
10 | if _PIL_AVAILABLE:
11 | import PIL
12 | else: # pragma: no cover
13 | warn_missing_pkg("PIL", pypi_name="Pillow")
14 |
15 | if _TORCHVISION_AVAILABLE:
16 | from torchvision.datasets import STL10
17 | else: # pragma: no cover
18 | warn_missing_pkg("torchvision")
19 | STL10 = object
20 |
21 |
@under_review()
class SRSTL10(SRDatasetMixin, STL10):
    """STL10 dataset that can be used to train Super Resolution models.

    Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image.

    """

    def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None:
        # STL10 images are 96x96 RGB; the low-res size follows from the scale factor.
        hr_image_size = 96
        self.image_channels = 3
        super().__init__(hr_image_size, hr_image_size // scale_factor, self.image_channels, *args, **kwargs)

    def _get_image(self, index: int):
        # Raw STL10 data is stored channels-first (C, H, W); PIL expects (H, W, C).
        return PIL.Image.fromarray(np.transpose(self.data[index], (1, 2, 0)))
38 |
--------------------------------------------------------------------------------
/src/pl_bolts/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/losses/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.metrics.aggregation import accuracy, mean, precision_at_k
2 |
3 | __all__ = [
4 | "accuracy",
5 | "mean",
6 | "precision_at_k",
7 | ]
8 |
--------------------------------------------------------------------------------
/src/pl_bolts/metrics/aggregation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def mean(res, key):
    """Recursively average ``entry[key]`` over arbitrarily nested lists of dicts."""
    # Dicts contribute their value at ``key`` directly; nested sequences contribute their own mean.
    values = [entry[key] if isinstance(entry, dict) else mean(entry, key) for entry in res]
    return torch.stack(values).mean()
7 |
8 |
def accuracy(preds, labels):
    """Fraction of rows in ``preds`` whose arg-max class matches ``labels``.

    Returns a scalar float tensor in [0, 1].
    """
    predicted_classes = preds.float().argmax(dim=1)
    correct = (predicted_classes == labels).sum().item()
    return torch.tensor(correct).float() / len(labels)
15 |
16 |
def precision_at_k(output, target, top_k=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Returns a list with one single-element percentage tensor per entry in ``top_k``.
    """
    with torch.no_grad():
        max_k = max(top_k)
        batch_size = target.size(0)

        # Indices of the max_k highest scores per row, transposed to shape (max_k, batch).
        _, top_indices = output.topk(max_k, 1, True, True)
        top_indices = top_indices.t()
        # Boolean grid marking where the i-th ranked prediction equals the target.
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

        results = []
        for k in top_k:
            # Any hit within the first k ranks counts as correct for precision@k.
            hits_at_k = hits[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            results.append(hits_at_k.mul_(100.0 / batch_size))
        return results
32 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/__init__.py:
--------------------------------------------------------------------------------
1 | """Collection of PyTorchLightning models."""
2 | from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE
3 | from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
4 | from pl_bolts.models.mnist_module import LitMNIST
5 | from pl_bolts.models.regression import LinearRegression, LogisticRegression
6 |
7 | __all__ = [
8 | "AE",
9 | "VAE",
10 | "LitMNIST",
11 | "LinearRegression",
12 | "LogisticRegression",
13 | ]
14 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/autoencoders/__init__.py:
--------------------------------------------------------------------------------
"""Here are an AE and a VAE."""
2 |
3 | from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE
4 | from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
5 | from pl_bolts.models.autoencoders.components import (
6 | resnet18_decoder,
7 | resnet18_encoder,
8 | resnet50_decoder,
9 | resnet50_encoder,
10 | )
11 |
12 | __all__ = [
13 | "AE",
14 | "VAE",
15 | "resnet18_decoder",
16 | "resnet18_encoder",
17 | "resnet50_decoder",
18 | "resnet50_encoder",
19 | ]
20 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/autoencoders/basic_ae/__init__.py:
--------------------------------------------------------------------------------
1 | """AE Template.
2 |
3 | This is a basic template for implementing an Autoencoder in PyTorch Lightning.
4 |
5 | A default encoder and decoder have been provided but can easily be replaced by custom models.
6 |
7 | This template uses the CIFAR10 dataset but image data of any dimension can be fed in as long as the image
8 | width and image height are even values. For other types of data, such as sound, it will be necessary
9 | to change the Encoder and Decoder.
10 |
11 | The default encoder is a resnet18 backbone followed by linear layers which map representations to latent space.
12 | The default decoder mirrors the encoder architecture and is similar to an inverted resnet18.
13 |
14 | .. code-block:: python
15 |
16 | from pl_bolts.models.autoencoders import AE
17 |
18 | model = AE()
19 | trainer = pl.Trainer()
20 | trainer.fit(model)
21 |
22 | """
23 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/autoencoders/basic_vae/__init__.py:
--------------------------------------------------------------------------------
1 | """VAE Template.
2 |
3 | This is a basic template for implementing a Variational Autoencoder in PyTorch Lightning.
4 |
5 | A default encoder and decoder have been provided but can easily be replaced by custom models.
6 |
7 | This template uses the CIFAR10 dataset but image data of any dimension can be fed in as long as the image width and
8 | image height are even values. For other types of data, such as sound, it will be necessary to change the Encoder and
9 | Decoder.
10 |
11 | The default encoder is a resnet18 backbone followed by linear layers which map representations to mu and var. The
12 | default decoder mirrors the encoder architecture and is similar to an inverted resnet18. The model also assumes a
13 | Gaussian prior and a Gaussian approximate posterior distribution.
14 |
15 | """
16 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.detection import components
2 | from pl_bolts.models.detection.faster_rcnn import FasterRCNN
3 | from pl_bolts.models.detection.retinanet import RetinaNet
4 | from pl_bolts.models.detection.yolo.darknet_network import DarknetNetwork
5 | from pl_bolts.models.detection.yolo.torch_networks import (
6 | YOLOV4Backbone,
7 | YOLOV4Network,
8 | YOLOV4P6Network,
9 | YOLOV4TinyBackbone,
10 | YOLOV4TinyNetwork,
11 | YOLOV5Backbone,
12 | YOLOV5Network,
13 | YOLOV7Backbone,
14 | YOLOV7Network,
15 | YOLOXNetwork,
16 | )
17 | from pl_bolts.models.detection.yolo.yolo_module import YOLO
18 |
19 | __all__ = [
20 | "components",
21 | "FasterRCNN",
22 | "RetinaNet",
23 | "DarknetNetwork",
24 | "YOLOV4Backbone",
25 | "YOLOV4Network",
26 | "YOLOV4P6Network",
27 | "YOLOV4TinyBackbone",
28 | "YOLOV4TinyNetwork",
29 | "YOLOV5Backbone",
30 | "YOLOV5Network",
31 | "YOLOV7Backbone",
32 | "YOLOV7Network",
33 | "YOLOXNetwork",
34 | "YOLO",
35 | ]
36 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/components/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.detection.components.torchvision_backbones import create_torchvision_backbone
2 |
3 | __all__ = ["create_torchvision_backbone"]
4 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/components/_supported_models.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
2 | from pl_bolts.utils.warnings import warn_missing_pkg
3 |
if _TORCHVISION_AVAILABLE:
    import torchvision

    # Maps a model name to the torchvision constructor used to build that backbone.
    TORCHVISION_MODEL_ZOO = {
        "vgg11": torchvision.models.vgg11,
        "vgg13": torchvision.models.vgg13,
        "vgg16": torchvision.models.vgg16,
        "vgg19": torchvision.models.vgg19,
        "resnet18": torchvision.models.resnet18,
        "resnet34": torchvision.models.resnet34,
        "resnet50": torchvision.models.resnet50,
        "resnet101": torchvision.models.resnet101,
        "resnet152": torchvision.models.resnet152,
        "resnext50_32x4d": torchvision.models.resnext50_32x4d,
        # Fix: this constructor builds resnext101_32x8d, but it was previously keyed as
        # "resnext50_32x8d" — a model torchvision does not provide. Register it under the
        # correct name and keep the old key as a backward-compatible (mislabeled) alias.
        "resnext101_32x8d": torchvision.models.resnext101_32x8d,
        "resnext50_32x8d": torchvision.models.resnext101_32x8d,
        "mnasnet0_5": torchvision.models.mnasnet0_5,
        "mnasnet0_75": torchvision.models.mnasnet0_75,
        "mnasnet1_0": torchvision.models.mnasnet1_0,
        "mnasnet1_3": torchvision.models.mnasnet1_3,
        "mobilenet_v2": torchvision.models.mobilenet_v2,
    }

else:  # pragma: no cover
    warn_missing_pkg("torchvision")
    TORCHVISION_MODEL_ZOO = {}
29 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/faster_rcnn/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.detection.faster_rcnn.backbones import create_fasterrcnn_backbone
2 | from pl_bolts.models.detection.faster_rcnn.faster_rcnn_module import FasterRCNN
3 |
4 | __all__ = ["create_fasterrcnn_backbone", "FasterRCNN"]
5 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/faster_rcnn/backbones.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | import torch.nn as nn
4 |
5 | from pl_bolts.models.detection.components import create_torchvision_backbone
6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
7 | from pl_bolts.utils.stability import under_review
8 | from pl_bolts.utils.warnings import warn_missing_pkg
9 |
10 | if _TORCHVISION_AVAILABLE:
11 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
12 | else: # pragma: no cover
13 | warn_missing_pkg("torchvision")
14 |
15 |
@under_review()
def create_fasterrcnn_backbone(
    backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any
) -> nn.Module:
    """Creates a backbone network, optionally with a feature pyramid network (FPN), for Faster R-CNN.

    Args:
        backbone:
            Supported backbones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152",
            "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2",
            as resnets with fpn backbones.
            Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101",
            "resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19",
        fpn: If True then constructs fpn as well.
        pretrained: If None creates imagenet weights backbone.
            NOTE(review): when ``fpn=True`` this argument is ignored — the FPN branch below
            always calls ``resnet_fpn_backbone`` with ``pretrained=True``. Confirm whether
            that is intentional.
        trainable_backbone_layers: number of trainable resnet layers starting from final block.

    Returns:
        The constructed backbone as an ``nn.Module``.
    """

    if fpn:
        # Creates a torchvision resnet model with fpn added.
        # ``pretrained`` is hardcoded to True here (the function argument is not used).
        backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs)
    else:
        # This does not create fpn backbone, it is supported for all models
        backbone, _ = create_torchvision_backbone(backbone, pretrained)
    return backbone
40 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/retinanet/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.detection.retinanet.backbones import create_retinanet_backbone
2 | from pl_bolts.models.detection.retinanet.retinanet_module import RetinaNet
3 |
4 | __all__ = ["create_retinanet_backbone", "RetinaNet"]
5 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/retinanet/backbones.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | import torch.nn as nn
4 |
5 | from pl_bolts.models.detection.components import create_torchvision_backbone
6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
7 | from pl_bolts.utils.stability import under_review
8 | from pl_bolts.utils.warnings import warn_missing_pkg
9 |
10 | if _TORCHVISION_AVAILABLE:
11 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
12 | else: # pragma: no cover
13 | warn_missing_pkg("torchvision")
14 |
15 |
@under_review()
def create_retinanet_backbone(
    backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any
) -> nn.Module:
    """Creates a backbone network for a RetinaNet detector, optionally wrapped in a feature pyramid network (FPN).

    Args:
        backbone:
            Supported backbones are: "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
            "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2",
            as resnets with fpn backbones.
            Without fpn backbones supported are: "resnet18", "resnet34", "resnet50", "resnet101",
            "resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19".
        fpn: If True then constructs fpn as well.
        pretrained: If None creates imagenet weights backbone.
            NOTE(review): this argument is only forwarded in the non-FPN branch below; the FPN
            branch always passes ``pretrained=True`` to ``resnet_fpn_backbone`` — confirm intended.
        trainable_backbone_layers: number of trainable resnet layers starting from final block.

    Returns:
        The backbone network as an ``nn.Module``.
    """

    if fpn:
        # Creates a torchvision resnet model with fpn added.
        backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs)
    else:
        # This does not create fpn backbone, it is supported for all models
        backbone, _ = create_torchvision_backbone(backbone, pretrained)
    return backbone
40 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/yolo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/detection/yolo/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/detection/yolo/types.py:
--------------------------------------------------------------------------------
from typing import Any, Dict, List, Tuple, Union

from torch import Tensor

# Type aliases shared by the YOLO model and its data pipeline.

# A batch of input images: a tuple or list of image tensors.
IMAGES = Union[Tuple[Tensor, ...], List[Tensor]]
# A single prediction for one image; presumably keys such as boxes/scores/labels — confirm against the YOLO module.
PRED = Dict[str, Any]
# A batch of predictions, one dict per image.
PREDS = Union[Tuple[PRED, ...], List[PRED]]
# A single ground-truth annotation dict for one image.
TARGET = Dict[str, Any]
# A batch of targets, one dict per image.
TARGETS = Union[Tuple[TARGET, ...], List[TARGET]]
# One batch as fed to the model: the images together with their targets.
BATCH = Tuple[IMAGES, TARGETS]
NETWORK_OUTPUT = Tuple[List[Tensor], List[Tensor], List[int]]  # detections, losses, hits
12 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/gans/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.gans.basic.basic_gan_module import GAN
2 | from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN
3 | from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix
4 | from pl_bolts.models.gans.srgan.srgan_module import SRGAN
5 | from pl_bolts.models.gans.srgan.srresnet_module import SRResNet
6 |
7 | __all__ = [
8 | "GAN",
9 | "DCGAN",
10 | "Pix2Pix",
11 | "SRGAN",
12 | "SRResNet",
13 | ]
14 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/gans/basic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/gans/basic/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/gans/basic/components.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F # noqa: N812
5 |
6 |
class Generator(nn.Module):
    """Fully-connected GAN generator.

    Maps a latent vector of size ``latent_dim`` through three leaky-ReLU layers of doubling
    width onto a flattened image, then reshapes to ``img_shape``. The final ``tanh`` keeps
    pixel values in ``(-1, 1)``.
    """

    def __init__(self, latent_dim, img_shape, hidden_dim=256) -> None:
        super().__init__()
        self.img_shape = img_shape
        out_feats = int(np.prod(img_shape))
        # layer widths: hidden_dim -> 2*hidden_dim -> 4*hidden_dim -> prod(img_shape)
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim * 4)
        self.fc4 = nn.Linear(hidden_dim * 4, out_feats)

    def forward(self, z):
        hidden = F.leaky_relu(self.fc1(z), 0.2)
        hidden = F.leaky_relu(self.fc2(hidden), 0.2)
        hidden = F.leaky_relu(self.fc3(hidden), 0.2)
        flat_img = torch.tanh(self.fc4(hidden))
        return flat_img.view(flat_img.size(0), *self.img_shape)
24 |
25 |
class Discriminator(nn.Module):
    """Fully-connected GAN discriminator.

    Flattens the input image and passes it through three leaky-ReLU layers of halving width,
    returning a sigmoid probability of the image being real.

    Args:
        img_shape: shape of a single input image; its product is the flattened input size
        hidden_dim: width of the first hidden layer, halved at each subsequent layer
    """

    def __init__(self, img_shape, hidden_dim=1024) -> None:
        super().__init__()
        in_dim = int(np.prod(img_shape))
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, img):
        x = img.view(img.size(0), -1)
        x = F.leaky_relu(self.fc1(x), 0.2)
        # BUG FIX: F.dropout defaults to training=True regardless of module mode, so dropout
        # previously stayed active even after .eval(); tie it to self.training instead.
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))
45 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/gans/dcgan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/gans/dcgan/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/gans/pix2pix/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/gans/pix2pix/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/gans/srgan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/gans/srgan/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/regression/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.regression.linear_regression import LinearRegression
2 | from pl_bolts.models.regression.logistic_regression import LogisticRegression
3 |
4 | __all__ = [
5 | "LinearRegression",
6 | "LogisticRegression",
7 | ]
8 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/rl/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
2 | from pl_bolts.models.rl.double_dqn_model import DoubleDQN
3 | from pl_bolts.models.rl.dqn_model import DQN
4 | from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN
5 | from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
6 | from pl_bolts.models.rl.per_dqn_model import PERDQN
7 | from pl_bolts.models.rl.reinforce_model import Reinforce
8 | from pl_bolts.models.rl.sac_model import SAC
9 | from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
10 |
11 | __all__ = [
12 | "AdvantageActorCritic",
13 | "DoubleDQN",
14 | "DQN",
15 | "DuelingDQN",
16 | "NoisyDQN",
17 | "PERDQN",
18 | "Reinforce",
19 | "SAC",
20 | "VanillaPolicyGradient",
21 | ]
22 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/rl/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/rl/common/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/rl/common/cli.py:
--------------------------------------------------------------------------------
1 | """Contains generic arguments used for all models."""
2 |
3 | import argparse
4 |
5 | from pl_bolts.utils.stability import under_review
6 |
7 |
8 | @under_review()
def add_base_args(parent) -> argparse.ArgumentParser:
    """Adds the generic arguments shared by the DQN-family models.

    Note:
        These params are fine tuned for Pong env.

    Args:
        parent: parser whose arguments are inherited by the returned parser

    Returns:
        an ``argparse.ArgumentParser`` extending ``parent`` with the common RL options

    """
    arg_parser = argparse.ArgumentParser(parents=[parent])

    arg_parser.add_argument("--algo", type=str, default="dqn", help="algorithm to use for training")
    arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
    arg_parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")

    arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag")
    arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")

    arg_parser.add_argument("--episode_length", type=int, default=500, help="max length of an episode")
    arg_parser.add_argument("--max_episode_reward", type=int, default=18, help="max episode reward in the environment")
    arg_parser.add_argument(
        "--n_steps",
        type=int,
        default=4,
        help="how many steps to unroll for each update",
    )
    # BUG FIX: "--seed" was registered twice, which makes argparse raise an ArgumentError
    # ("conflicting option string: --seed") as soon as this function is called.
    arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run")
    arg_parser.add_argument("--epoch_len", type=int, default=1000, help="how many batches per epoch")
    arg_parser.add_argument("--num_envs", type=int, default=1, help="number of environments to run at once")
    arg_parser.add_argument(
        "--avg_reward_len", type=int, default=100, help="how many episodes to include in avg reward"
    )
    return arg_parser
45 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/rl/common/distributions.py:
--------------------------------------------------------------------------------
1 | """Distributions used in some continuous RL algorithms."""
2 | import torch
3 |
4 | from pl_bolts.utils.stability import under_review
5 |
6 |
7 | @under_review()
class TanhMultivariateNormal(torch.distributions.MultivariateNormal):
    """The distribution of X is an affine of tanh applied on a normal distribution::

        X = action_scale * tanh(Z) + action_bias
        Z ~ Normal(mean, variance)

    """

    def __init__(self, action_bias, action_scale, **kwargs) -> None:
        super().__init__(**kwargs)

        # affine parameters mapping tanh(Z) in (-1, 1) onto the action range
        self.action_bias = action_bias
        self.action_scale = action_scale

    def rsample_with_z(self, sample_shape=torch.Size()):
        """Samples X using the reparametrization trick with the intermediate variable Z.

        Args:
            sample_shape: shape of the sample to draw

        Returns:
            Sampled X and Z

        """
        # BUG FIX: sample_shape was previously discarded (rsample() was called without it),
        # so requesting multiple samples silently returned a single one.
        z = super().rsample(sample_shape)
        return self.action_scale * torch.tanh(z) + self.action_bias, z

    def log_prob_with_z(self, value, z):
        """Computes the log probability of a sampled X.

        Refer to the original paper of SAC for more details in equation (20), (21)

        Args:
            value: the value of X, in the original (unnormalized) action space
            z: the value of Z
        Returns:
            Log probability of the sample

        """
        value = (value - self.action_bias) / self.action_scale
        z_logprob = super().log_prob(z)
        # change-of-variables correction for the tanh squash; the 1e-7 guards log(0)
        # NOTE(review): .sum(1) assumes `value` is 2D (batch, action_dim) — confirm for other shapes
        correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1)
        return z_logprob - correction

    def rsample_and_log_prob(self, sample_shape=torch.Size()):
        """Samples X and computes the log probability of the sample.

        Args:
            sample_shape: shape of the sample to draw

        Returns:
            Sampled X and log probability

        """
        # BUG FIX: forward sample_shape instead of discarding it (see rsample_with_z)
        z = super().rsample(sample_shape)
        z_logprob = super().log_prob(z)
        value = torch.tanh(z)
        correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1)
        return self.action_scale * value + self.action_bias, z_logprob - correction

    def rsample(self, sample_shape=torch.Size()):
        fz, z = self.rsample_with_z(sample_shape)
        return fz

    def log_prob(self, value):
        """Computes the log probability of ``value`` (given in the unnormalized action space)."""
        # BUG FIX: log_prob_with_z performs the normalization itself, so the *unnormalized*
        # value must be passed to it. Previously the value was normalized here AND inside
        # log_prob_with_z, giving wrong results whenever action_scale != 1 or action_bias != 0.
        normalized = (value - self.action_bias) / self.action_scale
        z = torch.atanh(normalized)  # inverse of the tanh squash (was log((1+v)/(1-v))/2)
        return self.log_prob_with_z(value, z)
70 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/rl/dueling_dqn_model.py:
--------------------------------------------------------------------------------
1 | """Dueling DQN."""
2 | import argparse
3 |
4 | from pytorch_lightning import Trainer
5 |
6 | from pl_bolts.models.rl.common.networks import DuelingCNN
7 | from pl_bolts.models.rl.dqn_model import DQN
8 | from pl_bolts.utils.stability import under_review
9 |
10 |
@under_review()
class DuelingDQN(DQN):
    """PyTorch Lightning implementation of `Dueling DQN <https://arxiv.org/abs/1511.06581>`_.

    This subclass only overrides ``build_networks`` so that both the train and target
    networks use the dueling architecture (``DuelingCNN``); everything else is inherited
    from :class:`DQN`.

    Paper authors: Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas

    Model implemented by:

    - Donal Byrne

    Example:

        >>> from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN
        ...
        >>> model = DuelingDQN("PongNoFrameskip-v4")

    Train::

        trainer = Trainer()
        trainer.fit(model)

    .. note:: Currently only supports CPU and single GPU training with `accelerator=dp`
    """

    def build_networks(self) -> None:
        """Initializes the Dueling DQN train and target networks (both ``DuelingCNN``)."""
        self.net = DuelingCNN(self.obs_shape, self.n_actions)
        self.target_net = DuelingCNN(self.obs_shape, self.n_actions)
39 |
40 |
@under_review()
def cli_main():
    """CLI entry point: builds a DuelingDQN from command-line arguments and trains it."""
    arg_parser = argparse.ArgumentParser(add_help=False)

    # trainer-level arguments
    arg_parser = Trainer.add_argparse_args(arg_parser)
    # model-level arguments
    arg_parser = DuelingDQN.add_model_specific_args(arg_parser)
    parsed = arg_parser.parse_args()

    dueling_dqn = DuelingDQN(**vars(parsed))

    trainer = Trainer.from_argparse_args(parsed)
    trainer.fit(dueling_dqn)


if __name__ == "__main__":
    cli_main()
60 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/__init__.py:
--------------------------------------------------------------------------------
"""These models have been pre-trained using self-supervised learning. The models can also be used without
pre-training and overwritten for your own research.
3 |
4 | Here's an example for using these as pretrained models.
5 |
6 | .. code-block ::
7 |
8 | from pl_bolts.models.self_supervised import CPC_v2
9 |
10 | images = get_imagenet_batch()
11 |
12 | # extract unsupervised representations
13 | pretrained = CPC_v2(pretrained=True)
14 | representations = pretrained(images)
15 |
16 | # use these in classification or any downstream task
17 | classifications = classifier(representations)
18 |
19 | """
20 | from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM
21 | from pl_bolts.models.self_supervised.byol.byol_module import BYOL
22 | from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2
23 | from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
24 | from pl_bolts.models.self_supervised.moco.moco_module import MoCo
25 | from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
26 | from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam
27 | from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
28 | from pl_bolts.models.self_supervised.swav.swav_module import SwAV
29 |
30 | __all__ = [
31 | "AMDIM",
32 | "BYOL",
33 | "CPC_v2",
34 | "SSLEvaluator",
35 | "MoCo",
36 | "SimCLR",
37 | "SimSiam",
38 | "SSLFineTuner",
39 | "SwAV",
40 | ]
41 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/amdim/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM
2 | from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder
3 | from pl_bolts.transforms.self_supervised.amdim_transforms import (
4 | AMDIMEvalTransformsCIFAR10,
5 | AMDIMEvalTransformsImageNet128,
6 | AMDIMEvalTransformsSTL10,
7 | AMDIMTrainTransformsCIFAR10,
8 | AMDIMTrainTransformsImageNet128,
9 | AMDIMTrainTransformsSTL10,
10 | )
11 |
12 | __all__ = [
13 | "AMDIM",
14 | "AMDIMEncoder",
15 | "AMDIMEvalTransformsCIFAR10",
16 | "AMDIMEvalTransformsImageNet128",
17 | "AMDIMEvalTransformsSTL10",
18 | "AMDIMTrainTransformsCIFAR10",
19 | "AMDIMTrainTransformsImageNet128",
20 | "AMDIMTrainTransformsSTL10",
21 | ]
22 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/byol/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/self_supervised/byol/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/byol/models.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | from torch import Tensor, nn
4 |
5 | from pl_bolts.utils.self_supervised import torchvision_ssl_encoder
6 |
7 |
class MLP(nn.Module):
    """MLP architecture used as projectors in online and target networks and predictors in the online network.

    Args:
        input_dim (int, optional): Input dimension. Defaults to 2048.
        hidden_dim (int, optional): Hidden layer dimension. Defaults to 4096.
        output_dim (int, optional): Output dimension. Defaults to 256.

    Note:
        Default values for input, hidden, and output dimensions are based on values used in BYOL.

    """

    def __init__(self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256) -> None:
        super().__init__()

        # linear -> batch-norm -> ReLU -> linear, matching the BYOL projector/predictor layout
        layers = [
            nn.Linear(input_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim, bias=True),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)
33 |
34 |
class SiameseArm(nn.Module):
    """SiameseArm consolidates the encoder and projector networks of BYOL's symmetric architecture into a single class.

    Args:
        encoder (Union[str, nn.Module], optional): Online and target network encoder architecture.
            A string selects a torchvision encoder by name; a module is used as-is.
            Defaults to "resnet50".
        encoder_out_dim (int, optional): Output dimension of encoder. Defaults to 2048.
        projector_hidden_dim (int, optional): Online and target network projector network hidden dimension.
            Defaults to 4096.
        projector_out_dim (int, optional): Online and target network projector network output dimension.
            Defaults to 256.

    """

    def __init__(
        self,
        encoder: Union[str, nn.Module] = "resnet50",
        encoder_out_dim: int = 2048,
        projector_hidden_dim: int = 4096,
        projector_out_dim: int = 256,
    ) -> None:
        super().__init__()

        # resolve a string spec into a torchvision encoder, otherwise take the module directly
        self.encoder = torchvision_ssl_encoder(encoder) if isinstance(encoder, str) else encoder
        self.projector = MLP(encoder_out_dim, projector_hidden_dim, projector_out_dim)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        representation = self.encoder(x)[0]
        projection = self.projector(representation)
        return representation, projection

    def encode(self, x: Tensor) -> Tensor:
        """Returns the encoded representation of a view. This method does not calculate the projection as in the
        forward method.

        Args:
            x (Tensor): sample to be encoded

        """
        return self.encoder(x)[0]
79 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/cpc/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2
2 | from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet50, cpc_resnet101
3 | from pl_bolts.transforms.self_supervised.cpc_transforms import (
4 | CPCEvalTransformsCIFAR10,
5 | CPCEvalTransformsImageNet128,
6 | CPCEvalTransformsSTL10,
7 | CPCTrainTransformsCIFAR10,
8 | CPCTrainTransformsImageNet128,
9 | CPCTrainTransformsSTL10,
10 | )
11 |
12 | __all__ = [
13 | "CPC_v2",
14 | "cpc_resnet50",
15 | "cpc_resnet101",
16 | "CPCEvalTransformsCIFAR10",
17 | "CPCEvalTransformsImageNet128",
18 | "CPCEvalTransformsSTL10",
19 | "CPCTrainTransformsCIFAR10",
20 | "CPCTrainTransformsImageNet128",
21 | "CPCTrainTransformsSTL10",
22 | ]
23 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 |
4 | from pytorch_lightning import Trainer, seed_everything
5 |
6 | from pl_bolts.models.self_supervised import CPC_v2, SSLFineTuner
7 | from pl_bolts.transforms.self_supervised.cpc_transforms import (
8 | CPCEvalTransformsCIFAR10,
9 | CPCEvalTransformsSTL10,
10 | CPCTrainTransformsCIFAR10,
11 | CPCTrainTransformsSTL10,
12 | )
13 | from pl_bolts.utils.stability import under_review
14 |
15 |
@under_review()
def cli_main():  # pragma: no cover
    """CLI entry point: fine-tunes a pretrained CPC_v2 backbone on CIFAR-10 or STL-10."""
    from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument("--dataset", type=str, help="stl10, cifar10", default="cifar10")
    parser.add_argument("--ckpt_path", type=str, help="path to ckpt")
    # BUG FIX: the help text said "path to ckpt" (copy-paste from the argument above)
    parser.add_argument("--data_dir", type=str, help="path to the data", default=os.getcwd())
    args = parser.parse_args()

    # load the backbone
    backbone = CPC_v2.load_from_checkpoint(args.ckpt_path, strict=False)

    if args.dataset == "cifar10":
        dm = CIFAR10DataModule.from_argparse_args(args)
        dm.train_transforms = CPCTrainTransformsCIFAR10()
        dm.val_transforms = CPCEvalTransformsCIFAR10()
        dm.test_transforms = CPCEvalTransformsCIFAR10()

    elif args.dataset == "stl10":
        dm = STL10DataModule.from_argparse_args(args)
        # use the labeled splits for supervised fine-tuning
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        dm.train_transforms = CPCTrainTransformsSTL10()
        dm.val_transforms = CPCEvalTransformsSTL10()
        dm.test_transforms = CPCEvalTransformsSTL10()

    else:
        # BUG FIX: an unknown dataset previously fell through and crashed below with a
        # confusing NameError on `dm`; fail fast with a clear message instead.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}. Expected 'cifar10' or 'stl10'.")

    # finetune
    tuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes)
    trainer = Trainer.from_argparse_args(args, early_stop_callback=True)
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)


if __name__ == "__main__":
    cli_main()
56 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/evaluator.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from pl_bolts.utils.stability import under_review
4 |
5 |
@under_review()
class SSLEvaluator(nn.Module):
    """Probe head used to evaluate self-supervised representations.

    Flattens the incoming features and classifies them with either a single linear layer
    (when ``n_hidden is None``) or a one-hidden-layer MLP.

    Args:
        n_input: dimensionality of the flattened input features
        n_classes: number of target classes
        n_hidden: width of the hidden layer; ``None`` selects the purely linear classifier
        p: dropout probability applied before each linear layer
    """

    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1) -> None:
        super().__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden

        if n_hidden is None:
            # linear classifier: flatten -> dropout -> linear
            layers = [Flatten(), nn.Dropout(p=p), nn.Linear(n_input, n_classes, bias=True)]
        else:
            # simple MLP classifier: flatten -> dropout -> linear -> BN -> ReLU -> dropout -> linear
            layers = [
                Flatten(),
                nn.Dropout(p=p),
                nn.Linear(n_input, n_hidden, bias=False),
                nn.BatchNorm1d(n_hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(p=p),
                nn.Linear(n_hidden, n_classes, bias=True),
            ]
        self.block_forward = nn.Sequential(*layers)

    def forward(self, x):
        return self.block_forward(x)
30 |
31 |
32 | @under_review()
class Flatten(nn.Module):
    """Flattens every dimension after the batch dimension into a single vector."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, input_tensor):
        batch_size = input_tensor.size(0)
        return input_tensor.view(batch_size, -1)
39 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/moco/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.transforms.self_supervised.moco_transforms import ( # noqa: F401
2 | MoCo2EvalCIFAR10Transforms,
3 | MoCo2EvalImagenetTransforms,
4 | MoCo2EvalSTL10Transforms,
5 | MoCo2TrainCIFAR10Transforms,
6 | MoCo2TrainImagenetTransforms,
7 | MoCo2TrainSTL10Transforms,
8 | )
9 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/moco/callbacks.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from pytorch_lightning import Callback
4 |
5 | from pl_bolts.utils.stability import under_review
6 |
7 |
@under_review()
class MoCoLRScheduler(Callback):
    """Adjusts the learning rate of the first optimizer at the start of every training epoch.

    Args:
        initial_lr: learning rate at epoch 0
        use_cosine_scheduler: if True use cosine annealing, otherwise stepwise decay
        schedule: epochs at which the stepwise schedule multiplies the lr by 0.1
        max_epochs: total number of epochs, used by the cosine schedule
    """

    def __init__(self, initial_lr=0.03, use_cosine_scheduler=False, schedule=(120, 160), max_epochs=200) -> None:
        super().__init__()
        self.lr = initial_lr
        self.use_cosine_scheduler = use_cosine_scheduler
        self.schedule = schedule
        self.max_epochs = max_epochs

    def on_train_epoch_start(self, trainer, pl_module):
        epoch = trainer.current_epoch

        if self.use_cosine_scheduler:
            # cosine annealing from initial_lr toward 0 over max_epochs
            lr = self.lr * 0.5 * (1.0 + math.cos(math.pi * epoch / self.max_epochs))
        else:
            # stepwise decay: multiply by 0.1 once for every milestone already reached
            decays = sum(1 for milestone in self.schedule if epoch >= milestone)
            lr = self.lr * 0.1**decays

        for param_group in trainer.optimizers[0].param_groups:
            param_group["lr"] = lr
30 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/simclr/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.transforms.self_supervised.simclr_transforms import ( # noqa: F401
2 | SimCLREvalDataTransform,
3 | SimCLRTrainDataTransform,
4 | )
5 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/simsiam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/self_supervised/simsiam/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/self_supervised/swav/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
2 | from pl_bolts.models.self_supervised.swav.swav_module import SwAV
3 | from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
4 | from pl_bolts.transforms.self_supervised.swav_transforms import (
5 | SwAVEvalDataTransform,
6 | SwAVFinetuneTransform,
7 | SwAVTrainDataTransform,
8 | )
9 |
10 | __all__ = [
11 | "SwAV",
12 | "resnet18",
13 | "resnet50",
14 | "SwAVEvalDataTransform",
15 | "SwAVFinetuneTransform",
16 | "SwAVTrainDataTransform",
17 | "SWAVLoss",
18 | ]
19 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.models.vision.image_gpt.gpt2 import GPT2
2 | from pl_bolts.models.vision.image_gpt.igpt_module import ImageGPT
3 | from pl_bolts.models.vision.pixel_cnn import PixelCNN
4 | from pl_bolts.models.vision.segmentation import SemSegment
5 | from pl_bolts.models.vision.unet import UNet
6 |
7 | __all__ = [
8 | "GPT2",
9 | "ImageGPT",
10 | "PixelCNN",
11 | "SemSegment",
12 | "UNet",
13 | ]
14 |
--------------------------------------------------------------------------------
/src/pl_bolts/models/vision/image_gpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/models/vision/image_gpt/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/models/vision/pixel_cnn.py:
--------------------------------------------------------------------------------
"""PixelCNN.

Implemented by: William Falcon.
Reference: https://arxiv.org/pdf/1905.09272.pdf (page 15). Accessed: May 14, 2020.

"""
7 | from torch import nn
8 | from torch.nn import functional as F # noqa: N812
9 |
10 | from pl_bolts.utils.stability import under_review
11 |
12 |
13 | @under_review()
class PixelCNN(nn.Module):
    """Implementation of `Pixel CNN <https://arxiv.org/abs/1606.05328>`_.

    Paper authors: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex Graves,
    Koray Kavukcuoglu

    Implemented by:

    - William Falcon

    A stack of residual convolutional blocks; each block preserves the spatial size and the
    channel count, so the output shape equals the input shape.

    Example::

        >>> from pl_bolts.models.vision import PixelCNN
        >>> import torch
        ...
        >>> model = PixelCNN(input_channels=3)
        >>> x = torch.rand(5, 3, 64, 64)
        >>> out = model(x)
        ...
        >>> out.shape
        torch.Size([5, 3, 64, 64])
    """

    def __init__(self, input_channels: int, hidden_channels: int = 256, num_blocks=5) -> None:
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.blocks = nn.ModuleList([self.conv_block(input_channels) for _ in range(num_blocks)])

    def conv_block(self, input_channels):
        """Builds one residual block: 1x1 conv -> ReLU -> (1,3) conv -> pad -> (2,1) conv -> ReLU -> 1x1 conv."""
        return nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=self.hidden_channels, kernel_size=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=(1, 3)),
            # pads one row at the top (value 1); together with the following (2, 1) conv this
            # keeps the height unchanged while the width lost by the (1, 3) conv is restored
            # by that conv's padding=(0, 1)
            nn.ConstantPad2d((0, 0, 1, 0, 0, 0, 0, 0), 1),
            nn.Conv2d(
                in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=(2, 1), padding=(0, 1)
            ),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.hidden_channels, out_channels=input_channels, kernel_size=(1, 1)),
        )

    def forward(self, z):
        out = z
        for block in self.blocks:
            out = out + block(out)  # residual connection around every block
        return F.relu(out)
63 |
--------------------------------------------------------------------------------
/src/pl_bolts/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.optimizers.lars import LARS
2 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR, linear_warmup_decay
3 |
4 | __all__ = [
5 | "LARS",
6 | "LinearWarmupCosineAnnealingLR",
7 | "linear_warmup_decay",
8 | ]
9 |
--------------------------------------------------------------------------------
/src/pl_bolts/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/src/pl_bolts/transforms/__init__.py
--------------------------------------------------------------------------------
/src/pl_bolts/transforms/dataset_normalizations.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
4 | from pl_bolts.utils.warnings import warn_missing_pkg
5 |
6 | if _TORCHVISION_AVAILABLE:
7 | from torchvision import transforms
8 | else: # pragma: no cover
9 | warn_missing_pkg("torchvision")
10 |
11 |
def imagenet_normalization() -> Callable:
    """Return a ``transforms.Normalize`` configured with the standard ImageNet statistics.

    Raises:
        ModuleNotFoundError: If ``torchvision`` is not installed.
    """
    if not _TORCHVISION_AVAILABLE:  # pragma: no cover
        raise ModuleNotFoundError(
            "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
        )

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    return transforms.Normalize(mean=mean, std=std)
19 |
20 |
def cifar10_normalization() -> Callable:
    """Return a ``transforms.Normalize`` configured with per-channel CIFAR-10 statistics.

    The statistics are given on the 0-255 pixel scale and rescaled to [0, 1] here.

    Raises:
        ModuleNotFoundError: If ``torchvision`` is not installed.
    """
    if not _TORCHVISION_AVAILABLE:  # pragma: no cover
        raise ModuleNotFoundError(
            "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
        )

    mean = [channel / 255.0 for channel in (125.3, 123.0, 113.9)]
    std = [channel / 255.0 for channel in (63.0, 62.1, 66.7)]
    return transforms.Normalize(mean=mean, std=std)
31 |
32 |
def stl10_normalization() -> Callable:
    """Return a ``transforms.Normalize`` configured with per-channel STL-10 statistics.

    Raises:
        ModuleNotFoundError: If ``torchvision`` is not installed.
    """
    if not _TORCHVISION_AVAILABLE:  # pragma: no cover
        raise ModuleNotFoundError(
            "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
        )

    mean = (0.43, 0.42, 0.39)
    std = (0.27, 0.26, 0.27)
    return transforms.Normalize(mean=mean, std=std)
40 |
41 |
def emnist_normalization(split: str) -> Callable:
    """Return a ``transforms.Normalize`` configured for the given EMNIST split.

    Args:
        split: EMNIST split name; one of ``'balanced'``, ``'byclass'``, ``'bymerge'``,
            ``'digits'``, ``'letters'`` or ``'mnist'``.

    Returns:
        A ``torchvision.transforms.Normalize`` using the split's mean and std.

    Raises:
        ModuleNotFoundError: If ``torchvision`` is not installed.
        ValueError: If ``split`` is not a known EMNIST split (previously a bare ``KeyError``).
    """
    if not _TORCHVISION_AVAILABLE:  # pragma: no cover
        raise ModuleNotFoundError(
            "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
        )

    # `stats` contains mean and std for each `split`.
    stats = {
        "balanced": (0.175, 0.333),
        "byclass": (0.174, 0.332),
        "bymerge": (0.174, 0.332),
        "digits": (0.173, 0.332),
        "letters": (0.172, 0.331),
        "mnist": (0.173, 0.332),
    }
    if split not in stats:
        raise ValueError(f"Unknown EMNIST split {split!r}; expected one of {sorted(stats)}.")

    mean, std = stats[split]
    return transforms.Normalize(mean=mean, std=std)
59 |
--------------------------------------------------------------------------------
/src/pl_bolts/transforms/self_supervised/__init__.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.transforms.self_supervised.ssl_transforms import Patchify, RandomTranslateWithReflect # noqa: F401
2 |
--------------------------------------------------------------------------------
/src/pl_bolts/utils/__init__.py:
--------------------------------------------------------------------------------
import operator
import platform

import torch
from lightning_utilities.core.imports import compare_version, module_available

from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification  # type: ignore

# Availability flags and version gates for optional dependencies. Each is evaluated
# once at import time so the rest of the code base can branch on them cheaply.
_NATIVE_AMP_AVAILABLE: bool = module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
_IS_WINDOWS = platform.system() == "Windows"
_TORCH_ORT_AVAILABLE = module_available("torch_ort")
_TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0")
_TORCHVISION_AVAILABLE: bool = module_available("torchvision")
_TORCHVISION_LESS_THAN_0_9_1: bool = compare_version("torchvision", operator.lt, "0.9.1")
# NOTE(review): despite the name this uses `operator.le`, so torchvision 0.13.0 itself
# also counts as "less than 0.13" — confirm this is intentional before relying on it.
_TORCHVISION_LESS_THAN_0_13: bool = compare_version("torchvision", operator.le, "0.13.0")
_TORCHMETRICS_DETECTION_AVAILABLE: bool = module_available("torchmetrics.detection")
_PL_GREATER_EQUAL_1_4 = compare_version("pytorch_lightning", operator.ge, "1.4.0")
_PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5")
_GYM_AVAILABLE: bool = module_available("gym")
_GYM_GREATER_EQUAL_0_20 = compare_version("gym", operator.ge, "0.20.0")
_SKLEARN_AVAILABLE: bool = module_available("sklearn")
_PIL_AVAILABLE: bool = module_available("PIL")
_SPARSEML_AVAILABLE = module_available("sparseml")
_OPENCV_AVAILABLE: bool = module_available("cv2")
_WANDB_AVAILABLE: bool = module_available("wandb")
_MATPLOTLIB_AVAILABLE: bool = module_available("matplotlib")
_JSONARGPARSE_GREATER_THAN_4_16_0 = compare_version("jsonargparse", operator.gt, "4.16.0")

# sparseml may be installed yet still incompatible with the installed torch version;
# probe its torch-version guard and keep the error text for user-facing diagnostics.
_SPARSEML_TORCH_SATISFIED = _SPARSEML_AVAILABLE
_SPARSEML_TORCH_SATISFIED_ERROR = None
if _SPARSEML_AVAILABLE:
    try:
        from sparseml.pytorch.base import _TORCH_MAX_VERSION  # noqa: F401

    except (ImportError, ValueError) as err:
        _SPARSEML_TORCH_SATISFIED = False
        _SPARSEML_TORCH_SATISFIED_ERROR = str(err)

__all__ = ["BatchGradientVerification"]
40 |
--------------------------------------------------------------------------------
/src/pl_bolts/utils/_dependency.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | from typing import Any, Callable
4 |
5 | from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache
6 |
7 |
8 | # ToDo: replace with utils wrapper after 0.10 is released
# ToDo: replace with utils wrapper after 0.10 is released
def requires(*module_path_version: str) -> Callable:
    """Decorator factory enforcing that certain dependencies are available.

    Each argument is either a dotted module path (checked via ``ModuleAvailableCache``)
    or a requirement string (checked via ``RequirementCache``). When any check fails,
    the decorated callable is replaced by a stub that raises ``ModuleNotFoundError``.
    """

    def decorator(func: Callable) -> Callable:
        checks = []
        for spec in module_path_version:
            cache_cls = ModuleAvailableCache if "." in spec else RequirementCache
            checks.append(cache_cls(spec))
        if all(bool(check) for check in checks):
            # Everything available: return the callable untouched.
            return func

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            missing = os.linesep.join(repr(check) for check in checks if not bool(check))
            raise ModuleNotFoundError(f"Required dependencies not available: \n{missing}")

        return wrapper

    return decorator
29 |
--------------------------------------------------------------------------------
/src/pl_bolts/utils/pretrained_weights.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from pytorch_lightning import LightningModule
4 |
5 | from pl_bolts.utils.stability import under_review
6 |
7 | vae_imagenet2012 = (
8 | "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/vae/imagenet_06_22_2019/checkpoints/epoch%3D63.ckpt"
9 | )
10 |
11 | cpcv2_resnet18 = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/resnet18-v6/epoch%3D85.ckpt"
12 | urls = {"vae-imagenet2012": vae_imagenet2012, "CPC_v2-resnet18": cpcv2_resnet18}
13 |
14 |
@under_review()
def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None:  # pragma: no cover
    """Overwrite ``model``'s weights, in place, with a pretrained checkpoint from the bolts bucket.

    Args:
        model: Module whose ``state_dict`` will be replaced.
        class_name: Key into the module-level ``urls`` mapping; defaults to the model's class name.
    """
    if class_name is None:
        class_name = model.__class__.__name__
    # Re-instantiate the same class from the remote checkpoint, then copy its weights over.
    pretrained = model.__class__.load_from_checkpoint(urls[class_name])
    model.load_state_dict(pretrained.state_dict())
22 |
--------------------------------------------------------------------------------
/src/pl_bolts/utils/self_supervised.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 |
3 | from pl_bolts.models.self_supervised import resnets
4 | from pl_bolts.utils.semi_supervised import Identity
5 | from pl_bolts.utils.stability import under_review
6 |
7 |
@under_review()
def torchvision_ssl_encoder(
    name: str,
    pretrained: bool = False,
    return_all_feature_maps: bool = False,
) -> Module:
    """Return a resnet from ``pl_bolts.models.self_supervised.resnets`` with its ``fc`` head removed.

    Args:
        name: Name of a resnet constructor in the ``resnets`` module.
        pretrained: Whether the constructor loads pretrained weights.
        return_all_feature_maps: Forwarded to the resnet constructor.
    """
    model_factory = getattr(resnets, name)
    encoder = model_factory(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps)
    # Replace the classification head with a pass-through so the model acts as a feature encoder.
    encoder.fc = Identity()
    return encoder
17 |
--------------------------------------------------------------------------------
/src/pl_bolts/utils/shaping.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import Tensor
4 |
5 | from pl_bolts.utils.stability import under_review
6 |
7 |
@under_review()
def tile(a: Tensor, dim: int, n_tile: int) -> Tensor:
    """Repeat each element of ``a`` ``n_tile`` times along dimension ``dim``.

    For ``a = [x0, x1]`` along ``dim`` with ``n_tile=2`` the result is ``[x0, x0, x1, x1]``.

    Args:
        a: Input tensor.
        dim: Dimension along which to tile.
        n_tile: Number of repetitions of each element.

    Returns:
        A new tensor whose size along ``dim`` is ``a.size(dim) * n_tile``.
    """
    # `repeat_interleave` produces exactly the ordering the previous manual
    # `repeat` + `index_select` implementation built, but stays on `a`'s device
    # (the old CPU `torch.LongTensor` index failed for CUDA inputs) and avoids
    # the numpy round-trip.
    return a.repeat_interleave(n_tile, dim=dim)
16 |
--------------------------------------------------------------------------------
/src/pl_bolts/utils/types.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Union
2 |
3 | import numpy as np
4 | import torch
5 |
# Recursive union of array-likes: tensors, ndarrays, or arbitrarily nested sequences of floats.
TArrays = Union[torch.Tensor, np.ndarray, Sequence[float], Sequence["TArrays"]]
7 |
--------------------------------------------------------------------------------
/src/pl_bolts/utils/warnings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from typing import Callable, Dict, Optional
4 |
# Per-package count of how many times a missing-package warning was requested.
MISSING_PACKAGE_WARNINGS: Dict[str, int] = {}

# Warnings are only emitted when the WARN_MISSING_PACKAGE env var parses to a non-zero int.
WARN_MISSING_PACKAGE = int(os.environ.get("WARN_MISSING_PACKAGE", False))


def warn_missing_pkg(
    pkg_name: str,
    pypi_name: Optional[str] = None,
    extra_text: Optional[str] = None,
    stdout_func: Callable = warnings.warn,
) -> int:
    """Template for warning on missing packages, show them just once.

    Args:
        pkg_name: Name of missing package
        pypi_name: In case that package name differ from PyPI name
        extra_text: Additional text after the base warning
        stdout_func: Define used function for streaming warning, use ``warnings.warn`` or ``logging.warning``

    Returns:
        Number of warning calls, or ``-1`` when warnings are disabled.

    """
    if not WARN_MISSING_PACKAGE:
        return -1

    if pkg_name in MISSING_PACKAGE_WARNINGS:
        # Already warned once: only bump the counter.
        MISSING_PACKAGE_WARNINGS[pkg_name] += 1
        return MISSING_PACKAGE_WARNINGS[pkg_name]

    message = (
        f"You want to use `{pkg_name}` which is not installed yet,"
        f" install it with `pip install {pypi_name or pkg_name}`."
    )
    if extra_text:
        message += os.linesep + extra_text
    stdout_func(message)
    MISSING_PACKAGE_WARNINGS[pkg_name] = 1
    return 1
44 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | ## How to run tests
2 |
The easiest way to run the tests is with `make test`. Please note that it may take quite some time, especially over an unreliable internet connection, since it downloads datasets.
4 |
5 | If you want to run specific tests, first make sure your environment is up to date by running `make env && make clean`. Then you can target specific tests such as `pytest "tests/models/gans/integration/test_gans.py::test_gan"`.
6 |
7 | If you want to run doctests, run `make doctest` from root directory.
8 |
9 | ## Testing compatibility
10 |
11 | Following PR [#844](https://github.com/Lightning-AI/lightning-bolts/pull/844), all of the new tests are required to use `catch_warnings` fixture in order to filter out compatibility issues. In the future, this fixture will be marked as `autouse=True`, however until then, please add them yourself.
12 |
13 | ### Example
14 |
15 | ```python
16 | def test_this_thing(catch_warnings): ...
17 | ```
18 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from pytorch_lightning import seed_everything
5 |
# Absolute path of the tests/ directory (the directory containing this file).
TEST_ROOT = os.path.realpath(os.path.dirname(__file__))
# Repository root: one level above tests/.
PROJECT_ROOT = os.path.dirname(TEST_ROOT)
# Dataset download/cache location; override with the DATASETS_PATH env var.
DATASETS_PATH = os.environ.get("DATASETS_PATH", os.path.join(PROJECT_ROOT, "_datasets"))
# generate a list of random seeds for each test
ROOT_SEED = 1234

# Keyword arguments for `pytest.mark.skipif` on tests that require a CUDA device.
_MARK_REQUIRE_GPU = {"condition": not torch.cuda.is_available(), "reason": "test requires GPU machine"}


def reset_seed():
    # Re-seed global RNGs via Lightning; called without a fixed seed.
    seed_everything()
17 |
--------------------------------------------------------------------------------
/tests/_data_configs/yolo.cfg:
--------------------------------------------------------------------------------
1 | [net]
2 | width=256
3 | height=256
4 | channels=3
5 |
6 | [convolutional]
7 | batch_normalize=1
8 | filters=8
9 | size=3
10 | stride=1
11 | pad=1
12 | activation=leaky
13 |
14 | [route]
15 | layers=-1
16 | groups=2
17 | group_id=1
18 |
19 | [maxpool]
20 | size=2
21 | stride=2
22 |
23 | [convolutional]
24 | batch_normalize=1
25 | filters=2
26 | size=1
27 | stride=1
28 | pad=1
29 | activation=mish
30 |
31 | [convolutional]
32 | batch_normalize=1
33 | filters=4
34 | size=3
35 | stride=1
36 | pad=1
37 | activation=mish
38 |
39 | [shortcut]
40 | from=-3
41 | activation=linear
42 |
43 | [convolutional]
44 | size=1
45 | stride=1
46 | pad=1
47 | filters=14
48 | activation=linear
49 |
50 | [yolo]
51 | mask=2,3
52 | anchors=1,2, 3,4, 5,6, 9,10
53 | classes=2
54 | scale_x_y=1.05
55 | cls_normalizer=1.0
56 | iou_normalizer=0.07
57 | ignore_thresh=0.7
58 |
59 | [route]
60 | layers = -4
61 |
62 | [upsample]
63 | stride=2
64 |
65 | [convolutional]
66 | size=1
67 | stride=1
68 | pad=1
69 | filters=14
70 | activation=linear
71 |
72 | [yolo]
73 | mask=0,1
74 | anchors=1,2, 3,4, 5,6, 9,10
75 | classes=2
76 | scale_x_y=1.05
77 | cls_normalizer=1.0
78 | iou_normalizer=0.07
79 | ignore_thresh=0.7
80 |
--------------------------------------------------------------------------------
/tests/_data_configs/yolo_giou.cfg:
--------------------------------------------------------------------------------
1 | [net]
2 | width=256
3 | height=256
4 | channels=3
5 |
6 | [convolutional]
7 | batch_normalize=1
8 | filters=8
9 | size=3
10 | stride=1
11 | pad=1
12 | activation=leaky
13 |
14 | [route]
15 | layers=-1
16 | groups=2
17 | group_id=1
18 |
19 | [maxpool]
20 | size=2
21 | stride=2
22 |
23 | [convolutional]
24 | batch_normalize=1
25 | filters=2
26 | size=1
27 | stride=1
28 | pad=1
29 | activation=mish
30 |
31 | [convolutional]
32 | batch_normalize=1
33 | filters=4
34 | size=3
35 | stride=1
36 | pad=1
37 | activation=mish
38 |
39 | [shortcut]
40 | from=-3
41 | activation=linear
42 |
43 | [convolutional]
44 | size=1
45 | stride=1
46 | pad=1
47 | filters=14
48 | activation=linear
49 |
50 | [yolo]
51 | mask=2,3
52 | anchors=1,2, 3,4, 5,6, 9,10
53 | classes=2
54 | iou_loss=giou
55 | scale_x_y=1.05
56 | cls_normalizer=1.0
57 | iou_normalizer=0.07
58 | ignore_thresh=0.7
59 |
60 | [route]
61 | layers = -4
62 |
63 | [upsample]
64 | stride=2
65 |
66 | [convolutional]
67 | size=1
68 | stride=1
69 | pad=1
70 | filters=14
71 | activation=linear
72 |
73 | [yolo]
74 | mask=0,1
75 | anchors=1,2, 3,4, 5,6, 9,10
76 | classes=2
77 | iou_loss=giou
78 | scale_x_y=1.05
79 | cls_normalizer=1.0
80 | iou_normalizer=0.07
81 | ignore_thresh=0.7
82 |
--------------------------------------------------------------------------------
/tests/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/callbacks/__init__.py
--------------------------------------------------------------------------------
/tests/callbacks/test_info_callbacks.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.callbacks import PrintTableMetricsCallback
2 |
3 |
def test_printtable_metrics_callback():
    """The callback should record `trainer.callback_metrics` once per train-epoch end."""
    callback = PrintTableMetricsCallback()

    class FakeTrainer:
        def __init__(self) -> None:
            self.callback_metrics = {}

    trainer = FakeTrainer()
    expected = [{"loss": 1.0, "epoch": 0}, {"loss": 0.5, "epoch": 2}]
    for metrics in expected:
        trainer.callback_metrics = metrics
        callback.on_train_epoch_end(trainer, None)

    assert len(callback.metrics) == 2
    assert callback.metrics == expected
23 |
--------------------------------------------------------------------------------
/tests/callbacks/test_ort.py:
--------------------------------------------------------------------------------
1 | # Copyright The PyTorch Lightning team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import pytest
16 | from pl_bolts.callbacks import ORTCallback
17 | from pl_bolts.utils import _TORCH_ORT_AVAILABLE
18 | from pytorch_lightning import Callback, Trainer
19 | from pytorch_lightning.core.lightning import LightningModule
20 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
21 |
22 | from tests.helpers.boring_model import BoringModel
23 |
24 | if _TORCH_ORT_AVAILABLE:
25 | from torch_ort import ORTModule
26 |
27 |
@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_init_train_enable_ort(tmpdir):
    """Fitting with `ORTCallback` should wrap `pl_module.model` in an `ORTModule`."""

    class AssertWrappedCallback(Callback):
        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
            assert isinstance(pl_module.model, ORTModule)

    class ModelWithAttr(BoringModel):
        def __init__(self) -> None:
            super().__init__()
            # Expose the layer under the `model` attribute that ORTCallback wraps.
            self.model = self.layer

        def forward(self, x):
            return self.model(x)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        callbacks=[ORTCallback(), AssertWrappedCallback()],
    )
    model = ModelWithAttr()
    trainer.fit(model)
    trainer.test(model)
46 |
47 |
@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
def test_ort_callback_fails_no_model(tmpdir):
    """`ORTCallback` should raise when the module does not expose a `model` attribute."""
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
    with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
        trainer.fit(model)
56 |
--------------------------------------------------------------------------------
/tests/callbacks/test_param_update_callbacks.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import pytest
4 | import torch
5 | from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
6 | from torch import nn
7 |
8 |
@pytest.mark.parametrize("initial_tau", [-0.1, 0.0, 0.996, 1.0, 1.1])
def test_byol_ma_weight_single_update_callback(initial_tau, catch_warnings):
    """Check BYOL exponential moving average weight update rule for a single update."""
    if 0.0 <= initial_tau <= 1.0:
        # Create simple one layer network and their copies
        online_network = nn.Linear(100, 10)
        target_network = deepcopy(online_network)
        online_network_copy = deepcopy(online_network)
        target_network_copy = deepcopy(target_network)

        # Check parameters are equal
        assert torch.equal(next(iter(online_network.parameters()))[0], next(iter(target_network.parameters()))[0])

        # Simulate weight update
        opt = torch.optim.SGD(online_network.parameters(), lr=0.1)
        y = online_network(torch.randn(3, 100))
        loss = y.sum()
        loss.backward()
        opt.step()
        opt.zero_grad()

        # Check online network update
        assert not torch.equal(
            next(iter(online_network.parameters()))[0], next(iter(online_network_copy.parameters()))[0]
        )

        # Update target network weights via callback
        cb = BYOLMAWeightUpdate(initial_tau)
        cb.update_weights(online_network, target_network)

        # Check target network update according to value of tau
        if initial_tau == 0.0:
            # tau == 0: the target must become an exact copy of the online network.
            assert torch.equal(next(iter(target_network.parameters()))[0], next(iter(online_network.parameters()))[0])
        elif initial_tau == 1.0:
            # tau == 1: the target must be unchanged from its pre-update copy.
            assert torch.equal(
                next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0]
            )
        else:
            # 0 < tau < 1: compute the expected EMA in-place on the copy, then compare
            # it against what the callback produced.
            for online_p, target_p in zip(online_network.parameters(), target_network_copy.parameters()):
                target_p.data = initial_tau * target_p.data + (1.0 - initial_tau) * online_p.data

            assert torch.equal(
                next(iter(target_network.parameters()))[0], next(iter(target_network_copy.parameters()))[0]
            )
    else:
        # Out-of-range tau must be rejected at construction time.
        with pytest.raises(ValueError, match="initial tau should be"):
            cb = BYOLMAWeightUpdate(initial_tau)
56 |
--------------------------------------------------------------------------------
/tests/callbacks/test_variational_callbacks.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.callbacks import LatentDimInterpolator
2 | from pl_bolts.models.gans import GAN
3 |
4 | try:
5 | from pytorch_lightning.loggers.logger import DummyLogger # PL v1.9+
6 | except ModuleNotFoundError:
7 | from pytorch_lightning.loggers.base import DummyLogger # PL v1.8
8 |
9 |
def test_latent_dim_interpolator():
    """`LatentDimInterpolator` should run on a GAN at an interval-eligible epoch without error."""

    class FakeTrainer:
        def __init__(self) -> None:
            self.current_epoch = 1
            self.global_step = 1
            self.logger = DummyLogger()

    gan = GAN(3, 28, 28)
    callback = LatentDimInterpolator(interpolate_epoch_interval=2)
    callback.on_train_epoch_end(FakeTrainer(), gan)
21 |
--------------------------------------------------------------------------------
/tests/callbacks/verification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/callbacks/verification/__init__.py
--------------------------------------------------------------------------------
/tests/datamodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/datamodules/__init__.py
--------------------------------------------------------------------------------
/tests/datamodules/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
3 | from pl_bolts.datasets.cifar10_dataset import CIFAR10
4 | from torch.utils.data import DataLoader
5 |
6 |
def test_async_dataloader(datadir):
    """Smoke-test `AsynchronousLoader` over a raw dataset and over a wrapped `DataLoader`."""
    dataset = CIFAR10(data_dir=datadir)

    if torch.cuda.device_count() == 0:  # Can only run this test with a GPU
        return

    device = torch.device("cuda", 0)
    loaders = (
        AsynchronousLoader(dataset, device=device),
        AsynchronousLoader(DataLoader(dataset, batch_size=16), device=device),
    )
    for loader in loaders:
        for _ in loader:
            pass
20 |
--------------------------------------------------------------------------------
/tests/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/datasets/__init__.py
--------------------------------------------------------------------------------
/tests/datasets/test_array_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from pl_bolts.datasets import ArrayDataset, DataModel
5 | from pl_bolts.datasets.utils import to_tensor
6 | from pytorch_lightning.utilities import exceptions
7 |
8 |
class TestArrayDataset:
    """Tests for `ArrayDataset` assembled from several `DataModel` columns."""

    @pytest.fixture()
    def array_dataset(self):
        features_1 = DataModel(data=[[1, 0, -1, 2], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]], transform=to_tensor)
        target_1 = DataModel(data=[1, 0, 0, 1], transform=to_tensor)

        features_2 = DataModel(data=np.array([[2, 1, -5, 1], [1, 0, -2, -1], [2, 5, 0, 3], [-7, 1, 2, 2]]))
        target_2 = DataModel(data=[1, 0, 1, 1])
        return ArrayDataset(features_1, target_1, features_2, target_2)

    def test_len(self, array_dataset):
        assert len(array_dataset) == 4

    def test_getitem_with_transforms(self, array_dataset):
        # (transformed features, transformed target, raw features, raw target) per sample.
        expected = [
            ([1, 0, -1, 2], 1, [2, 1, -5, 1], 1),
            ([1, 0, -2, -1], 0, [1, 0, -2, -1], 0),
            ([2, 5, 0, 3], 0, [2, 5, 0, 3], 1),
            ([-7, 1, 2, 2], 1, [-7, 1, 2, 2], 1),
        ]
        for idx, (feat_1, tgt_1, feat_2, tgt_2) in enumerate(expected):
            sample = array_dataset[idx]
            assert len(sample) == 4
            torch.testing.assert_close(sample[0], torch.tensor(feat_1))
            torch.testing.assert_close(sample[1], torch.tensor(tgt_1))
            np.testing.assert_array_equal(sample[2], np.array(feat_2))
            assert sample[3] == tgt_2

    def test__equal_size_true(self, array_dataset):
        assert array_dataset._equal_size() is True

    def test__equal_size_false(self):
        features = DataModel(data=[[1, 0, 1]])
        target = DataModel([1, 0, 1])
        with pytest.raises(exceptions.MisconfigurationException):
            ArrayDataset(features, target)
52 |
--------------------------------------------------------------------------------
/tests/datasets/test_base_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from pl_bolts.datasets.base_dataset import DataModel
5 | from pl_bolts.datasets.utils import to_tensor
6 | from pl_bolts.utils import _IS_WINDOWS
7 |
8 |
class TestDataModel:
    """Tests for `DataModel.process` with and without a transform."""

    @pytest.fixture()
    def data(self):
        return np.array([[1, 0, 0, 1], [0, 1, 1, 0]])

    def test_process_transform_is_none(self, data):
        # Without a transform, `process` should return the row unchanged.
        dm = DataModel(data=data)
        for row in data:
            np.testing.assert_array_equal(dm.process(row), row)

    @pytest.mark.skipif(  # todo
        _IS_WINDOWS, reason="AssertionError: The values for attribute 'dtype' do not match: torch.int32 != torch.int64"
    )
    def test_process_transform_is_not_none(self, data):
        # With `to_tensor`, each processed row should come back as a tensor.
        dm = DataModel(data=data, transform=to_tensor)
        torch.testing.assert_close(dm.process(data[0]), torch.tensor([1, 0, 0, 1]))
        torch.testing.assert_close(dm.process(data[1]), torch.tensor([0, 1, 1, 0]))
26 |
--------------------------------------------------------------------------------
/tests/datasets/test_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.testing
3 | from pl_bolts.datasets.utils import to_tensor
4 |
5 |
class TestToTensor:
    """`to_tensor` should convert lists, arrays and nested sequences to tensors."""

    def test_to_tensor_list(self):
        values = [1, 2, 3]
        torch.testing.assert_close(to_tensor(values), torch.tensor(values))

    def test_to_tensor_array(self):
        values = np.array([1, 2, 3])
        torch.testing.assert_close(to_tensor(values), torch.tensor(values))

    def test_to_tensor_sequence_(self):
        values = [[1.0, 2.0, 3.0]]
        torch.testing.assert_close(to_tensor(values), torch.tensor(values))
18 |
--------------------------------------------------------------------------------
/tests/helpers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/helpers/__init__.py
--------------------------------------------------------------------------------
/tests/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/losses/__init__.py
--------------------------------------------------------------------------------
/tests/losses/test_rl_loss.py:
--------------------------------------------------------------------------------
1 | """Test RL Loss Functions."""
2 |
3 | from unittest import TestCase
4 |
5 | import numpy as np
6 | import torch
7 | from pl_bolts.losses.rl import double_dqn_loss, dqn_loss, per_dqn_loss
8 | from pl_bolts.models.rl.common.gym_wrappers import make_environment
9 | from pl_bolts.models.rl.common.networks import CNN
10 | from torch import Tensor
11 |
12 |
class TestRLLoss(TestCase):
    """Smoke tests for the RL loss functions against a Pong environment."""

    def setUp(self) -> None:
        batch_size = 32
        self.state = torch.rand(batch_size, 4, 84, 84)
        self.next_state = torch.rand(batch_size, 4, 84, 84)
        self.action = torch.ones([batch_size])
        self.reward = torch.ones([batch_size])
        self.done = torch.zeros([batch_size]).long()
        self.batch = (self.state, self.action, self.reward, self.done, self.next_state)

        self.env = make_environment("PongNoFrameskip-v4")
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = CNN(self.obs_shape, self.n_actions)
        self.target_net = CNN(self.obs_shape, self.n_actions)

    def test_dqn_loss(self):
        """`dqn_loss` should return a Tensor."""
        assert isinstance(dqn_loss(self.batch, self.net, self.target_net), Tensor)

    def test_double_dqn_loss(self):
        """`double_dqn_loss` should return a Tensor."""
        assert isinstance(double_dqn_loss(self.batch, self.net, self.target_net), Tensor)

    def test_per_dqn_loss(self):
        """`per_dqn_loss` should return a Tensor loss and an ndarray of batch weights."""
        priorities = torch.ones([32])
        loss, batch_weights = per_dqn_loss(self.batch, priorities, self.net, self.target_net)
        assert isinstance(loss, Tensor)
        assert isinstance(batch_weights, np.ndarray)
--------------------------------------------------------------------------------
/tests/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/metrics/__init__.py
--------------------------------------------------------------------------------
/tests/metrics/test_aggregation.py:
--------------------------------------------------------------------------------
1 | """Test Aggregation Metric Functions."""
2 |
3 | import pytest
4 | import torch
5 | from pl_bolts.metrics.aggregation import accuracy, mean, precision_at_k
6 |
7 |
@pytest.mark.parametrize(
    ("preds", "expected_mean"),
    [(torch.tensor([[100.0, 100.0, 200.0, 200.0]]), torch.tensor(150.0))],
)
def test_mean(preds, expected_mean):
    """`mean` should average the values stored under the requested key across outputs."""
    outputs = [{"test": preds}]
    torch.testing.assert_close(mean(outputs, "test"), expected_mean)
15 |
16 |
@pytest.mark.parametrize(
    ("preds", "target", "expected_accuracy"),
    [
        (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[2]]), torch.tensor(1.0)),
        (torch.tensor([[100, 100, 200, 200]]), torch.tensor([[0]]), torch.tensor(0.0)),
    ],
)
def test_accuracy(preds, target, expected_accuracy):
    """`accuracy` should score 1.0 when the top prediction matches the target, 0.0 otherwise."""
    result = accuracy(preds, target)
    torch.testing.assert_close(result, expected_accuracy)
26 |
27 |
@pytest.mark.parametrize(
    ("output", "target", "expected_precision_at_k"),
    [
        (torch.tensor([[100.0, 100.0, 200.0, 200.0]]), torch.tensor([[2]]), [torch.tensor([100.0])]),
        (torch.tensor([[100.0, 100.0, 200.0, 200.0]]), torch.tensor([[1]]), [torch.tensor([0.0])]),
    ],
)
def test_precision_at_k(output, target, expected_precision_at_k):
    """`precision_at_k` should return the expected per-k precision list (values in percent)."""
    result = precision_at_k(output, target)
    torch.testing.assert_close(result, expected_precision_at_k)
37 |
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/__init__.py
--------------------------------------------------------------------------------
/tests/models/gans/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/gans/__init__.py
--------------------------------------------------------------------------------
/tests/models/gans/integration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/gans/integration/__init__.py
--------------------------------------------------------------------------------
/tests/models/gans/integration/test_gans.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import pytest
4 | from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule
5 | from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
6 | from pl_bolts.models.gans import DCGAN, GAN, SRGAN, SRResNet
7 | from pl_bolts.utils import _IS_WINDOWS
8 | from pytorch_lightning import Trainer, seed_everything
9 | from pytorch_lightning.utilities.warnings import PossibleUserWarning
10 | from torch.utils.data.dataloader import DataLoader
11 | from torchvision import transforms as transform_lib
12 |
13 |
@pytest.mark.parametrize(
    "dm_cls",
    [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")],
)
def test_gan(tmpdir, datadir, catch_warnings, dm_cls):
    """Smoke-test fitting the basic GAN on a single fast-dev-run batch."""
    # Validation loop for GANs is not well defined, so silence the related warnings.
    for message, category in (
        ("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.", UserWarning),
        ("The dataloader, train_dataloader, does not have many workers which may be a bottleneck", PossibleUserWarning),
    ):
        warnings.filterwarnings("ignore", message=message, category=category)

    seed_everything(1234)
    dm = dm_cls(data_dir=datadir, num_workers=0)
    model = GAN(*dm.dims)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=-1,
        fast_dev_run=True,
        log_every_n_steps=1,
        accelerator="auto",
        # TODO: We need to be able to support multiple GPUs in such a simple scenario.
        # But, DDP is throwing ugly errors at me at the moment
        devices=1,
    )
    trainer.fit(model, datamodule=dm)
45 |
46 |
@pytest.mark.parametrize(
    "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")]
)
def test_dcgan(tmpdir, datadir, dm_cls):
    """Smoke-test fitting DCGAN on a single fast-dev-run batch.

    Uses an explicit seed: calling ``seed_everything()`` without an argument
    selects a random seed, which made this test non-deterministic from run to
    run (the sibling tests in this file all pass a fixed seed).
    """
    seed_everything(1234)

    # DCGAN expects 64x64 inputs, so resize before converting to tensors.
    transforms = transform_lib.Compose([transform_lib.Resize(64), transform_lib.ToTensor()])
    dm = dm_cls(data_dir=datadir, train_transforms=transforms, val_transforms=transforms, test_transforms=transforms)

    model = DCGAN(image_channels=dm.dims[0])
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, dm)
59 |
60 |
@pytest.mark.skipif(_IS_WINDOWS, reason="failing...")  # todo
@pytest.mark.parametrize("sr_module_cls", [SRResNet, SRGAN])
@pytest.mark.parametrize("scale_factor", [2, 4])
def test_sr_modules(tmpdir, datadir, sr_module_cls, scale_factor):
    """Smoke-test fitting the super-resolution modules on a single fast-dev-run batch."""
    seed_everything(42)

    dataset = SRMNIST(scale_factor, root=datadir, download=True)
    model = sr_module_cls(image_channels=1, scale_factor=scale_factor)
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    trainer.fit(model, DataLoader(dataset))
71 |
--------------------------------------------------------------------------------
/tests/models/gans/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/gans/unit/__init__.py
--------------------------------------------------------------------------------
/tests/models/gans/unit/test_basic_components.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from pl_bolts.models.gans.basic.components import Discriminator, Generator
4 | from pytorch_lightning import seed_everything
5 |
6 |
@pytest.mark.parametrize(
    ("latent_dim", "img_shape"),
    [
        pytest.param(100, (3, 28, 28), id="100-multichannel"),
        pytest.param(100, (1, 28, 28), id="100-singlechannel"),
    ],
)
def test_generator(catch_warnings, latent_dim, img_shape):
    """Generated samples should have shape (batch, *img_shape)."""
    seed_everything(1234)
    n_samples = 10
    model = Generator(latent_dim=latent_dim, img_shape=img_shape)
    latent_noise = torch.randn(n_samples, latent_dim)
    generated = model(latent_noise)
    assert generated.shape == (n_samples, *img_shape)
21 |
22 |
@pytest.mark.parametrize(
    "img_shape",
    [
        pytest.param((3, 28, 28), id="discriminator-multichannel"),
        pytest.param((1, 28, 28), id="discriminator-singlechannel"),
    ],
)
def test_discriminator(catch_warnings, img_shape):
    """Discriminator output should be one value per sample, already inside [0, 1]."""
    seed_everything(1234)
    n_samples = 10
    model = Discriminator(img_shape=img_shape)
    batch = torch.randn(n_samples, *img_shape)
    scores = model(batch)
    assert scores.shape == (n_samples, 1)
    # Clamping to [0, 1] must be a no-op, i.e. every score is a valid probability.
    assert (torch.clamp(scores.clone(), 0, 1) == scores).all()
38 |
--------------------------------------------------------------------------------
/tests/models/regression/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/regression/__init__.py
--------------------------------------------------------------------------------
/tests/models/regression/test_logistic_regression.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import operator
3 |
4 | import pytorch_lightning as pl
5 | from pl_bolts import datamodules
6 | from pl_bolts.models import regression
7 |
8 |
def test_logistic_regression_model(datadir):
    """Train logistic regression on MNIST and check the final test accuracy and loss."""
    pl.seed_everything(0)

    dm = datamodules.MNISTDataModule(datadir)

    # Flatten the datamodule's image dimensions into a single input size.
    input_dim = functools.reduce(operator.mul, dm.dims, 1)
    model = regression.LogisticRegression(input_dim=input_dim, num_classes=10, learning_rate=0.001)

    trainer = pl.Trainer(max_epochs=3, logger=False, enable_checkpointing=False)
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
    assert trainer.state.finished
    assert trainer.callback_metrics["test_acc"] > 0.9
    assert trainer.callback_metrics["test_loss"] < 0.3
24 |
--------------------------------------------------------------------------------
/tests/models/rl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/rl/__init__.py
--------------------------------------------------------------------------------
/tests/models/rl/integration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/rl/integration/__init__.py
--------------------------------------------------------------------------------
/tests/models/rl/integration/test_actor_critic_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import pytest
4 | import torch.cuda
5 | from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
6 | from pl_bolts.models.rl.sac_model import SAC
7 | from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20
8 | from pytorch_lightning import Trainer
9 |
10 |
def test_a2c_cli():
    """Smoke test that the A2C model runs."""
    parser = argparse.ArgumentParser(add_help=False)
    parser = AdvantageActorCritic.add_model_specific_args(parser)
    hparams = parser.parse_args(["--env", "CartPole-v0"])

    trainer = Trainer(
        gpus=int(torch.cuda.is_available()),
        max_steps=100,
        max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
        val_check_interval=1,  # This just needs 'some' value, does not effect training right now
        fast_dev_run=True,
    )
    trainer.fit(AdvantageActorCritic(hparams.env))
28 |
29 |
@pytest.mark.skipif(_GYM_GREATER_EQUAL_0_20, reason="gym.error.DeprecatedEnv: Env Pendulum-v0 not found")
def test_sac_cli():
    """Smoke test that the SAC model runs."""
    parser = argparse.ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser = SAC.add_model_specific_args(parser)
    cli_args = [
        "--warm_start_size",
        "100",
        "--gpus",
        "0",  # fixme: RuntimeError on GPU: Expected all tensors to be on the same device
        "--env",
        "Pendulum-v0",
        "--batch_size",
        "10",
    ]
    hparams = parser.parse_args(cli_args)

    trainer = Trainer(
        gpus=hparams.gpus,
        max_steps=100,
        max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
        val_check_interval=1,  # This just needs 'some' value, does not effect training right now
        fast_dev_run=True,
    )
    model = SAC(**vars(hparams))
    trainer.fit(model)
58 |
--------------------------------------------------------------------------------
/tests/models/rl/integration/test_policy_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from unittest import TestCase
3 |
4 | import torch
5 | from pl_bolts.models.rl.reinforce_model import Reinforce
6 | from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
7 | from pytorch_lightning import Trainer
8 |
9 |
class TestPolicyModels(TestCase):
    def setUp(self) -> None:
        """Parse CLI-style hyperparameters and build a short, fast trainer."""
        parser = argparse.ArgumentParser(add_help=False)
        parser = VanillaPolicyGradient.add_model_specific_args(parser)
        self.hparams = parser.parse_args(["--env", "CartPole-v0"])

        self.trainer = Trainer(
            gpus=int(torch.cuda.is_available()),
            max_steps=100,
            max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
            val_check_interval=1,  # This just needs 'some' value, does not effect training right now
            fast_dev_run=True,
        )

    def test_reinforce(self):
        """Smoke test that the reinforce model runs."""
        self.trainer.fit(Reinforce(self.hparams.env))

    def test_policy_gradient(self):
        """Smoke test that the policy gradient model runs."""
        self.trainer.fit(VanillaPolicyGradient(self.hparams.env))
35 |
--------------------------------------------------------------------------------
/tests/models/rl/integration/test_value_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from unittest import TestCase
3 |
4 | import pytest
5 | import torch
6 | from pl_bolts.models.rl.double_dqn_model import DoubleDQN
7 | from pl_bolts.models.rl.dqn_model import DQN
8 | from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN
9 | from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
10 | from pl_bolts.models.rl.per_dqn_model import PERDQN
11 | from pl_bolts.utils import _IS_WINDOWS
12 | from pytorch_lightning import Trainer
13 |
14 |
class TestValueModels(TestCase):
    def setUp(self) -> None:
        """Parse DQN hyperparameters and build a short, fast trainer shared by all tests."""
        parser = argparse.ArgumentParser(add_help=False)
        parser = Trainer.add_argparse_args(parser)
        parser = DQN.add_model_specific_args(parser)
        cli_args = [
            "--warm_start_size",
            "100",
            "--gpus",
            str(int(torch.cuda.is_available())),
            "--env",
            "PongNoFrameskip-v4",
        ]
        self.hparams = parser.parse_args(cli_args)

        self.trainer = Trainer(
            gpus=self.hparams.gpus,
            max_steps=100,
            max_epochs=100,  # Set this as the same as max steps to ensure that it doesn't stop early
            val_check_interval=1,  # This just needs 'some' value, does not effect training right now
            fast_dev_run=True,
        )

    @pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut or MemoryError")  # todo
    def test_dqn(self):
        """Smoke test that the DQN model runs."""
        self.trainer.fit(DQN(self.hparams.env, num_envs=5))

    def test_double_dqn(self):
        """Smoke test that the Double DQN model runs."""
        self.trainer.fit(DoubleDQN(self.hparams.env))

    @pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut or MemoryError")  # todo
    def test_dueling_dqn(self):
        """Smoke test that the Dueling DQN model runs."""
        self.trainer.fit(DuelingDQN(self.hparams.env))

    @pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut or MemoryError")  # todo
    def test_noisy_dqn(self):
        """Smoke test that the Noisy DQN model runs."""
        self.trainer.fit(NoisyDQN(self.hparams.env))

    @pytest.mark.skip(reason="CI is killing this test")
    def test_per_dqn(self):
        """Smoke test that the PER DQN model runs."""
        self.trainer.fit(PERDQN(self.hparams.env))

    # def test_n_step_dqn(self):
    #     """Smoke test that the N Step DQN model runs"""
    #     model = DQN(self.hparams.env, n_steps=self.hparams.n_steps)
    #     result = self.trainer.fit(model)
71 |
--------------------------------------------------------------------------------
/tests/models/rl/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/rl/unit/__init__.py
--------------------------------------------------------------------------------
/tests/models/rl/unit/test_a2c.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
5 | from torch import Tensor
6 |
7 |
def test_a2c_loss():
    """The A2C loss computed on a random batch should be a tensor."""
    parser = argparse.ArgumentParser(add_help=False)
    parser = AdvantageActorCritic.add_model_specific_args(parser)
    hparams = parser.parse_args(["--env", "CartPole-v0", "--batch_size", "32"])
    model = AdvantageActorCritic(**vars(hparams))

    states = torch.rand(32, 4)
    actions = torch.rand(32).long()
    qvals = torch.rand(32)

    loss = model.loss(states, actions, qvals)

    assert isinstance(loss, Tensor)
23 |
24 |
def test_a2c_train_batch():
    """Tests that a single batch generates correctly."""
    parser = argparse.ArgumentParser(add_help=False)
    parser = AdvantageActorCritic.add_model_specific_args(parser)
    hparams = parser.parse_args(["--env", "CartPole-v0", "--batch_size", "32"])
    model = AdvantageActorCritic(**vars(hparams))

    model.n_steps = 4
    model.hparams.batch_size = 1
    batch = next(iter(model.train_dataloader()))

    assert isinstance(batch, list)
    assert len(batch) == 3
    assert len(batch[0]) == model.hparams.batch_size
    # States, actions and q-values should all come back as tensors.
    for element in batch:
        assert isinstance(element, Tensor)
45 |
--------------------------------------------------------------------------------
/tests/models/rl/unit/test_agents.py:
--------------------------------------------------------------------------------
1 | """Tests that the agent module works correctly."""
2 | from unittest import TestCase
3 | from unittest.mock import Mock
4 |
5 | import gym
6 | import numpy as np
7 | import torch
8 | from pl_bolts.models.rl.common.agents import ActorCriticAgent, Agent, PolicyAgent, ValueAgent
9 | from torch import Tensor
10 |
11 |
class TestAgents(TestCase):
    def setUp(self) -> None:
        """Create a CartPole environment, an initial state, and a mocked network."""
        self.env = gym.make("CartPole-v0")
        self.state = self.env.reset()
        self.net = Mock()

    def test_base_agent(self):
        """The base agent should return its actions as a list."""
        agent = Agent(self.net)
        actions = agent(self.state, "cuda:0")
        assert isinstance(actions, list)
22 |
23 |
class TestValueAgent(TestCase):
    def setUp(self) -> None:
        """Build a ValueAgent around a mocked network whose Q-values always prefer action 1."""
        self.env = gym.make("CartPole-v0")
        self.net = Mock(return_value=Tensor([[0.0, 100.0]]))
        self.state = [self.env.reset()]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.value_agent = ValueAgent(self.net, self.env.action_space.n)

    def test_value_agent(self):
        """Calling the agent should yield a list of integer actions."""
        actions = self.value_agent(self.state, self.device)
        assert isinstance(actions, list)
        assert isinstance(actions[0], int)

    def test_value_agent_get_action(self):
        """The greedy action should follow the mocked Q-values (action 1)."""
        actions = self.value_agent.get_action(self.state, self.device)
        assert isinstance(actions, np.ndarray)
        assert actions[0] == 1

    def test_value_agent_random(self):
        """Random actions should still be plain Python ints."""
        actions = self.value_agent.get_random_action(self.state)
        assert isinstance(actions[0], int)
45 |
46 |
class TestPolicyAgent(TestCase):
    def setUp(self) -> None:
        """Build a CartPole env and a mocked policy network that always prefers action 1."""
        self.env = gym.make("CartPole-v0")
        self.net = Mock(return_value=Tensor([[0.0, 100.0]]))
        self.states = [self.env.reset()]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def test_policy_agent(self):
        """The policy agent should pick the dominant action from the mocked logits."""
        agent = PolicyAgent(self.net)
        actions = agent(self.states, self.device)
        assert isinstance(actions, list)
        assert actions[0] == 1
59 |
60 |
def test_a2c_agent():
    """The actor-critic agent should pick the dominant action from mocked log-probabilities."""
    env = gym.make("CartPole-v0")
    # Pass `dim` explicitly: calling log_softmax without it is deprecated and emits
    # a warning; for a 2-D input the implicit choice is dim=1, so behavior is unchanged.
    logprobs = torch.nn.functional.log_softmax(Tensor([[0.0, 100.0]]), dim=1)
    net = Mock(return_value=(logprobs, Tensor([[1]])))
    states = [env.reset()]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    a2c_agent = ActorCriticAgent(net)
    action = a2c_agent(states, device)
    assert isinstance(action, list)
    assert action[0] == 1
71 |
--------------------------------------------------------------------------------
/tests/models/rl/unit/test_ppo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pl_bolts.models.rl.ppo_model import PPO
4 | from pytorch_lightning import Trainer
5 | from torch import Tensor
6 |
7 |
def test_discount_rewards():
    """Test calculation of discounted rewards."""
    model = PPO(env="CartPole-v0", batch_size=16, gamma=0.99)

    rewards = np.ones(4)
    # Expected q-values follow q[t] = reward + 0.99 * q[t+1].
    expected_qvals = [3.9403989999999998, 2.9701, 1.99, 1.0]

    assert model.discount_rewards(rewards, discount=0.99) == expected_qvals
18 |
19 |
def test_critic_loss():
    """The critic loss computed on random states and q-values should be a tensor."""
    model = PPO(env="CartPole-v0", batch_size=16, gamma=0.99)
    obs_dim = model.env.observation_space.shape[0]

    states = torch.rand(8, obs_dim)
    qvals = torch.rand(8)

    loss = model.critic_loss(states, qvals)

    assert isinstance(loss, Tensor)
32 |
33 |
def test_actor_loss_categorical():
    """Test the actor loss function on categorical action-space environment."""
    model = PPO(env="CartPole-v0", batch_size=16, gamma=0.99)
    obs_dim = model.env.observation_space.shape[0]

    # Discrete actions are passed as a long tensor of indices.
    states = torch.rand(8, obs_dim)
    actions = torch.rand(8).long()
    logp_old = torch.rand(8)
    adv = torch.rand(8)

    loss = model.actor_loss(states, actions, logp_old, adv)

    assert isinstance(loss, Tensor)
48 |
49 |
def test_actor_loss_continuous():
    """Test the actor loss function on continuous action-space environment."""
    model = PPO(env="MountainCarContinuous-v0", batch_size=16, gamma=0.99)
    obs_dim = model.env.observation_space.shape[0]
    action_dim = model.env.action_space.shape[0]

    # Continuous actions are passed as float tensors of shape (batch, action_dim).
    states = torch.rand(8, obs_dim)
    actions = torch.rand(8, action_dim)
    logp_old = torch.rand(8)
    adv = torch.rand(8)

    loss = model.actor_loss(states, actions, logp_old, adv)

    assert isinstance(loss, Tensor)
65 |
66 |
def test_generate_trajectory_samples():
    """A trajectory sample should yield a state of the right size plus non-zero statistics."""
    torch.manual_seed(123)
    model = PPO("CartPole-v0", batch_size=16)
    obs_dim = model.env.observation_space.shape[0]

    state, action, logp_old, qval, adv = next(model.generate_trajectory_samples())

    assert state.shape[0] == obs_dim
    assert isinstance(action, torch.Tensor)
    # The remaining outputs only need to be truthy (non-zero) values.
    assert logp_old
    assert qval
    assert adv
81 |
82 |
def test_training_categorical():
    """PPO should complete one epoch of training on a categorical action space."""
    model = PPO("CartPole-v0", batch_size=16)
    Trainer(max_epochs=1).fit(model)
87 |
88 |
def test_training_continous():
    """PPO should complete one epoch of training on a continuous action space."""
    model = PPO("MountainCarContinuous-v0", batch_size=16)
    Trainer(max_epochs=1).fit(model)
93 |
--------------------------------------------------------------------------------
/tests/models/rl/unit/test_reinforce.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from unittest import TestCase
3 |
4 | import gym
5 | import numpy as np
6 | import torch
7 | from pl_bolts.datamodules.experience_source import DiscountedExperienceSource
8 | from pl_bolts.models.rl.common.agents import Agent
9 | from pl_bolts.models.rl.common.gym_wrappers import ToTensor
10 | from pl_bolts.models.rl.common.networks import MLP
11 | from pl_bolts.models.rl.reinforce_model import Reinforce
12 | from torch import Tensor
13 |
14 |
class TestReinforce(TestCase):
    def setUp(self) -> None:
        """Build an env/agent/experience source and a Reinforce model from CLI-style args."""
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)
        self.exp_source = DiscountedExperienceSource(self.env, self.agent)

        parser = argparse.ArgumentParser(add_help=False)
        parser = Reinforce.add_model_specific_args(parser)
        self.hparams = parser.parse_args(["--env", "CartPole-v0", "--batch_size", "32", "--gamma", "0.99"])
        self.model = Reinforce(**vars(self.hparams))

        self.rl_dataloader = self.model.train_dataloader()

    def test_loss(self):
        """Test the reinforce loss function."""
        states = torch.rand(16, 4)
        actions = torch.rand(16).long()
        qvals = torch.rand(16)

        loss = self.model.loss(states, actions, qvals)

        assert isinstance(loss, Tensor)

    def test_get_qvals(self):
        """Q-values for an episode should be floats obeying the discount recursion."""
        batch_qvals = []
        rewards = np.ones(32)
        batch_qvals.append(self.model.calc_qvals(rewards))

        assert isinstance(batch_qvals[0][0], float)
        # q[t] = gamma * q[t+1] + reward.
        assert batch_qvals[0][0] == batch_qvals[0][1] * self.hparams.gamma + 1.0

    def test_calc_q_vals(self):
        """Discounted q-values for four unit rewards with gamma=0.99."""
        rewards = np.ones(4)
        expected_qvals = [3.9403989999999998, 2.9701, 1.99, 1.0]

        assert self.model.calc_qvals(rewards) == expected_qvals
61 |
--------------------------------------------------------------------------------
/tests/models/rl/unit/test_sac.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import pytest
4 | import torch
5 | from pl_bolts.models.rl.sac_model import SAC
6 | from pl_bolts.utils import _GYM_GREATER_EQUAL_0_20
7 | from torch import Tensor
8 |
9 |
@pytest.mark.skipif(_GYM_GREATER_EQUAL_0_20, reason="gym.error.DeprecatedEnv: Env Pendulum-v0 not found")
def test_sac_loss():
    """The SAC policy loss and both Q-network losses should come back as tensors."""
    parser = argparse.ArgumentParser(add_help=False)
    parser = SAC.add_model_specific_args(parser)
    hparams = parser.parse_args(["--env", "Pendulum-v0", "--batch_size", "32"])
    model = SAC(**vars(hparams))

    # Pendulum observations are 3-dimensional and actions 1-dimensional.
    states = torch.rand(32, 3)
    actions = torch.rand(32, 1)
    rewards = torch.rand(32)
    dones = torch.ones(32)
    next_states = torch.rand(32, 3)

    policy_loss, q1_loss, q2_loss = model.loss((states, actions, rewards, dones, next_states))

    for loss in (policy_loss, q1_loss, q2_loss):
        assert isinstance(loss, Tensor)
31 |
32 |
@pytest.mark.skipif(_GYM_GREATER_EQUAL_0_20, reason="gym.error.DeprecatedEnv: Env Pendulum-v0 not found")
def test_sac_train_batch():
    """Tests that a single batch generates correctly."""
    parser = argparse.ArgumentParser(add_help=False)
    parser = SAC.add_model_specific_args(parser)
    hparams = parser.parse_args(["--env", "Pendulum-v0", "--batch_size", "32"])
    model = SAC(**vars(hparams))

    batch = next(iter(model.train_dataloader()))

    assert isinstance(batch, list)
    # One element each for states, actions, rewards, dones and next states.
    assert len(batch) == 5
    assert len(batch[0]) == model.hparams.batch_size
    assert all(isinstance(element, Tensor) for element in batch)
50 |
--------------------------------------------------------------------------------
/tests/models/rl/unit/test_vpg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from unittest import TestCase
3 |
4 | import gym
5 | import pytest
6 | import torch
7 | from pl_bolts.models.rl.common.agents import Agent
8 | from pl_bolts.models.rl.common.gym_wrappers import ToTensor
9 | from pl_bolts.models.rl.common.networks import MLP
10 | from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient
11 | from pl_bolts.utils import _IS_WINDOWS
12 | from torch import Tensor
13 |
14 |
class TestPolicyGradient(TestCase):
    def setUp(self) -> None:
        """Build an env/agent pair and a VanillaPolicyGradient model from CLI-style args."""
        self.env = ToTensor(gym.make("CartPole-v0"))
        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.n
        self.net = MLP(self.obs_shape, self.n_actions)
        self.agent = Agent(self.net)

        parser = argparse.ArgumentParser(add_help=False)
        parser = VanillaPolicyGradient.add_model_specific_args(parser)
        self.hparams = parser.parse_args(["--env", "CartPole-v0", "--batch_size", "32"])
        self.model = VanillaPolicyGradient(**vars(self.hparams))

    @pytest.mark.skipif(_IS_WINDOWS, reason="strange TimeOut or MemoryError")  # todo
    def test_loss(self):
        """The policy-gradient loss computed on a random batch should be a tensor."""
        states = torch.rand(32, 4)
        actions = torch.rand(32).long()
        qvals = torch.rand(32)

        loss = self.model.loss(states, actions, qvals)

        assert isinstance(loss, Tensor)

    def test_train_batch(self):
        """Tests that a single batch generates correctly."""
        self.model.n_steps = 4
        self.model.batch_size = 1
        batch = next(iter(self.model.train_dataloader()))

        assert isinstance(batch, list)
        assert len(batch) == 3
        assert len(batch[0]) == self.model.batch_size
        assert isinstance(batch[0], Tensor)
        # Actions come back as a list of tensors rather than one stacked tensor.
        assert isinstance(batch[1], list)
        assert isinstance(batch[1][0], Tensor)
        assert isinstance(batch[2], Tensor)
56 |
--------------------------------------------------------------------------------
/tests/models/rl/unit/test_wrappers.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import gym
4 | from pl_bolts.models.rl.common.gym_wrappers import ToTensor
5 | from torch import Tensor
6 |
7 |
class TestToTensor(TestCase):
    def setUp(self) -> None:
        """Wrap CartPole so observations come back as tensors."""
        self.env = ToTensor(gym.make("CartPole-v0"))

    def test_wrapper(self):
        """Observations from both reset and step should be torch tensors."""
        state = self.env.reset()
        assert isinstance(state, Tensor)

        next_state, _, _, _ = self.env.step(1)
        assert isinstance(next_state, Tensor)
18 |
--------------------------------------------------------------------------------
/tests/models/self_supervised/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/self_supervised/__init__.py
--------------------------------------------------------------------------------
/tests/models/self_supervised/test_resnets.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from pl_bolts.models.self_supervised.amdim import AMDIMEncoder
4 | from pl_bolts.models.self_supervised.cpc import cpc_resnet50
5 | from pl_bolts.models.self_supervised.resnets import (
6 | resnet18,
7 | resnet34,
8 | resnet50,
9 | resnet101,
10 | resnet152,
11 | resnext50_32x4d,
12 | resnext101_32x8d,
13 | wide_resnet50_2,
14 | wide_resnet101_2,
15 | )
16 | from pl_bolts.utils import _IS_WINDOWS
17 |
18 |
19 | @pytest.mark.skipif(_IS_WINDOWS, reason="strange MemoryError") # todo
20 | @torch.no_grad()
21 | def test_cpc_resnet():
22 | x = torch.rand(3, 3, 64, 64)
23 | model = cpc_resnet50(x)
24 | model(x)
25 |
26 |
27 | @pytest.mark.parametrize(
28 | "model_class",
29 | [
30 | resnet18,
31 | resnet34,
32 | resnet50,
33 | resnet101,
34 | resnet152,
35 | resnext50_32x4d,
36 | pytest.param(resnext101_32x8d, marks=pytest.mark.skipif(_IS_WINDOWS, reason="strange MemoryError")), # todo
37 | wide_resnet50_2,
38 | pytest.param(wide_resnet101_2, marks=pytest.mark.skipif(_IS_WINDOWS, reason="strange MemoryError")), # todo
39 | ],
40 | )
41 | @torch.no_grad()
42 | def test_torchvision_resnets(model_class):
43 | x = torch.rand(3, 3, 64, 64)
44 | model = model_class(pretrained=False)
45 | model(x)
46 |
47 |
48 | @pytest.mark.parametrize(
49 | "size",
50 | [
51 | 32,
52 | pytest.param(64, marks=pytest.mark.skipif(_IS_WINDOWS, reason="failing...")),
53 | pytest.param(128, marks=pytest.mark.skipif(_IS_WINDOWS, reason="failing...")),
54 | ],
55 | )
56 | @torch.no_grad()
57 | def test_amdim_encoder(size):
58 | dummy_batch = torch.zeros((2, 3, size, size))
59 | model = AMDIMEncoder(dummy_batch, encoder_size=size)
60 | model.init_weights()
61 | model(dummy_batch)
62 |
--------------------------------------------------------------------------------
/tests/models/self_supervised/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/self_supervised/unit/__init__.py
--------------------------------------------------------------------------------
/tests/models/test_classic_ml.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset
4 | from pl_bolts.models.regression import LinearRegression
5 | from pytorch_lightning import Trainer, seed_everything
6 | from torch.utils.data import DataLoader
7 |
8 |
9 | @pytest.mark.flaky(reruns=3)
10 | def test_linear_regression_model(tmpdir):
11 | seed_everything()
12 |
13 | # --------------------
14 | # numpy data
15 | # --------------------
16 | x = np.array([[1.0, 1], [1, 2], [2, 2], [2, 3], [3, 3], [3, 4], [4, 4], [4, 5]])
17 | y = np.dot(x, np.array([1.0, 2])) + 3
18 | y = y[:, np.newaxis]
19 | loader = DataLoader(SklearnDataset(x, y), batch_size=2)
20 |
21 | model = LinearRegression(input_dim=2, learning_rate=0.6)
22 | trainer = Trainer(
23 | max_epochs=400,
24 | default_root_dir=tmpdir,
25 | logger=False,
26 | enable_checkpointing=False,
27 | )
28 | trainer.fit(
29 | model,
30 | loader,
31 | loader,
32 | )
33 |
34 | coeffs = model.linear.weight.detach().numpy().flatten()
35 | np.testing.assert_allclose(coeffs, [1, 2], rtol=1e-3)
36 | trainer.test(model, loader)
37 |
--------------------------------------------------------------------------------
/tests/models/test_mnist_templates.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from pl_bolts.datamodules import MNISTDataModule
4 | from pl_bolts.models import LitMNIST
5 | from pytorch_lightning import Trainer, seed_everything
6 | from pytorch_lightning.utilities.warnings import PossibleUserWarning
7 |
8 |
9 | def test_mnist(tmpdir, datadir, catch_warnings):
10 | warnings.filterwarnings(
11 | "ignore",
12 | message=".+does not have many workers which may be a bottleneck.+",
13 | category=PossibleUserWarning,
14 | )
15 |
16 | seed_everything(1234)
17 |
18 | datamodule = MNISTDataModule(data_dir=datadir, num_workers=0)
19 | model = LitMNIST()
20 | trainer = Trainer(
21 | limit_train_batches=0.02,
22 | limit_val_batches=0.02,
23 | max_epochs=1,
24 | limit_test_batches=0.02,
25 | default_root_dir=tmpdir,
26 | log_every_n_steps=5,
27 | accelerator="auto",
28 | )
29 | trainer.fit(model, datamodule=datamodule)
30 | loss = trainer.callback_metrics["train_loss"]
31 | assert loss <= 2.3, "mnist failed"
32 |
--------------------------------------------------------------------------------
/tests/models/yolo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/yolo/__init__.py
--------------------------------------------------------------------------------
/tests/models/yolo/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/models/yolo/unit/__init__.py
--------------------------------------------------------------------------------
/tests/models/yolo/unit/test_target_matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pl_bolts.models.detection.yolo.target_matching import _sim_ota_match
3 |
4 |
5 | def test_sim_ota_match(catch_warnings):
6 | # For each of the two targets, k will be the sum of the IoUs. 2 and 1 predictions will be selected for the first and
7 | # the second target respectively.
8 | ious = torch.tensor([[0.1, 0.2], [0.1, 0.3], [0.9, 0.4], [0.9, 0.1]])
9 | # Costs will determine that the first and the last prediction will be selected for the first target, and the first
10 | # prediction will be selected for the second target. The first prediction was selected for two targets, but it will
11 | # be matched to the best target only (the second one).
12 | costs = torch.tensor([[0.3, 0.1], [0.5, 0.2], [0.4, 0.5], [0.3, 0.3]])
13 | matched_preds, matched_targets = _sim_ota_match(costs, ious)
14 |
15 | # The first and the last prediction were matched.
16 | assert len(matched_preds) == 4
17 | assert matched_preds[0]
18 | assert not matched_preds[1]
19 | assert not matched_preds[2]
20 | assert matched_preds[3]
21 |
22 | # The first prediction was matched to the target 1 and the last prediction was matched to target 0.
23 | assert len(matched_targets) == 2
24 | assert matched_targets[0] == 1
25 | assert matched_targets[1] == 0
26 |
--------------------------------------------------------------------------------
/tests/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/optimizers/__init__.py
--------------------------------------------------------------------------------
/tests/test_utilities.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from tests import PROJECT_ROOT
4 |
5 |
6 | def test_paths():
7 | assert os.path.isdir(PROJECT_ROOT)
8 |
--------------------------------------------------------------------------------
/tests/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/transforms/__init__.py
--------------------------------------------------------------------------------
/tests/transforms/test_normalizations.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from pl_bolts.transforms.dataset_normalizations import (
4 | cifar10_normalization,
5 | emnist_normalization,
6 | imagenet_normalization,
7 | stl10_normalization,
8 | )
9 | from pytorch_lightning import seed_everything
10 |
11 |
12 | @pytest.mark.parametrize(
13 | "normalization",
14 | [cifar10_normalization, imagenet_normalization, stl10_normalization],
15 | )
16 | def test_normalizations(normalization, catch_warnings):
17 | """Test normalizations for CIFAR10, ImageNet, STL10."""
18 | seed_everything(1234)
19 | x = torch.rand(3, 32, 32)
20 | assert normalization()(x).shape == (3, 32, 32)
21 | assert x.min() >= 0.0
22 | assert x.max() <= 1.0
23 |
24 |
25 | @pytest.mark.parametrize(
26 | "split",
27 | ["balanced", "byclass", "bymerge", "digits", "letters", "mnist"],
28 | )
29 | def test_emnist_normalizations(split, catch_warnings):
30 | """Test normalizations for each EMNIST dataset split."""
31 | x = torch.rand(1, 28, 28)
32 | assert emnist_normalization(split)(x).shape == (1, 28, 28)
33 | assert x.min() >= 0.0
34 | assert x.max() <= 1.0
35 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/test_dependency.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pl_bolts.utils._dependency import requires
3 |
4 |
5 | @requires("torch")
6 | def using_torch():
7 | return True
8 |
9 |
10 | @requires("torch.anything.wrong")
11 | def using_torch_wrong_path():
12 | return True
13 |
14 |
15 | @requires("torch>99.0")
16 | def using_torch_bad_version():
17 | return True
18 |
19 |
20 | def test_requires_pass():
21 | """A satisfied requirement leaves the wrapped function callable."""
22 | assert using_torch() is True
22 |
23 |
24 | def test_requires_fail():
25 | """A missing module path or an unsatisfiable version raises ModuleNotFoundError."""
26 | with pytest.raises(ModuleNotFoundError, match="Required dependencies not available"):
27 | assert using_torch_wrong_path()
28 | with pytest.raises(ModuleNotFoundError, match="Required dependencies not available"):
29 | assert using_torch_bad_version()
29 |
--------------------------------------------------------------------------------
/tests/utils/test_self_supervised.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightning-Universe/lightning-bolts/2c4602aa684e7b90e7ffdcea1d3f93a20f9c2ead/tests/utils/test_self_supervised.py
--------------------------------------------------------------------------------
/tests/utils/test_semi_supervised.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from warnings import warn
3 |
4 | import numpy as np
5 | import pytest
6 | import torch
7 | from pl_bolts.utils.semi_supervised import balance_classes, generate_half_labeled_batches
8 |
9 | # Optional dependency: skip the sklearn-backed tests when scikit-learn is absent.
10 | try:
11 | from sklearn.utils import shuffle as sk_shuffle # noqa: F401
12 |
13 | _SKLEARN_AVAILABLE = True
14 | except ImportError:
15 | warn("Failing to import `sklearn` correctly")
16 | _SKLEARN_AVAILABLE = False
16 |
17 |
18 | @pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="failing to import SKLearn")
19 | def test_balance_classes():
20 | x = torch.rand(100, 3, 32, 32)
21 | c1 = torch.zeros(20, 1)
22 | c2 = torch.zeros(20, 1) + 1
23 | c3 = torch.zeros(60, 1) + 2
24 |
25 | y = torch.cat([c1, c2, c3], dim=0).int().cpu().numpy().flatten().tolist()
26 | balance_classes(x, y, batch_size=10)
27 |
28 |
29 | def test_generate_half_labeled_batches():
30 | smaller_set_x = np.random.rand(100, 3, 32, 32)
31 | smaller_set_y = np.random.randint(0, 3, (100, 1))
32 | larger_set_x = np.random.rand(100, 3, 32, 32)
33 | larger_set_y = np.zeros_like(smaller_set_y) + -1
34 |
35 | (mixed_x, mixed_y) = generate_half_labeled_batches(smaller_set_x, smaller_set_y, larger_set_x, larger_set_y, 10)
36 |
37 | # only half the batch should be unlabeled
38 | for i in range(0, len(mixed_y), 10):
39 | batch = mixed_y[i : i + 10]
40 | counts = Counter(batch.flatten().tolist())
41 | assert counts[-1] == 5
42 |
--------------------------------------------------------------------------------