├── .gitignore ├── .pyre_configuration ├── .watchmanconfig ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── dev-requirements.txt ├── pyproject.toml ├── requirements.txt ├── setup.py └── torchrecipes ├── __init__.py ├── _internal_patches.py ├── audio ├── __init__.py └── source_separation │ ├── conf │ ├── __init__.py │ └── default_config.yaml │ ├── datamodule │ ├── __init__.py │ ├── librimix.py │ └── utils.py │ ├── loss │ ├── __init__.py │ ├── si_sdr.py │ └── utils.py │ ├── main.py │ ├── metrics │ ├── __init__.py │ └── sdr.py │ ├── module │ └── conv_tasnet.py │ └── tests │ └── test_main.py ├── core ├── __init__.py ├── base_train_app.py ├── conf │ ├── __init__.py │ ├── lr_scheduler │ │ ├── exponential.yaml │ │ ├── multi_step.yaml │ │ └── step.yaml │ ├── module │ │ ├── loss │ │ │ └── cross_entropy.yaml │ │ ├── lr_scheduler │ │ │ └── step_lr.yaml │ │ └── optim │ │ │ ├── adam.yaml │ │ │ ├── adamw.yaml │ │ │ └── sgd.yaml │ ├── optim │ │ ├── adamw.yaml │ │ └── sgd.yaml │ └── trainer │ │ ├── cpu.yaml │ │ ├── multi_gpu.yaml │ │ └── single_gpu.yaml ├── logger.py ├── test_utils │ ├── conf_utils.py │ └── test_base.py └── tests │ └── test_base_train_app.py ├── launcher ├── __init__.py ├── run.py └── tests │ └── test_env.py ├── paved_path ├── .gitignore ├── .lintrunner.toml ├── README.md ├── airflow │ ├── README.md │ ├── charnn_dag.py │ └── setup.sh ├── charnn │ ├── char_dataset.py │ ├── char_transform.py │ ├── combined_module.py │ ├── data │ │ └── input.txt │ ├── export.py │ ├── main.py │ ├── model.py │ ├── serve │ │ ├── deploy.sh │ │ └── handler.py │ ├── tests │ │ └── test_cli.py │ ├── trainer.py │ ├── trainer_config.yaml │ └── utils.py ├── docker │ ├── Dockerfile │ ├── README.md │ └── build.sh ├── requirements.txt └── tools │ └── linter │ ├── pip_init.py │ └── ufmt_linter.py ├── rec ├── README.MD ├── __init__.py ├── accelerators │ ├── __init__.py │ └── torchrec.py ├── datamodules │ ├── __init__.py │ ├── commons.py │ ├── 
criteo_datamodule.py │ ├── random_rec_datamodule.py │ ├── samplers │ │ ├── tests │ │ │ └── test_undersampler.py │ │ └── undersampler.py │ └── tests │ │ ├── __init__.py │ │ ├── test_criteo_datamodule.py │ │ ├── test_random_rec_datamodule.py │ │ └── utils.py ├── dlrm_main.py ├── modules │ ├── __init__.py │ ├── lightning_dlrm.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_lightning_dlrm.py │ │ └── test_unsharded_lightning_dlrm.py │ └── unsharded_lightning_dlrm.py └── tests │ ├── __init__.py │ └── test_dlrm_main.py ├── tests └── test_version.py ├── text └── doc_classification │ ├── .torchxconfig │ ├── README.md │ ├── __init__.py │ ├── conf │ ├── __init__.py │ ├── datamodule │ │ ├── dataset │ │ │ └── sst2_dataset.yaml │ │ └── doc_classification_datamodule.yaml │ ├── default_config.yaml │ ├── module │ │ ├── model │ │ │ ├── xlmrbase_classifier.yaml │ │ │ └── xlmrbase_classifier_tiny.yaml │ │ └── optim │ │ │ └── adamw.yaml │ ├── tiny_model_full_config.yaml │ ├── tiny_model_mixed_config.yaml │ ├── trainer │ │ ├── cpu.yaml │ │ ├── multi_gpu.yaml │ │ └── single_gpu.yaml │ ├── transform │ │ ├── doc_classification_transform.yaml │ │ └── doc_classification_transform_tiny.yaml │ └── xlmrbase_sst2_config.yaml │ ├── datamodule │ └── doc_classification.py │ ├── main.py │ ├── module │ └── doc_classification.py │ ├── tests │ ├── common │ │ ├── __init__.py │ │ └── assets.py │ ├── data │ │ ├── SST2 │ │ │ └── SST-2.zip │ │ ├── spm_example.model │ │ └── vocab_example.pt │ ├── test_doc_classification_config.py │ ├── test_doc_classification_datamodule.py │ ├── test_doc_classification_main.py │ ├── test_doc_classification_module.py │ └── test_doc_classification_transform.py │ └── transform │ └── doc_classification_text_transform.py ├── utils ├── __init__.py ├── checkpoint.py ├── config_utils.py ├── distributed_utils.py ├── mixup_utils.py ├── task_test_base.py ├── test.py ├── tests │ ├── test_config_utils.py │ └── test_trainer_plugins.py └── trainer_plugins.py ├── version.py └── vision ├── 
core ├── __init__.py ├── conf │ └── datamodule │ │ ├── mnist_datamodule.yaml │ │ └── torchvision_datamodule.yaml ├── datamodule │ ├── __init__.py │ ├── mnist_data_module.py │ ├── tests │ │ ├── test_mnist_data_module.py │ │ └── test_torchvision_data_module.py │ ├── torchvision_data_module.py │ └── transforms │ │ ├── __init__.py │ │ └── builder.py ├── ops │ ├── __init__.py │ └── fine_tuning_wrapper.py ├── optim │ └── lr_scheduler.py ├── tests │ ├── test_fine_tuning_wrapper.py │ └── test_lr_scheduler.py └── utils │ ├── lr_scheduler.py │ ├── model_weights.py │ ├── model_weights_exporter.py │ └── test_module.py ├── image_classification ├── README.md ├── __init__.py ├── callbacks │ ├── __init__.py │ └── mixup_transform.py ├── conf │ ├── __init__.py │ ├── datamodule │ │ ├── datasets │ │ │ ├── cifar10.yaml │ │ │ └── fake_data.yaml │ │ └── torchvision_datamodule.yaml │ ├── default_config.yaml │ ├── module │ │ ├── default_module.yaml │ │ ├── loss │ │ │ └── cross_entropy.yaml │ │ ├── lr_scheduler │ │ │ └── step_lr.yaml │ │ ├── metrics │ │ │ ├── accuracy.yaml │ │ │ ├── accuracy_top1_top5.yaml │ │ │ ├── average_precision.yaml │ │ │ └── multilabel_accuracy.yaml │ │ ├── model │ │ │ ├── resnet18.yaml │ │ │ ├── resnet50.yaml │ │ │ └── resnext101_32x4d.yaml │ │ └── optim │ │ │ └── sgd.yaml │ └── trainer │ │ ├── cpu.yaml │ │ └── gpu.yaml ├── losses │ ├── soft_target_cross_entropy_loss.py │ └── tests │ │ └── test_soft_target_cross_entropy_loss.py ├── main.py ├── metrics │ ├── multilabel_accuracy.py │ └── tests │ │ └── test_multilabel_accuracy.py ├── module │ └── image_classification.py └── tests │ ├── test_image_classification_module.py │ └── test_main.py └── image_generation ├── __init__.py ├── callbacks ├── __init__.py └── image_generation.py ├── conf ├── __init__.py ├── datamodule │ ├── datasets │ │ └── fake_data.yaml │ ├── mnist.yaml │ ├── torchvision_datamodule.yaml │ └── transforms │ │ └── resize.yaml ├── gan_train_app.yaml ├── infogan_train_app.yaml └── module │ ├── criterion │ 
└── bce_loss.yaml │ ├── discriminator │ ├── dcgan.yaml │ ├── gan.yaml │ └── infogan.yaml │ ├── generator │ ├── dcgan.yaml │ ├── gan.yaml │ └── infogan.yaml │ ├── optim │ ├── default.yaml │ └── infogan_default.yaml │ └── tests │ └── test_gan_module_conf.py ├── models ├── dcgan.py ├── gan.py └── infogan.py ├── module ├── gan.py └── infogan.py ├── tests ├── test_image_generation.py └── test_train_app.py └── train_app.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /.pyre_configuration: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | ".*/build/.*", 4 | ".*/docs/.*" 5 | ], 6 | "site_package_search_strategy": "all", 7 | "source_directories": [ 8 | "." 9 | ], 10 | "strict": true, 11 | "version": "0.0.101681211430" 12 | } 13 | -------------------------------------------------------------------------------- /.watchmanconfig: -------------------------------------------------------------------------------- 1 | { 2 | "root_files": [ 3 | "pytorch", 4 | ".pyre_configuration", 5 | ".watchmanconfig" 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # CHANGELOG 2 | 3 | Summary of per-version changes. 
4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to torchrecipes 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 16 | 17 | ## Contributor License Agreement ("CLA") 18 | In order to accept your pull request, we need you to submit a CLA. You only need 19 | to do this once to work on any of Facebook's open source projects. 20 | 21 | Complete your CLA here: 22 | 23 | ## Issues 24 | We use GitHub issues to track public bugs. Please ensure your description is 25 | clear and has sufficient instructions to be able to reproduce the issue. 26 | 27 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 28 | disclosure of security bugs. In those cases, please go through the process 29 | outlined on that page and do not file a public issue. 
30 | 31 | ## License 32 | By contributing to torchrecipes, you agree that your contributions will be licensed 33 | under the LICENSE file in the root directory of this source tree. 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019-present, Facebook, Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt dev-requirements.txt 2 | global-include *.yaml 3 | recursive-exclude torchrecipes/fb * 4 | recursive-exclude torchrecipes/github * 5 | global-exclude fb_*.py *_fb.py 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](LICENSE) 2 | 3 | # TorchRecipes 4 | 5 |

Train machine learning models with a couple of lines of code using torchrecipes.

6 | 7 | > This library is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an GitHub issue or reach out. We'd love to hear about how you're using `torchrecipes`! 8 | 9 | A recipe is a ready-to-run application that trains a deep learning model by combining a model architecture, trainer, config, etc that you can easily run, modify, or extend. Recipes run on everything from local development environments on your laptop, to large scale clusters. They enable quick experimentation through configuration, and a good starting place to extend the code by forking when more extensive changes are needed. 10 | 11 | We provide a number of out-of-the-box recipes across popular domains (vision, NLP, etc) and tasks (image, text classification, etc) that you can use immediately or as a starting point for further work. 12 | 13 | ## Why `torchrecipes`? 14 | 15 | Getting started with training machine learning models is a lot easier if you can start with something that already runs, instead of having to write all the glue code yourself. 16 | 17 | Machine learning, whether for research or production training, requires working with a number of components like training loops or frameworks, configuration/hyper-parameter parsing, model architectures, data loading, etc. Recipes provide production-ready examples for common tasks that can be easily modified. A recipe at a high-level integrates these modular components so that you can modify the ones that matter for your problem! 18 | 19 | We focus our recipes on providing consistent, high-quality baselines that accurately reproduce research papers. 20 | 21 | ## Get Started 22 | 23 | ### Installation 24 | 25 | We recommend Anaconda as Python package management system. Please refer to pytorch.org for the detail of PyTorch (torch) installation. 
26 | 27 | ```bash 28 | pip install torchrecipes 29 | ``` 30 | 31 | To install `torchrecipes` from source, please run the following commands: 32 | 33 | ```bash 34 | git clone https://github.com/facebookresearch/recipes.git && cd recipes 35 | pip install -e . 36 | ``` 37 | 38 | ### Vision 39 | 40 | - [Image Classification Recipe](torchrecipes/vision/image_classification) 41 | 42 | ### Text 43 | - [Text Classification Recipe](torchrecipes/text/doc_classification) 44 | 45 | 46 | ## Anatomy of a Recipe 47 | 48 | A recipe is a Python application that you can run directly or customize: 49 | 50 | * `main.py`: the entrypoint to start training. The script name doesn't matter and might be different in various recipes. 51 | * `conf/`: Hydra configuration for the training job (including defaults) 52 | * `module/`: the model implementation for pytorch-lightning based recipes 53 | * `data_module/`: the data loading/processing implementation for pytorch-lightning based recipes 54 | * `tests/`: tests for recipe 55 | 56 | By default each recipe supports changing common things via configuration files like hyper-parameters as well as changing the model that is loaded itself (e.g. change from `ResNet18` to `ResNet52`). For research and experimentation, you can also override any of the functionality directly by modifying the model, training loop, etc. 57 | 58 | 59 | ## Contributing 60 | 61 | We welcome PRs! See the [CONTRIBUTING](CONTRIBUTING.md) file. 62 | 63 | ## License 64 | 65 | TorchRecipes is BSD licensed, as found in the [LICENSE](LICENSE) file. 
66 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | black>=21.5b1 2 | importlib-metadata 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.usort] 2 | first_party_detection = false 3 | 4 | [build-system] 5 | requires = [ 6 | "setuptools>=42", 7 | "wheel" 8 | ] 9 | build-backend = "setuptools.build_meta" 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core>=1.1.2 2 | pyre-extensions>=0.0.23 3 | pytorch-lightning @ git+https://github.com/PyTorchLightning/pytorch-lightning.git@9b011606f 4 | protobuf<=3.20.1 # strict. TODO: Remove after tensorboard gets compatible https://github.com/tensorflow/tensorboard/issues/5708 5 | TestSlide>=2.7.0 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import glob 8 | import os 9 | import re 10 | import shutil 11 | import subprocess 12 | from typing import List 13 | 14 | from setuptools import find_packages, setup 15 | import distutils.command.clean # isort:skip 16 | 17 | _PACKAGE_NAME: str = "torchrecipes" 18 | _VERSION_FILE: str = "version.py" 19 | _README: str = "README.md" 20 | _REQUIREMENTS: str = "requirements.txt" 21 | _DEV_REQUIREMENTS: str = "dev-requirements.txt" 22 | _GITIGNORE: str = ".gitignore" 23 | 24 | 25 | def get_version() -> str: 26 | """Retrieves the version of the library.""" 27 | version = os.getenv("BUILD_VERSION") 28 | if version: 29 | return version 30 | cwd = os.path.dirname(os.path.abspath(__file__)) 31 | version_file_path = os.path.join(_PACKAGE_NAME, _VERSION_FILE) 32 | version_regex = r"__version__: str = ['\"]([^'\"]*)['\"]" 33 | with open(version_file_path, "r") as f: 34 | search = re.search(version_regex, f.read(), re.M) 35 | assert search 36 | version = search.group(1) 37 | 38 | try: 39 | sha = ( 40 | subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd) 41 | .decode("ascii") 42 | .strip() 43 | ) 44 | version += "+" + sha[:7] 45 | except Exception: 46 | pass 47 | return version 48 | 49 | 50 | def get_long_description() -> str: 51 | """Fetch project description as Markdown.""" 52 | with open(_README, mode="r") as f: 53 | return f.read() 54 | 55 | 56 | def get_requirements() -> List[str]: 57 | """Fetch requirements.""" 58 | with open(_REQUIREMENTS, mode="r") as f: 59 | return f.readlines() 60 | 61 | 62 | def get_dev_requirements() -> List[str]: 63 | """Fetch requirements for library development.""" 64 | with open(_DEV_REQUIREMENTS, mode="r") as f: 65 | return f.readlines() 66 | 67 | 68 | class clean(distutils.command.clean.clean): 69 | def run(self) -> None: 70 | with open(_GITIGNORE, "r") as f: 71 | ignores = f.readlines() 72 | for wildcard in filter(None, ignores): 73 | for filename in glob.glob(wildcard): 74 | try: 75 | os.remove(filename) 76 | except 
OSError: 77 | shutil.rmtree(filename, ignore_errors=True) 78 | 79 | # It's an old-style class in Python 2.7... 80 | distutils.command.clean.clean.run(self) 81 | 82 | 83 | def main() -> None: 84 | global version 85 | version = get_version() 86 | print("Building wheel {}-{}".format(_PACKAGE_NAME, version)) 87 | 88 | setup( 89 | # Metadata 90 | name=_PACKAGE_NAME, 91 | version=version, 92 | author="PyTorch Ecosystem Foundations Team", 93 | author_email="luispe@fb.com", 94 | description="Prototype of training recipes for PyTorch", 95 | long_description=get_long_description(), 96 | long_description_content_type="text/markdown", 97 | url="https://github.com/facebookresearch/recipes", 98 | license="BSD-3", 99 | keywords=["pytorch", "machine learning"], 100 | python_requires=">=3.7", 101 | install_requires=get_requirements(), 102 | include_package_data=True, 103 | # Package info 104 | packages=find_packages(), 105 | # pyre-fixme[6]: For 15th argument expected `Mapping[str, Type[Command]]` 106 | # but got `Mapping[str, Type[clean]]`. 107 | cmdclass={ 108 | "clean": clean, 109 | }, 110 | extras_require={"dev": get_dev_requirements()}, 111 | # PyPI package information. 112 | classifiers=[ 113 | "Development Status :: 4 - Beta", 114 | "Intended Audience :: Developers", 115 | "Intended Audience :: Science/Research", 116 | "License :: OSI Approved :: BSD License", 117 | "Programming Language :: Python :: 3", 118 | "Programming Language :: Python :: 3.8", 119 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 120 | ], 121 | ) 122 | 123 | 124 | version: str 125 | 126 | 127 | if __name__ == "__main__": 128 | main() # pragma: no cover 129 | -------------------------------------------------------------------------------- /torchrecipes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


"""
If you change the symbols exported in this file, you very likely want to change
the symbols exported in ./fb/_internal_patches.py as well.

For any buck-based build internally (within fbcode), this file is silently
replaced with the file located in ./fb/_internal_patches.py. This enables us
to silently swap out symbols (such as Checkpoint, Logger, etc.) between
internal-only implementations and external versions w/o requiring user
involvement.
"""

from functools import wraps
from typing import Any, Dict

from pytorch_lightning.callbacks import ModelCheckpoint as OSSModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger as OSSTensorboardLogger

# Use a tuple since it is not mutable.
_FB_ONLY_ARGS = (
    "has_user_data",
    "ttl_days",
    "manifold_bucket",
    "manifold_path",
    "num_retries",
    "save_torchscript",
    "save_quantized",
    "api_key",
)


def _strip_fb_only_args(kwargs: Dict[str, Any]) -> None:
    """Remove (in place) internal-only kwargs that the OSS classes don't accept."""
    for arg_name in _FB_ONLY_ARGS:
        kwargs.pop(arg_name, None)


@wraps(OSSModelCheckpoint, updated=())
def ModelCheckpoint(**kwargs: Any) -> OSSModelCheckpoint:
    """OSS stand-in for the internal ModelCheckpoint; FB-only kwargs are ignored."""
    _strip_fb_only_args(kwargs)
    return OSSModelCheckpoint(**kwargs)


@wraps(OSSTensorboardLogger, updated=())
def TensorBoardLogger(**kwargs: Any) -> OSSTensorboardLogger:
    """OSS stand-in for the internal TensorBoardLogger; FB-only kwargs are ignored."""
    _strip_fb_only_args(kwargs)
    return OSSTensorboardLogger(**kwargs)


def log_run(**kwargs: Any) -> None:
    """Log Run. A no-op in OSS builds."""
    pass
5 | 6 | 7 | #!/usr/bin/env python3 8 | -------------------------------------------------------------------------------- /torchrecipes/audio/source_separation/conf/default_config.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: torchrecipes.audio.source_separation.datamodule.librimix.LibriMixDataModule 3 | root_dir: ??? 4 | batch_size: 6 5 | tr_split: train-360 6 | num_speakers: 2 7 | sample_rate: 8000 8 | task: sep_clean 9 | num_workers: 4 10 | 11 | module: 12 | _target_: torchrecipes.audio.source_separation.module.conv_tasnet.ConvTasNetModule 13 | model: 14 | _target_: torchaudio.models.ConvTasNet 15 | num_sources: 2 16 | enc_kernel_size: 16 17 | enc_num_feats: 512 18 | msk_kernel_size: 3 19 | msk_num_feats: 128 20 | msk_num_hidden_feats: 512 21 | msk_num_layers: 8 22 | msk_num_stacks: 3 23 | msk_activate: relu 24 | loss: 25 | _target_: torchrecipes.audio.source_separation.loss.si_sdr_loss 26 | _partial_: true 27 | optim_fn: 28 | _target_: torch.optim.Adam 29 | _partial_: true 30 | lr: 0.001 31 | metrics: 32 | sdri: 33 | _target_: torchrecipes.audio.source_separation.metrics.sdri_metric 34 | _partial_: true 35 | sisdri: 36 | _target_: torchrecipes.audio.source_separation.metrics.sisdri_metric 37 | _partial_: true 38 | lr_scheduler: 39 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 40 | _partial_: true 41 | mode: min 42 | factor: 0.5 43 | patience: 5 44 | 45 | trainer: 46 | _target_: pytorch_lightning.trainer.Trainer 47 | strategy: ddp 48 | accelerator: gpu 49 | devices: 2 50 | default_root_dir: null 51 | max_epochs: 200 52 | limit_train_batches: 1.0 53 | gradient_clip_val: 5.0 54 | logger: 55 | _target_: pytorch_lightning.loggers.TensorBoardLogger 56 | save_dir: /tmp/logs 57 | callbacks: 58 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 59 | dirpath: /tmp/checkpoints 60 | monitor: losses/val_loss 61 | mode: min 62 | save_top_k: 5 63 | save_weights_only: true 64 | - _target_: 
class LibriMixDataModule(LightningDataModule):
    def __init__(
        self,
        root_dir: str,
        batch_size: int = 6,
        tr_split: str = "train-360",
        num_speakers: int = 2,
        sample_rate: int = 8000,
        task: str = "sep_clean",
        num_workers: int = 4,
    ) -> None:
        """The LightningDataModule for LibriMix Dataset.

        Args:
            root_dir (str): the root directory of the dataset.
            batch_size (int, optional): the batch size of the dataset. (Default: 6)
            tr_split (str, optional): the training split in LibriMix dataset.
                Options: [``train-360``, ``train-100``] (Default: ``train-360``)
            num_speakers (int, optional): The number of speakers, which determines the directories
                to traverse. The datamodule will traverse ``s1`` to ``sN`` directories to collect
                N source audios. (Default: 2)
            sample_rate (int, optional): the sample rate of the audio. (Default: 8000)
            task (str, optional): the task of LibriMix.
                Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
                (Default: ``sep_clean``)
            num_workers (int, optional): the number of workers for each dataloader. (Default: 4)
        """
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.tr_split = tr_split
        self.num_speakers = num_speakers
        self.sample_rate = sample_rate
        self.task = task
        self.num_workers = num_workers
        # Populated lazily in setup(); keys are "train"/"val"/"test".
        self.datasets: Dict[str, LibriMix] = {}

    def _get_dataset(self, subset: str) -> LibriMix:
        # Build one torchaudio LibriMix dataset for the given subset name.
        return LibriMix(
            root=self.root_dir,
            subset=subset,
            num_speakers=self.num_speakers,
            sample_rate=self.sample_rate,
            task=self.task,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        # stage is None when Lightning prepares all stages at once.
        if stage == "fit" or stage is None:
            self.datasets["train"] = self._get_dataset(subset=self.tr_split)
            self.datasets["val"] = self._get_dataset(subset="dev")
        if stage == "test" or stage is None:
            self.datasets["test"] = self._get_dataset(subset="test")

    def train_dataloader(self) -> DataLoader:
        # Training crops/pads every sample to 3 seconds (see CollateFn) and
        # drops the last incomplete batch for stable batch shapes.
        return DataLoader(
            self.datasets["train"],
            batch_size=self.batch_size,
            collate_fn=CollateFn(sample_rate=self.sample_rate, duration=3),
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        # duration=-1: pad every sample in the batch to the longest waveform.
        return DataLoader(
            self.datasets["val"],
            batch_size=self.batch_size,
            collate_fn=CollateFn(sample_rate=self.sample_rate, duration=-1),
            num_workers=self.num_workers,
            drop_last=True,
        )

    def test_dataloader(self) -> DataLoader:
        # Test keeps every sample (no drop_last) and pads to the batch maximum.
        return DataLoader(
            self.datasets["test"],
            batch_size=self.batch_size,
            collate_fn=CollateFn(sample_rate=self.sample_rate, duration=-1),
            num_workers=self.num_workers,
        )
37 | """ 38 | if self.duration == -1: 39 | target_num_frames = max(s[1].shape[-1] for s in samples) 40 | else: 41 | target_num_frames = int(self.duration * self.sample_rate) 42 | 43 | mixes, srcs, masks = [], [], [] 44 | for sample in samples: 45 | mix, src, mask = self._fix_num_frames( 46 | sample, target_num_frames, self.sample_rate, random_start=True 47 | ) 48 | 49 | mixes.append(mix) 50 | srcs.append(src) 51 | masks.append(mask) 52 | 53 | return (torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0)) 54 | 55 | def _fix_num_frames( 56 | self, 57 | sample: torch.Tensor, 58 | target_num_frames: int, 59 | sample_rate: int, 60 | random_start: bool = False, 61 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 62 | """Ensure waveform has exact number of frames by slicing or padding""" 63 | mix = sample[1] # [1, time] 64 | src = torch.cat(sample[2], 0) # [num_sources, time] 65 | 66 | num_channels, num_frames = src.shape 67 | num_seconds = int(num_frames / sample_rate) 68 | target_seconds = int(target_num_frames / sample_rate) 69 | if num_frames >= target_num_frames: 70 | if random_start and num_frames > target_num_frames: 71 | start_frame = ( 72 | torch.randint(num_seconds - target_seconds + 1, (1,)) * sample_rate 73 | ) 74 | mix = mix[:, start_frame:] 75 | src = src[:, start_frame:] 76 | mix = mix[:, :target_num_frames] 77 | src = src[:, :target_num_frames] 78 | mask = torch.ones_like(mix) 79 | else: 80 | num_padding = target_num_frames - num_frames 81 | pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device) 82 | mix = torch.cat([mix, pad], 1) 83 | src = torch.cat([src, pad.expand(num_channels, -1)], 1) 84 | mask = torch.ones_like(mix) 85 | mask[..., num_frames:] = 0 86 | return mix, src, mask 87 | -------------------------------------------------------------------------------- /torchrecipes/audio/source_separation/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
def si_sdr_loss(
    estimate: torch.Tensor, reference: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """Compute the Si-SDR loss.

    Args:
        estimate (torch.Tensor): Estimated source signals.
            Tensor of dimension (batch, speakers, time)
        reference (torch.Tensor): Reference (original) source signals.
            Tensor of dimension (batch, speakers, time)
        mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
            Tensor of dimension (batch, 1, time)
    Returns:
        torch.Tensor: Si-SDR loss.
            Tensor of dimension (batch, )
    """
    # Zero-mean both signals along time so the SDR is scale-invariant.
    zm_estimate = estimate - estimate.mean(dim=2, keepdim=True)
    zm_reference = reference - reference.mean(dim=2, keepdim=True)

    # Permutation-invariant SDR; negated because higher SDR is better.
    return -utils.sdr_pit(zm_estimate, zm_reference, mask=mask).mean()
@hydra.main(config_path="conf", config_name="default_config")
def main(config: DictConfig) -> None:
    """Train a source-separation module from the composed Hydra config."""
    seed_everything(config.get("seed", 0), workers=True)
    log.info("Config:\n%s", OmegaConf.to_yaml(config))

    datamodule = hydra.utils.instantiate(config.datamodule)
    trainer = hydra.utils.instantiate(config.trainer)
    module = hydra.utils.instantiate(config.module)

    # Optionally warm-start from a pretrained checkpoint before fitting.
    checkpoint_path = getattr(config, "pretrained_checkpoint_path", None)
    if checkpoint_path:
        module = module.load_from_checkpoint(checkpoint_path=checkpoint_path)
    trainer.fit(module, datamodule=datamodule)
def sisdri_metric(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mix: torch.Tensor,
    mask: torch.Tensor,
) -> float:
    """Compute the improvement of scale-invariant SDR. (SI-SDRi).

    Args:
        estimate (torch.Tensor): Estimated source signals.
            Tensor of dimension (batch, speakers, time)
        reference (torch.Tensor): Reference (original) source signals.
            Tensor of dimension (batch, speakers, time)
        mix (torch.Tensor): Mixed souce signals, from which the setimated signals were generated.
            Tensor of dimension (batch, speakers == 1, time)
        mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
            Tensor of dimension (batch, 1, time)
    Returns:
        float: Batch-averaged SI-SDR improvement.
    References:
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    with torch.no_grad():
        # Zero-mean along time for scale invariance.
        # FIX: use torch-native `dim=` (was numpy-style `axis=`).
        estimate = estimate - estimate.mean(dim=2, keepdim=True)
        reference = reference - reference.mean(dim=2, keepdim=True)
        mix = mix - mix.mean(dim=2, keepdim=True)

        si_sdri = utils.sdri(estimate, reference, mix, mask=mask)

    # FIX: `.item()` yields a Python float, so annotate float (was torch.Tensor).
    return si_sdri.mean().item()


def sdri_metric(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mix: torch.Tensor,
    mask: torch.Tensor,
) -> float:
    """Compute the improvement of SDR. (SDRi).

    Args:
        estimate (torch.Tensor): Estimated source signals.
            Tensor of dimension (batch, speakers, time)
        reference (torch.Tensor): Reference (original) source signals.
            Tensor of dimension (batch, speakers, time)
        mix (torch.Tensor): Mixed souce signals, from which the setimated signals were generated.
            Tensor of dimension (batch, speakers == 1, time)
        mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
            Tensor of dimension (batch, 1, time)
    Returns:
        float: Batch-averaged SDR improvement.
    References:
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    with torch.no_grad():
        sdri = utils.sdri(estimate, reference, mix, mask=mask)
    # FIX: `.item()` yields a Python float, so annotate float (was torch.Tensor).
    return sdri.mean().item()
class TestDataset(Dataset):
    """Fake LibriMix-style dataset that emits random waveforms."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, key: int) -> Tuple[int, torch.Tensor, List[torch.Tensor]]:
        """Load the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded
        Returns:
            (int, Tensor, List[Tensor]): ``(sample_rate, mix_waveform, list_of_source_waveforms)``
        """
        sources = [torch.rand(1, 24000) for _ in range(2)]
        return 8000, torch.rand(1, 24000), sources


class TestMain(unittest.TestCase):
    @tempdir
    def test_train_model(self, root_dir: str) -> None:
        """Smoke-test the training entry point on CPU with a fake dataset."""
        overrides = [
            f"datamodule.root_dir={root_dir}",
            "trainer.accelerator=cpu",
            "trainer.devices=null",
            "trainer.strategy=null",
            "trainer.max_epochs=2",
            "+trainer.fast_dev_run=true",
        ]
        patch_target = (
            "torchrecipes.audio.source_separation.datamodule.librimix."
            "LibriMixDataModule._get_dataset"
        )
        with hydra.initialize_config_module(
            config_module="torchrecipes.audio.source_separation.conf"
        ):
            config = hydra.compose(config_name="default_config", overrides=overrides)
            # Swap the real LibriMix dataset for the in-memory fake.
            with unittest.mock.patch(patch_target, return_value=TestDataset()):
                main(config)
2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/lr_scheduler/exponential.yaml: -------------------------------------------------------------------------------- 1 | _target_: "torch.optim.lr_scheduler.ExponentialLR" 2 | gamma: 0.1 3 | last_epoch: -1 4 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/lr_scheduler/multi_step.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.MultiStepLR 2 | milestones: 3 | - 30 4 | - 60 5 | - 80 6 | gamma: 0.1 7 | last_epoch: -1 8 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/lr_scheduler/step.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.StepLR 2 | step_size: 10 3 | gamma: 0.9 4 | last_epoch: -1 5 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/module/loss/cross_entropy.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.CrossEntropyLoss 2 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/module/lr_scheduler/step_lr.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.StepLR 2 | step_size: 10 3 | gamma: 0.1 4 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/module/optim/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | lr: 0.001 3 | betas: [0.9, 0.999] 4 
| eps: 1e-8 5 | weight_decay: 0 6 | amsgrad: False 7 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/module/optim/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | lr: 0.001 3 | betas: [0.9, 0.999] 4 | eps: 1e-8 5 | weight_decay: 0.01 6 | amsgrad: False 7 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/module/optim/sgd.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.SGD 2 | lr: 0.1 3 | weight_decay: 1e-4 4 | momentum: 0.9 5 | nesterov: False 6 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/optim/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | lr: 0.001 3 | betas: [0.9, 0.999] 4 | eps: 1e-8 5 | weight_decay: 0.01 6 | amsgrad: False 7 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/optim/sgd.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.SGD 2 | lr: 0.1 3 | weight_decay: 1e-4 4 | momentum: 0.9 5 | nesterov: True 6 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | max_epochs: 1 2 | -------------------------------------------------------------------------------- /torchrecipes/core/conf/trainer/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | gpus: 8 2 | strategy: ddp 3 | max_epochs: 1 4 | num_sanity_val_steps: 0 5 | log_every_n_steps: 10 6 | -------------------------------------------------------------------------------- 
class AutoName(Enum):
    """Enum base class whose ``auto()`` values equal the member names."""

    @staticmethod
    def _generate_next_value_(
        name: str, start: int, count: int, last_values: List[auto]
    ) -> str:
        # auto() resolves to the member's own name, e.g. RUNNING -> "RUNNING".
        return name


@unique
class JobStatus(AutoName):
    """
    Training run job state.
    """

    RUNNING = auto()
    COMPLETED = auto()
    FAILED = auto()
def conf_asdict(datacls_obj: Any) -> Dict[str, Any]:
    """Convert a config dataclass to a plain dict, stripping Hydra fields.

    The dataclasses we provide may contain Hydra-specific fields.
    Use this method instead of ``dataclasses.asdict`` to remove those.

    Args
    :param datacls_obj: a dataclass object
    :return: dict of the dataclass
    """
    plain = asdict(datacls_obj)
    plain.pop("_target_", None)
    return plain
28 | """ 29 | 30 | def mock_init_trainer_params( 31 | original: Callable[..., Dict[str, Any]], 32 | ) -> Dict[str, Any]: 33 | trainer_params = original() 34 | 35 | trainer_params["logger"] = False 36 | trainer_params["enable_checkpointing"] = False 37 | trainer_params["fast_dev_run"] = True 38 | 39 | if overrides: 40 | trainer_params.update(overrides) 41 | 42 | return trainer_params 43 | 44 | return mock_init_trainer_params 45 | 46 | 47 | class BaseTrainAppTestCase(testslide.TestCase): 48 | """All Standard TrainApp unit tests should inherit from this class.""" 49 | 50 | def mock_trainer_params( 51 | self, app: BaseTrainApp, overrides: Optional[Dict[str, Any]] = None 52 | ) -> None: 53 | self.mock_callable( 54 | app, "_init_trainer_params", allow_private=True 55 | ).with_wrapper(get_mock_init_trainer_params(overrides)) 56 | 57 | def create_app_from_hydra( 58 | self, 59 | config_module: str, 60 | config_name: str, 61 | overrides: Optional[List[str]] = None, 62 | ) -> BaseTrainApp: 63 | with initialize_config_module(config_module=config_module): 64 | cfg = compose(config_name=config_name, overrides=overrides or []) 65 | print(OmegaConf.to_yaml(cfg)) 66 | return instantiate(cfg, _recursive_=False) 67 | 68 | def assert_train_output(self, output: TrainOutput) -> None: 69 | self.assertIsNotNone(output) 70 | # Ensure logger is set to False in test to avoid dependency on Manifold 71 | self.assertIsNone(output.tensorboard_log_dir) 72 | -------------------------------------------------------------------------------- /torchrecipes/core/tests/test_base_train_app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
class TestTrainApp(BaseTrainAppTestCase):
    def test_ckpt_callback_fallback_to_default(self) -> None:
        """With no user callbacks, a default ModelCheckpoint (no monitor) is attached."""
        app = BaseTrainApp(None, TrainerConf(), None)
        app._set_trainer_params(trainer_params={})
        callback = app._checkpoint_callback
        self.assertIsNotNone(callback)
        self.assertIsNone(callback.monitor)

    def test_ckpt_callback_user_provided(self) -> None:
        """A user-supplied ModelCheckpoint takes precedence over the default."""
        app = BaseTrainApp(None, TrainerConf(), None)
        user_ckpt = ModelCheckpoint(monitor="some_metrics")
        self.mock_callable(app, "get_callbacks").to_return_value([user_ckpt])
        app._set_trainer_params(trainer_params={})
        callback = app._checkpoint_callback
        self.assertIsNotNone(callback)
        self.assertEqual(callback.monitor, "some_metrics")
# Set default value of these environment variables in
# fbcode/torchx/components/fb/stl_apps.py

# Your TrainApp's hydra conf module. We need to import this module before calling hydra
CONFIG_MODULE = "CONFIG_MODULE"
# Which mode your App will run in.
# - prod: (Default) train + test, return test result
# - train: train only
# - test: test only
# - predict: train + predict
MODE = "MODE"


@unique
class Mode(Enum):
    """Run mode for the launcher, selected via the MODE environment variable."""

    PROD = auto()
    TRAIN = auto()
    TEST = auto()
    PREDICT = auto()


def _get_mode() -> Mode:
    """Fetch operating environment; fall back to PROD on unknown/missing MODE."""
    mode_key = os.getenv(MODE, "").upper()
    try:
        return Mode[mode_key]
    except KeyError:
        logger.warning("Unknown MODE, run train and test by default")
        return Mode.PROD


def run_in_certain_mode(
    app: BaseTrainApp,
) -> Union[TrainOutput, _EVALUATE_OUTPUT, Optional[_PREDICT_OUTPUT]]:
    """Run the app's train/test/predict stages according to the MODE env var.

    Returns the output of the final stage executed for the selected mode.
    """
    mode = _get_mode()
    if mode == Mode.TRAIN:
        logger.info("MODE set to train, run train only.")
        return app.train()
    elif mode == Mode.TEST:
        logger.info("MODE set to test, run test only.")
        return app.test()
    elif mode == Mode.PREDICT:
        # FIX: log message previously said "precit".
        logger.info("MODE set to predict, run train and predict.")
        app.train()
        return app.predict()
    else:
        # By default (PROD or unrecognized), run train and test
        app.train()
        return app.test()


@hydra.main()
def run_with_hydra(
    cfg: TrainAppConf,
) -> Union[TrainOutput, _EVALUATE_OUTPUT, Optional[_PREDICT_OUTPUT]]:
    """Instantiate the TrainApp described by the Hydra config and run it."""
    logger.info(OmegaConf.to_yaml(cfg))
    app = hydra.utils.instantiate(cfg, _recursive_=False)
    return run_in_certain_mode(app)


@errors.record
def main() -> None:
    """Launcher entry point: import the app's config module, then run via Hydra.

    Raises:
        NotImplementedError: if CONFIG_MODULE is unset (non-hydra apps are
            not supported yet).
    """
    config_module = os.getenv(CONFIG_MODULE)
    logger.info(f"CONFIG_MODULE: {config_module}")
    # only needed for apps that use hydra config
    if config_module:
        importlib.import_module(config_module)
        run_with_hydra()
    else:
        # TODO: T93277666 add entry point for non-hydra apps
        raise NotImplementedError


if __name__ == "__main__":
    main()
class TestLauncherMain(testslide.TestCase):
    """Tests for the launcher entry point's environment handling."""

    @mock.patch.dict(os.environ, {}, clear=True)
    def test_no_env(self) -> None:
        # Without CONFIG_MODULE, the non-hydra path is not implemented yet.
        self.assertRaises(NotImplementedError, run.main)

    @mock.patch.dict(os.environ, {"CONFIG_MODULE": "test_module"})
    def test_import_module(self) -> None:
        # The named config module must be imported before hydra runs.
        self.mock_callable(importlib, "import_module").for_call(
            "test_module"
        ).to_return_value(None)
        self.mock_callable(run, "run_with_hydra").to_return_value(
            StrictMock(TrainOutput)
        )
        run.main()


class TestRunWithCertainEnv(testslide.TestCase):
    """Tests covering every MODE value handled by run_in_certain_mode."""

    def assert_train(self, app: BaseTrainApp) -> None:
        """Expect exactly one call to app.train()."""
        self.mock_callable(app, "train").to_return_value(
            StrictMock(TrainOutput)
        ).and_assert_called_once()

    def assert_test(self, app: BaseTrainApp) -> None:
        """Expect exactly one call to app.test()."""
        self.mock_callable(app, "test").to_return_value([]).and_assert_called_once()

    @mock.patch.dict(os.environ, {"MODE": "prod"})
    def test_prod_mode(self) -> None:
        app = StrictMock(template=BaseTrainApp)
        self.assert_train(app)
        self.assert_test(app)
        run.run_in_certain_mode(app)

    @mock.patch.dict(os.environ, {"MODE": "train"})
    def test_train_only(self) -> None:
        app = StrictMock(template=BaseTrainApp)
        self.assert_train(app)
        run.run_in_certain_mode(app)

    @mock.patch.dict(os.environ, {"MODE": "test"})
    def test_test_only(self) -> None:
        app = StrictMock(template=BaseTrainApp)
        self.assert_test(app)
        run.run_in_certain_mode(app)

    @mock.patch.dict(os.environ, {"MODE": "random"})
    def test_wrong_input(self) -> None:
        # Unknown MODE falls back to the prod behavior: train + test.
        app = StrictMock(template=BaseTrainApp)
        self.assert_train(app)
        self.assert_test(app)
        run.run_in_certain_mode(app)

    @mock.patch.dict(os.environ, {}, clear=True)
    def test_no_input(self) -> None:
        # Missing MODE also falls back to train + test.
        app = StrictMock(template=BaseTrainApp)
        self.assert_train(app)
        self.assert_test(app)
        run.run_in_certain_mode(app)
Such that Airflow can discover your dag configs.

4. Run a task instance
```bash
airflow tasks run charnn_dag train 2022-08-01
```
> **_NOTE:_**: the instance can be monitored in UI: http://0.0.0.0:8080/taskinstance/list

5. Backfill the dag over 2 days
```bash
airflow dags backfill charnn_dag --start-date 2022-08-01 --end-date 2022-08-02 --reset-dagruns
```
> **_NOTE:_**: the dag runs can be monitored in UI: http://0.0.0.0:8080/dagrun/list/
class CharDataset(Dataset):
    """Character-level dataset over a text corpus.

    Each sample is a pair of integer-id tensors ``(x, y)`` of length
    ``block_size``, where ``y`` is ``x`` shifted one character ahead —
    i.e. the next-character prediction targets.
    """

    def __init__(self, data_path: str, block_size: int) -> None:
        """
        Args:
            data_path: path/URL of the text file (anything fsspec can open).
            block_size: context length of every sample.
        """
        self.block_size = block_size
        self.transform = CharTransform(data_path)
        self.data = self.transform.data
        self.vocab_size = self.transform.vocab_size

    def __len__(self) -> int:
        # The final block_size characters cannot start a full sample.
        return self.transform.data_size - self.block_size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Grab a chunk of (block_size + 1) characters so that the input and
        # its shifted-by-one target can be sliced from one encoding.
        chunk = self.data[idx : idx + self.block_size + 1]
        # Encode every character to an integer id.
        ids = self.transform(chunk)
        x = torch.tensor(ids[:-1], dtype=torch.long)
        y = torch.tensor(ids[1:], dtype=torch.long)
        return (x, y)
class CharTransform(nn.Module):
    """Character-level codec: maps text to integer-id tensors and back.

    The vocabulary is derived from the distinct characters of the file at
    ``data_path`` (opened through fsspec, so local paths and remote URLs
    both work).
    """

    def __init__(self, data_path: str):
        super().__init__()
        fs, path = fsspec.core.url_to_fs(data_path)
        with fs.open(path, "r") as f:
            self.data = f.read()
        self.data_path = data_path
        vocabulary = sorted(set(self.data))
        self.data_size = len(self.data)
        self.vocab_size = len(vocabulary)
        print(f"Data has {self.data_size} characters, {self.vocab_size} unique.")
        # Forward and inverse lookup tables between characters and ids.
        self.stoi = {character: index for index, character in enumerate(vocabulary)}
        self.itos = {index: character for index, character in enumerate(vocabulary)}

    def forward(self, text: str) -> Tensor:
        """Alias for :meth:`encode` so the module is directly callable."""
        return self.encode(text)

    def encode(self, text: str) -> Tensor:
        """Map every character of *text* to its integer id; returns a 1-D long tensor."""
        return torch.tensor(
            [self.stoi[character] for character in text], dtype=torch.long
        )

    def decode(self, ids: Tensor) -> str:
        """Map a 1-D tensor of integer ids back to the corresponding string."""
        token_ids: List[int] = ids.tolist()
        return "".join([self.itos[token_id] for token_id in token_ids])
12 | """ 13 | 14 | from typing import Optional 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torch import Tensor 19 | from torch.nn import functional as F 20 | 21 | 22 | class CombinedModule(nn.Module): 23 | device: str 24 | 25 | def __init__(self, transform: nn.Module, model: nn.Module) -> None: 26 | super().__init__() 27 | self.transform = transform 28 | self.model = model 29 | self.device = "" 30 | 31 | def forward(self, text: str) -> str: 32 | tokens = self.transform(text) 33 | tokens = tokens.unsqueeze(0).to(self.device) 34 | generated_ids = self.generate(tokens).squeeze() 35 | return self.transform.decode(generated_ids) 36 | 37 | @torch.jit.export 38 | def set_device(self, device: str) -> None: 39 | self.device = device 40 | 41 | def top_k_logits(self, logits: Tensor, k: int) -> Tensor: 42 | v, ix = torch.topk(logits, k) 43 | out = logits.clone() 44 | out[out < v[:, [-1]]] = -float("Inf") 45 | return out 46 | 47 | def generate( 48 | self, 49 | x: Tensor, 50 | steps: int = 100, 51 | temperature: float = 1.0, 52 | sample: bool = True, 53 | top_k: Optional[int] = 10, 54 | ) -> Tensor: 55 | """ 56 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 57 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 58 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 59 | of block_size, unlike an RNN that has an infinite context window. 
60 | """ 61 | block_size = 128 62 | for _ in range(steps): 63 | x_cond = ( 64 | x if x.size(1) <= block_size else x[:, -block_size:] 65 | ) # crop context if needed 66 | logits = self.model(x_cond) 67 | # pluck the logits at the final step and scale by temperature 68 | logits = logits[:, -1, :] / temperature 69 | # optionally crop probabilities to only the top k options 70 | if top_k is not None: 71 | logits = self.top_k_logits(logits, top_k) 72 | # apply softmax to convert to probabilities 73 | probs = F.softmax(logits, dim=-1) 74 | # sample from the distribution or take the most likely 75 | if sample: 76 | ix = torch.multinomial(probs, num_samples=1) 77 | else: 78 | _, ix = torch.topk(probs, k=1, dim=-1) 79 | # append to the sequence and continue 80 | x = torch.cat((x, ix), dim=1) 81 | 82 | return x 83 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/charnn/export.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Script to export a module to torchscript. Optionally, the module can 10 | be quantized before torchscripting. 11 | """ 12 | 13 | import argparse 14 | 15 | import fsspec 16 | import torch 17 | 18 | parser = argparse.ArgumentParser(description="Quantize the model from a snapshot") 19 | parser.add_argument( 20 | "-i", 21 | "--input_path", 22 | type=str, 23 | required=True, 24 | help="Snapshot path to load. It can be a local path, S3 or Google Cloud Storage URL", 25 | ) 26 | parser.add_argument( 27 | "-o", 28 | "--output_path", 29 | type=str, 30 | required=True, 31 | help="Snapshot path to save. 
def main() -> None:
    """Load a saved model, optionally dynamic-quantize and/or torchscript it,
    and write the result to the output path.

    Both paths may be local, S3, or Google Cloud Storage URLs — resolution
    is delegated to fsspec.
    """
    args = parser.parse_args()

    # Fixed typo: was `intput_path`.
    fs, input_path = fsspec.core.url_to_fs(args.input_path)
    with fs.open(input_path, "rb") as f:
        model = torch.load(f, map_location="cpu")

    # quantize the model. Note that dynamic Quantization currently only
    # supports nn.Linear and nn.LSTM in qconfig_spec
    if args.quantize:
        print("quantizing the model...")
        model = torch.quantization.quantize_dynamic(
            model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
        )

    fs, output_path = fsspec.core.url_to_fs(args.output_path)
    with fs.open(output_path, "wb") as f:
        if args.torchscript:
            print("torchscripting the model...")
            model_jit = torch.jit.script(model)
            torch.jit.save(model_jit, f)
        else:
            torch.save(model, f)

    print(f"exported the module to {args.output_path}")
    def handle(self, data, context):
        """
        Handle user's request, extract the text and return generated text.

        Batch processing is not supported: only the first request in a
        batch is handled.

        Args:
            data: list of request dicts, e.g. ``[{"body": "hello world"}]``;
                each "body" is the input text as str, bytes, or bytearray.
            context: torchserve request context (unused here).

        Returns:
            A single-element list with the generated text — torchserve
            requires the handler output to be a list.
        """
        text = data[0].get("body")
        # Decode text if not a str but bytes or bytearray
        if isinstance(text, (bytes, bytearray)):
            text = text.decode("utf-8")
        # torchserve requires output to be a list
        return [self.model(text)]
51 | """ 52 | model = torch.jit.load(model_pt_path, map_location=self.device) 53 | model.set_device(str(self.device)) 54 | return model 55 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/charnn/tests/test_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | import subprocess 10 | 11 | WORKING_DIR = os.path.dirname(os.path.dirname(__file__)) 12 | 13 | 14 | def run(command): 15 | """Run a command and return error code""" 16 | proc = subprocess.Popen(command) 17 | proc.communicate() 18 | return proc.returncode 19 | 20 | 21 | def test_main(): 22 | command = ["python3", os.path.join(WORKING_DIR, "main.py")] 23 | assert run(command) == 0 24 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/charnn/trainer_config.yaml: -------------------------------------------------------------------------------- 1 | opt: 2 | lr: 0.0006 3 | weight_decay: 0.1 4 | dataset: 5 | path: data/input.txt 6 | max_iter: 200 7 | trainer: 8 | work_dir: "/tmp/charnn" 9 | # each run's outputs will be saved under work_dir/job_name. If not specified, job_name will be auto-generated. 
10 | job_name: "" 11 | max_epochs: 1 12 | lr: 0.0006 13 | batch_size: 128 14 | data_loader_workers: 1 15 | enable_profile: False 16 | snapshot_path: "" # specify your snapshot path to restore training state 17 | model: 18 | n_layer: 2 # 8 19 | n_head: 2 # 8 20 | n_embd: 32 # 512 21 | charnn: 22 | dist: ddp 23 | # train or generate 24 | task: train 25 | # start string 26 | phrase: Hello there 27 | 28 | hydra: 29 | run: 30 | dir: ./ 31 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/charnn/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | 11 | def get_realpath(path: str) -> str: 12 | if "://" in path or os.path.isabs(path): 13 | return path 14 | 15 | work_dir = os.path.dirname(__file__) 16 | return os.path.join(work_dir, path) 17 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime 2 | 3 | # Set working directory 4 | WORKDIR /workspace/paved_path 5 | 6 | COPY requirements.txt . 7 | 8 | # Install Requirements 9 | ARG aws 10 | RUN if [ "$aws" = true ]; then pip3 install boto3==1.21.21; fi 11 | RUN apt-get update && apt-get upgrade -y 12 | RUN pip3 install -r requirements.txt 13 | 14 | # Install awscli v2 15 | RUN apt-get install curl unzip -y 16 | RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" && unzip awscliv2.zip && ./aws/install 17 | 18 | # Install jdk-11. 
It's required by torchserve 19 | RUN apt-get install openjdk-11-jdk -y 20 | 21 | # Copy training script 22 | COPY ./charnn . 23 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/docker/README.md: -------------------------------------------------------------------------------- 1 | ## Build and push a docker image 2 | 3 | 1. Authenticate to your default registry 4 | ```bash 5 | export REGION=YOUR_AWS_REGION 6 | export ECR_URL=YOUR_AWS_ACCOUNT_ID.dkr.ecr.region.amazonaws.com 7 | aws ecr get-login-password --region $REGION | docker login --username AWS --password-stdin $ECR_URL 8 | ``` 9 | 2. Create a repository 10 | ```bash 11 | aws ecr create-repository \ 12 | --repository-name charnn \ 13 | --image-scanning-configuration scanOnPush=true \ 14 | --region $REGION 15 | ``` 16 | 3. Build and push the image to Amazon ECR 17 | ```bash 18 | cd recipes/torchrecipes/paved_path 19 | ./docker/build.sh 20 | ``` 21 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sudo docker build -f docker/Dockerfile -t charnn:latest ./ --build-arg aws=true 4 | sudo docker tag charnn:latest "$ECR_URL/charnn:latest" 5 | sudo docker push "$ECR_URL/charnn:latest" 6 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | tensorboard 3 | hydra-core 4 | --pre 5 | torchsnapshot-nightly 6 | fsspec[s3] 7 | torchserve 8 | torch-model-archiver 9 | -------------------------------------------------------------------------------- /torchrecipes/paved_path/tools/linter/pip_init.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta 
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="pip initializer")
    parser.add_argument(
        "packages",
        nargs="+",
        help="pip packages to install",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="verbose logging",
    )
    parser.add_argument(
        "--dry-run", help="do not install anything, just print what would be done."
    )

    args = parser.parse_args()

    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=logging.NOTSET if args.verbose else logging.DEBUG,
        stream=sys.stderr,
    )

    # Validate up front that every requested package pins an exact version.
    # BUG FIX: the RuntimeError message was missing its `f` prefix, so the
    # {package_name} placeholder was never interpolated ("product" was also
    # a typo for "produce").
    for package in args.packages:
        package_name, _, version = package.partition("=")
        if version == "":
            raise RuntimeError(
                f"Package {package_name} did not have a version specified. "
                "Please specify a version to produce a consistent linting experience."
            )

    pip_args = ["pip3", "install"]

    # If we are in a global install, use `--user` to install so that you do not
    # need root access in order to initialize linters.
    #
    # However, `pip install --user` interacts poorly with virtualenvs (see:
    # https://bit.ly/3vD4kvl) and conda (see: https://bit.ly/3KG7ZfU). So in
    # these cases perform a regular installation.
    in_conda = os.environ.get("CONDA_PREFIX") is not None
    in_virtualenv = os.environ.get("VIRTUAL_ENV") is not None
    if not in_conda and not in_virtualenv:
        pip_args.append("--user")

    pip_args.extend(args.packages)

    dry_run = args.dry_run == "1"
    if dry_run:
        print(f"Would have run: {pip_args}")
        sys.exit(0)

    # BUG FIX: the install previously ran inside the per-package loop, so the
    # full package list was installed once per package (and --dry-run exited
    # after the first iteration). Install everything in a single invocation.
    run_command(pip_args)
def check_file(
    filename: str,
) -> List[LintMessage]:
    """Run ufmt (usort + black) on one file and report the fix as a lint message.

    Returns:
        [] when the file is already formatted; a single "format" WARNING
        carrying the suggested replacement text otherwise; or one ADVICE
        message (via format_error_message) if the formatter raised.
    """
    with open(filename, "rb") as f:
        original = f.read().decode("utf-8")

    try:
        path = Path(filename)

        # Configs are discovered relative to the file being linted.
        usort_config = UsortConfig.find(path)
        black_config = make_black_config(path)

        # Use UFMT API to call both usort and black
        replacement = ufmt_string(
            path=path,
            content=original,
            usort_config=usort_config,
            black_config=black_config,
        )

        if original == replacement:
            return []

        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="UFMT",
                severity=LintSeverity.WARNING,
                name="format",
                original=original,
                replacement=replacement,
                description="Run `lintrunner -a` to apply this patch.",
            )
        ]
    except Exception as err:
        # Surface formatter failures as lint advice instead of crashing the run.
        return [format_error_message(filename, err)]
print(json.dumps(lint_message._asdict()), flush=True) 136 | except Exception: 137 | logging.critical('Failed at "%s".', futures[future]) 138 | raise 139 | 140 | 141 | if __name__ == "__main__": 142 | main() 143 | -------------------------------------------------------------------------------- /torchrecipes/rec/README.MD: -------------------------------------------------------------------------------- 1 | # Running 2 | 3 | ## Locally 4 | torchx run -s local_cwd dist.ddp -j 1x2 --script dlrm_main.py 5 | 6 | ## To specify arguments 7 | To specify arguments from the main function, add a '--' and then write arguments 8 | e.g: 9 | torchx run -s local_cwd dist.ddp -j 1x2 --script dlrm_main.py \ 10 | -- --num_embeddings_per_feature "45833188,36746,17245,7413,20243,3,7114,1441,62,29275261,1572176,345138,10,2209,11267,128,4,974,14,48937457,11316796,40094537,452104,12606,104,35" 11 | -------------------------------------------------------------------------------- /torchrecipes/rec/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /torchrecipes/rec/accelerators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/recipes/daf2daac75dccb7810d8dda656177c2243ccf90a/torchrecipes/rec/accelerators/__init__.py -------------------------------------------------------------------------------- /torchrecipes/rec/accelerators/torchrec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
class TorchrecStrategy(ParallelStrategy):
    """
    Lightning Trainer takes care of the operations that are related to DDP.
    However, our models are parallelization aware, which are not fully compatible to the
    given accelerators and strategies provided by Lightning.

    The torchrec accelerator and strategies bypasses the corresponding logic in Lightning.

    All collective wrappers below require torch.distributed to be initialized
    and raise AssertionError otherwise.
    """

    def __init__(self) -> None:
        super().__init__()
        logger.info("Creating torchrec strategy")

    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcast ``obj`` from rank ``src`` to all ranks and return it."""
        # BUG FIX: `dist.is_initialized` is a function; referencing it without
        # calling it is always truthy, so the un-initialized branch below
        # could never fire. Same fix applied in barrier/all_gather/reduce.
        if dist.is_initialized():
            if isinstance(obj, torch.Tensor):
                dist.broadcast(obj, src)
                return obj
            else:
                object_list = [obj]
                dist.broadcast_object_list(object_list=object_list, src=src)
                return object_list[0]
        else:
            raise AssertionError(
                "Broadcast called in torchrec strategy w/o initializing distributed"
            )

    @property
    def root_device(self) -> torch.device:
        """Device for this process: ``cuda:<LOCAL_RANK>`` when CUDA is
        available, otherwise CPU."""
        rank = int(os.environ["LOCAL_RANK"])
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
        else:
            device = torch.device("cpu")
        return device

    def save_checkpoint(
        self,
        checkpoint: Dict[str, Any],
        filepath: _PATH,
        # pyre-ignore[2]: Parameter `storage_options` has type `None` but type `Any` is specified.
        storage_options: Optional[Any] = None,
    ) -> None:
        """Delegate checkpoint serialization to the configured checkpoint IO."""
        self.checkpoint_io.save_checkpoint(
            checkpoint=checkpoint, path=filepath, storage_options=storage_options
        )

    # pyre-ignore[3]
    def batch_to_device(
        self,
        # pyre-ignore[2]
        batch: Any,
        device: Optional[torch.device] = None,
        dataloader_idx: Optional[int] = None,
    ) -> Any:
        """Move ``batch`` to the lightning module's device.

        Implicitly returns None when no lightning module is attached.
        """
        if self.lightning_module:
            return batch.to(self.lightning_module.device)

    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes."""
        if dist.is_initialized():
            dist.barrier()
        else:
            # BUG FIX: message previously said "All gather" — copy-paste error.
            raise AssertionError(
                "Barrier called in torchrec strategy w/o initializing distributed"
            )

    def all_gather(
        self,
        tensor: torch.Tensor,
        # pyre-ignore[2]: Parameter `group` has type `None` but type `Any` is specified.
        group: Optional[Any] = None,
        sync_grads: bool = False,
    ) -> torch.Tensor:
        """Gather ``tensor`` across ranks.

        NOTE(review): ``dist.all_gather`` expects ``(tensor_list, tensor,
        group=..., async_op=...)``; passing ``(tensor, group, sync_grads)``
        positionally looks wrong — confirm the intended call before relying
        on this path. Left unchanged here to avoid altering behavior.
        """
        if dist.is_initialized():
            dist.all_gather(tensor, group, sync_grads)
            return tensor
        else:
            raise AssertionError(
                "All gather called in torchrec strategy w/o initializing distributed"
            )

    # pyre-ignore[3]: Return type must be specified as type other than `Any`.
    def reduce(
        self,
        # pyre-ignore[2]: Parameter `tensor` must have a type other than `Any`.
        tensor: Union[Any, torch.Tensor],
        *args: Any,
        **kwargs: Any,
    ) -> Union[Any, torch.Tensor]:
        """All-reduce (sum) ``tensor`` across ranks in place and return it."""
        if dist.is_initialized():
            dist.all_reduce(tensor)
            return tensor
        else:
            raise AssertionError(
                "Reduce called in torchrec strategy w/o initializing distributed"
            )

    def model_to_device(self) -> None:
        # No-op: per the class docstring, the models are parallelization
        # aware and manage their own placement.
        pass

    def teardown(self) -> None:
        return None
from dataclasses import dataclass

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@dataclass
class Batch:
    """One recommendation-model batch.

    Bundles dense features, sparse (jagged) features and labels so the three
    tensors can be moved between devices, stream-marked and pinned as a unit.
    """

    dense_features: torch.Tensor
    sparse_features: KeyedJaggedTensor
    labels: torch.Tensor

    def _fields(self):
        # Iteration order matches the dataclass field order, so the results
        # can be splatted straight back into the constructor.
        return (self.dense_features, self.sparse_features, self.labels)

    def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
        """Return a copy of this batch with every field on ``device``."""
        moved = (
            field.to(device=device, non_blocking=non_blocking)
            for field in self._fields()
        )
        return Batch(*moved)

    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
        """Mark every field as in use by ``stream`` (CUDA caching-allocator safety)."""
        for field in self._fields():
            field.record_stream(stream)

    def pin_memory(self) -> "Batch":
        """Return a copy of this batch with every field in pinned host memory."""
        pinned = (field.pin_memory() for field in self._fields())
        return Batch(*pinned)
from dataclasses import dataclass
from typing import List, Optional

import pytorch_lightning as pl
from hydra.core.config_store import ConfigStore
from torch.utils.data import DataLoader
from torchrec.datasets.random import RandomRecDataset
from torchrecipes.core.conf import DataModuleConf
from torchrecipes.utils.config_utils import get_class_name_str


class RandomRecDataModule(pl.LightningDataModule):
    """DataModule that wraps ``RandomRecDataset``.

    The dataset yields random batches shaped like::

        {
            "dense_features": torch.Tensor,
            "sparse_features": KeyedJaggedTensor,
            "labels": torch.Tensor,
        }

    A single loader is built eagerly in ``__init__`` and served for the
    train, validation and test splits alike.
    """

    def __init__(
        self,
        batch_size: int = 3,
        hash_size: Optional[int] = 100,
        hash_sizes: Optional[List[int]] = None,
        manual_seed: Optional[int] = None,
        pin_memory: bool = False,
        keys: Optional[List[str]] = None,
        ids_per_feature: int = 2,
        num_dense: int = 50,
        num_workers: int = 0,
        *,
        num_generated_batches: int = 10,
        min_ids_per_features: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Default feature names are intentionally not sorted ("f1", "f3",
        # "f2") -- downstream tests depend on this exact ordering.
        self.keys: List[str] = keys if keys else ["f1", "f3", "f2"]
        self.batch_size = batch_size
        self.manual_seed = manual_seed
        self.pin_memory = pin_memory
        self.hash_size = hash_size
        self.hash_sizes = hash_sizes
        self.ids_per_feature = ids_per_feature
        # NOTE: the keyword argument is historically pluralized
        # ("min_ids_per_features"); the attribute keeps the singular spelling
        # expected by RandomRecDataset.
        self.min_ids_per_feature = min_ids_per_features
        self.num_generated_batches = num_generated_batches
        self.num_dense = num_dense
        self.num_workers = num_workers

        dataset = RandomRecDataset(
            keys=self.keys,
            batch_size=self.batch_size,
            hash_size=self.hash_size,
            hash_sizes=self.hash_sizes,
            manual_seed=self.manual_seed,
            ids_per_feature=self.ids_per_feature,
            min_ids_per_feature=self.min_ids_per_feature,
            num_dense=self.num_dense,
            num_generated_batches=self.num_generated_batches,
        )
        # batch_size/batch_sampler are disabled on the DataLoader because the
        # dataset already emits fully formed batches.
        self.init_loader: DataLoader = DataLoader(
            dataset,
            batch_size=None,
            batch_sampler=None,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )

    def train_dataloader(self) -> DataLoader:
        return self.init_loader

    def val_dataloader(self) -> DataLoader:
        return self.init_loader

    def test_dataloader(self) -> DataLoader:
        return self.init_loader


@dataclass
class RandomRecDataModuleConf(DataModuleConf):
    _target_: str = get_class_name_str(RandomRecDataModule)


cs: ConfigStore = ConfigStore.instance()

cs.store(
    group="schema/datamodule",
    name="random_rec_datamodule",
    node=RandomRecDataModuleConf,
    package="datamodule",
)
#!/usr/bin/env python3

from typing import Iterable, Iterator, TypeVar

import testslide
from torch.utils.data import IterDataPipe
from torchrecipes.rec.datamodules.samplers.undersampler import (
    DistributionUnderSampler,
    ProportionUnderSampler,
)

T = TypeVar("T")


class IDP_NoLen(IterDataPipe[T]):
    """IterDataPipe wrapper that deliberately does not implement ``__len__``."""

    def __init__(self, input_dp: Iterable[T]) -> None:
        super().__init__()
        self.input_dp = input_dp

    def __iter__(self) -> Iterator[T]:
        yield from self.input_dp


class TestUnderSampler(testslide.TestCase):
    def test_proportion_undersampler(self) -> None:
        pipe = IDP_NoLen(range(20))
        # Keeping 0% of the even class must leave only odd elements.
        sampled = ProportionUnderSampler(pipe, lambda x: x % 2, {0: 0.0, 1: 0.5})
        self.assertTrue(all(value % 2 == 1 for value in sampled))

    def test_proportion_undersampler_errors(self) -> None:
        pipe = IDP_NoLen(range(20))
        with self.assertRaisesRegex(
            ValueError, "All proportions must be within 0 and 1."
        ):
            ProportionUnderSampler(pipe, lambda x: x % 2, {0: 1.1})

    def test_distribution_undersampler(self) -> None:
        pipe = IDP_NoLen(range(20))
        sampled = DistributionUnderSampler(pipe, lambda x: x % 2, {0: 0.0, 1: 1.0})
        self.assertTrue(all(value % 2 == 1 for value in sampled))

    def test_distribution_undersampler_known_input_dist(self) -> None:
        pipe = IDP_NoLen(range(20))
        sampled = DistributionUnderSampler(
            pipe, lambda x: x % 2, {0: 0.0, 1: 1.0}, {0: 0.5, 1: 0.5}
        )
        self.assertTrue(all(value % 2 == 1 for value in sampled))

    def test_distribution_undersampler_errors(self) -> None:
        pipe = IDP_NoLen(range(20))
        with self.assertRaisesRegex(
            ValueError, "Only non-negative values are allowed in output_dist."
        ):
            DistributionUnderSampler(pipe, lambda x: x % 2, {0: 0.0, 1: -1.0})
        with self.assertRaisesRegex(
            ValueError, "Only positive values are allowed in input_dist."
        ):
            DistributionUnderSampler(
                pipe, lambda x: x % 2, {0: 0.0, 1: 1.0}, {0: 0.5, 1: 0.5, 2: 0.0}
            )
        with self.assertRaisesRegex(
            ValueError, "All keys in output_dist must be present in input_dist."
        ):
            DistributionUnderSampler(
                pipe, lambda x: x % 2, {0: 0.0, 1: 0.9, 2: 0.1}, {0: 0.5, 1: 0.5}
            )
#!/usr/bin/env python3

import testslide
import torch
from torchrecipes.rec.datamodules.random_rec_datamodule import RandomRecDataModule


class TestRandomRecDataModule(testslide.TestCase):
    """Batch generation is reproducible iff a manual seed is supplied."""

    def _paired_batches(self, **module_kwargs):
        """Yield 10 pairs of batches from two independently built modules."""
        first = iter(RandomRecDataModule(**module_kwargs).init_loader)
        second = iter(RandomRecDataModule(**module_kwargs).init_loader)
        for _ in range(10):
            yield next(first), next(second)

    def test_manual_seed_generator(self) -> None:
        pairs = self._paired_batches(manual_seed=353434, min_ids_per_features=2)
        for left, right in pairs:
            self.assertTrue(torch.equal(left.dense_features, right.dense_features))
            self.assertTrue(
                torch.equal(
                    left.sparse_features.values(), right.sparse_features.values()
                )
            )
            self.assertTrue(
                torch.equal(
                    left.sparse_features.offsets(), right.sparse_features.offsets()
                )
            )
            self.assertTrue(torch.equal(left.labels, right.labels))

    def test_no_manual_seed_generator(self) -> None:
        for left, right in self._paired_batches(min_ids_per_features=2):
            self.assertFalse(torch.equal(left.dense_features, right.dense_features))
            self.assertFalse(
                torch.equal(
                    left.sparse_features.values(), right.sparse_features.values()
                )
            )
            # Offsets are deterministic even without a seed.
            self.assertTrue(
                torch.equal(
                    left.sparse_features.offsets(), right.sparse_features.offsets()
                )
            )
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import contextlib
import csv
import os
import random
from typing import Generator, List, Optional

# Column counts follow the Criteo click-logs layout.
INT_FEATURE_COUNT = 13
CAT_FEATURE_COUNT = 26


@contextlib.contextmanager
def create_dataset_tsv(
    num_rows: int = 10000,
    train: bool = True,
    dataset_path: Optional[str] = None,
    num_days: int = 1,
    num_days_test: int = 0,
    is_kaggle: bool = False,
) -> Generator[List[str], None, None]:
    """Util function to create the dataset tsv locally following the pattern of the criteo dataset

    Args:
        num_rows: number of rows we want to create in this dataset
        train: if it's training dataset which will determine the generation of labels
        dataset_path: the path to create the dataset, required as an input
        num_days: number of days (out of 25) of data to use for train/validation
            only valid for criteo 1tb, as kaggle only have 1 train file
        num_days_test: number of days (out of 25) of data to use for testing
            only valid for criteo 1tb; the kaggle test data has no labels, thus not usable
        is_kaggle: if we generate Kaggle data or not

    Yields:
        The list of paths of the files that were written.

    Raises:
        ValueError: if ``dataset_path`` is not provided.

    Examples:
        >>> with create_dataset_tsv(
            num_days=num_days, num_days_test=num_days_test, dataset_path=dataset_path
        ) as _:
        >>> dm = CriteoDataModule(
                num_days=1,
                batch_size=3,
                num_days_test=0,
                num_workers=0,
                dataset_path=dataset_path,
            )
    """
    # Typing fix: the old signature declared ``dataset_path: str = None``
    # (pyre-fixme); it is now Optional with an explicit check, so callers get
    # a clear ValueError instead of a TypeError from os.path.join(None, ...).
    if dataset_path is None:
        raise ValueError("dataset_path is required")

    if is_kaggle is False:
        filenames = [f"day_{day}.tsv" for day in range(num_days + num_days_test)]
    else:
        filenames = ["train.txt", "test.txt"]
    paths = [os.path.join(dataset_path, filename) for filename in filenames]
    for path in paths:
        with open(path, "w") as f:
            rows = []
            for _ in range(num_rows):
                row = []
                # Training files get a leading 0/1 label column.
                if train:
                    row.append(str(random.randint(0, 1)))
                row += [
                    *(str(random.randint(0, 100)) for _ in range(INT_FEATURE_COUNT)),
                    *(
                        # 8-hex-char categorical tokens derived from a random id.
                        ("%x" % abs(hash(str(random.randint(0, 1000))))).zfill(8)[:8]
                        for _ in range(CAT_FEATURE_COUNT)
                    ),
                ]

                rows.append(row)
            cf = csv.writer(f, delimiter="\t")
            cf.writerows(rows)
    yield paths
#!/usr/bin/env python3

import unittest

import pytorch_lightning as pl
from torchrec import EmbeddingBagCollection
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrecipes.rec.datamodules.random_rec_datamodule import RandomRecDataModule
from torchrecipes.rec.modules.unsharded_lightning_dlrm import UnshardedLightningDLRM


class TestUnshardedLightningDLRM(unittest.TestCase):
    def test_train_model(self) -> None:
        """Smoke test: one manual forward pass, then fit and test on random data."""
        num_embeddings = 100
        embedding_dim = 10
        num_dense = 50

        # Two embedding tables: t1 backs features f1/f3, t2 backs f2.
        table_configs = [
            EmbeddingBagConfig(
                name="t1",
                embedding_dim=embedding_dim,
                num_embeddings=num_embeddings,
                feature_names=["f1", "f3"],
            ),
            EmbeddingBagConfig(
                name="t2",
                embedding_dim=embedding_dim,
                num_embeddings=num_embeddings,
                feature_names=["f2"],
            ),
        ]
        ebc = EmbeddingBagCollection(tables=table_configs)

        model = UnshardedLightningDLRM(
            ebc,
            dense_in_features=num_dense,
            dense_arch_layer_sizes=[20, embedding_dim],
            over_arch_layer_sizes=[5, 1],
        )
        datamodule = RandomRecDataModule(num_dense=num_dense)

        trainer = pl.Trainer(
            max_epochs=3,
            enable_checkpointing=False,
            limit_train_batches=100,
            limit_val_batches=100,
            limit_test_batches=100,
            logger=False,
        )

        # Sanity-check a single forward pass before handing off to the trainer.
        batch = next(iter(datamodule.init_loader))
        model(
            dense_features=batch.dense_features,
            sparse_features=batch.sparse_features,
        )
        trainer.fit(model, datamodule=datamodule)
        trainer.test(model, datamodule=datamodule)
-------------------------------------------------------------------------------- /torchrecipes/rec/tests/test_dlrm_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import os 8 | import tempfile 9 | import unittest 10 | import uuid 11 | 12 | from torch.distributed.launcher.api import elastic_launch, LaunchConfig 13 | from torchrec import test_utils 14 | from torchrecipes.rec.datamodules.tests.utils import create_dataset_tsv 15 | from torchrecipes.rec.dlrm_main import main 16 | 17 | 18 | class MainTest(unittest.TestCase): 19 | @classmethod 20 | def _run_trainer(cls) -> None: 21 | num_days = 1 22 | num_days_test = 1 23 | dataset_path: str = tempfile.mkdtemp() 24 | with create_dataset_tsv( 25 | num_days=num_days, num_days_test=num_days_test, dataset_path=dataset_path 26 | ): 27 | tensorboard_save_dir: str = tempfile.mkdtemp() 28 | main( 29 | [ 30 | "--limit_train_batches", 31 | "5", 32 | "--limit_val_batches", 33 | "5", 34 | "--limit_test_batches", 35 | "5", 36 | "--over_arch_layer_sizes", 37 | "8,1", 38 | "--dense_arch_layer_sizes", 39 | "8,8", 40 | "--embedding_dim", 41 | "8", 42 | "--num_embeddings", 43 | "64", 44 | "--tensorboard_save_dir", 45 | tensorboard_save_dir, 46 | "--dataset_path", 47 | dataset_path, 48 | ] 49 | ) 50 | 51 | @test_utils.skip_if_asan 52 | def test_main_function(self) -> None: 53 | with tempfile.TemporaryDirectory() as tmpdir: 54 | lc = LaunchConfig( 55 | min_nodes=1, 56 | max_nodes=1, 57 | nproc_per_node=2, 58 | run_id=str(uuid.uuid4()), 59 | rdzv_backend="c10d", 60 | rdzv_endpoint=os.path.join(tmpdir, "rdzv"), 61 | rdzv_configs={"store_type": "file"}, 62 | start_method="spawn", 63 | monitor_interval=1, 64 | max_restarts=0, 65 | ) 66 | 67 | elastic_launch(config=lc, 
entrypoint=self._run_trainer)() 68 | -------------------------------------------------------------------------------- /torchrecipes/tests/test_version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import testslide 7 | 8 | 9 | class VersionTest(testslide.TestCase): 10 | def test_can_get_version(self) -> None: 11 | import torchrecipes 12 | 13 | self.assertIsNotNone(torchrecipes.__version__) 14 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/.torchxconfig: -------------------------------------------------------------------------------- 1 | [flow] 2 | secure_group = oncall_pt_lightning 3 | entitlement = gpu_default_global 4 | proxy_workflow_image = None 5 | 6 | [cli:run] 7 | component = fb.dist.ddp 8 | workspace = //torchrecipes/text/doc_classification:penv 9 | 10 | [component:fb.dist.ddp] 11 | img = pytorch_recipes_text_doc_classification:latest 12 | m = main.py 13 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/recipes/daf2daac75dccb7810d8dda656177c2243ccf90a/torchrecipes/text/doc_classification/__init__.py -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from torchrecipes.text.doc_classification.transform.doc_classification_text_transform import ( # noqa 7 | DocClassificationTextTransform, 8 | ) 9 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/datamodule/dataset/sst2_dataset.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchtext.datasets.sst2.SST2 2 | root: ~/.torchtext/cache 3 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/datamodule/doc_classification_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.text.doc_classification.datamodule.doc_classification.DocClassificationDataModule.from_config 2 | columns: 3 | - text 4 | - label 5 | label_column: label 6 | batch_size: 16 7 | num_workers: 0 8 | drop_last: False 9 | pin_memory: False 10 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/default_config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - module/optim: adamw 3 | - module/model: xlmrbase_classifier_tiny 4 | - datamodule: doc_classification_datamodule 5 | - datamodule/dataset: sst2_dataset 6 | - trainer: cpu 7 | - transform: doc_classification_transform_tiny 8 | - _self_ 9 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/module/model/xlmrbase_classifier.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchtext.models.RobertaBundle.build_model 2 | encoder_conf: 3 | _target_: torchtext.models.RobertaEncoderConf 4 | vocab_size: 250002 5 | embedding_dim: 768 6 | ffn_dimension: 3072 7 | padding_idx: 1 8 | max_seq_len: 514 9 | num_attention_heads: 12 10 | 
num_encoder_layers: 12 11 | dropout: 0.1 12 | scaling: null 13 | normalize_before: False 14 | head: 15 | _target_: torchtext.models.RobertaClassificationHead 16 | num_classes: 2 17 | input_dim: 768 18 | inner_dim: 1024 19 | dropout: 0.4 20 | freeze_encoder: False 21 | checkpoint: https://download.pytorch.org/models/text/xlmr.base.encoder.pt 22 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/module/model/xlmrbase_classifier_tiny.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchtext.models.RobertaBundle.build_model 2 | encoder_conf: 3 | _target_: torchtext.models.RobertaEncoderConf 4 | vocab_size: 102 5 | embedding_dim: 8 6 | ffn_dimension: 8 7 | padding_idx: 1 8 | max_seq_len: 128 9 | num_attention_heads: 1 10 | num_encoder_layers: 1 11 | dropout: 0.1 12 | scaling: null 13 | normalize_before: False 14 | head: 15 | _target_: torchtext.models.RobertaClassificationHead 16 | num_classes: 2 17 | input_dim: 8 18 | inner_dim: 8 19 | dropout: 0.4 20 | freeze_encoder: True 21 | checkpoint: null 22 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/module/optim/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | lr: 1.0e-05 3 | betas: 4 | - 0.9 5 | - 0.999 6 | eps: 1.0e-08 7 | weight_decay: 0 8 | amsgrad: false 9 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/tiny_model_full_config.yaml: -------------------------------------------------------------------------------- 1 | # This config demostrates specifying full config options in single yaml file 2 | # Optionally, users can add config groups to hold some frequently used configs and 3 | # refer them in `defaults` to make this config more concise as shown in 4 | # 
tiny_model_mixed_config.yaml 5 | 6 | module: 7 | # equivalent to setting `module/optim: adamw` under `defaults` 8 | optim: 9 | _target_: torch.optim.AdamW 10 | lr: 1.0e-05 11 | betas: 12 | - 0.9 13 | - 0.999 14 | eps: 1.0e-08 15 | weight_decay: 0 16 | amsgrad: false 17 | # equivalent to setting `module/model: xlmrbase_classifier_tiny` under `defaults` 18 | model: 19 | _target_: torchtext.models.RobertaBundle.build_model 20 | encoder_conf: 21 | _target_: torchtext.models.RobertaEncoderConf 22 | vocab_size: 102 23 | embedding_dim: 8 24 | ffn_dimension: 8 25 | padding_idx: 1 26 | max_seq_len: 128 27 | num_attention_heads: 1 28 | num_encoder_layers: 1 29 | dropout: 0.1 30 | scaling: null 31 | normalize_before: False 32 | head: 33 | _target_: torchtext.models.RobertaClassificationHead 34 | num_classes: 2 35 | input_dim: 8 36 | inner_dim: 8 37 | dropout: 0.4 38 | freeze_encoder: True 39 | checkpoint: null 40 | 41 | # equivalent to setting `datamodule: doc_classification_datamodule` under `defaults` 42 | datamodule: 43 | _target_: torchrecipes.text.doc_classification.datamodule.doc_classification.DocClassificationDataModule.from_config 44 | columns: 45 | - text 46 | - label 47 | label_column: label 48 | batch_size: 16 49 | num_workers: 0 50 | drop_last: False 51 | pin_memory: False 52 | dataset: 53 | _target_: torchtext.datasets.sst2.SST2 54 | root: ~/.torchtext/cache 55 | 56 | # equivalent to setting `transform: doc_classification_transform_tiny` under `defaults` 57 | transform: 58 | transform: 59 | _target_: torchrecipes.text.doc_classification.transform.doc_classification_text_transform.DocClassificationTextTransform 60 | vocab_path: https://download.pytorch.org/models/text/xlmr.vocab_example.pt 61 | spm_model_path: https://download.pytorch.org/models/text/xlmr.sentencepiece_example.bpe.model 62 | label_transform: 63 | _target_: torchtext.transforms.LabelToIndex 64 | label_names: 65 | - "0" 66 | - "1" 67 | num_labels: 2 68 | 69 | # equivalent to setting `trainer: cpu` 
under `defaults` 70 | trainer: 71 | _target_: pytorch_lightning.trainer.Trainer 72 | accelerator: cpu 73 | devices: null 74 | strategy: null 75 | max_epochs: 1 76 | default_root_dir: /tmp/doc_classification/torchrecipes 77 | enable_checkpointing: true 78 | fast_dev_run: false 79 | logger: 80 | _target_: pytorch_lightning.loggers.TensorBoardLogger 81 | save_dir: /tmp/torchrecipes/doc_classification/logs 82 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/tiny_model_mixed_config.yaml: -------------------------------------------------------------------------------- 1 | # This config demostrates specifying a mix of full options from the config group for 2 | # a component, referring to preset config groups and override options in config groups. 3 | 4 | # New recipe developers could start with specifying full options in a single yaml 5 | # file like `tiny_model_full_config.yaml`. When the best defaults for 6 | # some options are settled, they can be added as a config group(e.g. 7 | # module/model=my_model) and referred here to make this file concise. 
8 | 9 | defaults: 10 | # this is optional as it will be overrided by the `model` section below 11 | - module/model: xlmrbase_classifier_tiny 12 | - module/optim: adamw 13 | - datamodule: doc_classification_datamodule 14 | - transform: doc_classification_transform_tiny 15 | - trainer: cpu 16 | - _self_ 17 | 18 | datamodule: 19 | dataset: 20 | _target_: torchtext.datasets.sst2.SST2 21 | root: ~/.torchtext/cache 22 | 23 | trainer: 24 | fast_dev_run: false 25 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.trainer.Trainer 2 | accelerator: cpu 3 | devices: null 4 | strategy: null 5 | max_epochs: 1 6 | default_root_dir: "" 7 | enable_checkpointing: true 8 | fast_dev_run: false 9 | logger: 10 | _target_: pytorch_lightning.loggers.TensorBoardLogger 11 | save_dir: /tmp/logs 12 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/trainer/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.trainer.Trainer 2 | accelerator: gpu 3 | devices: 8 4 | strategy: ddp 5 | max_epochs: 1 6 | num_sanity_val_steps: 0 7 | log_every_n_steps: 10 8 | default_root_dir: "" 9 | enable_checkpointing: true 10 | fast_dev_run: false 11 | logger: 12 | _target_: pytorch_lightning.loggers.TensorBoardLogger 13 | save_dir: /tmp/logs 14 | callbacks: 15 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 16 | dirpath: /tmp/checkpoints 17 | - _target_: pytorch_lightning.callbacks.LearningRateMonitor 18 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/trainer/single_gpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: 
pytorch_lightning.trainer.Trainer 2 | accelerator: gpu 3 | devices: 1 4 | strategy: null 5 | max_epochs: 1 6 | default_root_dir: "" 7 | enable_checkpointing: true 8 | fast_dev_run: false 9 | logger: 10 | _target_: pytorch_lightning.loggers.TensorBoardLogger 11 | save_dir: /tmp/logs 12 | callbacks: 13 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 14 | dirpath: /tmp/checkpoints 15 | - _target_: pytorch_lightning.callbacks.LearningRateMonitor 16 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/transform/doc_classification_transform.yaml: -------------------------------------------------------------------------------- 1 | transform: 2 | _target_: torchrecipes.text.doc_classification.transform.doc_classification_text_transform.DocClassificationTextTransform 3 | vocab_path: https://download.pytorch.org/models/text/xlmr.vocab.pt 4 | spm_model_path: https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model 5 | label_transform: 6 | _target_: torchtext.transforms.LabelToIndex 7 | label_names: 8 | - "0" 9 | - "1" 10 | num_labels: 2 11 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/transform/doc_classification_transform_tiny.yaml: -------------------------------------------------------------------------------- 1 | transform: 2 | _target_: torchrecipes.text.doc_classification.transform.doc_classification_text_transform.DocClassificationTextTransform 3 | vocab_path: https://download.pytorch.org/models/text/xlmr.vocab_example.pt 4 | spm_model_path: https://download.pytorch.org/models/text/xlmr.sentencepiece_example.bpe.model 5 | label_transform: 6 | _target_: torchtext.transforms.LabelToIndex 7 | label_names: 8 | - "0" 9 | - "1" 10 | num_labels: 2 11 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/conf/xlmrbase_sst2_config.yaml: 
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


#!/usr/bin/env python3

# pyre-strict

from dataclasses import dataclass
from typing import Optional

import hydra
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything, Trainer
from torchrecipes.text.doc_classification.datamodule.doc_classification import (
    DocClassificationDataModule,
)
from torchrecipes.text.doc_classification.module.doc_classification import (
    DocClassificationModule,
)


@dataclass
class TrainOutput:
    """Artifacts produced by one train-and-test run.

    Either field is None when the corresponding component was not configured
    (e.g. checkpointing disabled, or no logger with a ``save_dir``).
    """

    # Path of the best checkpoint reported by the trainer's checkpoint callback.
    best_model_path: Optional[str] = None
    # Directory the trainer's logger wrote to (TensorBoard by default config).
    tensorboard_log_dir: Optional[str] = None


def train_and_test(cfg: DictConfig) -> TrainOutput:
    """Build module, datamodule and trainer from ``cfg``, then fit and test.

    Args:
        cfg: Hydra config providing ``module`` (model/optim), ``transform``,
            ``datamodule`` and ``trainer`` sections, plus an optional
            ``random_seed``.

    Returns:
        TrainOutput with the best checkpoint path and log directory, when
        those components produced them.
    """
    # Only seed when a seed is explicitly configured; otherwise keep the
    # framework defaults (non-deterministic run).
    if cfg.get("random_seed") is not None:
        seed_everything(cfg.random_seed)

    module = DocClassificationModule.from_config(
        model=cfg.module.model,
        optim=cfg.module.optim,
        transform=cfg.transform,
        num_classes=cfg.transform.num_labels,
    )
    datamodule = DocClassificationDataModule.from_config(
        transform=cfg.transform,
        dataset=cfg.datamodule.dataset,
        columns=cfg.datamodule.columns,
        label_column=cfg.datamodule.label_column,
        batch_size=cfg.datamodule.batch_size,
        num_workers=cfg.datamodule.num_workers,
        drop_last=cfg.datamodule.drop_last,
        pin_memory=cfg.datamodule.pin_memory,
    )

    trainer: Trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.fit(module, datamodule=datamodule)
    trainer.test(module, datamodule=datamodule)
    # getattr fallbacks: the checkpoint callback or logger may be absent
    # (e.g. fast_dev_run disables checkpointing), in which case return None.
    return TrainOutput(
        best_model_path=getattr(trainer.checkpoint_callback, "best_model_path", None),
        tensorboard_log_dir=getattr(trainer.logger, "save_dir", None),
    )


@hydra.main(config_path="conf", config_name="default_config")
def main(cfg: DictConfig) -> TrainOutput:
    """Hydra CLI entry point: print the resolved config, then train and test."""
    print(f"config:\n{OmegaConf.to_yaml(cfg)}")
    return train_and_test(cfg)


if __name__ == "__main__":
    main()
5 | 6 | 7 | import os.path 8 | import shutil 9 | from typing import Tuple, Union 10 | 11 | from torchtext.datasets import sst2 12 | 13 | _DATA_DIR_PATH: str = os.path.realpath( 14 | os.path.join(os.path.dirname(__file__), "..", "data") 15 | ) 16 | 17 | 18 | def get_asset_path(*paths: Union[str, Tuple[str]]) -> str: 19 | """Return full path of a test asset""" 20 | # pyre-fixme[6]: For 2nd argument expected `Union[PathLike[str], str]` but got 21 | # `Union[Tuple[str], str]`. 22 | return os.path.join(_DATA_DIR_PATH, *paths) 23 | 24 | 25 | def copy_asset(cur_path: str, new_path: str) -> None: 26 | new_path_dir = os.path.dirname(new_path) 27 | if not os.path.exists(new_path_dir): 28 | os.makedirs(new_path_dir) 29 | shutil.copy(cur_path, new_path) 30 | 31 | 32 | def copy_partial_sst2_dataset(root_dir: str) -> None: 33 | cur_path = get_asset_path(sst2.DATASET_NAME, sst2._PATH) 34 | new_path = os.path.join(root_dir, sst2.DATASET_NAME, sst2._PATH) 35 | copy_asset(cur_path, new_path) 36 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/tests/data/SST2/SST-2.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/recipes/daf2daac75dccb7810d8dda656177c2243ccf90a/torchrecipes/text/doc_classification/tests/data/SST2/SST-2.zip -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/tests/data/spm_example.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/recipes/daf2daac75dccb7810d8dda656177c2243ccf90a/torchrecipes/text/doc_classification/tests/data/spm_example.model -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/tests/data/vocab_example.pt: 
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


#!/usr/bin/env python3

# pyre-strict

from unittest.mock import patch

import testslide
import torch
from torchrecipes.text.doc_classification.datamodule.doc_classification import (
    DocClassificationDataModule,
)
from torchrecipes.text.doc_classification.tests.common.assets import (
    _DATA_DIR_PATH,
    get_asset_path,
)
from torchrecipes.text.doc_classification.transform.doc_classification_text_transform import (
    DocClassificationTextTransform,
)
from torchtext.datasets.sst2 import SST2
from torchtext.transforms import LabelToIndex


class TestDocClassificationDataModule(testslide.TestCase):
    """Exercises DocClassificationDataModule over the bundled dummy SST-2 data."""

    def setUp(self) -> None:
        super().setUp()
        # patch the _hash_check() fn output to make it work with the dummy dataset
        # (the bundled partial SST-2 zip does not match the published hashes)
        self.patcher = patch(
            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
        )
        self.patcher.start()

    def tearDown(self) -> None:
        self.patcher.stop()
        super().tearDown()

    def get_datamodule(self) -> DocClassificationDataModule:
        """Build a datamodule over the dummy SST-2 splits with example transforms."""
        train_dataset, val_dataset, test_dataset = SST2(root=_DATA_DIR_PATH)
        text_transform = DocClassificationTextTransform(
            vocab_path=get_asset_path("vocab_example.pt"),
            spm_model_path=get_asset_path("spm_example.model"),
        )
        label_transform = LabelToIndex(label_names=["0", "1"])
        return DocClassificationDataModule(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            # TODO: Note that the following line should be replaced by
            # `test_dataset` once we update the lightning module to support
            # test data with and without labels
            test_dataset=val_dataset,
            transform=text_transform,
            label_transform=label_transform,
            columns=["text", "label"],
            label_column="label",
            batch_size=8,
        )

    def test_doc_classification_datamodule(self) -> None:
        """A train batch must yield tensor label/token ids of the expected sizes."""
        datamodule = self.get_datamodule()
        self.assertIsInstance(datamodule, DocClassificationDataModule)

        dataloader = datamodule.train_dataloader()
        batch = next(iter(dataloader))

        self.assertTrue(torch.is_tensor(batch["label_ids"]))
        self.assertTrue(torch.is_tensor(batch["token_ids"]))

        # batch_size=8; 35 is the expected padded sequence length for this
        # dummy batch.
        self.assertEqual(batch["label_ids"].size(), torch.Size([8]))
        self.assertEqual(batch["token_ids"].size(), torch.Size([8, 35]))
5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | # pyre-strict 10 | from unittest.mock import patch 11 | 12 | import torchrecipes.text.doc_classification.conf # noqa 13 | from hydra import compose, initialize_config_module 14 | from omegaconf import DictConfig 15 | from torchrecipes.core.test_utils.test_base import BaseTrainAppTestCase 16 | from torchrecipes.text.doc_classification.main import train_and_test 17 | 18 | 19 | class TestDocClassificationMain(BaseTrainAppTestCase): 20 | def setUp(self) -> None: 21 | super().setUp() 22 | # patch the _hash_check() fn output to make it work with the dummy dataset 23 | self.patcher = patch( 24 | "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True 25 | ) 26 | self.patcher.start() 27 | 28 | def tearDown(self) -> None: 29 | self.patcher.stop() 30 | super().tearDown() 31 | 32 | def _get_config(self, config_name: str) -> DictConfig: 33 | with initialize_config_module("torchrecipes.text.doc_classification.conf"): 34 | cfg = compose( 35 | config_name=config_name, 36 | overrides=[ 37 | # train with 1 batch of data and skip checkpointing 38 | "trainer.fast_dev_run=true", 39 | ], 40 | ) 41 | return cfg 42 | 43 | def test_default_config(self) -> None: 44 | cfg = self._get_config("default_config") 45 | output = train_and_test(cfg) 46 | self.assertIsNotNone(output) 47 | 48 | def test_tiny_model_full_config(self) -> None: 49 | cfg = self._get_config("tiny_model_full_config") 50 | output = train_and_test(cfg) 51 | self.assertIsNotNone(output) 52 | 53 | def test_tiny_model_mixed_config(self) -> None: 54 | cfg = self._get_config("tiny_model_mixed_config") 55 | output = train_and_test(cfg) 56 | self.assertIsNotNone(output) 57 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/tests/test_doc_classification_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# pyre-strict
import os
from unittest.mock import patch

import torchtext
from pytorch_lightning.trainer import Trainer
from torch.optim import AdamW
from torchrecipes.text.doc_classification.datamodule.doc_classification import (
    DocClassificationDataModule,
)
from torchrecipes.text.doc_classification.module.doc_classification import (
    DocClassificationModule,
)
from torchrecipes.text.doc_classification.tests.common.assets import (
    _DATA_DIR_PATH,
    get_asset_path,
)
from torchrecipes.text.doc_classification.transform.doc_classification_text_transform import (
    DocClassificationTextTransform,
)
from torchrecipes.utils.task_test_base import TaskTestCaseBase
from torchtext.datasets.sst2 import SST2
from torchtext.transforms import LabelToIndex


class TestDocClassificationModule(TaskTestCaseBase):
    """Tests DocClassificationModule using a deliberately tiny RoBERTa model."""

    def setUp(self) -> None:
        # NOTE(review): super().setUp() is not called here (tearDown does call
        # super()); confirm the base class needs no per-test setup.
        self.base_dir = os.path.join(os.path.dirname(__file__), "data")
        # patch the _hash_check() fn output to make it work with the dummy dataset
        self.patcher = patch(
            "torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True
        )
        self.patcher.start()

    def tearDown(self) -> None:
        self.patcher.stop()
        super().tearDown()

    def get_transform(self) -> DocClassificationTextTransform:
        """Text transform built from the bundled example vocab/spm assets."""
        return DocClassificationTextTransform(
            vocab_path=get_asset_path("vocab_example.pt"),
            spm_model_path=get_asset_path("spm_example.model"),
        )

    def get_standard_task(self) -> DocClassificationModule:
        """Return the module under test wrapping a tiny 1-layer RoBERTa.

        The small dimensions keep the test fast; vocab_size=102 presumably
        matches the bundled example vocab — confirm against the asset.
        """
        model = torchtext.models.RobertaBundle.build_model(
            encoder_conf=torchtext.models.roberta.model.RobertaEncoderConf(
                vocab_size=102,
                embedding_dim=8,
                ffn_dimension=8,
                padding_idx=1,
                max_seq_len=64,
                num_attention_heads=1,
                num_encoder_layers=1,
            ),
            head=torchtext.models.roberta.model.RobertaClassificationHead(
                num_classes=2,
                input_dim=8,
                inner_dim=8,
            ),
        )
        optim = AdamW(model.parameters())
        return DocClassificationModule(
            transform=self.get_transform(),
            model=model,
            optim=optim,
            num_classes=2,
        )

    def get_datamodule(self) -> DocClassificationDataModule:
        """Datamodule over the dummy SST-2 splits used by test_train."""
        train_dataset, val_dataset, test_dataset = SST2(root=_DATA_DIR_PATH)
        label_transform = LabelToIndex(label_names=["0", "1"])
        return DocClassificationDataModule(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            # TODO: Note that the following line should be replaced by
            # `test_dataset` once we update the lightning module to support
            # test data with and without labels
            test_dataset=val_dataset,
            transform=self.get_transform(),
            label_transform=label_transform,
            columns=["text", "label"],
            label_column="label",
            batch_size=8,
        )

    def test_train(self) -> None:
        """One fast_dev_run fit, then forward with and without labels."""
        # pyre-fixme[16]: `TestDocClassificationModule` has no attribute `datamodule`.
        self.datamodule = self.get_datamodule()
        task = self.get_standard_task()
        trainer = Trainer(fast_dev_run=True)
        trainer.fit(task, datamodule=self.datamodule)

        pred1 = task.forward({"text": ["hello world", "how are you?"]})
        pred2 = task.forward(
            {"text": ["hello world", "how are you?"], "label": ["1", "0"]}
        )
        self.assertTrue(pred1 is not None)
        self.assertTrue(pred2 is not None)
2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import hydra 7 | import testslide 8 | import torch 9 | from omegaconf import OmegaConf 10 | from torchrecipes.text.doc_classification.tests.common.assets import get_asset_path 11 | from torchrecipes.text.doc_classification.transform.doc_classification_text_transform import ( 12 | DocClassificationTextTransform, 13 | ) 14 | from torchrecipes.utils.config_utils import get_class_name_str 15 | 16 | 17 | class TestDocClassificationTransform(testslide.TestCase): 18 | def test_doc_classification_transform(self) -> None: 19 | transform_conf = OmegaConf.create( 20 | { 21 | "_target_": get_class_name_str(DocClassificationTextTransform), 22 | "vocab_path": get_asset_path("vocab_example.pt"), 23 | "spm_model_path": get_asset_path("spm_example.model"), 24 | } 25 | ) 26 | transform = hydra.utils.instantiate(transform_conf, _recursive_=False) 27 | 28 | # check whether correct class is being instantiated by hydra 29 | self.assertIsInstance(transform, DocClassificationTextTransform) 30 | 31 | test_input = {"text": ["XLMR base Model Comparison"]} 32 | actual = transform(test_input) 33 | expected_token_ids = torch.tensor( 34 | [[0, 6, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2]], dtype=torch.long 35 | ) 36 | self.assertTrue(torch.all(actual["token_ids"].eq(expected_token_ids))) 37 | -------------------------------------------------------------------------------- /torchrecipes/text/doc_classification/transform/doc_classification_text_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import os.path 8 | from typing import Any, Dict, List 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torchtext 13 | import torchtext.transforms as T 14 | from torchtext.functional import to_tensor 15 | 16 | 17 | class DocClassificationTextTransform(nn.Module): 18 | def __name__(self) -> str: 19 | return "DocClassificationTextTransform" 20 | 21 | def __init__( 22 | self, 23 | vocab_path: str, 24 | spm_model_path: str, 25 | text_column: str = "text", 26 | token_ids_column: str = "token_ids", 27 | pad_idx: int = 1, 28 | ) -> None: 29 | super().__init__() 30 | 31 | if os.path.exists(vocab_path): 32 | vocab = torch.load(vocab_path) 33 | else: 34 | vocab = torchtext._download_hooks.load_state_dict_from_url(vocab_path) 35 | 36 | self.xlmr_roberta_model_transform = T.Sequential( 37 | T.SentencePieceTokenizer(spm_model_path), 38 | T.VocabTransform(vocab), 39 | T.Truncate(254), 40 | T.AddToken(token=0, begin=True), 41 | T.AddToken(token=2, begin=False), 42 | ) 43 | self.text_column = text_column 44 | self.token_ids_column = token_ids_column 45 | self.pad_idx = pad_idx 46 | 47 | def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: 48 | texts = batch[self.text_column] 49 | assert torch.jit.isinstance(texts, List[str]) 50 | 51 | tokens_list = self.xlmr_roberta_model_transform(texts) 52 | tokens_tensor: torch.Tensor = to_tensor(tokens_list, self.pad_idx) 53 | batch[self.token_ids_column] = tokens_tensor 54 | return batch 55 | -------------------------------------------------------------------------------- /torchrecipes/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | #!/usr/bin/env python3 8 | -------------------------------------------------------------------------------- /torchrecipes/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import os 8 | from typing import Optional 9 | 10 | from fsspec.core import url_to_fs 11 | from pytorch_lightning.callbacks import ModelCheckpoint 12 | 13 | 14 | def find_last_checkpoint_path(checkpoint_dir: Optional[str]) -> Optional[str]: 15 | """Takes in a checkpoint directory path, looks for a last.ckpt checkpoint inside, 16 | and returns the full path that we can use for resuming from that checkpoint. 17 | 18 | Args: 19 | checkpoint_dir: Path where the model file(s) are saved. 20 | 21 | Returns: 22 | Full path for the last model checkpoint from the given checkpoint directory. 23 | """ 24 | if checkpoint_dir is None: 25 | return None 26 | checkpoint_file_name = ( 27 | f"{ModelCheckpoint.CHECKPOINT_NAME_LAST}{ModelCheckpoint.FILE_EXTENSION}" 28 | ) 29 | last_checkpoint_filepath = os.path.join(checkpoint_dir, checkpoint_file_name) 30 | fs, _ = url_to_fs(last_checkpoint_filepath) 31 | if not fs.exists(last_checkpoint_filepath): 32 | return None 33 | 34 | return last_checkpoint_filepath 35 | -------------------------------------------------------------------------------- /torchrecipes/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | # pyre-strict 8 | 9 | from typing import Type, TypeVar 10 | 11 | 12 | T = TypeVar("T") 13 | 14 | _CONFIG_ATTR = "_config_entry" 15 | 16 | 17 | def config_entry(fn: T) -> T: 18 | """Decorator used to mark an object as the config entry-point for a class. 19 | 20 | Below is an exameple usage for this decorator. Only a single method in a 21 | class should be annoted as the @config_entry. 22 | 23 | class MyObject: 24 | @config_entry 25 | @staticmethod 26 | def from_config(config: MyConf) -> 'MyObject': 27 | ... 28 | """ 29 | setattr(fn, _CONFIG_ATTR, None) 30 | return fn 31 | 32 | 33 | def get_class_config_method(klass: Type[T]) -> str: 34 | """ 35 | Args: 36 | klass: The class definition. 37 | 38 | Raises: 39 | ValueError if the klass does not define a single such method or the 40 | defined method is invalid. 41 | 42 | Returns: 43 | The fully qualified name of the method. 44 | """ 45 | class_name = get_class_name_str(klass) 46 | fns = [fn for name, fn in klass.__dict__.items() if hasattr(fn, _CONFIG_ATTR)] 47 | if len(fns) != 1: 48 | raise ValueError( 49 | f"{class_name} has no config entrypoint. Did you use @config_entry to annotate the method?" 50 | ) 51 | fn = fns[0] 52 | if not isinstance(fn, staticmethod): 53 | raise ValueError( 54 | f"{class_name}.{fn.__name__} is not a standalone function. Did you forget @staticmethod?" 55 | ) 56 | return f"{class_name}.{fn.__func__.__name__}" 57 | 58 | 59 | def get_class_name_str(klass: Type[T]) -> str: 60 | """ 61 | Args: 62 | klass: The class definition. 63 | 64 | Returns: 65 | The fully qualified name of the given class. 66 | """ 67 | return ".".join([klass.__module__, klass.__name__]) 68 | -------------------------------------------------------------------------------- /torchrecipes/utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch


def get_rank() -> int:
    """Return this process's distributed rank, or 0 when not distributed."""
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def get_world_size() -> int:
    """Return the distributed world size, or 1 when not distributed."""
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def barrier() -> None:
    """
    Wrapper over torch.distributed.barrier; returns immediately instead of
    raising when the distributed process group is not initialized.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


# ---------------------------------------------------------------------------
# Mixup utilities
# ---------------------------------------------------------------------------

from dataclasses import dataclass
from enum import Enum

import numpy as np


class MixupScheme(Enum):
    """Mixup scheme: Where to perform mixup within a model."""


@dataclass
class MixupParams:
    """
    alpha: float. Mixup ratio. Recommended values from (0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0)
    scheme: MixupScheme. The locations to perform mixup within a model.
    More details about the parameters and mixup can be found in FAIM WIKI:
    https://fburl.com/wiki/3f0qh0zr
    """

    alpha: float
    scheme: MixupScheme


class MixupUtil:
    """Applies mixup interpolation to a batch and its labels."""

    @staticmethod
    def _get_lambda(alpha: float = 1.0) -> float:
        # Draw the interpolation coefficient from Beta(alpha, alpha);
        # a non-positive alpha disables mixing entirely (lambda == 1).
        return np.random.beta(alpha, alpha) if alpha > 0.0 else 1.0

    def __init__(self, batch_size: int) -> None:
        # Random pairing of batch elements, reused for every mix in the batch.
        self.indices: torch.Tensor = torch.randperm(batch_size)
        self.lam: float = self._get_lambda()

    def mixup(self, x: torch.Tensor) -> torch.Tensor:
        """Linearly interpolate each element of x with its randomly paired one."""
        return self.lam * x + (1 - self.lam) * x[self.indices]

    def compute_loss(
        self,
        criterion: torch.nn.Module,
        pred: torch.Tensor,
        original_target: torch.Tensor,
        mixed_target: torch.Tensor,
    ) -> float:
        """Mixup loss: lambda-weighted blend of the two per-target losses."""
        original_loss = criterion(pred, original_target)
        mixed_loss = criterion(pred, mixed_target)
        return self.lam * original_loss + (1 - self.lam) * mixed_loss

    def mixup_labels(self, x: torch.Tensor) -> torch.Tensor:
        """Return labels permuted by the same pairing used in mixup()."""
        return x[self.indices]
5 | 6 | 7 | # pyre-strict 8 | 9 | import pickle 10 | from abc import ABC, abstractmethod 11 | 12 | import testslide 13 | from pytorch_lightning import LightningModule 14 | 15 | 16 | class TaskTestCaseBase(ABC, testslide.TestCase): 17 | """All Standard Task unit tests should inherit from this class.""" 18 | 19 | @abstractmethod 20 | def get_standard_task(self) -> LightningModule: 21 | """Subclasses should implement a standard method of retrieving and instance of the Task to test.""" 22 | raise NotImplementedError 23 | 24 | def test_standard_task_is_torchscriptable(self) -> None: 25 | task = self.get_standard_task() 26 | _ = task.to_torchscript() 27 | 28 | def test_standard_task_is_pickleable(self) -> None: 29 | task = self.get_standard_task() 30 | pickle.dumps(task) 31 | -------------------------------------------------------------------------------- /torchrecipes/utils/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from functools import wraps 8 | from tempfile import TemporaryDirectory 9 | from typing import Callable, TypeVar 10 | from unittest import TestCase 11 | 12 | from pyre_extensions import ParameterSpecification 13 | 14 | 15 | TParams = ParameterSpecification("TParams") 16 | TReturn = TypeVar("TReturn") 17 | 18 | 19 | def tempdir(func: Callable[TParams, TReturn]) -> Callable[TParams, TReturn]: 20 | """A decorator for creating a tempory directory that is cleaned up after function execution.""" 21 | 22 | @wraps(func) 23 | def wrapper( 24 | self: TestCase, *args: TParams.args, **kwargs: TParams.kwargs 25 | ) -> TReturn: 26 | with TemporaryDirectory() as temp: 27 | return func(self, temp, *args, **kwargs) 28 | 29 | return wrapper 30 | -------------------------------------------------------------------------------- /torchrecipes/utils/tests/test_config_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | # pyre-strict 8 | 9 | import testslide 10 | from torchrecipes.utils.config_utils import ( 11 | config_entry, 12 | get_class_config_method, 13 | get_class_name_str, 14 | ) 15 | 16 | 17 | class TestConfigUtils(testslide.TestCase): 18 | def test_annotation_success(self) -> None: 19 | class TestClass: 20 | @config_entry 21 | @staticmethod 22 | def from_config(config: object) -> "TestClass": 23 | return TestClass() 24 | 25 | self.assertEqual( 26 | get_class_name_str(TestClass), 27 | "torchrecipes.utils.tests.test_config_utils.TestClass", 28 | ) 29 | self.assertEqual( 30 | get_class_config_method(TestClass), 31 | "torchrecipes.utils.tests.test_config_utils.TestClass.from_config", 32 | ) 33 | 34 | def test_annotation_failures(self) -> None: 35 | class NotStatic: 36 | @config_entry 37 | def from_config(self, config: object) -> "NotStatic": 38 | return NotStatic() 39 | 40 | class MultipleEntries: 41 | @config_entry 42 | @staticmethod 43 | def from_config(config: object) -> "MultipleEntries": 44 | return MultipleEntries() 45 | 46 | @config_entry 47 | @staticmethod 48 | def from_hydra(config: object) -> "MultipleEntries": 49 | return MultipleEntries() 50 | 51 | class NoEntry: 52 | @staticmethod 53 | def from_config(config: object) -> "NoEntry": 54 | return NoEntry() 55 | 56 | self.assertIn("NotStatic", get_class_name_str(NotStatic)) 57 | with self.assertRaises(ValueError): 58 | get_class_config_method(NotStatic) 59 | 60 | self.assertIn("MultipleEntries", get_class_name_str(MultipleEntries)) 61 | with self.assertRaises(ValueError): 62 | get_class_config_method(MultipleEntries) 63 | 64 | self.assertIn("NoEntry", get_class_name_str(NoEntry)) 65 | with self.assertRaises(ValueError): 66 | get_class_config_method(NoEntry) 67 | -------------------------------------------------------------------------------- /torchrecipes/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Follows PEP-0440 version scheme guidelines 7 | # https://www.python.org/dev/peps/pep-0440/#version-scheme 8 | # 9 | # Examples: 10 | # 0.1.0.devN # Developmental release 11 | # 0.1.0aN # Alpha release 12 | # 0.1.0bN # Beta release 13 | # 0.1.0rcN # Release Candidate 14 | # 0.1.0 # Final release 15 | __version__: str = "0.1.0.dev2" 16 | -------------------------------------------------------------------------------- /torchrecipes/vision/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/recipes/daf2daac75dccb7810d8dda656177c2243ccf90a/torchrecipes/vision/core/__init__.py -------------------------------------------------------------------------------- /torchrecipes/vision/core/conf/datamodule/mnist_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.core.datamodule.mnist_data_module.MNISTDataModule 2 | data_dir: null 3 | val_split: 0.2 4 | num_workers: 16 5 | normalize: false 6 | batch_size: 32 7 | seed: 42 8 | shuffle: false 9 | pin_memory: false 10 | drop_last: false 11 | -------------------------------------------------------------------------------- /torchrecipes/vision/core/conf/datamodule/torchvision_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.core.datamodule.torchvision_data_module.TorchVisionDataModule 2 | datasets: ??? 
3 | batch_size: 32 4 | drop_last: false 5 | normalize: false 6 | num_workers: 16 7 | pin_memory: false 8 | seed: 42 9 | val_split: null 10 | -------------------------------------------------------------------------------- /torchrecipes/vision/core/datamodule/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | from torchrecipes.vision.core.datamodule.mnist_data_module import MNISTDataModule 9 | from torchrecipes.vision.core.datamodule.torchvision_data_module import ( 10 | TorchVisionDataModule, 11 | ) 12 | 13 | __all__ = [ 14 | "MNISTDataModule", 15 | "TorchVisionDataModule", 16 | ] 17 | -------------------------------------------------------------------------------- /torchrecipes/vision/core/datamodule/tests/test_mnist_data_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
#!/usr/bin/env python3

import unittest
from tempfile import TemporaryDirectory

import hydra
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torchrecipes.vision.core.datamodule.mnist_data_module import MNISTDataModule
from torchrecipes.vision.core.datamodule.transforms import build_transforms
from torchvision.datasets import MNIST


class TestMNISTDataModule(unittest.TestCase):
    """Tests MNISTDataModule over a real MNIST download shared by all tests."""

    # Path of the downloaded dataset, populated once in setUpClass.
    data_path: str

    @classmethod
    def setUpClass(cls) -> None:
        data_path_ctx = TemporaryDirectory()
        cls.addClassCleanup(data_path_ctx.cleanup)
        cls.data_path = data_path_ctx.name

        # download the dataset once for the whole class (requires network)
        MNIST(cls.data_path, train=True, download=True)
        MNIST(cls.data_path, train=False, download=True)

    def test_misconfiguration(self) -> None:
        """Tests init configuration validation."""
        with self.assertRaises(MisconfigurationException):
            MNISTDataModule(val_split=-1)

    def test_dataloading(self) -> None:
        """Tests loading batches from the dataset."""
        module = MNISTDataModule(data_dir=self.data_path, batch_size=1)
        module.prepare_data()
        module.setup()
        dataloder = module.train_dataloader()
        batch = next(iter(dataloder))
        # batch contains images and labels
        self.assertEqual(len(batch), 2)
        self.assertEqual(len(batch[0]), 1)

    def test_split_dataset(self) -> None:
        """Tests splitting the full dataset into train and validation set."""
        module = MNISTDataModule(data_dir=self.data_path, val_split=100)
        module.prepare_data()
        module.setup()
        # 60,000 MNIST training images minus the 100-image validation split.
        # pyre-fixme[6]: For 1st param expected `Sized` but got `Dataset[typing.Any]`.
        self.assertEqual(len(module.datasets["train"]), 59900)
        # pyre-fixme[6]: For 1st param expected `Sized` but got `Dataset[typing.Any]`.
        self.assertEqual(len(module.datasets["val"]), 100)

    def test_transforms(self) -> None:
        """Tests images being transformed correctly."""
        transform_config = [
            {
                "_target_": "torchvision.transforms.Resize",
                "size": 64,
            },
            {
                "_target_": "torchvision.transforms.ToTensor",
            },
        ]
        transforms = build_transforms(transform_config)
        module = MNISTDataModule(
            data_dir=self.data_path, batch_size=1, train_transforms=transforms
        )
        module.prepare_data()
        module.setup()
        dataloder = module.train_dataloader()
        image, _ = next(iter(dataloder))
        # Resize(64) on single-channel MNIST yields a (1, 1, 64, 64) batch.
        self.assertEqual(image.size(), torch.Size([1, 1, 64, 64]))

    def test_init_with_hydra(self) -> None:
        """Tests creating module with Hydra."""
        test_conf = {
            "_target_": "torchrecipes.vision.core.datamodule.mnist_data_module.MNISTDataModule",
            "data_dir": None,
            "val_split": 0.2,
            "num_workers": 16,
            "normalize": False,
            "batch_size": 32,
            "seed": 42,
            "shuffle": False,
            "pin_memory": False,
            "drop_last": False,
        }
        mnist_data_module = hydra.utils.instantiate(test_conf)
        self.assertIsInstance(mnist_data_module, MNISTDataModule)
def build_transforms(transforms_config: Iterable[Mapping[str, Any]]) -> Compose:
    """Build a ``Compose`` pipeline from a list of transform configs.

    Each entry is a Hydra-style mapping with a ``_target_`` key (and optional
    kwargs) that is instantiated into a transform callable.
    """
    return Compose([build_single_transform(cfg) for cfg in transforms_config])


def build_single_transform(config: Mapping[str, Any]) -> Callable[..., object]:
    """Instantiate one transform from its config.

    A nested ``transform`` key (a sequence of sub-configs) is recursively
    built into a ``Compose`` and passed as the ``transform`` kwarg of the
    outer target.
    """
    cfg = dict(config)
    if "transform" not in cfg:
        return hydra.utils.instantiate(cfg)
    nested = cfg.pop("transform")
    assert isinstance(nested, Sequence)
    composed = Compose([build_single_transform(sub) for sub in nested])
    return hydra.utils.instantiate(cfg, transform=composed)


def build_transforms_from_dataset_config(
    dataset_conf: Dict[str, Any],
) -> Dict[str, Any]:
    """Replace transform configs in ``dataset_conf`` with built callables.

    The torchvision dataset kwargs ``transform``, ``target_transform`` and
    ``transforms`` are each converted in place when present.
    """
    for key in ("transform", "target_transform", "transforms"):
        conf = dataset_conf.get(key)
        if conf is not None:
            dataset_conf[key] = build_transforms(conf)
    return dataset_conf
class CosineWithWarmup(SequentialLR):
    r"""Cosine Decay Learning Rate Scheduler with Linear Warmup.

    The learning rate is linearly increased from
    ``warmup_start_factor * base_lr`` to ``base_lr`` over the warmup period,
    then follows a cosine annealing decay over the remaining
    ``max_iters - warmup_iters`` iterations.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_iters (int): Max number of iterations. (This should be number of epochs/steps
            based on the unit of scheduler's step size.)
        warmup_iters (int or float): number or fraction of iterations where
            linear warmup happens. Approaching the end of the linear warmup
            period the linear warmup line will intersect with the cosine decay curve.
            Default: 0
        warmup_start_factor (float): multiplicative factor applied to the base
            learning rate at the first warmup step. NOTE(review): ``LinearLR``
            only accepts a start factor in (0, 1], so the 0.0 default is
            rejected by current torch versions — pass an explicit positive
            value. Default: 0.0
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if the resolved warmup length is negative or not strictly
            smaller than ``max_iters``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_iters: int,
        warmup_iters: Union[int, float] = 0,
        warmup_start_factor: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        # A float warmup_iters is interpreted as a fraction of max_iters.
        if isinstance(warmup_iters, float):
            warmup_iters = int(warmup_iters * max_iters)
        # Validate early: CosineAnnealingLR with T_max <= 0 would otherwise
        # only fail later with an obscure error at step() time.
        if not 0 <= warmup_iters < max_iters:
            raise ValueError(
                f"warmup_iters ({warmup_iters}) must be in [0, max_iters={max_iters})"
            )
        linear_lr = LinearLR(optimizer, warmup_start_factor, total_iters=warmup_iters)
        cosine_lr = CosineAnnealingLR(optimizer, T_max=max_iters - warmup_iters)
        super().__init__(optimizer, [linear_lr, cosine_lr], [warmup_iters], last_epoch)
self._get_model(freeze_trunk=True, norm_layer=FrozenBatchNorm2d) 34 | # trunk should be frozon 35 | params = [x for x in model.trunk.parameters() if x.requires_grad] 36 | self.assertEqual(0, len(params)) 37 | 38 | # head should be trainable 39 | params = [x for x in model.head.parameters() if x.requires_grad] 40 | self.assertEqual(2, len(params)) 41 | 42 | def test_full_fine_tuning(self) -> None: 43 | model = self._get_model(freeze_trunk=False) 44 | params = [x for x in model.parameters() if x.requires_grad] 45 | self.assertEqual(len(list(model.parameters())), len(params)) 46 | -------------------------------------------------------------------------------- /torchrecipes/vision/core/tests/test_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # pyre-strict 3 | 4 | import unittest 5 | from typing import List, Union 6 | 7 | import torch 8 | from parameterized import parameterized 9 | from torchrecipes.vision.core.optim.lr_scheduler import CosineWithWarmup 10 | 11 | 12 | class TestCosineWithWarmup(unittest.TestCase): 13 | def _get_target_schedule(self) -> List[float]: 14 | return [ 15 | 0.001, 16 | 0.0055, 17 | 0.01, 18 | 0.009619397662556433, 19 | 0.008535533905932736, 20 | 0.006913417161825449, 21 | 0.004999999999999999, 22 | 0.003086582838174551, 23 | 0.0014644660940672624, 24 | 0.00038060233744356627, 25 | ] 26 | 27 | # pyre-ignore[16]: Module parameterized.parameterized has attribute expand 28 | @parameterized.expand([(2,), (0.2,)]) 29 | def test_lr_schedule(self, warmup_iters: Union[int, float]) -> None: 30 | """Tests learning rate matches expected schedule during model training.""" 31 | test_parameter = torch.autograd.Variable( 32 | torch.randn([5, 5]), requires_grad=True 33 | ) 34 | optimizer = torch.optim.SGD([test_parameter], lr=0.01) 35 | lr_scheduler = CosineWithWarmup( 36 | optimizer, warmup_start_factor=0.1, max_iters=10, warmup_iters=warmup_iters 37 | ) 38 | 39 | 
target_schedule = self._get_target_schedule() 40 | 41 | for epoch in range(10): 42 | self.assertAlmostEqual( 43 | lr_scheduler.get_last_lr()[0], target_schedule[epoch] 44 | ) 45 | lr_scheduler.step() 46 | -------------------------------------------------------------------------------- /torchrecipes/vision/core/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from typing import Callable, List 8 | 9 | from torch.optim import lr_scheduler, Optimizer 10 | 11 | 12 | def sequential_lr( 13 | optimizer: Optimizer, 14 | scheduler_fns: List[Callable[[Optimizer], lr_scheduler._LRScheduler]], 15 | milestones: List[int], 16 | ) -> lr_scheduler.SequentialLR: 17 | """Helper function to construct SequentialLR with scheduler callables. 18 | 19 | Args: 20 | optimizer (Optimizer): Wrapped optimizer. 21 | scheduler_fns (List[Callable]): List of chained scheduler callables. 22 | milestones (List[int]): List of integers that reflects milestone points. 23 | """ 24 | schedulers = [fn(optimizer) for fn in scheduler_fns] 25 | # pyre-fixme[6]: For 2nd param expected `List[LRScheduler]` but got 26 | # `List[_LRScheduler]`. 27 | return lr_scheduler.SequentialLR(optimizer, schedulers, milestones) 28 | -------------------------------------------------------------------------------- /torchrecipes/vision/core/utils/model_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import logging 8 | from typing import Dict 9 | 10 | import torch 11 | from iopath.common.file_io import g_pathmgr 12 | from torch import nn 13 | 14 | logger: logging.Logger = logging.getLogger(__name__) 15 | 16 | 17 | def load_model_weights( 18 | module: nn.Module, weights_path: str, strict: bool = True 19 | ) -> nn.Module: 20 | """ 21 | Loads model weights from given model weights path in-place. 22 | 23 | Args: 24 | module (nn.Module): module to be operated on. 25 | weights_path (str): path to model weight file in state dict format. 26 | strict (bool): whether to load state dict in strict mode. 27 | """ 28 | with g_pathmgr.open(weights_path, "rb") as f: 29 | weights = torch.load(f, map_location="cpu") 30 | module.load_state_dict(weights, strict=strict) 31 | logger.info(f"Loaded model weights from {weights_path}.") 32 | return module 33 | 34 | 35 | def extract_model_weights_from_checkpoint( 36 | checkpoint_path: str, model_name: str = "" 37 | ) -> Dict[str, torch.Tensor]: 38 | """ 39 | Extracts model weights from given Lightning checkpoint. 40 | 41 | Args: 42 | checkpoint_path (str): path to the Lightning checkpoint. 43 | model_name (str): name of model attribute in the Lightning module. 44 | Set to empty if model is the Lightning module itself. 45 | """ 46 | with g_pathmgr.open(checkpoint_path, "rb") as f: 47 | ckpt = torch.load(f, map_location="cpu") 48 | if "state_dict" not in ckpt: 49 | raise ValueError( 50 | 'The checkpoint doesn\'t have key "state_dict",' 51 | " please make sure it's a valid Lightning checkpoint." 52 | ) 53 | state_dict = ckpt["state_dict"] 54 | logger.info(f"Loaded state dict from checkpoint {checkpoint_path}.") 55 | 56 | prefix_len = 0 if not model_name else len(model_name) + 1 # e.g. "model." 
57 | weights: Dict[str, torch.Tensor] = {} 58 | for k, v in state_dict.items(): 59 | if model_name and k.startswith(model_name): 60 | weights[k[prefix_len:]] = v 61 | if not weights: 62 | raise ValueError( 63 | f"No model weights found with prefix '{model_name}' in provided state dict" 64 | ) 65 | return weights 66 | -------------------------------------------------------------------------------- /torchrecipes/vision/core/utils/model_weights_exporter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | import argparse 9 | import logging 10 | from typing import Dict 11 | 12 | import torch 13 | from iopath.common.file_io import g_pathmgr 14 | from torchrecipes.vision.core.utils.model_weights import ( 15 | extract_model_weights_from_checkpoint, 16 | ) 17 | 18 | 19 | def main() -> None: 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--checkpoint-path", 23 | type=str, 24 | required=True, 25 | help="Path to the Lightning checkpoint.", 26 | ) 27 | parser.add_argument( 28 | "--model-name", 29 | type=str, 30 | default="", 31 | help="Name of model attribute in Lightning module.", 32 | ) 33 | parser.add_argument( 34 | "--model-weights-path", type=str, help="Export model weights to path." 
35 | ) 36 | 37 | args: argparse.Namespace = parser.parse_args() 38 | weights: Dict[str, torch.Tensor] = extract_model_weights_from_checkpoint( 39 | args.checkpoint_path, args.model_name 40 | ) 41 | 42 | if args.model_weights_path: 43 | with g_pathmgr.open(args.model_weights_path, "wb") as f: 44 | torch.save(weights, f) 45 | logging.info(f"Saved model weights to {args.model_weights_path}.") 46 | 47 | 48 | if __name__ == "__main__": 49 | main() # pragma: no cover 50 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/README.md: -------------------------------------------------------------------------------- 1 | # Image Classification Training Recipe 2 | 3 | Recipe for training image classification models from TorchVision. 4 | 5 | ## Training 6 | 7 | ### Launch Jobs without TorchX 8 | 9 | * We can easily launch a job to train a ResNet18 model on the CIFAR10 dataset with the following command: 10 | 11 | ```bash 12 | python torchrecipes/vision/image_classification/main.py 13 | ``` 14 | 15 | * Config overrides allow us to swap out different parts of the training job. For example, the following command launches a job to train a ResNet50 model on GPUs: 16 | 17 | ```bash 18 | python torchrecipes/vision/image_classification/main.py --config-name default_config module/model=resnet50 trainer=gpu 19 | ``` 20 | 21 | ### Launch Jobs with TorchX 22 | 23 | * We often use [TorchX](https://pytorch.org/torchx) to launch training jobs across different environments. 
You can install TorchX with 24 | 25 | ```bash 26 | pip install torchx 27 | ``` 28 | 29 | * Training jobs can then be launched with the following commands: 30 | 31 | ```bash 32 | torchx run --scheduler local_cwd utils.python --script torchrecipes/vision/image_classification/main.py 33 | ``` 34 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | def register_components() -> None: 8 | """ 9 | Imports all python files in the folder to trigger the 10 | code to register them to Hydra's ConfigStore. 11 | """ 12 | import torchrecipes.vision.image_classification.callbacks.mixup_transform # noqa 13 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/callbacks/mixup_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from typing import Any, Dict, Optional 7 | 8 | import torch 9 | from pyre_extensions import none_throws 10 | from pytorch_lightning.callbacks import Callback 11 | from pytorch_lightning.core.lightning import LightningModule 12 | from pytorch_lightning.trainer.trainer import Trainer 13 | from torch.distributions.beta import Beta 14 | 15 | 16 | def convert_to_one_hot(targets: torch.Tensor, num_classes: int) -> torch.Tensor: 17 | """ 18 | This function converts target class indices to one-hot vectors, 19 | given the number of classes. 20 | 21 | """ 22 | assert ( 23 | torch.max(targets).item() < num_classes 24 | ), "Class Index must be less than number of classes" 25 | one_hot_targets = torch.zeros( 26 | (targets.shape[0], num_classes), dtype=torch.long, device=targets.device 27 | ) 28 | one_hot_targets.scatter_(1, targets.long(), 1) 29 | return one_hot_targets 30 | 31 | 32 | class MixupTransform(Callback): 33 | """ 34 | This implements the mixup data augmentation in the paper 35 | "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412) 36 | """ 37 | 38 | def __init__(self, alpha: float, num_classes: Optional[int] = None) -> None: 39 | """ 40 | Args: 41 | alpha: the hyperparameter of Beta distribution used to sample mixup 42 | coefficient. 43 | num_classes: number of classes in the dataset. 
44 | """ 45 | self.alpha = alpha 46 | self.num_classes = num_classes 47 | 48 | def on_train_batch_start( 49 | self, 50 | trainer: Trainer, 51 | pl_module: LightningModule, 52 | batch: Dict[str, Any], 53 | batch_idx: int, 54 | unused: Optional[int] = None, 55 | ) -> None: 56 | if batch["target"].ndim == 1: 57 | assert ( 58 | self.num_classes is not None 59 | ), f"num_classes is expected for 1D target: {batch['target']}" 60 | batch["target"] = convert_to_one_hot( 61 | batch["target"].view(-1, 1), none_throws(self.num_classes) 62 | ) 63 | else: 64 | assert batch["target"].ndim == 2, "target tensor shape must be 1D or 2D" 65 | 66 | c = ( 67 | Beta(self.alpha, self.alpha) 68 | .sample(sample_shape=torch.Size()) 69 | .to(device=batch["target"].device) 70 | ) 71 | permuted_indices = torch.randperm(batch["target"].shape[0]) 72 | for key in ["input", "target"]: 73 | batch[key] = c * batch[key] + (1.0 - c) * batch[key][permuted_indices, :] 74 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | #!/usr/bin/env python3 8 | # flake8: noqa 9 | import torch 10 | import torchvision.models 11 | from pytorch_lightning.callbacks import LearningRateMonitor 12 | from torch.optim import SGD 13 | from torch.optim.lr_scheduler import StepLR 14 | from torchmetrics import AveragePrecision 15 | from torchrecipes.vision.core.datamodule.torchvision_data_module import ( 16 | TorchVisionDataModule, 17 | ) 18 | from torchrecipes.vision.image_classification.module.image_classification import ( 19 | ImageClassificationModule, 20 | ) 21 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/datamodule/datasets/cifar10.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | _target_: torchvision.datasets.CIFAR10 3 | train: true 4 | root: /tmp/resnet18/dataset/ 5 | download: true 6 | transform: 7 | _target_: torchvision.transforms.Compose 8 | transforms: 9 | - _target_: torchvision.transforms.Resize 10 | size: 64 11 | - _target_: torchvision.transforms.ToTensor 12 | val: null 13 | test: 14 | _target_: torchvision.datasets.CIFAR10 15 | train: false 16 | root: /tmp/resnet18/dataset/ 17 | download: true 18 | transform: 19 | _target_: torchvision.transforms.Compose 20 | transforms: 21 | - _target_: torchvision.transforms.Resize 22 | size: 64 23 | - _target_: torchvision.transforms.ToTensor 24 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/datamodule/datasets/fake_data.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | _target_: torchvision.datasets.FakeData 3 | transform: 4 | _target_: torchvision.transforms.ToTensor 5 | val: 6 | _target_: torchvision.datasets.FakeData 7 | transform: 8 | _target_: torchvision.transforms.ToTensor 9 | test: 10 | _target_: torchvision.datasets.FakeData 11 | transform: 12 | _target_: 
torchvision.transforms.ToTensor 13 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/datamodule/torchvision_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.core.datamodule.torchvision_data_module.TorchVisionDataModule 2 | datasets: ??? 3 | batch_size: 32 4 | drop_last: false 5 | normalize: false 6 | num_workers: 16 7 | pin_memory: false 8 | seed: 42 9 | val_split: 0.9 10 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/default_config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | # module 3 | - module: default_module 4 | - module/model: resnet18 5 | - module/loss: cross_entropy 6 | - module/optim: sgd 7 | - module/metrics: accuracy 8 | # datamodule 9 | - datamodule: torchvision_datamodule 10 | - datamodule/datasets: cifar10 11 | # trainer 12 | - trainer: cpu 13 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/default_module.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_classification.module.image_classification.ImageClassificationModule 2 | model: ??? 3 | loss: ??? 4 | optim: ??? 5 | metrics: ??? 
6 | lr_scheduler: null 7 | apply_softmax: false 8 | norm_weight_decay: 0.0 9 | lr_scheduler_interval: epoch 10 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/loss/cross_entropy.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.CrossEntropyLoss 2 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/lr_scheduler/step_lr.yaml: -------------------------------------------------------------------------------- 1 | # use _partial_ to instantiate an LR scheduler callable, which returns 2 | # an LR scheduler in configure_optimizer. 3 | _target_: torch.optim.lr_scheduler.StepLR 4 | _partial_: true 5 | step_size: 10 6 | gamma: 0.1 7 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/metrics/accuracy.yaml: -------------------------------------------------------------------------------- 1 | accuracy: 2 | _target_: torchmetrics.Accuracy 3 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/metrics/accuracy_top1_top5.yaml: -------------------------------------------------------------------------------- 1 | accuracy_top1: 2 | _target_: torchmetrics.Accuracy 3 | top_k: 1 4 | accuracy_top5: 5 | _target_: torchmetrics.Accuracy 6 | top_k: 5 7 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/metrics/average_precision.yaml: -------------------------------------------------------------------------------- 1 | average_precision: 2 | _target_: torchmetrics.AveragePrecision 3 | -------------------------------------------------------------------------------- 
/torchrecipes/vision/image_classification/conf/module/metrics/multilabel_accuracy.yaml: -------------------------------------------------------------------------------- 1 | accuracy_top1: 2 | _target_: torchrecipes.vision.image_classification.metrics.multilabel_accuracy.MultilabelAccuracy 3 | top_k: 1 4 | accuracy_top5: 5 | _target_: torchrecipes.vision.image_classification.metrics.multilabel_accuracy.MultilabelAccuracy 6 | top_k: 5 7 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/model/resnet18.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.models.resnet18 2 | pretrained: false 3 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/model/resnet50.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.models.resnet50 2 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/model/resnext101_32x4d.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchvision.models.ResNet 2 | block: '${get_method: torchvision.models.resnet.Bottleneck}' 3 | layers: [3, 4, 23, 3] 4 | groups: 32 5 | width_per_group: 4 6 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/module/optim/sgd.yaml: -------------------------------------------------------------------------------- 1 | # use _partial_ to instantiate an optimizer callable, which returns 2 | # an optimizer in configure_optimizer. 
3 | _target_: torch.optim.SGD 4 | _partial_: true 5 | lr: 0.1 6 | weight_decay: 1e-4 7 | momentum: 0.9 8 | nesterov: False 9 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.trainer.Trainer 2 | default_root_dir: null 3 | accelerator: cpu 4 | devices: null 5 | strategy: null 6 | max_epochs: 1 7 | enable_checkpointing: false 8 | fast_dev_run: false 9 | logger: 10 | _target_: pytorch_lightning.loggers.TensorBoardLogger 11 | save_dir: /tmp/logs 12 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/conf/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.trainer.Trainer 2 | default_root_dir: null 3 | accelerator: gpu 4 | devices: null 5 | strategy: ddp 6 | max_epochs: 20 7 | fast_dev_run: false 8 | logger: 9 | _target_: pytorch_lightning.loggers.TensorBoardLogger 10 | save_dir: /tmp/logs 11 | callbacks: 12 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 13 | dirpath: /tmp/checkpoints 14 | - _target_: pytorch_lightning.callbacks.LearningRateMonitor 15 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/losses/soft_target_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | import torch 10 | 11 | 12 | def _convert_to_one_hot(targets: torch.Tensor, classes: int) -> torch.Tensor: 13 | """ 14 | This function converts target class indices to one-hot vectors, 15 | given the number of classes. 16 | 17 | """ 18 | if torch.max(targets).item() >= classes: 19 | raise ValueError("Class Index must be less than number of classes") 20 | one_hot_targets = torch.zeros( 21 | (targets.shape[0], classes), dtype=torch.long, device=targets.device 22 | ) 23 | one_hot_targets.scatter_(1, targets.long(), 1) 24 | return one_hot_targets 25 | 26 | 27 | class SoftTargetCrossEntropyLoss(torch.nn.CrossEntropyLoss): 28 | """This loss allows the targets for the cross entropy loss to be multi-label. 29 | 30 | Args: 31 | reduction (str): specifies reduction to apply to the output. 32 | normalize_targets (bool): whether the targets should be normalized to a sum of 1 33 | based on the total count of positive targets for a given sample. 34 | """ 35 | 36 | def __init__( 37 | self, 38 | reduction: str = "mean", 39 | normalize_targets: bool = True, 40 | ) -> None: 41 | super().__init__(reduction=reduction) 42 | self.normalize_targets = normalize_targets 43 | self._eps: float = torch.finfo(torch.float32).eps 44 | 45 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 46 | target = target.detach().clone() 47 | # Check if targets are inputted as class integers 48 | if target.ndim == 1: 49 | if input.shape[0] != target.shape[0]: 50 | raise ValueError( 51 | "SoftTargetCrossEntropyLoss requires input and target to have same batch size!" 
class TestSoftTargetCrossEntropyLoss(testslide.TestCase):
    """Tests for SoftTargetCrossEntropyLoss covering soft, integer, and
    unnormalized targets, plus deep-copy safety."""

    def _get_outputs(self) -> torch.Tensor:
        # One sample of logits over 5 classes, shared by several tests.
        return torch.tensor([[1.0, 7.0, 0.0, 0.0, 2.0]])

    def _get_targets(self) -> torch.Tensor:
        # Multi-label target: classes 0 and 4 are both positive.
        return torch.tensor([[1, 0, 0, 0, 1]])

    def _get_loss(self) -> float:
        # Expected mean loss for the fixtures above with normalized targets.
        return 5.51097965

    def test_soft_target_cross_entropy(self) -> None:
        """Mean-reduced loss on a multi-label target matches the reference value."""
        crit = SoftTargetCrossEntropyLoss(reduction="mean")
        outputs = self._get_outputs()
        targets = self._get_targets()
        self.assertAlmostEqual(crit(outputs, targets).item(), self._get_loss())

    def test_soft_target_cross_entropy_none_reduction(self) -> None:
        """reduction="none" yields one loss value per sample."""
        crit = SoftTargetCrossEntropyLoss(reduction="none")

        outputs = torch.tensor([[1.0, 7.0, 0.0, 0.0, 2.0], [4.0, 2.0, 1.0, 6.0, 0.5]])
        targets = torch.tensor([[1, 0, 0, 0, 1], [0, 1, 0, 1, 0]])
        loss = crit(outputs, targets)
        self.assertEqual(loss.numel(), outputs.size(0))

    def test_soft_target_cross_entropy_integer_label(self) -> None:
        """1-D integer targets are expanded to one-hot before the loss."""
        crit = SoftTargetCrossEntropyLoss(reduction="mean")
        outputs = self._get_outputs()
        targets = torch.tensor([4])
        self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)

    def test_unnormalized_soft_target_cross_entropy(self) -> None:
        """With normalize_targets=False the targets are used as-is."""
        crit = SoftTargetCrossEntropyLoss(reduction="none", normalize_targets=False)
        outputs = self._get_outputs()
        targets = self._get_targets()
        self.assertAlmostEqual(crit(outputs, targets).item(), 11.0219593)

    def test_deep_copy(self) -> None:
        """A deep-copied criterion produces the same loss as the original."""
        crit = SoftTargetCrossEntropyLoss(reduction="mean")
        outputs = self._get_outputs()
        targets = self._get_targets()
        crit(outputs, targets)

        crit2 = copy.deepcopy(crit)
        self.assertAlmostEqual(crit2(outputs, targets).item(), self._get_loss())
5 | 6 | 7 | #!/usr/bin/env python3 8 | import logging 9 | from dataclasses import dataclass 10 | from typing import Optional 11 | 12 | import hydra 13 | import torchrecipes.vision.image_classification.conf # noqa 14 | from omegaconf import DictConfig, OmegaConf 15 | from pytorch_lightning import seed_everything 16 | 17 | log: logging.Logger = logging.getLogger(__name__) 18 | 19 | 20 | @dataclass 21 | class TrainOutput: 22 | log_dir: Optional[str] = None 23 | best_model_path: Optional[str] = None 24 | 25 | 26 | @hydra.main(config_path="conf", config_name="default_config") 27 | def main(config: DictConfig) -> TrainOutput: 28 | seed = config.get("seed", 0) 29 | seed_everything(seed, workers=True) 30 | log.info(f"Config:\n{OmegaConf.to_yaml(config)}") 31 | 32 | log.info("Instantiating a datamodule, a module, and a trainer") 33 | datamodule = hydra.utils.instantiate(config.datamodule) 34 | trainer = hydra.utils.instantiate(config.trainer) 35 | module = hydra.utils.instantiate(config.module) 36 | 37 | if getattr(config, "pretrained_checkpoint_path", None): 38 | log.info(f"Loading module from checkpoint {config.pretrained_checkpoint_path}") 39 | module = module.load_from_checkpoint( 40 | checkpoint_path=config.pretrained_checkpoint_path 41 | ) 42 | 43 | log.info("Training started") 44 | trainer.fit(module, datamodule=datamodule) 45 | logging.info("Testing started") 46 | trainer.test(module, datamodule=datamodule) 47 | 48 | train_output = TrainOutput( 49 | best_model_path=getattr(trainer.checkpoint_callback, "best_model_path", None), 50 | log_dir=getattr(trainer.logger, "save_dir", None), 51 | ) 52 | log.info(f"Training output: {train_output}") 53 | return train_output 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/metrics/multilabel_accuracy.py: -------------------------------------------------------------------------------- 1 | # 
class MultilabelAccuracy(Metric):
    """Computes top-k accuracy for multilabel targets. A sample counts as
    correctly classified when any of its labels appears among the top-k
    predictions.

    Args:
        top_k: Number of highest score predictions considered to find the
            correct label.
        compute_on_step: Forwarded to ``Metric``; compute the value on every
            forward() call.
        dist_sync_on_step: Synchronize metric state across processes at each
            forward() before returning the value at the step.
    """

    def __init__(
        self, top_k: int, compute_on_step: bool = True, dist_sync_on_step: bool = False
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step
        )

        self._top_k = top_k
        # Running counts; summed across processes when state is synchronized.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    # pyre-fixme[14]: `update` overrides method defined in `Metric` inconsistently.
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulates one batch of predictions and targets.

        Args:
            preds: tensor of shape (B, C) where each value is either logit or
                class probability.
            target: tensor of shape (B, C), which is one-hot / multi-label
                encoded.
        """
        assert preds.shape == target.shape, (
            "predictions and target must be of the same shape. "
            f"Got preds({preds.shape}) vs target({target.shape})."
        )
        num_classes = target.shape[1]
        assert (
            num_classes >= self._top_k
        ), f"top-k({self._top_k}) is greater than the number of classes({num_classes})"
        preds, target = self._format_inputs(preds, target)

        # topk already returns exactly top_k columns, sorted by score.
        top_idx = preds.topk(self._top_k, dim=1, largest=True, sorted=True).indices

        # A sample is a hit when the maximum target value gathered at its
        # top-k predicted classes is positive.
        hits = torch.gather(target, dim=1, index=top_idx).max(dim=1).values
        # pyre-ignore[16]: Accuracy has attribute correct
        self.correct += hits.sum().item()
        # pyre-ignore[16]: Accuracy has attribute total
        self.total += preds.shape[0]

    def compute(self) -> torch.Tensor:
        """Returns the accuracy over all batches seen so far (0.0 before any update)."""
        # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Tensor, Module]`.
        if not torch.is_nonzero(self.total):
            return torch.tensor(0.0)
        # pyre-ignore[29]: state attributes behave as tensors at runtime.
        return self.correct / self.total

    @staticmethod
    def _format_inputs(
        preds: torch.Tensor, target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # .topk() is not compatible with fp16, so promote predictions to full
        # precision; targets are compared as integers.
        return preds.float(), target.int()
5 | 6 | 7 | import testslide 8 | import torch 9 | from torchrecipes.vision.image_classification.metrics.multilabel_accuracy import ( 10 | MultilabelAccuracy, 11 | ) 12 | 13 | 14 | class TestMultilabelAccuracy(testslide.TestCase): 15 | def test_top_k(self) -> None: 16 | metric = MultilabelAccuracy(top_k=2) 17 | self.assertTrue(torch.equal(metric.compute(), torch.tensor(0.0))) 18 | 19 | preds = torch.tensor([[1.0, -1.0, 2.0]]) 20 | target = torch.tensor([[0, 1.0, 0]]) 21 | metric.update(preds, target) 22 | # none of the labels is in the top_k prediction 23 | self.assertTrue(torch.equal(metric.compute(), torch.tensor(0.0))) 24 | 25 | preds = torch.tensor([[1.0, -0.5, 2.0]]) 26 | target = torch.tensor([[1, 1, 0]]) 27 | metric.update(preds, target) 28 | # one of the labels is in the top_k prediction 29 | # one out of the two samples is correctly classified 30 | self.assertTrue(torch.equal(metric.compute(), torch.tensor(0.5))) 31 | 32 | def test_invalid_top_k(self) -> None: 33 | metric = MultilabelAccuracy(top_k=10) 34 | 35 | preds = torch.tensor([[1.0]]) 36 | target = torch.tensor([[0]]) 37 | with self.assertRaises(AssertionError): 38 | metric.update(preds, target) 39 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/tests/test_image_classification_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | # pyre-strict 8 | from functools import partial 9 | 10 | import testslide 11 | import torch 12 | from torchrecipes.vision.image_classification.module.image_classification import ( 13 | ImageClassificationModule, 14 | ) 15 | from torchvision.models.resnet import resnet18 16 | 17 | 18 | class TestImageClassificationModule(testslide.TestCase): 19 | def test_custom_norm_weight_decay(self) -> None: 20 | module = ImageClassificationModule( 21 | model=resnet18(), 22 | loss=torch.nn.CrossEntropyLoss(), 23 | optim=partial(torch.optim.SGD, lr=0.1), 24 | metrics={}, 25 | norm_weight_decay=0.1, 26 | ) 27 | 28 | param_groups = module.get_optimizer_param_groups() 29 | self.assertEqual(2, len(param_groups)) 30 | self.assertEqual(0.1, param_groups[1]["weight_decay"]) 31 | 32 | def test_custom_optimizer_interval(self) -> None: 33 | module = ImageClassificationModule( 34 | model=resnet18(), 35 | loss=torch.nn.CrossEntropyLoss(), 36 | optim=partial(torch.optim.SGD, lr=0.1), 37 | # pyre-fixme[6]: For 4th param expected 38 | # `Optional[typing.Callable[[Optimizer], _LRScheduler]]` but got 39 | # `partial[StepLR]`. 40 | lr_scheduler=partial(torch.optim.lr_scheduler.StepLR, step_size=10), 41 | metrics={}, 42 | lr_scheduler_interval="step", 43 | ) 44 | optim = module.configure_optimizers() 45 | # pyre-ignore[16]: optim["lr_scheduler"] has key "interval" 46 | self.assertEqual("step", optim["lr_scheduler"]["interval"]) 47 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_classification/tests/test_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | # pyre-strict 10 | from copy import deepcopy 11 | 12 | import hydra 13 | import testslide 14 | import torch 15 | import torchrecipes.vision.image_classification.conf # noqa 16 | from omegaconf import DictConfig 17 | from torch import nn 18 | from torchrecipes.utils.test import tempdir 19 | from torchrecipes.vision.core.ops.fine_tuning_wrapper import FineTuningWrapper 20 | from torchrecipes.vision.image_classification.main import main 21 | from torchvision.models.resnet import resnet18 22 | from torchvision.ops.misc import FrozenBatchNorm2d 23 | 24 | 25 | class TestMain(testslide.TestCase): 26 | def _get_config(self, tb_save_dir: str) -> DictConfig: 27 | with hydra.initialize_config_module( 28 | config_module="torchrecipes.vision.image_classification.conf" 29 | ): 30 | config = hydra.compose( 31 | config_name="default_config", 32 | overrides=[ 33 | "datamodule/datasets=fake_data", 34 | "+module.model.num_classes=10", 35 | "trainer.enable_checkpointing=false", 36 | "trainer.fast_dev_run=true", 37 | f"trainer.logger.save_dir={tb_save_dir}", 38 | ], 39 | ) 40 | return config 41 | 42 | @tempdir 43 | def test_train_model(self, root_dir: str) -> None: 44 | config = self._get_config(tb_save_dir=root_dir) 45 | output = main(config) 46 | self.assertIsNotNone(output) 47 | 48 | @tempdir 49 | def test_fine_tuning(self, root_dir: str) -> None: 50 | trunk = resnet18(norm_layer=FrozenBatchNorm2d) 51 | head = nn.Linear(in_features=512, out_features=10) 52 | fine_tune_model = FineTuningWrapper(trunk, "flatten", head) 53 | origin_trunk = deepcopy(fine_tune_model.trunk) 54 | 55 | config = self._get_config(tb_save_dir=root_dir) 56 | datamodule = hydra.utils.instantiate(config.datamodule) 57 | trainer = hydra.utils.instantiate(config.trainer) 58 | module = hydra.utils.instantiate(config.module, model=fine_tune_model) 59 | trainer.fit(module, datamodule=datamodule) 60 | 61 | with torch.no_grad(): 62 | inp = torch.randn(1, 3, 28, 28) 63 | 
origin_out = origin_trunk(inp) 64 | tuned_out = module.model.trunk(inp) 65 | self.assertTrue(torch.equal(origin_out["flatten"], tuned_out["flatten"])) 66 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | 10 | def register_components() -> None: 11 | """ 12 | Calls register_components() for all subfolders so we can register 13 | subcomponents to Hydra's ConfigStore. 14 | """ 15 | import torchrecipes.vision.image_generation.module.gan # noqa 16 | import torchrecipes.vision.image_generation.module.infogan # noqa 17 | import torchrecipes.vision.image_generation.train_app # noqa 18 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | from torchrecipes.vision.image_generation.callbacks.image_generation import ( 10 | TensorboardGenerativeModelImageSampler, 11 | ) 12 | 13 | __all__ = [ 14 | "TensorboardGenerativeModelImageSampler", 15 | ] 16 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | import torchrecipes.core.conf # noqa 10 | import torchrecipes.vision.core.datamodule # noqa 11 | 12 | # Components to register with this config 13 | from torchrecipes.vision.image_generation import register_components 14 | 15 | register_components() 16 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/datamodule/datasets/fake_data.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | _target_: torchvision.datasets.FakeData 3 | image_size: 4 | - 1 5 | - 28 6 | - 28 7 | transform: 8 | _target_: torchvision.transforms.ToTensor 9 | val: 10 | _target_: torchvision.datasets.FakeData 11 | image_size: 12 | - 1 13 | - 28 14 | - 28 15 | transform: 16 | _target_: torchvision.transforms.ToTensor 17 | test: 18 | _target_: torchvision.datasets.FakeData 19 | image_size: 20 | - 1 21 | - 28 22 | - 28 23 | transform: 24 | _target_: torchvision.transforms.ToTensor 25 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/datamodule/mnist.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.core.datamodule.mnist_data_module.MNISTDataModule 2 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/datamodule/torchvision_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.core.datamodule.torchvision_data_module.TorchVisionDataModule 2 | datasets: ??? 
3 | batch_size: 32 4 | drop_last: false 5 | normalize: false 6 | num_workers: 16 7 | pin_memory: false 8 | seed: 42 9 | val_split: null 10 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/datamodule/transforms/resize.yaml: -------------------------------------------------------------------------------- 1 | resize: 64 2 | train: 3 | transforms_config: 4 | - _target_: torchvision.transforms.Resize 5 | size: ${datamodule.transforms.resize} 6 | - _target_: torchvision.transforms.ToTensor 7 | val: 8 | transforms_config: 9 | - _target_: torchvision.transforms.Resize 10 | size: ${datamodule.transforms.resize} 11 | - _target_: torchvision.transforms.ToTensor 12 | test: 13 | transforms_config: 14 | - _target_: torchvision.transforms.Resize 15 | size: ${datamodule.transforms.resize} 16 | - _target_: torchvision.transforms.ToTensor 17 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/gan_train_app.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_generation.train_app.GANTrainApp 2 | 3 | defaults: 4 | - _self_ 5 | # module 6 | - schema/module: gan_module_conf 7 | - module/generator: gan 8 | - module/discriminator: gan 9 | - module/criterion: bce_loss 10 | - module/optim: default 11 | # datamodule 12 | - datamodule: mnist 13 | # trainer 14 | - schema/trainer: trainer 15 | - trainer: cpu 16 | 17 | hydra: 18 | searchpath: 19 | - pkg://torchrecipes.core.conf 20 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/infogan_train_app.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_generation.train_app.GANTrainApp 2 | 3 | defaults: 4 | - _self_ 5 | # module 6 | - schema/module: infogan_module_conf 7 | - 
module/generator: infogan 8 | - module/discriminator: infogan 9 | - module/optim: infogan_default 10 | # datamodule 11 | - datamodule: mnist 12 | # trainer 13 | - schema/trainer: trainer 14 | - trainer: cpu 15 | 16 | hydra: 17 | searchpath: 18 | - pkg://torchrecipes.core.conf 19 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/criterion/bce_loss.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.BCELoss 2 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/discriminator/dcgan.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_generation.models.dcgan.Discriminator 2 | feature_maps: 28 3 | image_channels: 1 4 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/discriminator/gan.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_generation.models.gan.Discriminator 2 | img_shape: 3 | - 1 4 | - 28 5 | - 28 6 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/discriminator/infogan.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_generation.models.infogan.Discriminator 2 | n_classes: ${module.generator.n_classes} 3 | code_dim: ${module.generator.code_dim} 4 | img_size: ${module.generator.img_size} 5 | channels: ${module.generator.channels} 6 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/generator/dcgan.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_generation.models.dcgan.Generator 2 | latent_dim: 32 3 | feature_maps: 28 4 | image_channels: 1 5 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/generator/gan.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_generation.models.gan.Generator 2 | latent_dim: 32 3 | img_shape: 4 | - 1 5 | - 28 6 | - 28 7 | hidden_dim: 256 8 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/generator/infogan.yaml: -------------------------------------------------------------------------------- 1 | _target_: torchrecipes.vision.image_generation.models.infogan.Generator 2 | latent_dim: 62 3 | n_classes: 10 4 | code_dim: 2 5 | img_size: 32 6 | channels: 1 7 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/optim/default.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.0002 2 | beta1: 0.9 3 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/optim/infogan_default.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.0002 2 | b1: 0.5 3 | b2: 0.999 4 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/conf/module/tests/test_gan_module_conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | import unittest 10 | 11 | import hydra 12 | from hydra.experimental import compose, initialize_config_module 13 | from torchrecipes.vision.image_generation.module.gan import GAN 14 | 15 | 16 | class TestGANModuleConf(unittest.TestCase): 17 | def test_init_with_hydra(self) -> None: 18 | with initialize_config_module( 19 | config_module="torchrecipes.vision.image_generation.conf" 20 | ): 21 | test_conf = compose( 22 | config_name="gan_train_app", 23 | ) 24 | test_module = hydra.utils.instantiate(test_conf.module, _recursive_=False) 25 | self.assertIsInstance(test_module, GAN) 26 | self.assertIsNotNone(test_module.generator) 27 | self.assertIsNotNone(test_module.discriminator) 28 | self.assertIsNotNone(test_module.criterion) 29 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/models/gan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | # based on https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/gans/basic/basic_gan_module.py 9 | from typing import Tuple 10 | 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | from torch.nn import functional as F 15 | 16 | 17 | class Generator(nn.Module): 18 | """Generator model from the 19 | `"Generative Adversarial Networks" `_ paper. 
20 | 21 | Args: 22 | latent_dim (int): dimension of latent 23 | img_shape (tuple): shape of image tensor 24 | hidden_dim (int): dimension of hidden layer 25 | """ 26 | 27 | def __init__( 28 | self, latent_dim: int, img_shape: Tuple[int, int, int], hidden_dim: int = 256 29 | ) -> None: 30 | super().__init__() 31 | feats = int(np.prod(img_shape)) 32 | self.img_shape = img_shape 33 | self.fc1 = nn.Linear(latent_dim, hidden_dim) 34 | self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features * 2) 35 | self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features * 2) 36 | self.fc4 = nn.Linear(self.fc3.out_features, feats) 37 | 38 | # forward method 39 | def forward(self, z: torch.Tensor) -> torch.Tensor: 40 | z = F.leaky_relu(self.fc1(z), 0.2) 41 | z = F.leaky_relu(self.fc2(z), 0.2) 42 | z = F.leaky_relu(self.fc3(z), 0.2) 43 | img = torch.tanh(self.fc4(z)) 44 | img = img.view(img.size(0), *self.img_shape) 45 | return img 46 | 47 | 48 | class Discriminator(nn.Module): 49 | """Discriminator model from the 50 | `"Generative Adversarial Networks" `_ paper. 
51 | 52 | Args: 53 | img_shape (tuple): shape of image tensor 54 | hidden_dim (int): dimension of hidden layer 55 | """ 56 | 57 | def __init__(self, img_shape: Tuple[int, int, int], hidden_dim: int = 1024) -> None: 58 | super().__init__() 59 | in_dim = int(np.prod(img_shape)) 60 | self.fc1 = nn.Linear(in_dim, hidden_dim) 61 | self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2) 62 | self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2) 63 | self.fc4 = nn.Linear(self.fc3.out_features, 1) 64 | 65 | # forward method 66 | def forward(self, img: torch.Tensor) -> torch.Tensor: 67 | x = img.view(img.size(0), -1) 68 | x = F.leaky_relu(self.fc1(x), 0.2) 69 | x = F.dropout(x, 0.3) 70 | x = F.leaky_relu(self.fc2(x), 0.2) 71 | x = F.dropout(x, 0.3) 72 | x = F.leaky_relu(self.fc3(x), 0.2) 73 | x = F.dropout(x, 0.3) 74 | return torch.sigmoid(self.fc4(x)) 75 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/tests/test_image_generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | #!/usr/bin/env python3 8 | import unittest 9 | from unittest.mock import patch 10 | 11 | import torch 12 | import torchvision 13 | from pytorch_lightning import Trainer 14 | from pytorch_lightning.loggers import CSVLogger 15 | from torchrecipes.utils.test import tempdir 16 | from torchrecipes.vision.core.utils.test_module import TestModule 17 | from torchrecipes.vision.image_generation.callbacks import ( 18 | TensorboardGenerativeModelImageSampler, 19 | ) 20 | 21 | 22 | class TestGANModule(unittest.TestCase): 23 | @tempdir 24 | def test_module_without_dimension(self, tmp_dir: str) -> None: 25 | """tests using the callback with a module that doesn't define image and 26 | latent dimension. 27 | """ 28 | module = TestModule() 29 | trainer = Trainer( 30 | default_root_dir=tmp_dir, 31 | fast_dev_run=True, 32 | callbacks=[TensorboardGenerativeModelImageSampler()], 33 | ) 34 | 35 | with self.assertRaises(AssertionError): 36 | trainer.fit(module) 37 | 38 | @tempdir 39 | def test_logger_without_add_image(self, tmp_dir: str) -> None: 40 | """tests using the callback with an unsupported logger.""" 41 | module = TestModule() 42 | trainer = Trainer( 43 | default_root_dir=tmp_dir, 44 | fast_dev_run=True, 45 | logger=CSVLogger(tmp_dir), 46 | callbacks=[TensorboardGenerativeModelImageSampler()], 47 | ) 48 | 49 | with self.assertRaises(AssertionError): 50 | trainer.fit(module) 51 | 52 | @tempdir 53 | def test_callback_triggered(self, tmp_dir: str) -> None: 54 | """tests image generation is triggered by end of an epoch.""" 55 | 56 | class MyModule(TestModule): 57 | def __init__(self) -> None: 58 | super().__init__() 59 | self.latent_dim = 32 60 | self.img_dim = (1, 1, 2) 61 | 62 | def forward(self, x) -> torch.Tensor: 63 | assert x.size() == torch.Size([1, 32]) 64 | return super().forward(x) 65 | 66 | module = MyModule() 67 | trainer = Trainer( 68 | default_root_dir=tmp_dir, 69 | fast_dev_run=True, 70 | callbacks=[TensorboardGenerativeModelImageSampler(num_samples=1)], 
71 | ) 72 | 73 | with patch.object( 74 | torchvision.utils, "make_grid", return_value=torch.randn(1, 1, 1) 75 | ) as mock_call: 76 | trainer.fit(module) 77 | # called twice: once for training, once for validation 78 | self.assertEqual(mock_call.call_count, 2) 79 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/tests/test_train_app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | from typing import cast 10 | 11 | import torchrecipes.vision.image_generation.conf # noqa 12 | from torchrecipes.core.test_utils.test_base import BaseTrainAppTestCase 13 | from torchrecipes.vision.image_generation.train_app import GANTrainApp 14 | 15 | 16 | class TestGANTrainApp(BaseTrainAppTestCase): 17 | def get_train_app(self, config_name: str) -> GANTrainApp: 18 | app = self.create_app_from_hydra( 19 | config_module="torchrecipes.vision.image_generation.conf", 20 | config_name=config_name, 21 | overrides=[ 22 | "datamodule=torchvision_datamodule", 23 | "+datamodule/datasets=fake_data", 24 | ], 25 | ) 26 | self.mock_trainer_params(app) 27 | return cast(GANTrainApp, app) 28 | 29 | def test_gan_train_app(self) -> None: 30 | train_app = self.get_train_app("gan_train_app") 31 | output = train_app.train() 32 | self.assertIsNotNone(output) 33 | 34 | def test_infogan_train_app(self) -> None: 35 | train_app = self.get_train_app("infogan_train_app") 36 | output = train_app.train() 37 | self.assertIsNotNone(output) 38 | -------------------------------------------------------------------------------- /torchrecipes/vision/image_generation/train_app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the BSD-style license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | #!/usr/bin/env python3 8 | 9 | from typing import List, Optional 10 | 11 | import hydra 12 | 13 | # @manual "fbsource//third-party/pypi/omegaconf:omegaconf" 14 | 15 | from pytorch_lightning import LightningDataModule 16 | from pytorch_lightning.callbacks import Callback 17 | from torchrecipes.core.base_train_app import BaseTrainApp 18 | from torchrecipes.core.conf import DataModuleConf, TrainerConf 19 | from torchrecipes.vision.image_generation.callbacks import ( 20 | TensorboardGenerativeModelImageSampler, 21 | ) 22 | from torchrecipes.vision.image_generation.module.gan import GANModuleConf 23 | 24 | 25 | class GANTrainApp(BaseTrainApp): 26 | def __init__( 27 | self, 28 | module: GANModuleConf, 29 | trainer: TrainerConf, 30 | datamodule: DataModuleConf, 31 | ) -> None: 32 | super().__init__(module, trainer, datamodule) 33 | 34 | def get_data_module(self) -> Optional[LightningDataModule]: 35 | """ 36 | Instantiate a LightningDataModule. 37 | """ 38 | return hydra.utils.instantiate(self.datamodule_conf) 39 | 40 | def get_callbacks(self) -> List[Callback]: 41 | # TODO(kaizh): make callback configurable 42 | return [TensorboardGenerativeModelImageSampler()] 43 | --------------------------------------------------------------------------------