├── .circleci └── config.yml ├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configen └── conf │ ├── torch │ └── configen.yaml │ └── torchvision.yaml ├── examples ├── mnist_00.md └── mnist_00.py ├── hydra-configs-projects.txt ├── hydra-configs-torch ├── hydra_configs │ └── torch │ │ ├── nn │ │ └── modules │ │ │ └── loss.py │ │ ├── optim │ │ ├── __init__.py │ │ ├── adadelta.py │ │ ├── adagrad.py │ │ ├── adam.py │ │ ├── adamax.py │ │ ├── adamw.py │ │ ├── asgd.py │ │ ├── lbfgs.py │ │ ├── lr_scheduler.py │ │ ├── rmsprop.py │ │ ├── rprop.py │ │ ├── sgd.py │ │ └── sparse_adam.py │ │ └── utils │ │ └── data │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── distributed.py │ │ └── sampler.py ├── requirements │ └── dev.txt ├── setup.py └── tests │ ├── test_instantiate_data.py │ ├── test_instantiate_losses.py │ └── test_instantiate_optimizers.py ├── hydra-configs-torchvision ├── hydra_configs │ └── torchvision │ │ ├── __init__.py │ │ ├── datasets │ │ ├── mnist.py │ │ └── vision.py │ │ └── transforms │ │ └── transforms.py ├── requirements │ └── dev.txt ├── setup.py └── tests │ ├── test_instantiate_datasets.py │ └── test_instantiate_transforms.py ├── noxfile.py ├── requirements ├── dev.txt └── requirements.txt └── setup.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | commands: 4 | linux: 5 | description: Commands run on Linux 6 | parameters: 7 | py_version: 8 | type: string 9 | steps: 10 | - checkout 11 | - run: 12 | name: Preparing environment - Conda 13 | command: | 14 | curl -o Miniconda3-py38_4.8.3-Linux-x86_64.sh https://repo.anaconda.com/miniconda/Miniconda3-py38_4.8.3-Linux-x86_64.sh 15 | bash ./Miniconda3-py38_4.8.3-Linux-x86_64.sh -b 16 | - run: 17 | name: Preparing environment - Other Dependencies 18 | command: | 19 | sudo apt-get update 20 | sudo apt-get install -y openjdk-11-jre 21 | - run: 22 | name: Preparing environment - Hydra Torch Configs 23 | command: | 24 | ~/miniconda3/bin/conda init bash 25 | ~/miniconda3/bin/conda create -n hypytorch python=<< parameters.py_version >> -yq 26 | 27 | jobs: 28 | test_linux: 29 | parameters: 30 | py_version: 31 | type: string 32 | docker: 33 | - image: cimg/base:stable-18.04 34 | steps: 35 | - linux: 36 | py_version: << parameters.py_version >> 37 | - run: 38 | name: Linting/Testing Hydra Torch Configs 39 | command: | 40 | export PATH="$HOME/miniconda3/envs/hypytorch/bin:$PATH" 41 | export NOX_PYTHON_VERSIONS=<< parameters.py_version >> 42 | pip install nox 43 | nox 44 | 45 | 46 | 47 | workflows: 48 | version: 2.1 49 | build: 50 | jobs: 51 | - test_linux: 52 | matrix: 53 | parameters: 54 | py_version: ["3.6", "3.7", "3.8"] 55 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = 3 | .git 4 | ,.nox,build 5 | ,hydra/grammar/gen 6 | ,dist 7 | ,tools/configen/example/gen 8 | ,tools/configen/tests/test_modules/expected 9 | max-line-length = 119 10 | copyright-check = True 11 | select = E,F,W,C 12 | copyright-regexp=Copyright \(c\) Facebook, Inc. and its affiliates. 
All Rights Reserved 13 | ignore=W503,E203 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | *.sw* 4 | *.pyc 5 | *.egg-info 6 | dist 7 | build 8 | .nox 9 | outputs 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: local 9 | hooks: 10 | - id: black 11 | name: black 12 | entry: python -m black 13 | language: system 14 | types: [python] 15 | - id: flake8 16 | name: flake8 17 | entry: python -m flake8 18 | language: system 19 | types: [python] 20 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | Change logs be here. 2 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 
45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to hydra-torch 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | To see a list of config projects currently maintained in this repository, please see: [hydra-configs-projects.txt](hydra-configs-projects.txt) 6 | 7 | In order to track progress or find an issue to draft a PR for, please see the [**Projects**](https://github.com/pytorch/hydra-torch/projects) tab. 8 | 9 | ## Opportunities to Contribute 10 | There are 3 main ways to contribute starting with the most straightfoward option: 11 | 12 | 1. **Filing Issues against Configs:** Noticing a bug that you believe is a problem with the generated config? Please file an issue for the offending config class stating clearly your versions of `hydra-core` and the library being configured (e.g. `torch`). If you believe it is actually a problem with hydra or torch, file the issue in their respective repositories. 13 | 14 | > **NOTE:** Please include the manual tag for the project and library version in your issue title. If, for example, there is an issue with `AdamConf` for `torch1.6` which comes from `hydra-configs-torch` v1.6.1, your issue name might look like: 15 | **[hydra-configs-torch][1.6.1] AdamConf does not instantiate**. 16 | 17 | 2. **Example Usecase / Tutorial:** The `hydra-torch` repository not only hosts config packages like `hydra-configs-torch`,`hydra-configs-torchvision`, etc., it also aggregates examples of how to structure projects utilizing hydra and torch. The bar is high for examples that get included, but we will work together as a community to hone in on what the best practices are. 
Ideally, example usecases will come along with an incremental tutorial that introduces a user to the methodology being followed. If you have an interesting way to use hydra/torch, write up an example and show us in a draft PR! 18 | 19 | 3. **Maintaining Configs:** After the initial (considerable) setup effort, the goal of this repository is to be self-sustaining meaning code can be autogrenerated when APIs change. In order to contribute to a particular package like `hydra-configs-torch`, please see the [**Projects**](https://github.com/pytorch/hydra-torch/projects) tab to identify outstanding issues per project and configured library version. Before contributing at this level, please familiarize with [configen](https://github.com/facebookresearch/hydra/tree/main/tools/configen). We are actively developing this tool as well. 20 | 21 | ## Linting / Formatting 22 | Please download the formatting / linting requirements: `pip install -r requirements/dev.txt`. 23 | Please install the pre-commit config for this environment: `pre-commit install`. 24 | 25 | ## Pull Requests 26 | We actively welcome your pull requests. 27 | 28 | 1. Fork the repo and create your branch from `main`. 29 | 2. If you've added code that should be tested, add tests. 30 | 3. If you've changed APIs, update the documentation. 31 | 4. Ensure the test suite passes. 32 | 5. Make sure your code lints. 33 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 34 | 35 | ## Contributor License Agreement ("CLA") 36 | In order to accept your pull request, we need you to submit a CLA. You only need 37 | to do this once to work on any of Facebook's open source projects. 38 | 39 | Complete your CLA here: 40 | 41 | ## Issues 42 | We use GitHub issues to track public bugs. Please ensure your description is 43 | clear and has sufficient instructions to be able to reproduce the issue. 44 | 45 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 46 | disclosure of security bugs. In those cases, please go through the process 47 | outlined on that page and do not file a public issue. 48 | 49 | ## License 50 | By contributing to hydra-torch, you agree that your contributions will be licensed 51 | under the LICENSE file in the root directory of this source tree. 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hydra-torch 2 | Configuration classes enabling type-safe PyTorch configuration for Hydra apps. 3 | **This repo is work in progress.** 4 | 5 | The config dataclasses are generated using [configen](https://github.com/facebookresearch/hydra/tree/main/tools/configen), check it out if you want to generate config dataclasses for your own project. 6 | 7 | ### Install: 8 | ``` 9 | # For now, please obtain through github. Soon, versioned (per-project) dists will be on PyPI. 10 | pip install git+https://github.com/pytorch/hydra-torch 11 | ``` 12 | 13 | ### Example config: 14 | Here is one of many configs available. Notice it uses the defaults defined in the torch function signatures: 15 | ```python 16 | @dataclass 17 | class TripletMarginLossConf: 18 | _target_: str = "torch.nn.modules.loss.TripletMarginLoss" 19 | margin: float = 1.0 20 | p: float = 2.0 21 | eps: float = 1e-06 22 | swap: bool = False 23 | size_average: Any = None 24 | reduce: Any = None 25 | reduction: str = "mean" 26 | ``` 27 | 28 | ### Importing Convention: 29 | ```python 30 | from hydra_configs..path.to.module import Conf 31 | ``` 32 | where `` is the package being configured and `path.to.module` is the path in the original package. 33 | 34 | Inferring where the package is located is as simple as prepending `hydra_configs.` and postpending `Conf` to the original class import: 35 | e.g. 36 | ```python 37 | #module to be configured 38 | from torch.optim.adam import Adam 39 | 40 | #config for the module 41 | from hydra_configs.torch.optim.adam import AdamConf 42 | ``` 43 | 44 | 45 | ### Getting Started: 46 | Take a look at our tutorial series: 47 | 1. [Basic Tutorial](examples/mnist_00.md) 48 | 2. Intermediate Tutorial (coming soon) 49 | 3. Advanced Tutorial (coming soon) 50 | 51 | ### Other Config Projects: 52 | A list of projects following the `hydra_configs` convention (please notify us if you have one!): 53 | 54 | [Pytorch Lightning](https://github.com/romesco/hydra-lightning) 55 | 56 | ### License 57 | hydra-torch is licensed under [MIT License](LICENSE). 58 | -------------------------------------------------------------------------------- /configen/conf/torch/configen.yaml: -------------------------------------------------------------------------------- 1 | configen: 2 | # output directory 3 | output_dir: ${hydra:runtime.cwd} 4 | 5 | header: | 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # 8 | # Generated by configen, do not edit. 
9 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 10 | # fmt: off 11 | # isort:skip_file 12 | # flake8: noqa 13 | 14 | module_path_pattern: 'hydra_configs/{{module_path}}.py' 15 | 16 | # list of modules to generate configs for 17 | modules: 18 | - name: torch.optim.adadelta 19 | classes: 20 | - Adadelta 21 | 22 | - name: torch.optim.adagrad 23 | classes: 24 | - Adagrad 25 | 26 | - name: torch.optim.adam 27 | classes: 28 | - Adam 29 | 30 | - name: torch.optim.adamax 31 | classes: 32 | - Adamax 33 | 34 | - name: torch.optim.adamw 35 | classes: 36 | - AdamW 37 | 38 | - name: torch.optim.asgd 39 | classes: 40 | - ASGD 41 | 42 | - name: torch.optim.lbfgs 43 | classes: 44 | - LBFGS 45 | 46 | - name: torch.optim.rmsprop 47 | classes: 48 | - RMSprop 49 | 50 | - name: torch.optim.rprop 51 | classes: 52 | - Rprop 53 | 54 | - name: torch.optim.sgd 55 | classes: 56 | - SGD 57 | 58 | - name: torch.optim.sparse_adam 59 | classes: 60 | - SparseAdam 61 | 62 | - name: torch.optim.lr_scheduler 63 | classes: 64 | - LambdaLR 65 | - MultiplicativeLR 66 | - StepLR 67 | - MultiStepLR 68 | - ExponentialLR 69 | - CosineAnnealingLR 70 | - ReduceLROnPlateau 71 | - CyclicLR 72 | - CosineAnnealingWarmRestarts 73 | - OneCycleLR 74 | 75 | - name: torch.utils.data.dataloader 76 | classes: 77 | - DataLoader 78 | 79 | - name: torch.utils.data.dataset 80 | classes: 81 | - Dataset 82 | - ChainDataset 83 | - ConcatDataset 84 | - IterableDataset 85 | - TensorDataset 86 | - Subset 87 | 88 | - name: torch.utils.data.sampler 89 | classes: 90 | - Sampler 91 | - BatchSampler 92 | - RandomSampler 93 | - SequentialSampler 94 | - SubsetRandomSampler 95 | - WeightedRandomSampler 96 | 97 | - name: torch.utils.data.distributed 98 | classes: 99 | - DistributedSampler 100 | 101 | - name: torch.nn.modules.loss 102 | classes: 103 | - BCELoss 104 | - BCEWithLogitsLoss 105 | - CosineEmbeddingLoss 106 | - CTCLoss 107 | - L1Loss 108 | - HingeEmbeddingLoss 109 | - KLDivLoss 110 | - MarginRankingLoss 111 | - MSELoss 112 | - MultiLabelMarginLoss 113 | - MultiLabelSoftMarginLoss 114 | - MultiMarginLoss 115 | - NLLLoss 116 | - NLLLoss2d 117 | - PoissonNLLLoss 118 | - SmoothL1Loss 119 | - SoftMarginLoss 120 | - TripletMarginLoss 121 | -------------------------------------------------------------------------------- /configen/conf/torchvision.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - configen_schema 3 | 4 | configen: 5 | # output directory 6 | output_dir: ${hydra:runtime.cwd} 7 | 8 | header: | 9 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 10 | # 11 | # Generated by configen, do not edit. 
12 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 13 | # fmt: off 14 | # isort:skip_file 15 | # flake8: noqa 16 | 17 | module_path_pattern: 'hydra_configs/{{module_path}}.py' 18 | 19 | # list of modules to generate configs for 20 | modules: 21 | - name: torchvision.datasets.vision 22 | classes: 23 | - VisionDataset 24 | - StandardTransform 25 | 26 | - name: torchvision.datasets.mnist 27 | # mnist datasets 28 | classes: 29 | - MNIST 30 | - FashionMNIST 31 | - KMNIST 32 | # TODO: The following need to be manually created for torchvision==0.7 33 | # - EMNIST 34 | # - QMNIST 35 | # 36 | - name: torchvision.transforms.transforms 37 | default_flags: 38 | _convert_: ALL 39 | # mnist datasets 40 | classes: 41 | - CenterCrop 42 | - ColorJitter 43 | - Compose 44 | - ConvertImageDtype 45 | - FiveCrop 46 | - Grayscale 47 | - Lambda 48 | - LinearTransformation 49 | - Normalize 50 | - Pad 51 | - PILToTensor 52 | - RandomAffine 53 | - RandomApply 54 | - RandomChoice 55 | - RandomCrop 56 | - RandomErasing 57 | - RandomGrayscale 58 | - RandomHorizontalFlip 59 | - RandomOrder 60 | - RandomPerspective 61 | - RandomResizedCrop 62 | - RandomRotation 63 | - RandomTransforms 64 | - RandomVerticalFlip 65 | - Resize 66 | - TenCrop 67 | - ToPILImage 68 | - ToTensor 69 | -------------------------------------------------------------------------------- /examples/mnist_00.md: -------------------------------------------------------------------------------- 1 | # MNIST Basic Tutorial 2 | 3 | This tutorial series is built around the [PyTorch MNIST example] and is meant to demonstrate how to modify your PyTorch code to be configured by Hydra. We will start with the simplest case which introduces one central concept while minimizing altered code. In the following tutorials ([Intermediate][Intermediate Tutorial] and [Advanced][Advanced Tutorial]), we will show how a few additional changes can yield an even more powerful end product. 4 | 5 | The source file can be found at [mnist_00.py]. 6 | 7 | ### Pre-reading 8 | Although this tutorial is aimed at being self-contained, taking a look through Hydra's terminology as well as the basic and advanced tutorials couldn't hurt. 9 | 10 | 1. [Hydra Terminology] 11 | 2. [Hydra Basic Tutorial] 12 | 3. [Hydra Structured Configs Tutorial] 13 | 14 | ### Contents 15 | 16 | 1. [The Hydra Block](#the-hydra-block) 17 | 1. [Imports](#imports) 18 | 2. [Parting with Argparse](#parting-with-argparse) 19 | 3. [Top Level Config](#top-level-config) 20 | 4. [Adding the Top Level Config to the ConfigStore](#adding-the-top-level-config-to-the-configstore) 21 | 2. [Dropping into `main()`](#dropping-into-main) 22 | 1. [Instantiating the Optimizer and Scheduler](#instantiating-the-optimizer-and-scheduler) 23 | 3. [Running with Hydra](#running-with-hydra) 24 | 1. [Commandline Overrides](#command-line-overrides) 25 | 2. [Multirun](#multirun) 26 | 4. [Summary](#summary) 27 | 28 | 29 | *** 30 | ## The 'HYDRA BLOCK' 31 | 32 | For clarity, as we modify the [PyTorch MNIST example], we will make the diffs explicit. Most of the changes we introduce will be at the top of the file within the commented `##### HYDRA BLOCK #####`, though in practice much of this block could reside in its own concise imported file. 
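Assembled, the entire block is short. Below is a condensed sketch of the version in [mnist_00.py] (only a couple of the training fields are shown here; each piece is explained in the subsections that follow):

```python
import hydra
from hydra.core.config_store import ConfigStore
from dataclasses import dataclass

# hydra-torch structured config imports
from hydra_configs.torch.optim import AdadeltaConf
from hydra_configs.torch.optim.lr_scheduler import StepLRConf


@dataclass
class MNISTConf:
    # training options formerly handled by argparse (abridged here)
    epochs: int = 14
    save_model: bool = False
    # nested structured configs provided by hydra-torch
    adadelta: AdadeltaConf = AdadeltaConf()
    steplr: StepLRConf = StepLRConf(step_size=1)


cs = ConfigStore.instance()
cs.store(name="mnistconf", node=MNISTConf)
```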
33 | 34 | ### Imports 35 | ```python 36 | import hydra 37 | from hydra.core.config_store import ConfigStore 38 | from dataclasses import dataclass 39 | 40 | # hydra-torch structured config imports 41 | from hydra_configs.torch.optim import AdadeltaConf 42 | from hydra_configs.torch.optim.lr_scheduler import StepLRConf 43 | ``` 44 | 45 | There are two areas in our Hydra-specific imports. First, since we define configs in this file, we need access to the following: 46 | - the `ConfigStore` 47 | - the `dataclass` decorator (for structured configs) 48 | 49 | **The [ConfigStore]** is a singleton object which all config objects are registered to. This gives Hydra access to our structured config definitions *once they're registered*. 50 | 51 | **[Structured Configs][hydra structured configs tutorial]** are dataclasses that Hydra can use to compose complex config objects. We can think of them as templates or 'starting points' for our configs. Each `*Conf` file provided by `hydra-torch` is a structured config. See an example of one below: 52 | 53 | ```python 54 | # the structured config for Adadelta imported from config.torch.optim: 55 | @dataclass 56 | class AdadeltaConf: 57 | _target_: str = "torch.optim.adadelta.Adadelta" 58 | params: Any = MISSING 59 | lr: Any = 1.0 60 | rho: Any = 0.9 61 | eps: Any = 1e-06 62 | weight_decay: Any = 0 63 | ``` 64 | 65 | > **NOTE:** [`MISSING`] is a special constant used to indicate there is no default value specified. 66 | 67 | The second set of imports correspond to two components in the training pipeline of the [PyTorch MNIST example]: 68 | 69 | - `Adadelta` which resides in `torch.optim` 70 | - `StepLR` which resides in `torch.optim.lr_scheduler` 71 | 72 | Note that the naming convention for the import hierarchy mimics that of `torch`. We correspondingly import the following structured configs: 73 | - `AdadeltaConf` from `config.torch.optim` 74 | - `StepLRConf` from `config.torch.optim.lr_scheduler` 75 | 76 | Generally, we follow the naming convention of applying the suffix `-Conf` to distinguish the structured config class from the class of the object to be configured. 77 | 78 | *** 79 | ### Top Level Config 80 | After importing two pre-defined structured configs for components in our training pipeline, the optimizer and scheduler, we still need a "top level" config to merge everything. We can call this config class `MNISTConf`. You will notice that this class is nothing more than a python `dataclass` and corresponds to, you guessed it, a *structured config*. 81 | 82 | > **NOTE:** The top level config is application specific and thus is not provided by `hydra-torch`. 83 | 84 | We can start this out by including the configs we know we will need for the optimizer (`Adadelta`) and scheduler (`StepLR`): 85 | ```python 86 | # our top level config: 87 | @dataclass 88 | class MNISTConf: 89 | adadelta: AdadeltaConf = AdadeltaConf() 90 | steplr: StepLRConf = StepLRConf(step_size=1) 91 | ``` 92 | Notice that for `StepLRConf()` we need to pass `step_size=1` when we initialize because it's default value is `MISSING`. 93 | ```python 94 | # the structured config imported from hydra-torch in config.torch.optim.lr_scheduler 95 | @dataclass 96 | class StepLRConf: 97 | _target_: str = "torch.optim.lr_scheduler.StepLR" 98 | optimizer: Any = MISSING 99 | step_size: Any = MISSING 100 | gamma: Any = 0.1 last_epoch: Any = -1 101 | ``` 102 | > **NOTE:** The `hydra-torch` configs are generated from the PyTorch source and rely on whether the module uses type annotation. 
Once additional type annotation is added, these configs will become more strict providing greater type safety. 103 | 104 | Later, we will specify the optimizer (also default `MISSING`) as a passed through argument when the actual `StepLR` object is instantiated. 105 | 106 | ### Adding the Top Level Config to the ConfigStore 107 | Very simply, but crucially, we add the top-level config class `MNISTConf` to the `ConfigStore` in two lines: 108 | ```python 109 | cs = ConfigStore.instance() 110 | cs.store(name="mnistconf", node=MNISTConf) 111 | ``` 112 | The name `mnistconf` will be passed to the `@hydra` decorator when we get to `main()`. 113 | 114 | *** 115 | ### Parting with Argparse 116 | 117 | Now we're starting to realize our relationship with `argparse` isn't as serious as we thought it was. Although `argparse` is powerful, we can take it a step further. In the process we hope to introduce greater organization and free our primary file from as much boilerplate as possible. 118 | 119 | One feature Hydra provides us is aggregating our configuration files alongside any 'specifications' we pass via command line arguments. What this means is as long as we have the configuration file which defines possible arguments like `save_model` or `dry_run`, there is no need to also litter our code with `argparse` definitions. 120 | 121 | This whole block in `main()`: 122 | ```python 123 | def main(): 124 | # Training settings 125 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 126 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 127 | help='input batch size for training (default: 64)') 128 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 129 | help='input batch size for testing (default: 1000)') 130 | parser.add_argument('--epochs', type=int, default=14, metavar='N', 131 | help='number of epochs to train (default: 14)') 132 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 133 | help='learning rate (default: 1.0)') 134 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 135 | help='Learning rate step gamma (default: 0.7)') 136 | parser.add_argument('--no-cuda', action='store_true', default=False, 137 | help='disables CUDA training') 138 | parser.add_argument('--dry-run', action='store_true', default=False, 139 | help='quickly check a single pass') 140 | parser.add_argument('--seed', type=int, default=1, metavar='S', 141 | help='random seed (default: 1)') 142 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 143 | help='how many batches to wait before logging training status') 144 | parser.add_argument('--save-model', action='store_true', default=False, 145 | help='For Saving the current Model') 146 | args = parser.parse_args() 147 | ``` 148 | becomes: 149 | ```python 150 | def main(cfg): 151 | # All argparse args now reside in cfg 152 | ``` 153 | Our initial strategy is to dump these arguments directly in our top-level configuration. 154 | ```python 155 | @dataclass 156 | class MNISTConf: 157 | batch_size: int = 64 158 | test_batch_size: int = 1000 159 | epochs: int = 14 160 | no_cuda: bool = False 161 | dry_run: bool = False 162 | seed: int = 1 163 | log_interval: int 164 | save_model: bool = False 165 | adadelta: AdadeltaConf = AdadeltaConf() 166 | steplr: StepLRConf = StepLRConf(step_size=1) 167 | ``` 168 | > **NOTE:** `learning_rate` and `gamma` are included in `AdadeltaConf()` and so they were omitted from the top-level args. 
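If you want to see exactly what Hydra will compose from this dataclass, one quick check is to render it with OmegaConf. This is a sketch, not part of the tutorial code; the dump below is approximate, and `???` marks the `MISSING` fields we fill in later:

```python
from omegaconf import OmegaConf

# Build a structured config from the top-level dataclass and dump it as YAML
print(OmegaConf.to_yaml(OmegaConf.structured(MNISTConf)))
# batch_size: 64
# test_batch_size: 1000
# epochs: 14
# no_cuda: false
# dry_run: false
# seed: 1
# log_interval: 10
# save_model: false
# adadelta:
#   _target_: torch.optim.adadelta.Adadelta
#   params: ???
#   lr: 1.0
#   rho: 0.9
#   eps: 1e-06
#   weight_decay: 0
# steplr:
#   _target_: torch.optim.lr_scheduler.StepLR
#   optimizer: ???
#   step_size: 1
#   gamma: 0.1
#   last_epoch: -1
```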
169 | 170 | This works, but can feel a bit flat and disorganized (much like `argparse` args can be). Don't worry, we will remedy this later in the tutorial series. Note, we also sacrifice `help` strings. This is a planned feature, but not supported in Hydra just yet. 171 | 172 | Now our `argparse` args are at the same level as our optimizer and scheduler configs. We will remove `lr` and `gamma` since they are already present within the optimizer config `AdadeltaConf`. 173 | *** 174 | ## Dropping into `main()` 175 | Now that we've defined all of our configs, we just need to let Hydra create our `cfg` object at runtime and make sure the `cfg` is plumbed to any object we want it to configure. 176 | ```python 177 | @hydra.main(config_name='mnistconf') 178 | def main(cfg): 179 | print(cfg.pretty()) 180 | ... 181 | ``` 182 | The single idea here is that `@hydra.main` looks for a config in the `ConfigStore` instance, `cs` named "`mnistconf`". It finds the `MNISTConf` (our top level conf) we registered to that name and populates `cfg` inside `main()` with the fully expanded structured config. This includes our optimizer and scheduler configs, `cfg.adadelta` and `cfg.steplr`, respectively. 183 | 184 | Instrumenting `main()` is simple. Anywhere we find `args`, replace this with `cfg` since we put all of the `argparse` arguments at the top level. For example, `args.batch_size` becomes `cfg.batch_size`: 185 | ```python 186 | # the first few lines of main 187 | ... 188 | use_cuda = not cfg.no_cuda and torch.cuda.is_available() # DIFF args.no_cuda 189 | torch.manual_seed(cfg.seed) # DIFF args.seed 190 | device = torch.device("cuda" if use_cuda else "cpu") 191 | 192 | train_kwargs = {'batch_size': cfg.batch_size} # DIFF args.batch_size 193 | test_kwargs = {'batch_size': cfg.test_batch_size} # DIFF args.test_batch_size 194 | ... 195 | ``` 196 | 197 | 198 | ### Instantiating the optimizer and scheduler 199 | Still inside `main()`, we want to draw attention to two slightly special cases before moving on. Both the `optimizer` and `scheduler` are instantiated manually by specifying each argument with its `cfg` equivalent. Note that since these are nested fields, each of these parameters is two levels down e.g. `lr=args.learning_rate` becomes `lr=cfg.adadelta.lr`. 200 | 201 | ```python 202 | optimizer = Adadelta(lr=cfg.adadelta.lr, #DIFF lr=args.learning_rate 203 | rho=cfg.adadelta.rho, 204 | eps=cfg.adadelta.eps, 205 | weight_decay=cfg.adadelta.weight_decay, 206 | params=model.parameters() 207 | ``` 208 | In this case, the `optimizer` has one argument that is not a part of our config -- `params`. If it wasn't obvious, this needs to be passed from the initialized `Net()` called model. In the structured config that initialized `cfg.adadelta`, `params` is default to `MISSING`. The same is true of the `optimizer` field in `StepLRConf`. 209 | 210 | ```python 211 | scheduler = StepLR(step_size=cfg.steplr.step_size, 212 | gamma=cfg.steplr.gamma, 213 | last_epoch=cfg.steplr.last_epoch, 214 | optimizer=optimizer 215 | ``` 216 | This method for instantiation is the least invasive to the original code, but it is also the least flexible and highly verbose. Check out the [Intermediate Tutorial] for a better approach that will allow us to hotswap optimizers and schedulers, all while writing less code. 217 | 218 | *** 219 | ## Running with Hydra 220 | 221 | ```bash 222 | $ python 00_minst.py 223 | ``` 224 | That's it. 
Since the `@hydra.main` decorator is above `def main(cfg)`, Hydra will manage the command line, logging, and saving outputs to a date/time stamped directory automatically. These are all configurable, but the default behavior ensures expected functionality. For example, if a model checkpoint is saved, it will appear in a new directory `./outputs/DATE/TIME/`. 225 | 226 | ### New Super Powers 🦸 227 | 228 | #### Command Line Overrides 229 | 230 | Much like passing argparse args through the CLI, we can use our default values specified in `MNISTConf` and override only the arguments/parameters we want to tweak: 231 | 232 | ```bash 233 | $ python mnist_00.py epochs=1 save_model=True checkpoint_name='experiment0.pt' 234 | ``` 235 | 236 | For more on command line overrides, see: [Hydra CLI] and [Hydra override syntax]. 237 | 238 | #### Multirun 239 | We often end up wanting to sweep our optimizer's learning rate. Here's how Hydra can help facilitate: 240 | ```bash 241 | $ python mnist_00.py -m adadelta.lr="0.001, 0.01, 0.1" 242 | ``` 243 | Notice the `-m` which indicates we want to schedule 3 jobs where the learning rate changes by an order of magnitude across each training session. 244 | 245 | It can be useful to test multirun outputs by passing `dry_run=True` and setting `epochs=1`: 246 | ```bash 247 | $ python mnist_00.py -m epochs=1 dry_run=True adadelta.lr="0.001,0.01, 0.1" 248 | ``` 249 | 250 | > **NOTE:** these jobs can be dispatched to different resources and run in parallel or scheduled to run serially (by default). More info on multirun: [Hydra Multirun]. Hydra can use different hyperparameter search tools as well. See: [Hydra Ax plugin] and [Hydra Nevergrad plugin]. 251 | 252 | *** 253 | ## Summary 254 | In this tutorial, we demonstrated the path of least resistance to configuring your existing PyTorch code with Hydra. The main benefits we get from the 'Basic' level are: 255 | - No more boilerplate `argparse` taking up precious linecount. 256 | - All training related arguments (`epochs`, `save_model`, etc.) are now configurable via Hydra. 257 | - **All** optimizer/scheduler (`Adadelta`/`StepLR`) arguments are exposed for configuration 258 | -- extending beyond only the ones the user wrote argparse code for. 259 | - We have offloaded the book-keeping of compatible `argparse` code to Hydra via `hydra-torch` which runs tests ensuring all arguments track the API for the correct version of `pytorch`. 260 | 261 | However, there are some limitations in our current strategy that the [Intermediate Tutorial] will address. Namely: 262 | - Configuring the model (*think architecture search*) 263 | - Configuring the dataset (*think transfer learning*) 264 | - Swapping in and out different Optimizers/Schedulers 265 | 266 | Once comfortable with the basics, continue on to the [Intermediate Tutorial]. 267 | 268 | [//]: # (These are reference links used in the body of this note and get stripped out when the markdown processor does its job. There is no need to format nicely because it shouldn't be seen. 
Thanks SO - http://stackoverflow.com/questions/4823468/store-comments-in-markdown-syntax) 269 | [pytorch mnist example]: 270 | [mnist_00.py]: mnist_00.py 271 | [config schema]: 272 | [configstore]: 273 | [hydra basic tutorial]: 274 | [hydra structured configs tutorial]: 275 | [hydra structured configs example]: 276 | [hydra terminology]: 277 | [omegaconf]: 278 | [`missing`]: 279 | [hydra cli]: 280 | [hydra override syntax]: 281 | [hydra multirun]: 282 | [hydra ax plugin]: 283 | [hydra nevergrad plugin]: 284 | [Intermediate Tutorial]: mnist_01.md 285 | [Advanced Tutorial]: mnist_02.md 286 | -------------------------------------------------------------------------------- /examples/mnist_00.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # flake8: noqa 3 | from __future__ import print_function 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import datasets, transforms 8 | from torch.optim import Adadelta 9 | from torch.optim.lr_scheduler import StepLR 10 | 11 | ###### HYDRA BLOCK ###### 12 | import hydra 13 | from hydra.core.config_store import ConfigStore 14 | from dataclasses import dataclass 15 | 16 | # hydra-torch structured config imports 17 | from hydra_configs.torch.optim import AdadeltaConf 18 | from hydra_configs.torch.optim.lr_scheduler import StepLRConf 19 | 20 | 21 | @dataclass 22 | class MNISTConf: 23 | batch_size: int = 64 24 | test_batch_size: int = 1000 25 | epochs: int = 14 26 | no_cuda: bool = False 27 | dry_run: bool = False 28 | seed: int = 1 29 | log_interval: int = 10 30 | save_model: bool = False 31 | checkpoint_name: str = "unnamed.pt" 32 | adadelta: AdadeltaConf = AdadeltaConf() 33 | steplr: StepLRConf = StepLRConf( 34 | step_size=1 35 | ) # we pass a default for step_size since it is required, but missing a default in PyTorch (and consequently in hydra-torch) 36 | 37 | 38 | cs = ConfigStore.instance() 39 | cs.store(name="mnistconf", node=MNISTConf) 40 | 41 | ###### / HYDRA BLOCK ###### 42 | 43 | 44 | class Net(nn.Module): 45 | def __init__(self): 46 | super(Net, self).__init__() 47 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 48 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 49 | self.dropout1 = nn.Dropout2d(0.25) 50 | self.dropout2 = nn.Dropout2d(0.5) 51 | self.fc1 = nn.Linear(9216, 128) 52 | self.fc2 = nn.Linear(128, 10) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x = F.relu(x) 57 | x = self.conv2(x) 58 | x = F.relu(x) 59 | x = F.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = F.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | output = F.log_softmax(x, dim=1) 67 | return output 68 | 69 | 70 | def train(args, model, device, train_loader, optimizer, epoch): 71 | model.train() 72 | for batch_idx, (data, target) in enumerate(train_loader): 73 | data, target = data.to(device), target.to(device) 74 | optimizer.zero_grad() 75 | output = model(data) 76 | loss = F.nll_loss(output, target) 77 | loss.backward() 78 | optimizer.step() 79 | if batch_idx % args.log_interval == 0: 80 | print( 81 | "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 82 | epoch, 83 | batch_idx * len(data), 84 | len(train_loader.dataset), 85 | 100.0 * batch_idx / len(train_loader), 86 | loss.item(), 87 | ) 88 | ) 89 | if args.dry_run: 90 | break 91 | 92 | 93 | def test(model, device, test_loader): 94 | model.eval() 95 | test_loss = 0 96 | correct = 0 97 | with 
torch.no_grad(): 98 | for data, target in test_loader: 99 | data, target = data.to(device), target.to(device) 100 | output = model(data) 101 | test_loss += F.nll_loss( 102 | output, target, reduction="sum" 103 | ).item() # sum up batch loss 104 | pred = output.argmax( 105 | dim=1, keepdim=True 106 | ) # get the index of the max log-probability 107 | correct += pred.eq(target.view_as(pred)).sum().item() 108 | 109 | test_loss /= len(test_loader.dataset) 110 | 111 | print( 112 | "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( 113 | test_loss, 114 | correct, 115 | len(test_loader.dataset), 116 | 100.0 * correct / len(test_loader.dataset), 117 | ) 118 | ) 119 | 120 | 121 | @hydra.main(config_name="mnistconf") 122 | def main(cfg): # DIFF 123 | print(cfg.pretty()) 124 | use_cuda = not cfg.no_cuda and torch.cuda.is_available() # DIFF 125 | torch.manual_seed(cfg.seed) # DIFF 126 | device = torch.device("cuda" if use_cuda else "cpu") 127 | 128 | train_kwargs = {"batch_size": cfg.batch_size} # DIFF 129 | test_kwargs = {"batch_size": cfg.test_batch_size} # DIFF 130 | if use_cuda: 131 | cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} 132 | train_kwargs.update(cuda_kwargs) 133 | test_kwargs.update(cuda_kwargs) 134 | 135 | transform = transforms.Compose( 136 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 137 | ) 138 | dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) 139 | dataset2 = datasets.MNIST("../data", train=False, transform=transform) 140 | train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) 141 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 142 | 143 | model = Net().to(device) 144 | 145 | optimizer = Adadelta( 146 | lr=cfg.adadelta.lr, 147 | rho=cfg.adadelta.rho, 148 | eps=cfg.adadelta.eps, 149 | weight_decay=cfg.adadelta.weight_decay, 150 | params=model.parameters(), 151 | ) # DIFF 152 | scheduler = StepLR( 153 | step_size=cfg.steplr.step_size, 154 | gamma=cfg.steplr.gamma, 155 | last_epoch=cfg.steplr.last_epoch, 156 | optimizer=optimizer, 157 | ) # DIFF 158 | 159 | for epoch in range(1, cfg.epochs + 1): # DIFF 160 | train(cfg, model, device, train_loader, optimizer, epoch) # DIFF 161 | test(model, device, test_loader) 162 | scheduler.step() 163 | 164 | if cfg.save_model: # DIFF 165 | torch.save(model.state_dict(), cfg.checkpoint_name) # DIFF 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /hydra-configs-projects.txt: -------------------------------------------------------------------------------- 1 | hydra-configs-torch 2 | hydra-configs-torchvision 3 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/nn/modules/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 
4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class BCELossConf: 16 | _target_: str = "torch.nn.modules.loss.BCELoss" 17 | weight: Any = MISSING # Optional[Tensor] 18 | size_average: Any = None 19 | reduce: Any = None 20 | reduction: str = "mean" 21 | 22 | 23 | @dataclass 24 | class BCEWithLogitsLossConf: 25 | _target_: str = "torch.nn.modules.loss.BCEWithLogitsLoss" 26 | weight: Any = MISSING # Optional[Tensor] 27 | size_average: Any = None 28 | reduce: Any = None 29 | reduction: str = "mean" 30 | pos_weight: Any = MISSING # Optional[Tensor] 31 | 32 | 33 | @dataclass 34 | class CosineEmbeddingLossConf: 35 | _target_: str = "torch.nn.modules.loss.CosineEmbeddingLoss" 36 | margin: float = 0.0 37 | size_average: Any = None 38 | reduce: Any = None 39 | reduction: str = "mean" 40 | 41 | 42 | @dataclass 43 | class CTCLossConf: 44 | _target_: str = "torch.nn.modules.loss.CTCLoss" 45 | blank: int = 0 46 | reduction: str = "mean" 47 | zero_infinity: bool = False 48 | 49 | 50 | @dataclass 51 | class L1LossConf: 52 | _target_: str = "torch.nn.modules.loss.L1Loss" 53 | size_average: Any = None 54 | reduce: Any = None 55 | reduction: str = "mean" 56 | 57 | 58 | @dataclass 59 | class HingeEmbeddingLossConf: 60 | _target_: str = "torch.nn.modules.loss.HingeEmbeddingLoss" 61 | margin: float = 1.0 62 | size_average: Any = None 63 | reduce: Any = None 64 | reduction: str = "mean" 65 | 66 | 67 | @dataclass 68 | class KLDivLossConf: 69 | _target_: str = "torch.nn.modules.loss.KLDivLoss" 70 | size_average: Any = None 71 | reduce: Any = None 72 | reduction: str = "mean" 73 | log_target: bool = False 74 | 75 | 76 | @dataclass 77 | class MarginRankingLossConf: 78 | _target_: str = "torch.nn.modules.loss.MarginRankingLoss" 79 | margin: float = 0.0 80 | size_average: Any = None 81 | reduce: Any = None 82 | reduction: str = "mean" 83 | 84 | 85 | @dataclass 86 | class MSELossConf: 87 | _target_: str = "torch.nn.modules.loss.MSELoss" 88 | size_average: Any = None 89 | reduce: Any = None 90 | reduction: str = "mean" 91 | 92 | 93 | @dataclass 94 | class MultiLabelMarginLossConf: 95 | _target_: str = "torch.nn.modules.loss.MultiLabelMarginLoss" 96 | size_average: Any = None 97 | reduce: Any = None 98 | reduction: str = "mean" 99 | 100 | 101 | @dataclass 102 | class MultiLabelSoftMarginLossConf: 103 | _target_: str = "torch.nn.modules.loss.MultiLabelSoftMarginLoss" 104 | weight: Any = MISSING # Optional[Tensor] 105 | size_average: Any = None 106 | reduce: Any = None 107 | reduction: str = "mean" 108 | 109 | 110 | @dataclass 111 | class MultiMarginLossConf: 112 | _target_: str = "torch.nn.modules.loss.MultiMarginLoss" 113 | p: int = 1 114 | margin: float = 1.0 115 | weight: Any = MISSING # Optional[Tensor] 116 | size_average: Any = None 117 | reduce: Any = None 118 | reduction: str = "mean" 119 | 120 | 121 | @dataclass 122 | class NLLLossConf: 123 | _target_: str = "torch.nn.modules.loss.NLLLoss" 124 | weight: Any = MISSING # Optional[Tensor] 125 | size_average: Any = None 126 | ignore_index: int = -100 127 | reduce: Any = None 128 | reduction: str = "mean" 129 | 130 | 131 | @dataclass 132 | class NLLLoss2dConf: 133 | _target_: str = "torch.nn.modules.loss.NLLLoss2d" 134 | weight: Any = MISSING # Optional[Tensor] 135 | size_average: Any = None 136 | ignore_index: int = -100 137 | reduce: Any = 
None 138 | reduction: str = "mean" 139 | 140 | 141 | @dataclass 142 | class PoissonNLLLossConf: 143 | _target_: str = "torch.nn.modules.loss.PoissonNLLLoss" 144 | log_input: bool = True 145 | full: bool = False 146 | size_average: Any = None 147 | eps: float = 1e-08 148 | reduce: Any = None 149 | reduction: str = "mean" 150 | 151 | 152 | @dataclass 153 | class SmoothL1LossConf: 154 | _target_: str = "torch.nn.modules.loss.SmoothL1Loss" 155 | size_average: Any = None 156 | reduce: Any = None 157 | reduction: str = "mean" 158 | 159 | 160 | @dataclass 161 | class SoftMarginLossConf: 162 | _target_: str = "torch.nn.modules.loss.SoftMarginLoss" 163 | size_average: Any = None 164 | reduce: Any = None 165 | reduction: str = "mean" 166 | 167 | 168 | @dataclass 169 | class TripletMarginLossConf: 170 | _target_: str = "torch.nn.modules.loss.TripletMarginLoss" 171 | margin: float = 1.0 172 | p: float = 2.0 173 | eps: float = 1e-06 174 | swap: bool = False 175 | size_average: Any = None 176 | reduce: Any = None 177 | reduction: str = "mean" 178 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # Mirrors torch/optim __init__ to allow for symmetric import structure 3 | from .adadelta import AdadeltaConf 4 | from .adagrad import AdagradConf 5 | from .adam import AdamConf 6 | from .adamw import AdamWConf 7 | from .sparse_adam import SparseAdamConf 8 | from .adamax import AdamaxConf 9 | from .asgd import ASGDConf 10 | from .sgd import SGDConf 11 | from .rprop import RpropConf 12 | from .rmsprop import RMSpropConf 13 | 14 | from .lbfgs import LBFGSConf 15 | from . import lr_scheduler 16 | 17 | 18 | del adadelta 19 | del adagrad 20 | del adam 21 | del adamw 22 | del sparse_adam 23 | del adamax 24 | del asgd 25 | del sgd 26 | del rprop 27 | del rmsprop 28 | del lbfgs 29 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/adadelta.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class AdadeltaConf: 16 | _target_: str = "torch.optim.adadelta.Adadelta" 17 | params: Any = MISSING 18 | lr: Any = 1.0 19 | rho: Any = 0.9 20 | eps: Any = 1e-06 21 | weight_decay: Any = 0 22 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/adagrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 
4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class AdagradConf: 16 | _target_: str = "torch.optim.adagrad.Adagrad" 17 | params: Any = MISSING 18 | lr: Any = 0.01 19 | lr_decay: Any = 0 20 | weight_decay: Any = 0 21 | initial_accumulator_value: Any = 0 22 | eps: Any = 1e-10 23 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/adam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class AdamConf: 16 | _target_: str = "torch.optim.adam.Adam" 17 | params: Any = MISSING 18 | lr: Any = 0.001 19 | betas: Any = (0.9, 0.999) 20 | eps: Any = 1e-08 21 | weight_decay: Any = 0 22 | amsgrad: Any = False 23 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/adamax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class AdamaxConf: 16 | _target_: str = "torch.optim.adamax.Adamax" 17 | params: Any = MISSING 18 | lr: Any = 0.002 19 | betas: Any = (0.9, 0.999) 20 | eps: Any = 1e-08 21 | weight_decay: Any = 0 22 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/adamw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class AdamWConf: 16 | _target_: str = "torch.optim.adamw.AdamW" 17 | params: Any = MISSING 18 | lr: Any = 0.001 19 | betas: Any = (0.9, 0.999) 20 | eps: Any = 1e-08 21 | weight_decay: Any = 0.01 22 | amsgrad: Any = False 23 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/asgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 
4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class ASGDConf: 16 | _target_: str = "torch.optim.asgd.ASGD" 17 | params: Any = MISSING 18 | lr: Any = 0.01 19 | lambd: Any = 0.0001 20 | alpha: Any = 0.75 21 | t0: Any = 1000000.0 22 | weight_decay: Any = 0 23 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/lbfgs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class LBFGSConf: 16 | _target_: str = "torch.optim.lbfgs.LBFGS" 17 | params: Any = MISSING 18 | lr: Any = 1 19 | max_iter: Any = 20 20 | max_eval: Any = None 21 | tolerance_grad: Any = 1e-07 22 | tolerance_change: Any = 1e-09 23 | history_size: Any = 100 24 | line_search_fn: Any = None 25 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class LambdaLRConf: 16 | _target_: str = "torch.optim.lr_scheduler.LambdaLR" 17 | optimizer: Any = MISSING 18 | lr_lambda: Any = MISSING 19 | last_epoch: Any = -1 20 | 21 | 22 | @dataclass 23 | class MultiplicativeLRConf: 24 | _target_: str = "torch.optim.lr_scheduler.MultiplicativeLR" 25 | optimizer: Any = MISSING 26 | lr_lambda: Any = MISSING 27 | last_epoch: Any = -1 28 | 29 | 30 | @dataclass 31 | class StepLRConf: 32 | _target_: str = "torch.optim.lr_scheduler.StepLR" 33 | optimizer: Any = MISSING 34 | step_size: Any = MISSING 35 | gamma: Any = 0.1 36 | last_epoch: Any = -1 37 | 38 | 39 | @dataclass 40 | class MultiStepLRConf: 41 | _target_: str = "torch.optim.lr_scheduler.MultiStepLR" 42 | optimizer: Any = MISSING 43 | milestones: Any = MISSING 44 | gamma: Any = 0.1 45 | last_epoch: Any = -1 46 | 47 | 48 | @dataclass 49 | class ExponentialLRConf: 50 | _target_: str = "torch.optim.lr_scheduler.ExponentialLR" 51 | optimizer: Any = MISSING 52 | gamma: Any = MISSING 53 | last_epoch: Any = -1 54 | 55 | 56 | @dataclass 57 | class CosineAnnealingLRConf: 58 | _target_: str = "torch.optim.lr_scheduler.CosineAnnealingLR" 59 | optimizer: Any = MISSING 60 | T_max: Any = MISSING 61 | eta_min: Any = 0 62 | last_epoch: Any = -1 63 | 64 | 65 | @dataclass 66 | class ReduceLROnPlateauConf: 67 | _target_: str = "torch.optim.lr_scheduler.ReduceLROnPlateau" 68 | optimizer: Any = MISSING 69 | mode: Any = "min" 70 | factor: Any = 0.1 71 | patience: Any = 10 72 | verbose: Any = False 73 | threshold: Any = 0.0001 74 | 
threshold_mode: Any = "rel" 75 | cooldown: Any = 0 76 | min_lr: Any = 0 77 | eps: Any = 1e-08 78 | 79 | 80 | @dataclass 81 | class CyclicLRConf: 82 | _target_: str = "torch.optim.lr_scheduler.CyclicLR" 83 | optimizer: Any = MISSING 84 | base_lr: Any = MISSING 85 | max_lr: Any = MISSING 86 | step_size_up: Any = 2000 87 | step_size_down: Any = None 88 | mode: Any = "triangular" 89 | gamma: Any = 1.0 90 | scale_fn: Any = None 91 | scale_mode: Any = "cycle" 92 | cycle_momentum: Any = True 93 | base_momentum: Any = 0.8 94 | max_momentum: Any = 0.9 95 | last_epoch: Any = -1 96 | 97 | 98 | @dataclass 99 | class CosineAnnealingWarmRestartsConf: 100 | _target_: str = "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts" 101 | optimizer: Any = MISSING 102 | T_0: Any = MISSING 103 | T_mult: Any = 1 104 | eta_min: Any = 0 105 | last_epoch: Any = -1 106 | 107 | 108 | @dataclass 109 | class OneCycleLRConf: 110 | _target_: str = "torch.optim.lr_scheduler.OneCycleLR" 111 | optimizer: Any = MISSING 112 | max_lr: Any = MISSING 113 | total_steps: Any = None 114 | epochs: Any = None 115 | steps_per_epoch: Any = None 116 | pct_start: Any = 0.3 117 | anneal_strategy: Any = "cos" 118 | cycle_momentum: Any = True 119 | base_momentum: Any = 0.85 120 | max_momentum: Any = 0.95 121 | div_factor: Any = 25.0 122 | final_div_factor: Any = 10000.0 123 | last_epoch: Any = -1 124 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/rmsprop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class RMSpropConf: 16 | _target_: str = "torch.optim.rmsprop.RMSprop" 17 | params: Any = MISSING 18 | lr: Any = 0.01 19 | alpha: Any = 0.99 20 | eps: Any = 1e-08 21 | weight_decay: Any = 0 22 | momentum: Any = 0 23 | centered: Any = False 24 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/rprop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class RpropConf: 16 | _target_: str = "torch.optim.rprop.Rprop" 17 | params: Any = MISSING 18 | lr: Any = 0.01 19 | etas: Any = (0.5, 1.2) 20 | step_sizes: Any = (1e-06, 50) 21 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 
4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class SGDConf: 16 | _target_: str = "torch.optim.sgd.SGD" 17 | params: Any = MISSING 18 | lr: Any = MISSING # _RequiredParameter 19 | momentum: Any = 0 20 | dampening: Any = 0 21 | weight_decay: Any = 0 22 | nesterov: Any = False 23 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/optim/sparse_adam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class SparseAdamConf: 16 | _target_: str = "torch.optim.sparse_adam.SparseAdam" 17 | params: Any = MISSING 18 | lr: Any = 0.001 19 | betas: Any = (0.9, 0.999) 20 | eps: Any = 1e-08 21 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/utils/data/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class DataLoaderConf: 16 | _target_: str = "torch.utils.data.dataloader.DataLoader" 17 | dataset: Any = MISSING 18 | batch_size: Any = 1 19 | shuffle: Any = False 20 | sampler: Any = None 21 | batch_sampler: Any = None 22 | num_workers: Any = 0 23 | collate_fn: Any = None 24 | pin_memory: Any = False 25 | drop_last: Any = False 26 | timeout: Any = 0 27 | worker_init_fn: Any = None 28 | multiprocessing_context: Any = None 29 | generator: Any = None 30 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 
4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class DatasetConf: 16 | _target_: str = "torch.utils.data.dataset.Dataset" 17 | 18 | 19 | @dataclass 20 | class ChainDatasetConf: 21 | _target_: str = "torch.utils.data.dataset.ChainDataset" 22 | datasets: Any = MISSING 23 | 24 | 25 | @dataclass 26 | class ConcatDatasetConf: 27 | _target_: str = "torch.utils.data.dataset.ConcatDataset" 28 | datasets: Any = MISSING 29 | 30 | 31 | @dataclass 32 | class IterableDatasetConf: 33 | _target_: str = "torch.utils.data.dataset.IterableDataset" 34 | 35 | 36 | @dataclass 37 | class TensorDatasetConf: 38 | _target_: str = "torch.utils.data.dataset.TensorDataset" 39 | tensors: Any = MISSING 40 | 41 | 42 | @dataclass 43 | class SubsetConf: 44 | _target_: str = "torch.utils.data.dataset.Subset" 45 | dataset: Any = MISSING 46 | indices: Any = MISSING 47 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class DistributedSamplerConf: 16 | _target_: str = "torch.utils.data.distributed.DistributedSampler" 17 | dataset: Any = MISSING 18 | num_replicas: Any = None 19 | rank: Any = None 20 | shuffle: Any = True 21 | seed: Any = 0 22 | -------------------------------------------------------------------------------- /hydra-configs-torch/hydra_configs/torch/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 
4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class SamplerConf: 16 | _target_: str = "torch.utils.data.sampler.Sampler" 17 | data_source: Any = MISSING 18 | 19 | 20 | @dataclass 21 | class BatchSamplerConf: 22 | _target_: str = "torch.utils.data.sampler.BatchSampler" 23 | sampler: Any = MISSING 24 | batch_size: Any = MISSING 25 | drop_last: Any = MISSING 26 | 27 | 28 | @dataclass 29 | class RandomSamplerConf: 30 | _target_: str = "torch.utils.data.sampler.RandomSampler" 31 | data_source: Any = MISSING 32 | replacement: Any = False 33 | num_samples: Any = None 34 | generator: Any = None 35 | 36 | 37 | @dataclass 38 | class SequentialSamplerConf: 39 | _target_: str = "torch.utils.data.sampler.SequentialSampler" 40 | data_source: Any = MISSING 41 | 42 | 43 | @dataclass 44 | class SubsetRandomSamplerConf: 45 | _target_: str = "torch.utils.data.sampler.SubsetRandomSampler" 46 | indices: Any = MISSING 47 | generator: Any = None 48 | 49 | 50 | @dataclass 51 | class WeightedRandomSamplerConf: 52 | _target_: str = "torch.utils.data.sampler.WeightedRandomSampler" 53 | weights: Any = MISSING 54 | num_samples: Any = MISSING 55 | replacement: Any = True 56 | generator: Any = None 57 | -------------------------------------------------------------------------------- /hydra-configs-torch/requirements/dev.txt: -------------------------------------------------------------------------------- 1 | torch==1.6 2 | -------------------------------------------------------------------------------- /hydra-configs-torch/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from setuptools import find_namespace_packages, setup 3 | 4 | requirements = [ 5 | "omegaconf", 6 | ] 7 | 8 | setup( 9 | name="hydra-configs-torch", 10 | version="1.6.1", 11 | packages=find_namespace_packages(include=["hydra_configs*"]), 12 | author=["Omry Yadan", "Rosario Scalise"], 13 | author_email=["omry@fb.com", "rosario@cs.uw.edu"], 14 | url="http://github.com/pytorch/hydra-torch", 15 | include_package_data=True, 16 | install_requires=requirements, 17 | ) 18 | -------------------------------------------------------------------------------- /hydra-configs-torch/tests/test_instantiate_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import pytest 3 | from hydra.utils import get_class, instantiate 4 | from omegaconf import OmegaConf 5 | 6 | import torch.utils.data as data 7 | 8 | import torch 9 | from typing import Any 10 | 11 | dummy_tensor = torch.tensor((1, 1)) 12 | dummy_dataset = data.dataset.TensorDataset(dummy_tensor) 13 | dummy_sampler = data.Sampler(data_source=dummy_dataset) 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "modulepath, classname, cfg, passthrough_args, passthrough_kwargs, expected", 18 | [ 19 | pytest.param( 20 | "utils.data.dataloader", 21 | "DataLoader", 22 | {"batch_size": 4}, 23 | [], 24 | {"dataset": dummy_dataset}, 25 | data.DataLoader(batch_size=4, dataset=dummy_dataset), 26 | id="DataLoaderConf", 27 | ), 28 | pytest.param( 29 | "utils.data.dataset", 30 | "Dataset", 31 | {}, 32 | [], 33 | {}, 34 | data.Dataset(), 35 | id="DatasetConf", 36 | ), 37 | pytest.param( 38 | "utils.data.dataset", 39 | "ChainDataset", 40 | {}, 41 | [], 42 | {"datasets": [dummy_dataset, dummy_dataset]}, 43 | data.ChainDataset(datasets=[dummy_dataset, dummy_dataset]), 44 | id="ChainDatasetConf", 45 | ), 46 | pytest.param( 47 | "utils.data.dataset", 48 | "ConcatDataset", 49 | {}, 50 | [], 51 | {"datasets": [dummy_dataset, dummy_dataset]}, 52 | data.ConcatDataset(datasets=[dummy_dataset, dummy_dataset]), 53 | id="ConcatDatasetConf", 54 | ), 55 | pytest.param( 56 | "utils.data.dataset", 57 | "IterableDataset", 58 | {}, 59 | [], 60 | {}, 61 | data.IterableDataset(), 62 | id="IterableDatasetConf", 63 | ), 64 | # TODO: investigate asterisk in signature instantiation limitation 65 | # pytest.param( 66 | # "utils.data.dataset", 67 | # "TensorDataset", 68 | # {}, 69 | # [], 70 | # {"tensors":[dummy_tensor]}, 71 | # data.TensorDataset(dummy_tensor), 72 | # id="TensorDatasetConf", 73 | # ), 74 | pytest.param( 75 | "utils.data.dataset", 76 | "Subset", 77 | {}, 78 | [], 79 | {"dataset": dummy_dataset, "indices": [0]}, 80 | data.Subset(dummy_dataset, 0), 81 | id="SubsetConf", 82 | ), 83 | pytest.param( 84 | "utils.data.sampler", 85 | "Sampler", 86 | {}, 87 | [], 88 | {"data_source": dummy_dataset}, 89 | data.Sampler(data_source=dummy_dataset), 90 | id="SamplerConf", 91 | ), 92 | pytest.param( 93 | "utils.data.sampler", 94 | "BatchSampler", 95 | {"batch_size": 4, "drop_last": False}, 96 | [], 97 | {"sampler": dummy_sampler}, 98 | data.BatchSampler(sampler=dummy_sampler, batch_size=4, drop_last=False), 99 | id="BatchSamplerConf", 100 | ), 101 | pytest.param( 102 | "utils.data.sampler", 103 | "RandomSampler", 104 | {}, 105 | [], 106 | {"data_source": dummy_dataset}, 107 | data.RandomSampler(data_source=dummy_dataset), 108 | id="RandomSamplerConf", 109 | ), 110 | pytest.param( 111 | "utils.data.sampler", 112 | "SequentialSampler", 113 | {}, 114 | [], 115 | {"data_source": dummy_dataset}, 116 | data.SequentialSampler(data_source=dummy_dataset), 117 | id="SequentialSamplerConf", 118 | ), 119 | pytest.param( 120 | "utils.data.sampler", 121 | "SubsetRandomSampler", 122 | {"indices": [1]}, 123 | [], 124 | {}, 125 | data.SubsetRandomSampler(indices=[1]), 126 | id="SubsetRandomSamplerConf", 127 | ), 128 | pytest.param( 129 | "utils.data.sampler", 130 | "WeightedRandomSampler", 131 | {"weights": [1], "num_samples": 1}, 132 | [], 133 | {}, 134 | data.WeightedRandomSampler(weights=[1], num_samples=1), 135 | id="WeightedRandomSamplerConf", 136 | ), 137 | # TODO: investigate testing distributed instantiation 138 | # pytest.param( 139 | # "utils.data.distributed", 140 | # "DistributedSampler", 141 | # {}, 142 | # [], 143 | 
# {"dataset": dummy_dataset}, 144 | # data.DistributedSampler(group=dummy_group,dataset=dummy_dataset), 145 | # id="DistributedSamplerConf", 146 | # ), 147 | ], 148 | ) 149 | def test_instantiate_classes( 150 | modulepath: str, 151 | classname: str, 152 | cfg: Any, 153 | passthrough_args: Any, 154 | passthrough_kwargs: Any, 155 | expected: Any, 156 | ) -> None: 157 | full_class = f"hydra_configs.torch.{modulepath}.{classname}Conf" 158 | schema = OmegaConf.structured(get_class(full_class)) 159 | cfg = OmegaConf.merge(schema, cfg) 160 | obj = instantiate(cfg, *passthrough_args, **passthrough_kwargs) 161 | 162 | assert isinstance(obj, type(expected)) 163 | -------------------------------------------------------------------------------- /hydra-configs-torch/tests/test_instantiate_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import pytest 3 | from hydra.utils import get_class, instantiate 4 | from omegaconf import OmegaConf 5 | 6 | import torch.nn.modules.loss as loss 7 | 8 | from torch.tensor import Tensor 9 | from typing import Any 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "modulepath, classname, cfg, passthrough_args, passthrough_kwargs, expected", 14 | [ 15 | pytest.param( 16 | "nn.modules.loss", 17 | "BCELoss", 18 | {}, 19 | [], 20 | {"weight": Tensor([1])}, 21 | loss.BCELoss(), 22 | id="BCELossConf", 23 | ), 24 | pytest.param( 25 | "nn.modules.loss", 26 | "BCEWithLogitsLoss", 27 | {}, 28 | [], 29 | {"weight": Tensor([1]), "pos_weight": Tensor([1])}, 30 | loss.BCEWithLogitsLoss(), 31 | id="BCEWithLogitsLossConf", 32 | ), 33 | pytest.param( 34 | "nn.modules.loss", 35 | "CosineEmbeddingLoss", 36 | {}, 37 | [], 38 | {}, 39 | loss.CosineEmbeddingLoss(), 40 | id="CosineEmbeddingLossConf", 41 | ), 42 | pytest.param( 43 | "nn.modules.loss", 44 | "CTCLoss", 45 | {}, 46 | [], 47 | {}, 48 | loss.CTCLoss(), 49 | id="CTCLossConf", 50 | ), 51 | pytest.param( 52 | "nn.modules.loss", 53 | "L1Loss", 54 | {}, 55 | [], 56 | {}, 57 | loss.L1Loss(), 58 | id="L1LossConf", 59 | ), 60 | pytest.param( 61 | "nn.modules.loss", 62 | "HingeEmbeddingLoss", 63 | {}, 64 | [], 65 | {}, 66 | loss.HingeEmbeddingLoss(), 67 | id="HingeEmbeddingLossConf", 68 | ), 69 | pytest.param( 70 | "nn.modules.loss", 71 | "KLDivLoss", 72 | {}, 73 | [], 74 | {}, 75 | loss.KLDivLoss(), 76 | id="KLDivLossConf", 77 | ), 78 | pytest.param( 79 | "nn.modules.loss", 80 | "MarginRankingLoss", 81 | {}, 82 | [], 83 | {}, 84 | loss.MarginRankingLoss(), 85 | id="MarginRankingLossConf", 86 | ), 87 | pytest.param( 88 | "nn.modules.loss", 89 | "MSELoss", 90 | {}, 91 | [], 92 | {}, 93 | loss.MSELoss(), 94 | id="MSELossConf", 95 | ), 96 | pytest.param( 97 | "nn.modules.loss", 98 | "MultiLabelMarginLoss", 99 | {}, 100 | [], 101 | {}, 102 | loss.MultiLabelMarginLoss(), 103 | id="MultiLabelMarginLossConf", 104 | ), 105 | pytest.param( 106 | "nn.modules.loss", 107 | "MultiLabelSoftMarginLoss", 108 | {}, 109 | [], 110 | {"weight": Tensor([1])}, 111 | loss.MultiLabelSoftMarginLoss(), 112 | id="MultiLabelSoftMarginLossConf", 113 | ), 114 | pytest.param( 115 | "nn.modules.loss", 116 | "MultiMarginLoss", 117 | {}, 118 | [], 119 | {"weight": Tensor([1])}, 120 | loss.MultiMarginLoss(), 121 | id="MultiMarginLossConf", 122 | ), 123 | pytest.param( 124 | "nn.modules.loss", 125 | "NLLLoss", 126 | {}, 127 | [], 128 | {"weight": Tensor([1])}, 129 | loss.NLLLoss(), 130 | id="NLLLossConf", 131 | ), 132 | pytest.param( 133 | "nn.modules.loss", 134 
| "NLLLoss2d", 135 | {}, 136 | [], 137 | {"weight": Tensor([1])}, 138 | loss.NLLLoss2d(), 139 | id="NLLLoss2dConf", 140 | ), 141 | pytest.param( 142 | "nn.modules.loss", 143 | "PoissonNLLLoss", 144 | {}, 145 | [], 146 | {}, 147 | loss.PoissonNLLLoss(), 148 | id="PoissonNLLLossConf", 149 | ), 150 | pytest.param( 151 | "nn.modules.loss", 152 | "SmoothL1Loss", 153 | {}, 154 | [], 155 | {}, 156 | loss.SmoothL1Loss(), 157 | id="SmoothL1LossConf", 158 | ), 159 | pytest.param( 160 | "nn.modules.loss", 161 | "SoftMarginLoss", 162 | {}, 163 | [], 164 | {}, 165 | loss.SoftMarginLoss(), 166 | id="SoftMarginLossConf", 167 | ), 168 | pytest.param( 169 | "nn.modules.loss", 170 | "TripletMarginLoss", 171 | {}, 172 | [], 173 | {}, 174 | loss.TripletMarginLoss(), 175 | id="TripletMarginLossConf", 176 | ), 177 | ], 178 | ) 179 | def test_instantiate_classes( 180 | modulepath: str, 181 | classname: str, 182 | cfg: Any, 183 | passthrough_args: Any, 184 | passthrough_kwargs: Any, 185 | expected: Any, 186 | ) -> None: 187 | full_class = f"hydra_configs.torch.{modulepath}.{classname}Conf" 188 | schema = OmegaConf.structured(get_class(full_class)) 189 | cfg = OmegaConf.merge(schema, cfg) 190 | obj = instantiate(cfg, *passthrough_args, **passthrough_kwargs) 191 | 192 | assert isinstance(obj, type(expected)) 193 | -------------------------------------------------------------------------------- /hydra-configs-torch/tests/test_instantiate_optimizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import pytest 3 | from hydra.utils import get_class, instantiate 4 | from omegaconf import OmegaConf 5 | 6 | import torch.optim as optim 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from typing import Any 12 | 13 | model = nn.Linear(1, 1) 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "modulepath, classname, cfg, passthrough_kwargs, expected", 18 | [ 19 | pytest.param( 20 | "optim.adadelta", 21 | "Adadelta", 22 | {"lr": 0.1}, 23 | {"params": model.parameters()}, 24 | optim.Adadelta(lr=0.1, params=model.parameters()), 25 | id="AdadeltaConf", 26 | ), 27 | pytest.param( 28 | "optim.adagrad", 29 | "Adagrad", 30 | {"lr": 0.1}, 31 | {"params": model.parameters()}, 32 | optim.Adagrad(lr=0.1, params=model.parameters()), 33 | id="AdagradConf", 34 | ), 35 | pytest.param( 36 | "optim.adam", 37 | "Adam", 38 | {"lr": 0.1}, 39 | {"params": model.parameters()}, 40 | optim.Adam(lr=0.1, params=model.parameters()), 41 | id="AdamConf", 42 | ), 43 | pytest.param( 44 | "optim.adamax", 45 | "Adamax", 46 | {"lr": 0.1}, 47 | {"params": model.parameters()}, 48 | optim.Adamax(lr=0.1, params=model.parameters()), 49 | id="AdamaxConf", 50 | ), 51 | pytest.param( 52 | "optim.adamw", 53 | "AdamW", 54 | {"lr": 0.1}, 55 | {"params": model.parameters()}, 56 | optim.AdamW(lr=0.1, params=model.parameters()), 57 | id="AdamWConf", 58 | ), 59 | pytest.param( 60 | "optim.asgd", 61 | "ASGD", 62 | {"lr": 0.1}, 63 | {"params": model.parameters()}, 64 | optim.ASGD(lr=0.1, params=model.parameters()), 65 | id="ASGDConf", 66 | ), 67 | pytest.param( 68 | "optim.lbfgs", 69 | "LBFGS", 70 | {"lr": 0.1}, 71 | {"params": model.parameters()}, 72 | optim.LBFGS(lr=0.1, params=model.parameters()), 73 | id="LBFGSConf", 74 | ), 75 | pytest.param( 76 | "optim.rmsprop", 77 | "RMSprop", 78 | {"lr": 0.1}, 79 | {"params": model.parameters()}, 80 | optim.RMSprop(lr=0.1, params=model.parameters()), 81 | id="RMSpropConf", 82 | ), 83 | 
pytest.param( 84 | "optim.rprop", 85 | "Rprop", 86 | {"lr": 0.1}, 87 | {"params": model.parameters()}, 88 | optim.Rprop(lr=0.1, params=model.parameters()), 89 | id="RpropConf", 90 | ), 91 | pytest.param( 92 | "optim.sgd", 93 | "SGD", 94 | {"lr": 0.1}, 95 | {"params": model.parameters()}, 96 | optim.SGD(lr=0.1, params=model.parameters()), 97 | id="SGDConf", 98 | ), 99 | pytest.param( 100 | "optim.sparse_adam", 101 | "SparseAdam", 102 | {"lr": 0.1}, 103 | {"params": list(model.parameters())}, 104 | optim.SparseAdam(lr=0.1, params=list(model.parameters())), 105 | id="SparseAdamConf", 106 | ), 107 | ], 108 | ) 109 | def test_instantiate_classes( 110 | modulepath: str, classname: str, cfg: Any, passthrough_kwargs: Any, expected: Any 111 | ) -> None: 112 | full_class = f"hydra_configs.torch.{modulepath}.{classname}Conf" 113 | schema = OmegaConf.structured(get_class(full_class)) 114 | cfg = OmegaConf.merge(schema, cfg) 115 | obj = instantiate(cfg, **passthrough_kwargs) 116 | 117 | def closure(): 118 | return model(Tensor([10])) 119 | 120 | assert torch.all(torch.eq(obj.step(closure), expected.step(closure))) 121 | -------------------------------------------------------------------------------- /hydra-configs-torchvision/hydra_configs/torchvision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from packaging import version 10 | from pkg_resources import get_distribution 11 | import warnings 12 | import torchvision 13 | 14 | CONFIGS_VERSION = get_distribution('hydra-configs-torchvision').version 15 | 16 | # checks if major.minor versions are matched. patch version is always different 17 | if version.parse(torchvision.__version__).release[:2] != version.parse(CONFIGS_VERSION).release[:2]: 18 | warnings.warn(f'Your config and library versions are mismatched. \n HYDRA-CONFIGS-TORCHVISION VERSION: {CONFIGS_VERSION}, \n TORCHVISION VERSION: {torchvision.__version__}. \n Please install the matching configs for reliable functionality.') 19 | -------------------------------------------------------------------------------- /hydra-configs-torchvision/hydra_configs/torchvision/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 
4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class MNISTConf: 16 | _target_: str = "torchvision.datasets.mnist.MNIST" 17 | root: Any = MISSING 18 | train: Any = True 19 | transform: Any = None 20 | target_transform: Any = None 21 | download: Any = False 22 | 23 | 24 | @dataclass 25 | class FashionMNISTConf: 26 | _target_: str = "torchvision.datasets.mnist.FashionMNIST" 27 | root: Any = MISSING 28 | train: Any = True 29 | transform: Any = None 30 | target_transform: Any = None 31 | download: Any = False 32 | 33 | 34 | @dataclass 35 | class KMNISTConf: 36 | _target_: str = "torchvision.datasets.mnist.KMNIST" 37 | root: Any = MISSING 38 | train: Any = True 39 | transform: Any = None 40 | target_transform: Any = None 41 | download: Any = False 42 | -------------------------------------------------------------------------------- /hydra-configs-torchvision/hydra_configs/torchvision/datasets/vision.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class VisionDatasetConf: 16 | _target_: str = "torchvision.datasets.vision.VisionDataset" 17 | root: Any = MISSING 18 | transforms: Any = None 19 | transform: Any = None 20 | target_transform: Any = None 21 | 22 | 23 | @dataclass 24 | class StandardTransformConf: 25 | _target_: str = "torchvision.datasets.vision.StandardTransform" 26 | transform: Any = None 27 | target_transform: Any = None 28 | -------------------------------------------------------------------------------- /hydra-configs-torchvision/hydra_configs/torchvision/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # 3 | # Generated by configen, do not edit. 
4 | # See https://github.com/facebookresearch/hydra/tree/main/tools/configen 5 | # fmt: off 6 | # isort:skip_file 7 | # flake8: noqa 8 | 9 | from dataclasses import dataclass, field 10 | from omegaconf import MISSING 11 | from typing import Any 12 | 13 | 14 | @dataclass 15 | class CenterCropConf: 16 | _target_: str = "torchvision.transforms.transforms.CenterCrop" 17 | _convert_: str = "ALL" 18 | size: Any = MISSING 19 | 20 | 21 | @dataclass 22 | class ColorJitterConf: 23 | _target_: str = "torchvision.transforms.transforms.ColorJitter" 24 | _convert_: str = "ALL" 25 | brightness: Any = 0 26 | contrast: Any = 0 27 | saturation: Any = 0 28 | hue: Any = 0 29 | 30 | 31 | @dataclass 32 | class ComposeConf: 33 | _target_: str = "torchvision.transforms.transforms.Compose" 34 | _convert_: str = "ALL" 35 | transforms: Any = MISSING 36 | 37 | 38 | @dataclass 39 | class ConvertImageDtypeConf: 40 | _target_: str = "torchvision.transforms.transforms.ConvertImageDtype" 41 | _convert_: str = "ALL" 42 | dtype: Any = MISSING # dtype 43 | 44 | 45 | @dataclass 46 | class FiveCropConf: 47 | _target_: str = "torchvision.transforms.transforms.FiveCrop" 48 | _convert_: str = "ALL" 49 | size: Any = MISSING 50 | 51 | 52 | @dataclass 53 | class GrayscaleConf: 54 | _target_: str = "torchvision.transforms.transforms.Grayscale" 55 | _convert_: str = "ALL" 56 | num_output_channels: Any = 1 57 | 58 | 59 | @dataclass 60 | class LambdaConf: 61 | _target_: str = "torchvision.transforms.transforms.Lambda" 62 | _convert_: str = "ALL" 63 | lambd: Any = MISSING 64 | 65 | 66 | @dataclass 67 | class LinearTransformationConf: 68 | _target_: str = "torchvision.transforms.transforms.LinearTransformation" 69 | _convert_: str = "ALL" 70 | transformation_matrix: Any = MISSING 71 | mean_vector: Any = MISSING 72 | 73 | 74 | @dataclass 75 | class NormalizeConf: 76 | _target_: str = "torchvision.transforms.transforms.Normalize" 77 | _convert_: str = "ALL" 78 | mean: Any = MISSING 79 | std: Any = MISSING 80 | inplace: Any = False 81 | 82 | 83 | @dataclass 84 | class PadConf: 85 | _target_: str = "torchvision.transforms.transforms.Pad" 86 | _convert_: str = "ALL" 87 | padding: Any = MISSING 88 | fill: Any = 0 89 | padding_mode: Any = "constant" 90 | 91 | 92 | @dataclass 93 | class PILToTensorConf: 94 | _target_: str = "torchvision.transforms.transforms.PILToTensor" 95 | _convert_: str = "ALL" 96 | 97 | 98 | @dataclass 99 | class RandomAffineConf: 100 | _target_: str = "torchvision.transforms.transforms.RandomAffine" 101 | _convert_: str = "ALL" 102 | degrees: Any = MISSING 103 | translate: Any = None 104 | scale: Any = None 105 | shear: Any = None 106 | resample: Any = 0 107 | fillcolor: Any = 0 108 | 109 | 110 | @dataclass 111 | class RandomApplyConf: 112 | _target_: str = "torchvision.transforms.transforms.RandomApply" 113 | _convert_: str = "ALL" 114 | transforms: Any = MISSING 115 | p: Any = 0.5 116 | 117 | 118 | @dataclass 119 | class RandomChoiceConf: 120 | _target_: str = "torchvision.transforms.transforms.RandomChoice" 121 | _convert_: str = "ALL" 122 | transforms: Any = MISSING 123 | 124 | 125 | @dataclass 126 | class RandomCropConf: 127 | _target_: str = "torchvision.transforms.transforms.RandomCrop" 128 | _convert_: str = "ALL" 129 | size: Any = MISSING 130 | padding: Any = None 131 | pad_if_needed: Any = False 132 | fill: Any = 0 133 | padding_mode: Any = "constant" 134 | 135 | 136 | @dataclass 137 | class RandomErasingConf: 138 | _target_: str = "torchvision.transforms.transforms.RandomErasing" 139 | _convert_: str = "ALL" 140 
| p: Any = 0.5 141 | scale: Any = (0.02, 0.33) 142 | ratio: Any = (0.3, 3.3) 143 | value: Any = 0 144 | inplace: Any = False 145 | 146 | 147 | @dataclass 148 | class RandomGrayscaleConf: 149 | _target_: str = "torchvision.transforms.transforms.RandomGrayscale" 150 | _convert_: str = "ALL" 151 | p: Any = 0.1 152 | 153 | 154 | @dataclass 155 | class RandomHorizontalFlipConf: 156 | _target_: str = "torchvision.transforms.transforms.RandomHorizontalFlip" 157 | _convert_: str = "ALL" 158 | p: Any = 0.5 159 | 160 | 161 | @dataclass 162 | class RandomOrderConf: 163 | _target_: str = "torchvision.transforms.transforms.RandomOrder" 164 | _convert_: str = "ALL" 165 | transforms: Any = MISSING 166 | 167 | 168 | @dataclass 169 | class RandomPerspectiveConf: 170 | _target_: str = "torchvision.transforms.transforms.RandomPerspective" 171 | _convert_: str = "ALL" 172 | distortion_scale: Any = 0.5 173 | p: Any = 0.5 174 | interpolation: Any = 2 175 | fill: Any = 0 176 | 177 | 178 | @dataclass 179 | class RandomResizedCropConf: 180 | _target_: str = "torchvision.transforms.transforms.RandomResizedCrop" 181 | _convert_: str = "ALL" 182 | size: Any = MISSING 183 | scale: Any = (0.08, 1.0) 184 | ratio: Any = (0.75, 1.3333333333333333) 185 | interpolation: Any = 2 186 | 187 | 188 | @dataclass 189 | class RandomRotationConf: 190 | _target_: str = "torchvision.transforms.transforms.RandomRotation" 191 | _convert_: str = "ALL" 192 | degrees: Any = MISSING 193 | resample: Any = False 194 | expand: Any = False 195 | center: Any = None 196 | fill: Any = None 197 | 198 | 199 | @dataclass 200 | class RandomTransformsConf: 201 | _target_: str = "torchvision.transforms.transforms.RandomTransforms" 202 | _convert_: str = "ALL" 203 | transforms: Any = MISSING 204 | 205 | 206 | @dataclass 207 | class RandomVerticalFlipConf: 208 | _target_: str = "torchvision.transforms.transforms.RandomVerticalFlip" 209 | _convert_: str = "ALL" 210 | p: Any = 0.5 211 | 212 | 213 | @dataclass 214 | class ResizeConf: 215 | _target_: str = "torchvision.transforms.transforms.Resize" 216 | _convert_: str = "ALL" 217 | size: Any = MISSING 218 | interpolation: Any = 2 219 | 220 | 221 | @dataclass 222 | class TenCropConf: 223 | _target_: str = "torchvision.transforms.transforms.TenCrop" 224 | _convert_: str = "ALL" 225 | size: Any = MISSING 226 | vertical_flip: Any = False 227 | 228 | 229 | @dataclass 230 | class ToPILImageConf: 231 | _target_: str = "torchvision.transforms.transforms.ToPILImage" 232 | _convert_: str = "ALL" 233 | mode: Any = None 234 | 235 | 236 | @dataclass 237 | class ToTensorConf: 238 | _target_: str = "torchvision.transforms.transforms.ToTensor" 239 | _convert_: str = "ALL" 240 | -------------------------------------------------------------------------------- /hydra-configs-torchvision/requirements/dev.txt: -------------------------------------------------------------------------------- 1 | torch==1.6 2 | torchvision==0.7 3 | -------------------------------------------------------------------------------- /hydra-configs-torchvision/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from setuptools import find_namespace_packages, setup 3 | 4 | requirements = [ 5 | "omegaconf", 6 | ] 7 | 8 | setup( 9 | name="hydra-configs-torchvision", 10 | version="0.8.2", 11 | packages=find_namespace_packages(include=["hydra_configs*"]), 12 | author=["Omry Yadan", "Rosario Scalise"], 13 | author_email=["omry@fb.com", "rosario@cs.uw.edu"], 14 | url="http://github.com/pytorch/hydra-torch", 15 | include_package_data=True, 16 | install_requires=requirements, 17 | ) 18 | -------------------------------------------------------------------------------- /hydra-configs-torchvision/tests/test_instantiate_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import os 3 | import pytest 4 | from pathlib import Path 5 | from hydra.utils import get_class, instantiate 6 | from omegaconf import OmegaConf 7 | from typing import Any 8 | 9 | import torch 10 | import torchvision.datasets as datasets 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "modulepath, classname, cfg, passthrough_args, passthrough_kwargs, expected_class", 15 | [ 16 | pytest.param( 17 | "datasets.vision", 18 | "VisionDataset", 19 | {"root": None}, 20 | [], 21 | {}, 22 | datasets.VisionDataset, 23 | id="VisionDatasetConf", 24 | ), 25 | pytest.param( 26 | "datasets.mnist", 27 | "MNIST", 28 | {"root": None}, 29 | [], 30 | {}, 31 | datasets.MNIST, 32 | id="MNISTConf", 33 | ), 34 | pytest.param( 35 | "datasets.mnist", 36 | "FashionMNIST", 37 | {"root": None}, 38 | [], 39 | {}, 40 | datasets.FashionMNIST, 41 | id="FashionMNISTConf", 42 | ), 43 | pytest.param( 44 | "datasets.mnist", 45 | "KMNIST", 46 | {"root": None}, 47 | [], 48 | {}, 49 | datasets.KMNIST, 50 | id="KMNISTConf", 51 | ), 52 | # TODO: These tests will need to be changed after blockers: 53 | # 1. EMNISTConf and QMNISTConf are manually created 54 | # 2. 
hydra.utils.instantiate is updated to allow *kwargs instantiation 55 | # pytest.param( 56 | # "datasets.mnist", 57 | # "EMNIST", 58 | # {"root":None, 59 | # "split":"byclass", 60 | # "kwargs":None}, 61 | # [], 62 | # {}, 63 | # datasets.EMNIST, 64 | # id="EMNISTConf", 65 | # ), 66 | # pytest.param( 67 | # "datasets.mnist", 68 | # "QMNIST", 69 | # {"root":None, 70 | # "what":'test', 71 | # "compat":None, 72 | # "kwargs":None}, 73 | # [], 74 | # {}, 75 | # datasets.QMNIST, 76 | # id="QMNISTConf", 77 | # ), 78 | ], 79 | ) 80 | def test_instantiate_classes( 81 | tmpdir: Path, 82 | modulepath: str, 83 | classname: str, 84 | cfg: Any, 85 | passthrough_args: Any, 86 | passthrough_kwargs: Any, 87 | expected_class: Any, 88 | ) -> None: 89 | 90 | # Create fake dataset and put it in tmpdir for test: 91 | tmp_data_root = tmpdir.mkdir("data") 92 | processed_dir = os.path.join(tmp_data_root, classname, "processed") 93 | os.makedirs(processed_dir) 94 | torch.save(torch.tensor([[1.0], [1.0]]), processed_dir + "/training.pt") 95 | torch.save(torch.tensor([1.0]), processed_dir + "/test.pt") 96 | 97 | # cfg is populated here since it requires tmpdir testfixture 98 | cfg["root"] = str(tmp_data_root) 99 | full_class = f"hydra_configs.torchvision.{modulepath}.{classname}Conf" 100 | schema = OmegaConf.structured(get_class(full_class)) 101 | cfg = OmegaConf.merge(schema, cfg) 102 | obj = instantiate(cfg, *passthrough_args, **passthrough_kwargs) 103 | expected_obj = expected_class(root=tmp_data_root) 104 | 105 | assert isinstance(obj, type(expected_obj)) 106 | -------------------------------------------------------------------------------- /hydra-configs-torchvision/tests/test_instantiate_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import pytest 3 | from hydra.utils import get_class, instantiate 4 | from omegaconf import OmegaConf 5 | 6 | import torch 7 | 8 | # import torchvision.datasets as datasets 9 | import torchvision.transforms as transforms 10 | from torchvision.transforms.transforms import ToTensor 11 | 12 | from typing import Any 13 | 14 | 15 | def identity(x): 16 | return x 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "modulepath, classname, cfg, passthrough_args, passthrough_kwargs, expected", 21 | [ 22 | # pytest.param( 23 | # "datasets.vision", 24 | # "StandardTransform", 25 | # {}, 26 | # [], 27 | # {}, 28 | # datasets.vision.StandardTransform(), 29 | # id="StandardTransformConf", 30 | # ), 31 | pytest.param( 32 | "transforms.transforms", 33 | "CenterCrop", 34 | {"size": (10, 10)}, 35 | [], 36 | {}, 37 | transforms.transforms.CenterCrop(size=(10, 10)), 38 | id="CenterCropConf", 39 | ), 40 | pytest.param( 41 | "transforms.transforms", 42 | "ColorJitter", 43 | {}, 44 | [], 45 | {}, 46 | transforms.transforms.ColorJitter(), 47 | id="ColorJitterConf", 48 | ), 49 | pytest.param( 50 | "transforms.transforms", 51 | "Compose", 52 | {"transforms": []}, 53 | [], 54 | {}, 55 | transforms.transforms.Compose(transforms=[]), 56 | id="ComposeConf", 57 | ), 58 | pytest.param( 59 | "transforms.transforms", 60 | "ConvertImageDtype", 61 | {}, 62 | [], 63 | {"dtype": torch.int32}, 64 | transforms.transforms.ConvertImageDtype(dtype=torch.int32), 65 | id="ConvertImageDtypeConf", 66 | ), 67 | pytest.param( 68 | "transforms.transforms", 69 | "FiveCrop", 70 | {"size": (10, 10)}, 71 | [], 72 | {}, 73 | transforms.transforms.FiveCrop(size=(10, 10)), 74 | id="FiveCropConf", 75 | ), 76 | pytest.param( 77 | "transforms.transforms", 78 | "Grayscale", 79 | {}, 80 | [], 81 | {}, 82 | transforms.transforms.Grayscale(), 83 | id="GrayscaleConf", 84 | ), 85 | pytest.param( 86 | "transforms.transforms", 87 | "Lambda", 88 | {}, 89 | [], 90 | {"lambd": identity}, 91 | transforms.transforms.Lambda(lambd=identity), 92 | id="LambdaConf", 93 | ), 94 | pytest.param( 95 | "transforms.transforms", 96 | "LinearTransformation", 97 | {}, 98 | [], 99 | { 100 | "transformation_matrix": torch.eye(2), 101 | "mean_vector": torch.Tensor([1, 1]), 102 | }, 103 | transforms.transforms.LinearTransformation( 104 | transformation_matrix=torch.eye(2), mean_vector=torch.Tensor([1, 1]) 105 | ), 106 | id="LinearTransformationConf", 107 | ), 108 | pytest.param( 109 | "transforms.transforms", 110 | "Normalize", 111 | {"mean": 0, "std": 1}, 112 | [], 113 | {}, 114 | transforms.transforms.Normalize(mean=0, std=1), 115 | id="NormalizeConf", 116 | ), 117 | pytest.param( 118 | "transforms.transforms", 119 | "Pad", 120 | {"padding": 0}, 121 | [], 122 | {}, 123 | transforms.transforms.Pad(padding=0), 124 | id="PaddingConf", 125 | ), 126 | pytest.param( 127 | "transforms.transforms", 128 | "PILToTensor", 129 | {}, 130 | [], 131 | {}, 132 | transforms.transforms.PILToTensor(), 133 | id="PILToTensorConf", 134 | ), 135 | pytest.param( 136 | "transforms.transforms", 137 | "RandomAffine", 138 | {"degrees": 0}, 139 | [], 140 | {}, 141 | transforms.transforms.RandomAffine(degrees=0), 142 | id="RandomAffineConf", 143 | ), 144 | pytest.param( 145 | "transforms.transforms", 146 | "RandomApply", 147 | {}, 148 | [], 149 | {"transforms": [ToTensor()]}, 150 | transforms.transforms.RandomApply([ToTensor()]), 151 | id="RandomApplyConf", 152 | ), 153 | pytest.param( 154 | "transforms.transforms", 155 | "RandomChoice", 156 | {}, 157 | [], 158 | {"transforms": 
[[ToTensor()]]}, 159 | transforms.transforms.RandomChoice([ToTensor()]), 160 | id="RandomChoiceConf", 161 | ), 162 | pytest.param( 163 | "transforms.transforms", 164 | "RandomCrop", 165 | {"size": (10, 10)}, 166 | [], 167 | {}, 168 | transforms.transforms.RandomCrop(size=(10, 10)), 169 | id="RandomCropConf", 170 | ), 171 | pytest.param( 172 | "transforms.transforms", 173 | "RandomErasing", 174 | {}, 175 | [], 176 | {}, 177 | transforms.transforms.RandomErasing(), 178 | id="RandomErasingConf", 179 | ), 180 | pytest.param( 181 | "transforms.transforms", 182 | "RandomGrayscale", 183 | {}, 184 | [], 185 | {}, 186 | transforms.transforms.RandomGrayscale(), 187 | id="RandomGrayscaleConf", 188 | ), 189 | pytest.param( 190 | "transforms.transforms", 191 | "RandomHorizontalFlip", 192 | {}, 193 | [], 194 | {}, 195 | transforms.transforms.RandomHorizontalFlip(), 196 | id="RandomHorizontalFlipConf", 197 | ), 198 | pytest.param( 199 | "transforms.transforms", 200 | "RandomOrder", 201 | {}, 202 | [], 203 | {"transforms": [ToTensor()]}, 204 | transforms.transforms.RandomOrder([ToTensor()]), 205 | id="RandomOrderConf", 206 | ), 207 | pytest.param( 208 | "transforms.transforms", 209 | "RandomPerspective", 210 | {}, 211 | [], 212 | {}, 213 | transforms.transforms.RandomPerspective(), 214 | id="RandomPerspectiveConf", 215 | ), 216 | pytest.param( 217 | "transforms.transforms", 218 | "RandomResizedCrop", 219 | {"size": (10, 10)}, 220 | [], 221 | {}, 222 | transforms.transforms.RandomResizedCrop(size=(10, 10)), 223 | id="RandomResizedCropConf", 224 | ), 225 | pytest.param( 226 | "transforms.transforms", 227 | "RandomRotation", 228 | {"degrees": 0}, 229 | [], 230 | {}, 231 | transforms.transforms.RandomRotation(degrees=0), 232 | id="RandomRotationConf", 233 | ), 234 | pytest.param( 235 | "transforms.transforms", 236 | "RandomTransforms", 237 | {"transforms": []}, 238 | [], 239 | {}, 240 | transforms.transforms.RandomTransforms([]), 241 | id="RandomTransformsConf", 242 | ), 243 | pytest.param( 244 | "transforms.transforms", 245 | "RandomVerticalFlip", 246 | {}, 247 | [], 248 | {}, 249 | transforms.transforms.RandomVerticalFlip(), 250 | id="RandomVerticalFlipConf", 251 | ), 252 | pytest.param( 253 | "transforms.transforms", 254 | "Resize", 255 | {"size": (10, 10)}, 256 | [], 257 | {}, 258 | transforms.transforms.Resize(size=(10, 10)), 259 | id="ResizeConf", 260 | ), 261 | pytest.param( 262 | "transforms.transforms", 263 | "TenCrop", 264 | {"size": (10, 10)}, 265 | [], 266 | {}, 267 | transforms.transforms.TenCrop(size=(10, 10)), 268 | id="TenCropConf", 269 | ), 270 | pytest.param( 271 | "transforms.transforms", 272 | "ToPILImage", 273 | {}, 274 | [], 275 | {}, 276 | transforms.transforms.ToPILImage(), 277 | id="ToPILImageConf", 278 | ), 279 | pytest.param( 280 | "transforms.transforms", 281 | "ToTensor", 282 | {}, 283 | [], 284 | {}, 285 | transforms.transforms.ToTensor(), 286 | id="ToTensorConf", 287 | ), 288 | ], 289 | ) 290 | def test_instantiate_classes( 291 | modulepath: str, 292 | classname: str, 293 | cfg: Any, 294 | passthrough_args: Any, 295 | passthrough_kwargs: Any, 296 | expected: Any, 297 | ) -> None: 298 | full_class = f"hydra_configs.torchvision.{modulepath}.{classname}Conf" 299 | schema = OmegaConf.structured(get_class(full_class)) 300 | cfg = OmegaConf.merge(schema, cfg) 301 | obj = instantiate(cfg, *passthrough_args, **passthrough_kwargs) 302 | 303 | assert isinstance(obj, type(expected)) 304 | -------------------------------------------------------------------------------- /noxfile.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import nox 3 | import os 4 | 5 | DEFAULT_PYTHON_VERSIONS = ["3.6", "3.7", "3.8"] 6 | PYTHON_VERSIONS = os.environ.get( 7 | "NOX_PYTHON_VERSIONS", ",".join(DEFAULT_PYTHON_VERSIONS) 8 | ).split(",") 9 | 10 | VERBOSE = os.environ.get("VERBOSE", "0") 11 | SILENT = VERBOSE == "0" 12 | 13 | # Linted dirs/files: 14 | lint_targets = "." 15 | # Test dirs (corresponds to each project having its own tests folder): 16 | # Note the './', this installs local packages 17 | test_targets = [ 18 | "./" + p.rstrip("\n") for p in open("hydra-configs-projects.txt", "r").readlines() 19 | ] 20 | 21 | 22 | def setup_dev_env(session): 23 | session.run( 24 | "python", 25 | "-m", 26 | "pip", 27 | "install", 28 | "--upgrade", 29 | "setuptools", 30 | "pip", 31 | silent=SILENT, 32 | ) 33 | 34 | session.run("pip", "install", "-r", "requirements/dev.txt", silent=SILENT) 35 | 36 | 37 | @nox.session(python=PYTHON_VERSIONS, reuse_venv=True) 38 | def lint(session): 39 | setup_dev_env(session) 40 | session.run("black", *lint_targets, "--check") 41 | session.run("flake8", "--config", ".flake8", *lint_targets) 42 | 43 | 44 | @nox.session(python=PYTHON_VERSIONS, reuse_venv=True) 45 | def tests(session): 46 | setup_dev_env(session) 47 | for target in test_targets: 48 | session.run( 49 | "pip", "install", "-r", target + "/requirements/dev.txt", silent=SILENT 50 | ) 51 | session.install(*test_targets) # install config packages 52 | session.run("pytest", *test_targets) 53 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | torch==1.7.1 3 | torchvision==0.8.2 4 | hydra-configen==0.9.0.dev7 5 | black==20.8b1 6 | coverage 7 | flake8==3.8.4 8 | flake8-copyright 9 | isort==5.5.2 10 | mypy 11 | nox 12 | packaging 13 | pre-commit 14 | pytest 15 | pytest-snail 16 | setuptools 17 | towncrier 18 | twine 19 | yamllint 20 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.2.0 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from setuptools import setup 3 | 4 | projects = [p.rstrip("\n") for p in open("hydra-configs-projects.txt", "r").readlines()] 5 | project_uris = [ 6 | f"{project} @ git+https://github.com/pytorch/hydra-torch/#subdirectory={project}" 7 | for project in projects 8 | ] 9 | 10 | setup( 11 | name="hydra-torch", 12 | version="0.9", 13 | author=["Omry Yadan", "Rosario Scalise"], 14 | author_email=["omry@fb.com", "rosario@cs.uw.edu"], 15 | url="http://github.com/pytorch/hydra-torch", 16 | include_package_data=True, 17 | install_requires=project_uris, 18 | ) 19 | --------------------------------------------------------------------------------
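Illustrative usage note (not part of the repository): the test files above exercise the generated config dataclasses through hydra.utils.instantiate. The following is a minimal sketch of the same pattern in isolation, assuming hydra-core, torch, and the hydra-configs-torch package are installed; SGDConf is the generated dataclass from hydra_configs/torch/optim/sgd.py.

import torch
from torch import nn
from hydra.utils import instantiate
from omegaconf import OmegaConf

from hydra_configs.torch.optim.sgd import SGDConf

model = nn.Linear(1, 1)

# lr is MISSING in the generated schema, so it is supplied via merge; params is
# passed through as a keyword argument at instantiation time because parameter
# generators are not representable in a config file.
cfg = OmegaConf.merge(OmegaConf.structured(SGDConf), {"lr": 0.1})
optimizer = instantiate(cfg, params=model.parameters())

assert isinstance(optimizer, torch.optim.SGD)

Fields typed `Any = MISSING` (such as lr and params for SGD) must be provided either in the merged config or as passthrough arguments to instantiate, which is exactly what the parametrized tests in tests/test_instantiate_optimizers.py do.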