├── .coveragerc ├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── documentation.yml │ ├── feature-request.yml │ └── help-support.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── build_and_publish_docs.yaml │ ├── build_docs.yaml │ ├── nightly_build_cpu.yaml │ ├── pre_commit.yaml │ ├── release_build.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── conftest.py ├── dev-requirements.txt ├── docs ├── .gitignore ├── Makefile ├── README.md ├── license_header.txt ├── requirements.txt └── source │ ├── _static │ ├── css │ │ └── torchtnt.css │ └── js │ │ ├── jquery.js │ │ └── torchtnt.js │ ├── assets │ └── TNTDiagram.png │ ├── checkpointing.rst │ ├── conf.py │ ├── distributed.rst │ ├── docutils.conf │ ├── examples.rst │ ├── ext │ └── fbcode.py │ ├── framework │ ├── auto_unit.rst │ ├── callbacks.rst │ ├── eval.rst │ ├── fit.rst │ ├── predict.rst │ ├── state.rst │ ├── train.rst │ └── unit.rst │ ├── index.rst │ ├── overview.rst │ ├── templates │ ├── class_template.rst │ └── layout.html │ └── utils │ └── utils.rst ├── examples ├── auto_unit_example.py ├── mingpt │ ├── char_dataset.py │ ├── data │ │ └── input.txt │ ├── main.py │ └── model.py ├── mnist │ ├── README.md │ ├── main.py │ └── requirements.txt ├── torchrec │ ├── README.md │ ├── main.py │ ├── requirements.txt │ └── tests │ │ └── torchrec_example_test.py └── train_unit_example.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── tests ├── framework │ ├── __init__.py │ ├── callbacks │ │ ├── test_base_checkpointer.py │ │ ├── test_checkpoint_utils.py │ │ ├── test_csv_writer.py │ │ ├── test_dcp_saver.py │ │ ├── test_dcp_saver_gpu.py │ │ ├── test_early_stopping.py │ │ ├── test_empty_cuda_cache.py │ │ ├── test_garbage_collector.py │ │ ├── test_iteration_time_logger.py │ │ ├── test_lambda.py │ │ ├── test_learning_rate_monitor.py │ │ ├── test_memory_snapshot.py │ │ ├── test_module_summary.py │ │ ├── 
test_periodic_distributed_sync.py │ │ ├── test_progress_reporter.py │ │ ├── test_pytorch_profiler.py │ │ ├── test_slow_rank_detector.py │ │ ├── test_system_resources_monitor.py │ │ ├── test_tensorboard_parameter_monitor.py │ │ ├── test_tensorfloat32.py │ │ ├── test_throughput_logger.py │ │ ├── test_time_limit_interrupter.py │ │ ├── test_time_wait_for_batch_logger.py │ │ ├── test_torchsnapshot_saver.py │ │ ├── test_torchsnapshot_saver_gpu.py │ │ ├── test_tqdm_progress_bar.py │ │ └── test_train_progress_monitor.py │ ├── test_app_state_mixin.py │ ├── test_auto_unit.py │ ├── test_auto_unit_gpu.py │ ├── test_callback_handler.py │ ├── test_evaluate.py │ ├── test_fit.py │ ├── test_loop_utils.py │ ├── test_predict.py │ ├── test_state.py │ ├── test_train.py │ ├── test_unit.py │ ├── test_unit_utils.py │ ├── test_unit_utils_gpu.py │ └── test_utils.py └── utils │ ├── __init__.py │ ├── data │ ├── test_data_prefetcher.py │ ├── test_data_prefetcher_gpu.py │ ├── test_iterators.py │ ├── test_multi_dataloader.py │ └── test_profile_dataloader.py │ ├── loggers │ ├── __init__.py │ ├── test_anomaly_logger.py │ ├── test_csv.py │ ├── test_in_memory.py │ ├── test_json.py │ ├── test_stdout.py │ ├── test_tensorboard.py │ └── test_utils.py │ ├── test_anomaly_evaluation.py │ ├── test_checkpoint.py │ ├── test_checkpoint_gpu.py │ ├── test_device.py │ ├── test_device_gpu.py │ ├── test_device_mesh.py │ ├── test_distributed.py │ ├── test_distributed_gpu.py │ ├── test_early_stop_checker.py │ ├── test_early_stop_checker_gpu.py │ ├── test_env.py │ ├── test_flops.py │ ├── test_fsspec.py │ ├── test_memory.py │ ├── test_memory_snapshot_profiler.py │ ├── test_memory_snapshot_profiler_gpu.py │ ├── test_misc.py │ ├── test_module_summary.py │ ├── test_nan.py │ ├── test_oom.py │ ├── test_oom_gpu.py │ ├── test_optimizer.py │ ├── test_precision.py │ ├── test_prepare_module.py │ ├── test_prepare_module_gpu.py │ ├── test_progress.py │ ├── test_rank_zero_log.py │ ├── test_swa.py │ ├── test_timer.py │ ├── 
test_timer_gpu.py │ ├── test_tqdm.py │ └── test_version.py ├── torchtnt ├── __init__.py ├── framework │ ├── __init__.py │ ├── _callback_handler.py │ ├── _loop_utils.py │ ├── _test_utils.py │ ├── _unit_utils.py │ ├── auto_unit.py │ ├── callback.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── _checkpoint_utils.py │ │ ├── base_checkpointer.py │ │ ├── base_csv_writer.py │ │ ├── checkpointer_types.py │ │ ├── dcp_saver.py │ │ ├── early_stopping.py │ │ ├── empty_cuda_cache.py │ │ ├── garbage_collector.py │ │ ├── iteration_time_logger.py │ │ ├── lambda_callback.py │ │ ├── learning_rate_monitor.py │ │ ├── memory_snapshot.py │ │ ├── module_summary.py │ │ ├── periodic_distributed_sync.py │ │ ├── progress_reporter.py │ │ ├── pytorch_profiler.py │ │ ├── slow_rank_detector.py │ │ ├── system_resources_monitor.py │ │ ├── tensorboard_parameter_monitor.py │ │ ├── tensorfloat32.py │ │ ├── throughput_logger.py │ │ ├── time_limit_interrupter.py │ │ ├── time_wait_for_batch_logger.py │ │ ├── torch_compile.py │ │ ├── torchsnapshot_saver.py │ │ ├── tqdm_progress_bar.py │ │ └── train_progress_monitor.py │ ├── evaluate.py │ ├── fit.py │ ├── predict.py │ ├── state.py │ ├── train.py │ ├── unit.py │ └── utils.py ├── py.typed └── utils │ ├── __init__.py │ ├── anomaly_evaluation.py │ ├── checkpoint.py │ ├── data │ ├── __init__.py │ ├── data_prefetcher.py │ ├── iterators.py │ ├── multi_dataloader.py │ ├── profile_dataloader.py │ └── synthetic_data.py │ ├── device.py │ ├── device_mesh.py │ ├── distributed.py │ ├── early_stop_checker.py │ ├── env.py │ ├── event.py │ ├── event_handlers.py │ ├── flops.py │ ├── fsdp_utils.py │ ├── fsspec.py │ ├── loggers │ ├── __init__.py │ ├── anomaly_logger.py │ ├── csv.py │ ├── file.py │ ├── in_memory.py │ ├── json.py │ ├── logger.py │ ├── stdout.py │ ├── tensorboard.py │ └── utils.py │ ├── lr_scheduler.py │ ├── memory.py │ ├── memory_snapshot_profiler.py │ ├── misc.py │ ├── module_summary.py │ ├── nan.py │ ├── oom.py │ ├── optimizer.py │ ├── precision.py │ ├── 
prepare_module.py │ ├── progress.py │ ├── rank_zero_log.py │ ├── stateful.py │ ├── swa.py │ ├── test_utils.py │ ├── timer.py │ ├── tqdm.py │ └── version.py └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | tests/* 4 | 5 | [report] 6 | omit = 7 | tests/* 8 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Suggested config from pytorch that we can adopt 3 | select = B,C,E,F,P,T4,W,B9,TOR0,TOR1,TOR2 4 | max-line-length = 120 5 | # C408 ignored because we like the dict keyword argument syntax 6 | # E501 is not flexible enough, we're using B950 instead 7 | ignore = 8 | E203,E305,E402,E501,E704,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, 9 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 10 | # to line this up with executable bit 11 | EXE001, 12 | optional-ascii-coding = True 13 | exclude = 14 | ./.git, 15 | ./docs 16 | ./build 17 | ./scripts, 18 | ./venv, 19 | *.pyi 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug Report 2 | description: Create a report to help us reproduce and fix the bug 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the 9 | existing and past issues](https://github.com/pytorch/tnt/issues?q=is%3Aissue+sort%3Acreated-desc+). 10 | - type: textarea 11 | attributes: 12 | label: 🐛 Describe the bug 13 | description: | 14 | Please provide a clear and concise description of what the bug is. 15 | 16 | If relevant, add a minimal example so that we can reproduce the error by running the code. 
It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: 17 | 18 | ```python 19 | # All necessary imports at the beginning 20 | import torch 21 | import torchtnt 22 | 23 | # A succinct reproducing example trimmed down to the essential parts 24 | 25 | ``` 26 | 27 | If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. 28 | 29 | Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. 30 | placeholder: | 31 | A clear and concise description of what the bug is. 32 | 33 | ```python 34 | Sample code to reproduce the problem 35 | ``` 36 | 37 | ``` 38 | The error message you got, with the full traceback. 39 | ``` 40 | validations: 41 | required: true 42 | - type: textarea 43 | attributes: 44 | label: Versions 45 | description: | 46 | Please run the following and paste the output below. Make sure the version numbers of all relevant packages (e.g. torch, torchtnt, other domain packages) are included. 47 | ```sh 48 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py 49 | # For security purposes, please check the contents of collect_env.py before running it. 50 | python collect_env.py 51 | ``` 52 | validations: 53 | required: true 54 | 55 | - type: markdown 56 | attributes: 57 | value: > 58 | Thanks for contributing 🎉! 
59 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Documentation 2 | description: Report an issue related to inline documentation 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 📚 The doc issue 8 | description: > 9 | A clear and concise description of what content is an issue. 10 | validations: 11 | required: true 12 | - type: textarea 13 | attributes: 14 | label: Suggest a potential alternative/fix 15 | description: > 16 | Tell us how we could improve the documentation in this regard. 17 | - type: markdown 18 | attributes: 19 | value: > 20 | Thanks for contributing 🎉! 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Submit a proposal/request for a new feature 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 🚀 The feature 8 | description: > 9 | A clear and concise description of the feature proposal 10 | validations: 11 | required: true 12 | - type: textarea 13 | attributes: 14 | label: Motivation, pitch 15 | description: > 16 | Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., 17 | *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link 18 | here too. 19 | validations: 20 | required: true 21 | - type: textarea 22 | attributes: 23 | label: Alternatives 24 | description: > 25 | A description of any alternative solutions or features you've considered, if any. 26 | - type: textarea 27 | attributes: 28 | label: Additional context 29 | description: > 30 | Add any other context or screenshots about the feature request. 
31 | - type: markdown 32 | attributes: 33 | value: > 34 | Thanks for contributing 🎉! 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/help-support.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Help Support 2 | description: Do you need help/support? Send us your questions. 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 📚 Question 8 | description: > 9 | Description of your question or what you need support with. 10 | validations: 11 | required: true 12 | - type: markdown 13 | attributes: 14 | value: > 15 | Thanks for contributing 🎉! 16 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Please read through our [contribution guide](https://github.com/pytorch/tnt/blob/main/CONTRIBUTING.md) prior to creating your pull request. 
2 | 3 | Summary: 4 | 5 | 6 | Test plan: 7 | 8 | 9 | Fixes #{issue number} 10 | 11 | -------------------------------------------------------------------------------- /.github/workflows/build_and_publish_docs.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Update Docs 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | # Allow one concurrent deployment 8 | concurrency: 9 | group: "pages" 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | build_and_publish_docs: 14 | runs-on: ubuntu-latest 15 | permissions: 16 | # Grant write permission here so that the doc can be pushed to gh-pages branch 17 | contents: write 18 | steps: 19 | - name: Check out repo 20 | uses: actions/checkout@v2 21 | - name: Setup conda env 22 | uses: conda-incubator/setup-miniconda@v2 23 | with: 24 | miniconda-version: "latest" 25 | activate-environment: test 26 | python-version: "3.10" 27 | - name: Install dependencies 28 | shell: bash -l {0} 29 | run: | 30 | set -eux 31 | conda activate test 32 | pip install -r requirements.txt 33 | pip install -r dev-requirements.txt 34 | conda install pytorch cpuonly -c pytorch-nightly 35 | python setup.py sdist bdist_wheel 36 | pip install dist/*.whl 37 | - name: Build docs 38 | shell: bash -l {0} 39 | run: | 40 | set -eux 41 | conda activate test 42 | cd docs 43 | pip install -r requirements.txt 44 | make html 45 | touch build/html/.nojekyll 46 | cd .. 47 | - name: Deploy docs to Github pages 48 | uses: JamesIves/github-pages-deploy-action@v4.4.1 49 | with: 50 | branch: gh-pages # The branch the action should deploy to. 51 | folder: docs/build/html # The folder the action should deploy. 
52 | target-folder: master 53 | -------------------------------------------------------------------------------- /.github/workflows/build_docs.yaml: -------------------------------------------------------------------------------- 1 | name: Build Docs 2 | 3 | on: 4 | pull_request: 5 | 6 | # Allow one concurrent deployment 7 | concurrency: 8 | group: "pages" 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | build_docs: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out repo 16 | uses: actions/checkout@v2 17 | - name: Setup conda env 18 | uses: conda-incubator/setup-miniconda@v2 19 | with: 20 | miniconda-version: "latest" 21 | activate-environment: test 22 | python-version: "3.10" 23 | - name: Install dependencies 24 | shell: bash -l {0} 25 | run: | 26 | set -eux 27 | conda activate test 28 | pip install -r requirements.txt 29 | pip install -r dev-requirements.txt 30 | conda install pytorch cpuonly -c pytorch-nightly 31 | python setup.py sdist bdist_wheel 32 | pip install dist/*.whl 33 | - name: Build docs 34 | shell: bash -l {0} 35 | run: | 36 | set -eux 37 | conda activate test 38 | cd docs 39 | pip install -r requirements.txt 40 | make html 41 | touch build/html/.nojekyll 42 | cd .. 
43 | -------------------------------------------------------------------------------- /.github/workflows/nightly_build_cpu.yaml: -------------------------------------------------------------------------------- 1 | name: Push CPU Binary Nightly 2 | 3 | on: 4 | # run every day at 11:15am 5 | schedule: 6 | - cron: '15 11 * * *' 7 | # or manually trigger it 8 | workflow_dispatch: 9 | inputs: 10 | append_to_version: 11 | description: "Optional value to append to version string" 12 | 13 | 14 | jobs: 15 | unit_tests: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | python-version: [3.8, 3.9] 20 | steps: 21 | - name: Check out repo 22 | uses: actions/checkout@v2 23 | - name: Setup conda env 24 | uses: conda-incubator/setup-miniconda@v2 25 | with: 26 | miniconda-version: "latest" 27 | activate-environment: test 28 | python-version: ${{ matrix.python-version }} 29 | - name: Install dependencies 30 | shell: bash -l {0} 31 | run: | 32 | set -eux 33 | conda activate test 34 | pip install -r requirements.txt 35 | python setup.py sdist bdist_wheel 36 | pip install dist/*.whl 37 | pip install -r dev-requirements.txt 38 | conda install pytorch cpuonly -c pytorch-nightly 39 | - name: Run unit tests 40 | shell: bash -l {0} 41 | run: | 42 | set -eux 43 | conda activate test 44 | pytest tests -vv 45 | # TODO figure out how to deduplicate steps 46 | upload_to_pypi: 47 | needs: unit_tests 48 | runs-on: ubuntu-latest 49 | steps: 50 | - name: Check out repo 51 | uses: actions/checkout@v2 52 | - name: Setup conda env 53 | uses: conda-incubator/setup-miniconda@v2 54 | with: 55 | miniconda-version: "latest" 56 | activate-environment: test 57 | python-version: 3.8 58 | - name: Install dependencies 59 | shell: bash -l {0} 60 | run: | 61 | set -eux 62 | conda activate test 63 | conda install pytorch cpuonly -c pytorch-nightly 64 | pip install -r requirements.txt 65 | pip install --no-build-isolation -e ".[dev]" 66 | - name: Upload to PyPI 67 | shell: bash -l {0} 68 | env: 69 | 
PYPI_USER: ${{ secrets.PYPI_USER }} 70 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 71 | run: | 72 | set -eux 73 | conda activate test 74 | pip install twine 75 | python setup.py --nightly --append-to-version=${{ github.event.inputs.append_to_version }} sdist bdist_wheel 76 | twine upload --username "$PYPI_USER" --password "$PYPI_TOKEN" dist/* --verbose 77 | -------------------------------------------------------------------------------- /.github/workflows/pre_commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | with: 15 | python-version: 3.11 16 | - uses: pre-commit/action@v3.0.0 17 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: unit test 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | 8 | jobs: 9 | unit_tests_nightly_pytorch: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: [3.8, 3.9] 14 | steps: 15 | - name: Check out repo 16 | uses: actions/checkout@v2 17 | - name: Setup conda env 18 | uses: conda-incubator/setup-miniconda@v2 19 | with: 20 | miniconda-version: "latest" 21 | activate-environment: test 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | shell: bash -l {0} 25 | run: | 26 | set -eux 27 | conda activate test 28 | pip install -r requirements.txt 29 | pip install -r dev-requirements.txt 30 | pip install --no-build-isolation -e . 31 | conda install pytorch cpuonly -c pytorch-nightly 32 | - name: Run unit tests with coverage 33 | shell: bash -l {0} 34 | run: | 35 | set -eux 36 | conda activate test 37 | pytest --cov=. 
--cov-report xml tests -vv 38 | - name: Upload Coverage to Codecov 39 | uses: codecov/codecov-action@v2 40 | unit_tests_stable_pytorch: 41 | runs-on: ubuntu-latest 42 | strategy: 43 | matrix: 44 | python-version: [3.8] 45 | steps: 46 | - name: Check out repo 47 | uses: actions/checkout@v2 48 | - name: Setup conda env 49 | uses: conda-incubator/setup-miniconda@v2 50 | with: 51 | miniconda-version: "latest" 52 | activate-environment: test 53 | python-version: ${{ matrix.python-version }} 54 | - name: Install dependencies 55 | shell: bash -l {0} 56 | run: | 57 | set -eux 58 | conda activate test 59 | pip install -r requirements.txt 60 | pip install -r dev-requirements.txt 61 | pip install --no-build-isolation -e . 62 | - name: Run unit tests with coverage 63 | shell: bash -l {0} 64 | run: | 65 | set -eux 66 | conda activate test 67 | pytest --cov=. --cov-report xml tests -vv 68 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.1.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: check-ast 10 | - id: check-merge-conflict 11 | - id: check-added-large-files 12 | args: ['--maxkb=500'] 13 | - id: end-of-file-fixer 14 | 15 | - repo: https://github.com/Lucas-C/pre-commit-hooks 16 | rev: v1.1.7 17 | hooks: 18 | - id: insert-license 19 | files: \.py$ 20 | args: 21 | - --license-filepath 22 | - docs/license_header.txt 23 | 24 | - repo: https://github.com/pycqa/flake8 25 | rev: 4.0.1 26 | hooks: 27 | - id: flake8 28 | args: 29 | - --config=.flake8 30 | additional_dependencies: 31 | - torchfix==0.1.1 32 | 33 | - repo: https://github.com/omnilib/ufmt 34 | rev: v2.5.1 35 | hooks: 36 | - id: ufmt 37 | additional_dependencies: 38 | - black == 24.2.0 39 | - usort == 1.0.2 40 | 
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to tnt 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to tnt, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For torchtnt software 4 | 5 | Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Meta nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | TNT 2 | ========== 3 | 4 | **TNT** is a library for PyTorch **t**rai**n**ing **t**ools and utilities. 5 | 6 |

7 | build status 8 | pypi version 9 | pypi version 10 | pypi nightly version 11 | codecov 12 | bsd license 13 | documentation status 14 | 15 | 16 | 17 | ## Installation 18 | 19 | TNT can be installed with pip: 20 | 21 | ```buildoutcfg 22 | pip install torchtnt 23 | ``` 24 | Or, alternatively, via conda: 25 | 26 | ```buildoutcfg 27 | conda install -c conda-forge torchtnt 28 | ``` 29 | 30 | If you run into issues, make sure that PyTorch is installed first. 31 | 32 | You can also install the latest version from master. Just run: 33 | 34 | ```buildoutcfg 35 | pip install git+https://github.com/pytorch/tnt.git@master 36 | ``` 37 | 38 | To update to the latest version from master: 39 | 40 | ```buildoutcfg 41 | pip install --upgrade git+https://github.com/pytorch/tnt.git@master 42 | ``` 43 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | parameterized 2 | pytest 3 | pytest-cov 4 | torchsnapshot-nightly 5 | pyre-check 6 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | src/pytorch-sphinx-theme/ 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | Docs 2 | ========== 3 | 4 | 5 | ## Building the docs 6 | 7 | To build and preview the docs run the following commands: 8 | 9 | ```buildoutcfg 10 | cd docs 11 | pip3 install -r requirements.txt 12 | make html 13 | python3 -m http.server 8082 --bind :: 14 | ``` 15 | 16 | Now you should be able to view the docs in your browser at the link provided in your terminal. 
17 | 18 | To reload the preview after making changes, rerun: 19 | 20 | ``` 21 | make html 22 | python3 -m http.server 8082 --bind :: 23 | ``` 24 | -------------------------------------------------------------------------------- /docs/license_header.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the BSD-style license found in the 5 | LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==5.0.1 2 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 3 | -------------------------------------------------------------------------------- /docs/source/_static/css/torchtnt.css: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #redirect-banner { 10 | border-bottom: 1px solid #e2e2e2; 11 | } 12 | #redirect-banner > p { 13 | margin: 0.8rem; 14 | text-align: center; 15 | font-weight: bold; 16 | color: red; 17 | } 18 | -------------------------------------------------------------------------------- /docs/source/_static/js/torchtnt.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | const NETWORK_TEST_URL = 'https://staticdocs.thefacebook.com/ping'; 10 | fetch(NETWORK_TEST_URL).then(() => { 11 | $("#redirect-banner").prependTo("body").show(); 12 | }); 13 | -------------------------------------------------------------------------------- /docs/source/assets/TNTDiagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tnt/cb31137b0928acc24f4d341faa8c7d88b8ed4696/docs/source/assets/TNTDiagram.png -------------------------------------------------------------------------------- /docs/source/docutils.conf: -------------------------------------------------------------------------------- 1 | [html writers] 2 | table_style: colwidths-auto # Necessary for the table generated by autosummary to look decent 3 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ============ 3 | 4 | TrainUnit Example 5 | ~~~~~~~~~~~~~~~~~~~~~ 6 | The TrainUnit example shows how to use the :class:`~torchtnt.framework.unit.TrainUnit` to train a basic model. 7 | 8 | It can be found here: https://github.com/pytorch/tnt/blob/master/examples/train_unit_example.py 9 | 10 | 11 | AutoUnit Example 12 | ~~~~~~~~~~~~~~~~~~~~~ 13 | The AutoUnit example shows how to use the :class:`~torchtnt.framework.auto_unit.AutoUnit` to train a basic model with 14 | less code, and more training features enabled out of the box. 15 | 16 | It can be found here: https://github.com/pytorch/tnt/blob/master/examples/auto_unit_example.py 17 | 18 | 19 | TorchData Train Example 20 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 21 | The TorchData Train example shows how to use TorchData's `DataPipe `_ and `Dataloader2 `_. when training a basic model with TNT. 
22 | 23 | It can be found here: https://github.com/pytorch/tnt/blob/master/examples/torchdata_train_example.py 24 | 25 | 26 | MNIST Example 27 | ~~~~~~~~~~~~~~~~~~~~~ 28 | The MNIST example shows how to use TNT to train a convnet model on the MNIST dataset. 29 | 30 | It can be found here: https://github.com/pytorch/tnt/blob/master/examples/mnist/main.py 31 | 32 | 33 | MinGPT Example 34 | ~~~~~~~~~~~~~~~~~~~~~ 35 | The MinGPT example shows how to use the :class:`~torchtnt.framework.auto_unit.AutoUnit` to train a MinGPT model. 36 | 37 | It can be found here: https://github.com/pytorch/tnt/blob/master/examples/mingpt/main.py 38 | 39 | 40 | TorchRec Example 41 | ~~~~~~~~~~~~~~~~~~~~~ 42 | The TorchRec example shows how to use :class:`~torchtnt.framework.unit.TrainUnit` with `TorchRec `_. 43 | 44 | It can be found here: https://github.com/pytorch/tnt/blob/master/examples/torchrec/main.py 45 | -------------------------------------------------------------------------------- /docs/source/ext/fbcode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | import os 10 | 11 | from docutils import nodes 12 | from sphinx.util.docutils import SphinxDirective 13 | from sphinx.util.nodes import nested_parse_with_titles 14 | 15 | 16 | class FbcodeDirective(SphinxDirective): 17 | # this enables content in the directive 18 | has_content = True 19 | 20 | def run(self): 21 | if "fbcode" not in os.getcwd(): 22 | return [] 23 | node = nodes.section() 24 | node.document = self.state.document 25 | nested_parse_with_titles(self.state, self.content, node) 26 | return node.children 27 | 28 | 29 | def setup(app): 30 | app.add_directive("fbcode", FbcodeDirective) 31 | 32 | return { 33 | "version": "0.1", 34 | "parallel_read_safe": True, 35 | "parallel_write_safe": True, 36 | } 37 | -------------------------------------------------------------------------------- /docs/source/framework/auto_unit.rst: -------------------------------------------------------------------------------- 1 | AutoUnit 2 | =========== 3 | 4 | .. autoclass:: torchtnt.framework.auto_unit.AutoUnit 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: torchtnt.framework.auto_unit.AutoPredictUnit 9 | :members: 10 | :undoc-members: 11 | -------------------------------------------------------------------------------- /docs/source/framework/callbacks.rst: -------------------------------------------------------------------------------- 1 | Callbacks 2 | ======================= 3 | 4 | .. automodule:: torchtnt.framework.callback 5 | :members: 6 | :undoc-members: 7 | 8 | 9 | Built-in callbacks 10 | ~~~~~~~~~~~~~~~~~~~~~ 11 | 12 | We offer several pre-written callbacks which are ready to be used out of the box: 13 | 14 | 15 | .. currentmodule:: torchtnt.framework.callbacks 16 | 17 | .. 
autosummary:: 18 | :nosignatures: 19 | :toctree: generated/ 20 | :template: class_template.rst 21 | 22 | BaseCSVWriter 23 | EarlyStopping 24 | GarbageCollector 25 | IterationTimeLogger 26 | Lambda 27 | LearningRateMonitor 28 | MemorySnapshot 29 | ModuleSummary 30 | PeriodicDistributedSync 31 | ProgressReporter 32 | PyTorchProfiler 33 | SlowRankDetector 34 | SystemResourcesMonitor 35 | TensorBoardParameterMonitor 36 | TimeLimitInterrupter 37 | TimeWaitForBatchLogger 38 | ThroughputLogger 39 | TorchSnapshotSaver 40 | TQDMProgressBar 41 | TrainProgressMonitor 42 | -------------------------------------------------------------------------------- /docs/source/framework/eval.rst: -------------------------------------------------------------------------------- 1 | Evaluate 2 | ======================= 3 | 4 | Evaluate Entry Point 5 | ~~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autofunction:: torchtnt.framework.evaluate.evaluate 8 | -------------------------------------------------------------------------------- /docs/source/framework/fit.rst: -------------------------------------------------------------------------------- 1 | Fit 2 | ======================= 3 | 4 | Fit Entry Point 5 | ~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autofunction:: torchtnt.framework.fit.fit 8 | -------------------------------------------------------------------------------- /docs/source/framework/predict.rst: -------------------------------------------------------------------------------- 1 | Predict 2 | ======================= 3 | 4 | Predict Entry Point 5 | ~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autofunction:: torchtnt.framework.predict.predict 8 | -------------------------------------------------------------------------------- /docs/source/framework/state.rst: -------------------------------------------------------------------------------- 1 | State 2 | ========= 3 | 4 | .. autoclass:: torchtnt.framework.state.State 5 | :members: 6 | :undoc-members: 7 | 8 | PhaseState 9 | ~~~~~~~~~~~~~~~~~ 10 | .. 
autoclass:: torchtnt.framework.state.PhaseState 11 | :members: 12 | :undoc-members: 13 | -------------------------------------------------------------------------------- /docs/source/framework/train.rst: -------------------------------------------------------------------------------- 1 | Train 2 | ======================= 3 | 4 | Train Entry Point 5 | ~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autofunction:: torchtnt.framework.train.train 8 | -------------------------------------------------------------------------------- /docs/source/framework/unit.rst: -------------------------------------------------------------------------------- 1 | Unit 2 | ========= 3 | 4 | The Unit concept represents the primary place to organize your model code in TorchTNT. TorchTNT offers three different types of Unit classes for training, evaluation, and prediction. These interfaces are mutually exclusive and can be combined as needed, e.g. in the case of fitting (interleaving training and evaluation). 5 | 6 | TrainUnit 7 | ~~~~~~~~~~~~~~~~~ 8 | .. autoclass:: torchtnt.framework.unit.TrainUnit 9 | :members: 10 | :undoc-members: 11 | 12 | EvalUnit 13 | ~~~~~~~~~~~~~~~~~ 14 | .. autoclass:: torchtnt.framework.unit.EvalUnit 15 | :members: 16 | :undoc-members: 17 | 18 | PredictUnit 19 | ~~~~~~~~~~~~~~~~~ 20 | .. autoclass:: torchtnt.framework.unit.PredictUnit 21 | :members: 22 | :undoc-members: 23 | 24 | Combining Multiple Units 25 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 26 | In some cases, it is convenient to implement multiple Unit interfaces under the same class, e.g. if you plan to use your class to run several different phases; 27 | for example, running training and then prediction, or running training and evaluation interleaved (referred to as fitting). 28 | Here is an example of a unit which extends TrainUnit, EvalUnit, and PredictUnit. 29 | 30 | .. 
code-block:: python 31 | 32 | from torchtnt.framework.unit import TrainUnit, EvalUnit, PredictUnit 33 | 34 | Batch = Tuple[torch.tensor, torch.tensor] 35 | 36 | class MyUnit(TrainUnit[Batch], EvalUnit[Batch], PredictUnit[Batch]): 37 | def __init__( 38 | self, 39 | module: torch.nn.Module, 40 | optimizer: torch.optim.Optimizer, 41 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 42 | ): 43 | super().__init__() 44 | self.module = module 45 | self.optimizer = optimizer 46 | self.lr_scheduler = lr_scheduler 47 | 48 | def train_step(self, state: State, data: Batch) -> None: 49 | inputs, targets = data 50 | outputs = self.module(inputs) 51 | loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets) 52 | loss.backward() 53 | 54 | self.optimizer.step() 55 | self.optimizer.zero_grad() 56 | 57 | def eval_step(self, state: State, data: Batch) -> None: 58 | inputs, targets = data 59 | outputs = self.module(inputs) 60 | loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets) 61 | 62 | def predict_step(self, state: State, data: Batch) -> torch.tensor: 63 | inputs, targets = data 64 | outputs = self.module(inputs) 65 | return outputs 66 | 67 | def on_train_epoch_end(self, state: State) -> None: 68 | # step the learning rate scheduler 69 | self.lr_scheduler.step() 70 | 71 | my_unit = MyUnit(module=..., optimizer=..., lr_scheduler=...) 72 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to the TorchTNT documentation! 2 | =========================================== 3 | 4 | TNT is a library for PyTorch training tools and utilities. It has two main components, which are the top-level modules of the repo: 5 | 6 | 1. :mod:`torchtnt.framework`: contains a lightweight training framework to simplify maintaining training, evaluation, and prediction loops. 7 | 2. 
:mod:`torchtnt.utils`: contains a grab-bag of various independent, training-related utilities, including data related abstractions and wrappers around different publishers to simplify logging metrics. 8 | 9 | Installation 10 | -------------- 11 | 12 | TNT can be installed with pip. To do so, run: 13 | 14 | .. code-block:: shell 15 | 16 | pip install torchtnt 17 | 18 | If you run into issues, make sure that Pytorch is installed first. 19 | 20 | You can also install the latest version from master. Just run: 21 | 22 | .. code-block:: shell 23 | 24 | pip install git+https://github.com/pytorch/tnt.git@master 25 | 26 | To update to the latest version from master: 27 | 28 | .. code-block:: shell 29 | 30 | pip install --upgrade git+https://github.com/pytorch/tnt.git@master 31 | 32 | 33 | Documentation 34 | --------------- 35 | .. toctree:: 36 | :maxdepth: 1 37 | :caption: Overview 38 | :glob: 39 | 40 | overview 41 | 42 | .. fbcode:: 43 | 44 | .. toctree:: 45 | :maxdepth: 2 46 | :caption: Getting Started (Meta) 47 | :glob: 48 | 49 | meta/getting_started 50 | meta/migrating 51 | meta/migrating_example 52 | meta/tss_to_dcp 53 | 54 | .. toctree:: 55 | :maxdepth: 1 56 | :caption: Examples 57 | 58 | examples 59 | 60 | .. fbcode:: 61 | 62 | .. toctree:: 63 | :maxdepth: 2 64 | :caption: Examples (Meta) 65 | :glob: 66 | 67 | meta/examples 68 | 69 | .. fbcode:: 70 | 71 | .. toctree:: 72 | :maxdepth: 2 73 | :caption: Debugging FAQ (Meta) 74 | :glob: 75 | 76 | meta/checkpointing_FAQ 77 | meta/mem_debug 78 | 79 | .. toctree:: 80 | :maxdepth: 1 81 | :caption: Core Concepts 82 | 83 | distributed 84 | checkpointing 85 | 86 | .. toctree:: 87 | :maxdepth: 1 88 | :caption: Framework 89 | 90 | framework/unit 91 | framework/auto_unit 92 | framework/train 93 | framework/eval 94 | framework/predict 95 | framework/fit 96 | framework/state 97 | framework/callbacks 98 | 99 | .. toctree:: 100 | :maxdepth: 2 101 | :caption: Utils 102 | 103 | utils/utils 104 | 105 | .. fbcode:: 106 | 107 | .. 
toctree:: 108 | :maxdepth: 2 109 | :caption: Framework (Meta) 110 | :glob: 111 | 112 | meta/framework/callbacks 113 | 114 | .. fbcode:: 115 | 116 | .. toctree:: 117 | :maxdepth: 2 118 | :caption: Utils (Meta) 119 | :glob: 120 | 121 | meta/utils 122 | -------------------------------------------------------------------------------- /docs/source/templates/class_template.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline }} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | -------------------------------------------------------------------------------- /docs/source/templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {%- block extrabody %} 4 | {% if not fbcode %} 5 |

11 | {% elif fbcode %} 12 | 17 | {% endif %} 18 | {%- endblock %} 19 | -------------------------------------------------------------------------------- /examples/mingpt/char_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from dataclasses import dataclass 10 | from typing import Dict, Tuple 11 | 12 | import fsspec 13 | import torch 14 | from torch.utils.data import Dataset 15 | 16 | """ 17 | Adapted from https://github.com/karpathy/minGPT/blob/master/projects/chargpt/chargpt.py 18 | """ 19 | 20 | 21 | @dataclass 22 | class DataConfig: 23 | path: str 24 | block_size: int 25 | train_split: float 26 | truncate: float = 1.0 27 | 28 | 29 | class CharDataset(Dataset): 30 | def __init__(self, data_cfg: DataConfig) -> None: 31 | print(data_cfg.path) 32 | data = fsspec.open(data_cfg.path).open().read().decode("utf-8") 33 | data = data[: int(len(data) * data_cfg.truncate)] 34 | 35 | chars = sorted(set(data)) 36 | data_size, vocab_size = len(data), len(chars) 37 | print("Data has %d characters, %d unique." 
% (data_size, vocab_size)) 38 | 39 | self.stoi: Dict[str, int] = {ch: i for i, ch in enumerate(chars)} 40 | self.itos: Dict[int, str] = {i: ch for i, ch in enumerate(chars)} 41 | self.block_size: int = data_cfg.block_size 42 | self.vocab_size: int = vocab_size 43 | self.data: str = data 44 | 45 | def __len__(self) -> int: 46 | return len(self.data) - self.block_size 47 | 48 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 49 | # grab a chunk of (block_size + 1) characters from the data 50 | chunk = self.data[idx : idx + self.block_size + 1] 51 | # encode every character to an integer 52 | dix = [self.stoi[s] for s in chunk] 53 | x = torch.tensor(dix[:-1], dtype=torch.long) 54 | y = torch.tensor(dix[1:], dtype=torch.long) 55 | return x, y 56 | -------------------------------------------------------------------------------- /examples/mnist/README.md: -------------------------------------------------------------------------------- 1 | # Basic MNIST Example 2 | 3 | ```bash 4 | pip install -r requirements.txt 5 | python main.py 6 | ``` 7 | -------------------------------------------------------------------------------- /examples/mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torcheval 3 | torchtnt 4 | torchvision 5 | -------------------------------------------------------------------------------- /examples/torchrec/README.md: -------------------------------------------------------------------------------- 1 | # Basic TorchRec Example 2 | 3 | ```bash 4 | pip install -r requirements.txt 5 | python main.py 6 | ``` 7 | -------------------------------------------------------------------------------- /examples/torchrec/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torcheval 3 | torchrec 4 | torchtnt 5 | -------------------------------------------------------------------------------- 
/examples/torchrec/tests/torchrec_example_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import unittest 11 | 12 | from torchtnt.utils.distributed import spawn_multi_process 13 | from torchtnt.utils.test_utils import skip_if_asan, skip_if_not_gpu 14 | 15 | from ..main import main 16 | 17 | 18 | class TorchrecExampleTest(unittest.TestCase): 19 | @skip_if_asan 20 | @skip_if_not_gpu 21 | def test_torchrec_example(self) -> None: 22 | spawn_multi_process(2, "nccl", main, []) 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.usort] 2 | 3 | first_party_detection = false 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.3.0 2 | numpy==1.24.4 3 | fsspec 4 | tensorboard 5 | packaging 6 | psutil 7 | pyre_extensions 8 | typing_extensions 9 | setuptools 10 | tqdm 11 | tabulate 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import os 9 | import sys 10 | 11 | from datetime import date 12 | from typing import List 13 | 14 | from setuptools import find_packages, setup 15 | from torchtnt import __version__ 16 | 17 | 18 | def current_path(file_name: str) -> str: 19 | return os.path.abspath(os.path.join(__file__, os.path.pardir, file_name)) 20 | 21 | 22 | def read_requirements(file_name: str) -> List[str]: 23 | with open(current_path(file_name), encoding="utf8") as f: 24 | return [r for r in f.read().strip().split() if not r.startswith("-")] 25 | 26 | 27 | def get_nightly_version() -> str: 28 | return date.today().strftime("%Y.%m.%d") 29 | 30 | 31 | def parse_args() -> argparse.Namespace: 32 | parser = argparse.ArgumentParser(description="torchtnt setup") 33 | parser.add_argument( 34 | "--nightly", 35 | dest="nightly", 36 | action="store_true", 37 | help="enable settings for nightly package build", 38 | default=False, 39 | ) 40 | parser.add_argument( 41 | "--append-to-version", 42 | dest="append_version", 43 | help="append string to end of version number (e.g. 
a1)", 44 | ) 45 | return parser.parse_known_args() 46 | 47 | 48 | if __name__ == "__main__": 49 | with open(current_path("README.md"), encoding="utf8") as f: 50 | readme: str = f.read() 51 | 52 | custom_args, setup_args = parse_args() 53 | package_name = "torchtnt" if not custom_args.nightly else "torchtnt-nightly" 54 | version = __version__ if not custom_args.nightly else get_nightly_version() 55 | if custom_args.append_version: 56 | version = f"{version}{custom_args.append_version}" 57 | 58 | print(f"using package_name={package_name}, version={version}") 59 | 60 | sys.argv = [sys.argv[0]] + setup_args 61 | 62 | setup( 63 | name=package_name, 64 | version=version, 65 | author="PyTorch", 66 | author_email="daniellepintz@fb.com", 67 | description="A lightweight library for PyTorch training tools and utilities", 68 | long_description=readme, 69 | long_description_content_type="text/markdown", 70 | url="https://github.com/pytorch/tnt", 71 | license="BSD-3", 72 | keywords=["pytorch", "torch", "training", "tools", "utilities"], 73 | python_requires=">=3.7", 74 | install_requires=read_requirements("requirements.txt"), 75 | packages=find_packages(), 76 | package_data={"torchtnt": ["py.typed"]}, 77 | zip_safe=True, 78 | classifiers=[ 79 | "Development Status :: 2 - Pre-Alpha", 80 | "Intended Audience :: Developers", 81 | "Intended Audience :: Science/Research", 82 | "License :: OSI Approved :: BSD License", 83 | "Programming Language :: Python :: 3", 84 | "Programming Language :: Python :: 3.7", 85 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 86 | ], 87 | extras_require={"dev": read_requirements("dev-requirements.txt")}, 88 | ) 89 | -------------------------------------------------------------------------------- /tests/framework/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/framework/callbacks/test_empty_cuda_cache.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import unittest 11 | from unittest import mock 12 | 13 | from torchtnt.framework._test_utils import DummyFitUnit, generate_random_dataloader 14 | from torchtnt.framework.callbacks.empty_cuda_cache import EmptyCudaCache 15 | from torchtnt.framework.fit import fit 16 | 17 | 18 | class EmptyCudaCacheTest(unittest.TestCase): 19 | def test_empty_cuda_cache_call_count_fit(self) -> None: 20 | """ 21 | Test EmptyCudaCache callback was called correct number of times (with fit entry point) 22 | """ 23 | input_dim = 2 24 | train_dataset_len = 10 25 | eval_dataset_len = 6 26 | batch_size = 2 27 | max_epochs = 2 28 | evaluate_every_n_epochs = 1 29 | expected_num_total_steps = ( 30 | train_dataset_len / batch_size * max_epochs 31 | + eval_dataset_len / batch_size * max_epochs 32 | ) 33 | step_interval = 4 34 | 35 | my_unit = DummyFitUnit(2) 36 | ecc_callback = EmptyCudaCache(step_interval) 37 | 38 | train_dataloader = generate_random_dataloader( 39 | train_dataset_len, input_dim, batch_size 40 | ) 41 | eval_dataloader = generate_random_dataloader( 42 | eval_dataset_len, input_dim, batch_size 43 | ) 44 | 45 | expected_num_calls_to_cuda_empty = expected_num_total_steps / step_interval 46 | with mock.patch( 47 | "torchtnt.framework.callbacks.empty_cuda_cache.torch.cuda.empty_cache" 48 | ) as empty_mock: 49 | fit( 50 | my_unit, 51 | 
train_dataloader=train_dataloader, 52 | eval_dataloader=eval_dataloader, 53 | max_epochs=max_epochs, 54 | evaluate_every_n_epochs=evaluate_every_n_epochs, 55 | callbacks=[ecc_callback], 56 | ) 57 | self.assertEqual(empty_mock.call_count, expected_num_calls_to_cuda_empty) 58 | -------------------------------------------------------------------------------- /tests/framework/callbacks/test_learning_rate_monitor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import unittest 11 | from unittest.mock import MagicMock 12 | 13 | from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader 14 | from torchtnt.framework.callbacks.learning_rate_monitor import LearningRateMonitor 15 | from torchtnt.framework.train import train 16 | 17 | from torchtnt.utils.loggers.logger import MetricLogger 18 | 19 | 20 | class LearningRateMonitorTest(unittest.TestCase): 21 | def test_learning_rate_monitor_epoch(self) -> None: 22 | """ 23 | Test LearningRateMonitor callback with 'epoch' logging interval 24 | """ 25 | input_dim = 2 26 | dataset_len = 10 27 | batch_size = 2 28 | max_epochs = 2 29 | 30 | my_unit = DummyTrainUnit(input_dim=input_dim) 31 | log_writer = MagicMock(spec=MetricLogger) 32 | monitor = LearningRateMonitor(loggers=log_writer) 33 | 34 | dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size) 35 | train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[monitor]) 36 | self.assertEqual(log_writer.log_dict.call_count, 2) 37 | 38 | def test_learning_rate_monitor_step(self) -> None: 39 | """ 40 | Test LearningRateMonitor callback with 'step' logging interval 41 | """ 42 | input_dim = 2 43 | dataset_len = 10 44 | 
batch_size = 2 45 | max_epochs = 2 46 | 47 | my_unit = DummyTrainUnit(input_dim=input_dim) 48 | log_writer = MagicMock(spec=MetricLogger) 49 | monitor = LearningRateMonitor(loggers=log_writer, logging_interval="step") 50 | 51 | dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size) 52 | 53 | total_steps = (dataset_len / batch_size) * max_epochs 54 | 55 | train( 56 | my_unit, 57 | dataloader, 58 | max_epochs=max_epochs, 59 | # pyre-fixme[6]: For 4th argument expected `Optional[int]` but got `float`. 60 | max_steps=total_steps, 61 | callbacks=[monitor], 62 | ) 63 | self.assertEqual(log_writer.log_dict.call_count, total_steps) 64 | -------------------------------------------------------------------------------- /tests/framework/callbacks/test_memory_snapshot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
# pyre-strict

import tempfile
import unittest
from unittest.mock import MagicMock, Mock

from torchtnt.framework.callbacks.memory_snapshot import MemorySnapshot
from torchtnt.framework.state import EntryPoint
from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotProfiler


class TestMemorySnapshot(unittest.TestCase):
    """Verifies that MemorySnapshot steps its profiler on each step-end hook."""

    @staticmethod
    def _snapshot_with_mock_profiler(output_dir: str) -> MemorySnapshot:
        # Build a real callback, then swap in a Mock so .step() calls can be counted.
        callback = MemorySnapshot(
            memory_snapshot_profiler=MemorySnapshotProfiler(output_dir=output_dir),
        )
        callback.memory_snapshot_profiler = Mock()
        return callback

    def test_on_train_step_end(self) -> None:
        with tempfile.TemporaryDirectory() as output_dir:
            callback = self._snapshot_with_mock_profiler(output_dir)

            callback.on_train_step_end(MagicMock(), MagicMock())

            callback.memory_snapshot_profiler.step.assert_called_once()

    def test_on_eval_step_end(self) -> None:
        with tempfile.TemporaryDirectory() as output_dir:
            callback = self._snapshot_with_mock_profiler(output_dir)

            state = MagicMock()
            state.entry_point = EntryPoint.EVALUATE
            callback.on_eval_step_end(state, MagicMock())

            callback.memory_snapshot_profiler.step.assert_called_once()

    def test_on_predict_step_end(self) -> None:
        with tempfile.TemporaryDirectory() as output_dir:
            callback = self._snapshot_with_mock_profiler(output_dir)

            callback.on_predict_step_end(MagicMock(), MagicMock())

            callback.memory_snapshot_profiler.step.assert_called_once()
/tests/framework/callbacks/test_periodic_distributed_sync.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import unittest 11 | from unittest.mock import MagicMock, patch 12 | 13 | from torchtnt.framework._test_utils import DummyEvalUnit, DummyPredictUnit 14 | 15 | from torchtnt.framework.callbacks.periodic_distributed_sync import ( 16 | PeriodicDistributedSync, 17 | ) 18 | from torchtnt.framework.state import EntryPoint, State 19 | 20 | 21 | class PeriodicDistributedSyncTest(unittest.TestCase): 22 | @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier") 23 | def test_frequency_predict(self, barrier_mock: MagicMock) -> None: 24 | pds = PeriodicDistributedSync(sync_every_n_steps=2) 25 | unit = DummyPredictUnit(2) 26 | state = State(entry_point=EntryPoint.PREDICT) 27 | unit.predict_progress.increment_step() # 1 step completed 28 | pds.on_predict_step_end(state, unit) 29 | barrier_mock.assert_not_called() 30 | 31 | unit.predict_progress.increment_step() # 2 steps completed 32 | pds.on_predict_step_end(state, unit) 33 | barrier_mock.assert_called_once() 34 | 35 | @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier") 36 | def test_frequency_evaluate(self, barrier_mock: MagicMock) -> None: 37 | pds = PeriodicDistributedSync(sync_every_n_steps=2) 38 | unit = DummyEvalUnit(2) 39 | state = State(entry_point=EntryPoint.EVALUATE) 40 | unit.eval_progress.increment_step() # 1 step completed 41 | pds.on_eval_step_end(state, unit) 42 | barrier_mock.assert_not_called() 43 | 44 | unit.eval_progress.increment_step() # 2 steps completed 45 | pds.on_eval_step_end(state, unit) 46 | barrier_mock.assert_called_once() 47 | 
# --------------------------------------------------------------------------
# /tests/framework/callbacks/test_progress_reporter.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from torchtnt.framework._test_utils import DummyAutoUnit
from torchtnt.framework.callbacks.progress_reporter import ProgressReporter
from torchtnt.framework.state import EntryPoint, State
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.progress import Progress


class ProgressReporterTest(unittest.TestCase):
    def test_log_with_rank(self) -> None:
        """Run the rank-aware logging check on two gloo processes."""
        spawn_multi_process(2, "gloo", self._test_log_with_rank)

    @staticmethod
    def _test_log_with_rank() -> None:
        # Static method: runs inside each spawned process, so it uses a local
        # TestCase instance for assertions instead of `self`.
        progress_reporter = ProgressReporter()
        unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
        unit.train_progress = Progress(
            num_epochs_completed=1,
            num_steps_completed=5,
            num_steps_completed_in_epoch=3,
        )
        unit.eval_progress = Progress(
            num_epochs_completed=2,
            num_steps_completed=15,
            num_steps_completed_in_epoch=7,
        )
        state = State(entry_point=EntryPoint.FIT)
        tc = unittest.TestCase()
        with tc.assertLogs(level="INFO") as log:
            progress_reporter.on_train_end(state, unit)
        tc.assertEqual(
            log.output,
            [
                f"INFO:torchtnt.framework.callbacks.progress_reporter:Progress Reporter: rank {get_global_rank()} at on_train_end. "
                "Train progress: completed epochs: 1, completed steps: 5, completed steps in current epoch: 3. "
                "Eval progress: completed epochs: 2, completed steps: 15, completed steps in current epoch: 7."
            ],
        )


# --------------------------------------------------------------------------
# /tests/framework/callbacks/test_pytorch_profiler.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from unittest.mock import MagicMock

import torch
from torchtnt.framework._test_utils import (
    DummyEvalUnit,
    DummyPredictUnit,
    DummyTrainUnit,
    generate_random_dataloader,
)
from torchtnt.framework.callbacks.pytorch_profiler import PyTorchProfiler
from torchtnt.framework.evaluate import evaluate
from torchtnt.framework.predict import predict
from torchtnt.framework.train import train


class PyTorchProfilerTest(unittest.TestCase):
    def test_profiler_train(self) -> None:
        """
        Test PytorchProfiler callback with train entry point
        """
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        max_epochs = 2
        # Integer division: the dataloader yields exactly this many batches,
        # and `call_count` below is an int (previously float division compared
        # a float against an int call count).
        expected_num_total_steps = dataset_len // batch_size * max_epochs

        my_unit = DummyTrainUnit(input_dim)
        profiler_mock = MagicMock(spec=torch.profiler.profile)

        profiler = PyTorchProfiler(profiler=profiler_mock)

        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[profiler])
        # The profiler must be started once, stepped once per train step, and
        # stopped once.
        self.assertEqual(profiler_mock.start.call_count, 1)
        self.assertEqual(profiler_mock.step.call_count, expected_num_total_steps)
        self.assertEqual(profiler_mock.stop.call_count, 1)

    def test_profiler_evaluate(self) -> None:
        """
        Test PytorchProfiler callback with evaluate entry point
        """
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        # Integer division so the expected count is an int (see train test).
        expected_num_total_steps = dataset_len // batch_size

        my_unit = DummyEvalUnit(2)
        profiler_mock = MagicMock(spec=torch.profiler.profile)

        profiler = PyTorchProfiler(profiler=profiler_mock)

        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)

        evaluate(my_unit, dataloader, callbacks=[profiler])
        self.assertEqual(profiler_mock.start.call_count, 1)
        self.assertEqual(profiler_mock.step.call_count, expected_num_total_steps)
        self.assertEqual(profiler_mock.stop.call_count, 1)

    def test_profiler_predict(self) -> None:
        """
        Test PytorchProfiler callback with predict entry point
        """
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        # Integer division so the expected count is an int (see train test).
        expected_num_total_steps = dataset_len // batch_size

        my_unit = DummyPredictUnit(2)
        profiler_mock = MagicMock(spec=torch.profiler.profile)

        profiler = PyTorchProfiler(profiler=profiler_mock)

        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)

        predict(my_unit, dataloader, callbacks=[profiler])
        self.assertEqual(profiler_mock.start.call_count, 1)
        self.assertEqual(profiler_mock.step.call_count, expected_num_total_steps)
        self.assertEqual(profiler_mock.stop.call_count, 1)


# --------------------------------------------------------------------------
# /tests/framework/callbacks/test_system_resources_monitor.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest
from unittest.mock import MagicMock

from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.framework.callbacks.system_resources_monitor import SystemResourcesMonitor
from torchtnt.framework.train import train

from torchtnt.utils.loggers.logger import MetricLogger


class SystemResourcesMonitorTest(unittest.TestCase):
    def test_system_resources_monitor_epoch(self) -> None:
        """
        Test SystemResourcesMonitor callback with 'epoch' logging interval
        """
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        max_epochs = 2

        my_unit = DummyTrainUnit(input_dim=input_dim)
        log_writer = MagicMock(spec=MetricLogger)
        monitor = SystemResourcesMonitor(
            loggers=log_writer,
            logging_interval="epoch",
        )

        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[monitor])
        # One log_dict call per epoch.
        self.assertEqual(log_writer.log_dict.call_count, 2)

    def test_system_resources_monitor_step(self) -> None:
        """
        Test SystemResourcesMonitor callback with 'step' logging interval
        """
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        max_epochs = 2

        my_unit = DummyTrainUnit(input_dim=input_dim)
        log_writer = MagicMock(spec=MetricLogger)
        monitor = SystemResourcesMonitor(
            loggers=log_writer,
            logging_interval="step",
        )

        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)

        # Integer division: the dataloader yields exactly dataset_len // batch_size
        # batches per epoch, so total_steps is an int and can be passed to
        # `max_steps` directly (previously float division required a pyre-fixme).
        total_steps = (dataset_len // batch_size) * max_epochs

        train(
            my_unit,
            dataloader,
            max_epochs=max_epochs,
            max_steps=total_steps,
            callbacks=[monitor],
        )
        # One log_dict call per training step.
        self.assertEqual(log_writer.log_dict.call_count, total_steps)


# --------------------------------------------------------------------------
# /tests/framework/callbacks/test_tensorboard_parameter_monitor.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from unittest.mock import MagicMock

from torch.utils.tensorboard import SummaryWriter
from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.framework.callbacks.tensorboard_parameter_monitor import (
    TensorBoardParameterMonitor,
)
from torchtnt.framework.train import train


class TensorBoardParameterMonitorTest(unittest.TestCase):
    def test_monitor_train(self) -> None:
        """
        Test TensorBoardParameterMonitor callback with train entry point
        """
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        max_epochs = 2

        my_unit = DummyTrainUnit(input_dim=input_dim)
        summary_writer = MagicMock(spec=SummaryWriter)
        monitor = TensorBoardParameterMonitor(logger=summary_writer)

        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[monitor])
        # The monitor should have written at least one histogram of parameters.
        # pyre-fixme[6]: For 2nd argument expected `SupportsDunderLT[Variable[_T]]`
        #  but got `int`.
        self.assertGreater(summary_writer.add_histogram.call_count, 0)


# --------------------------------------------------------------------------
# /tests/framework/callbacks/test_torchsnapshot_saver_gpu.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os
import shutil
import tempfile
import unittest

import torch
import torch.distributed as dist
from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class TorchSnapshotSaverGPUTest(unittest.TestCase):
    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_save_restore_fsdp(self) -> None:
        """Save/restore an FSDP-wrapped unit across two NCCL processes."""
        spawn_multi_process(
            2,
            "nccl",
            self._save_restore_fsdp,
        )

    @staticmethod
    def _save_restore_fsdp() -> None:
        # Runs inside each spawned process; rank 0 creates the checkpoint dir
        # and the saver broadcasts it to all ranks via `dirpath`.
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        max_epochs = 2
        save_every_n_epochs = 1

        my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        if get_global_rank() == 0:
            temp_dir = tempfile.mkdtemp()
        else:
            temp_dir = ""

        snapshot_cb = TorchSnapshotSaver(
            temp_dir,
            save_every_n_epochs=save_every_n_epochs,
            replicated=["**"],
        )
        # Pick up the directory that was synced across ranks.
        temp_dir = snapshot_cb.dirpath
        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

        tc = unittest.TestCase()
        try:
            my_new_unit = DummyAutoUnit(
                module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
            )
            # Freshly initialized optimizer state must differ before restore.
            tc.assertNotEqual(
                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
            )
            # get latest checkpoint
            ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_train_step_10")
            snapshot_cb.restore(ckpt_path, my_new_unit)
            tc.assertEqual(
                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
            )
        finally:
            dist.barrier()  # avoid race condition
            if get_global_rank() == 0:
                shutil.rmtree(temp_dir)  # delete temp directory


# --------------------------------------------------------------------------
# /tests/framework/callbacks/test_train_progress_monitor.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest

from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.framework.callbacks.train_progress_monitor import TrainProgressMonitor
from torchtnt.framework.train import train

from torchtnt.utils.loggers import InMemoryLogger


class TrainProgressMonitorTest(unittest.TestCase):
    def test_train_progress_monitor(self) -> None:
        """
        Test TrainProgressMonitor callback.

        Trains for `max_epochs` epochs and verifies the monitor logs the
        cumulative number of completed training steps once at train start and
        once at the end of every epoch.
        """
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        max_epochs = 3
        # Integer division: the dataloader yields exactly this many batches per
        # epoch (previously float division compared floats to logged values).
        num_train_steps_per_epoch = dataset_len // batch_size

        my_unit = DummyTrainUnit(input_dim=input_dim)
        logger = InMemoryLogger()
        monitor = TrainProgressMonitor(loggers=logger)

        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[monitor])

        buf = logger.log_buffer
        self.assertEqual(
            len(buf), max_epochs + 1
        )  # +1 since we also log on_train_start
        # Entry k reflects the steps completed after k full epochs (entry 0 is
        # logged at on_train_start, before any steps have run).
        for epoch in range(max_epochs + 1):
            self.assertEqual(
                buf[epoch]["Training steps completed vs epochs"],
                num_train_steps_per_epoch * epoch,
            )


# --------------------------------------------------------------------------
# /tests/framework/test_state.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest

from torchtnt.framework import ActivePhase

from torchtnt.framework.state import _check_loop_condition, PhaseState
from torchtnt.utils.checkpoint import Phase


class StateTest(unittest.TestCase):
    """Tests for loop-condition validation, PhaseState, and ActivePhase helpers."""

    def test_check_loop_condition(self) -> None:
        name = "foo"
        # None and non-negative values are accepted silently.
        _check_loop_condition(name, None)
        _check_loop_condition(name, 100)
        # Negative values must raise, naming the offending field.
        with self.assertRaisesRegex(ValueError, f"Invalid value provided for {name}"):
            _check_loop_condition(name, -1)

    def test_phase_state_validation(self) -> None:
        # Every loop-bound field on the constructor must reject negatives.
        for field_name in (
            "max_epochs",
            "max_steps",
            "max_steps_per_epoch",
            "evaluate_every_n_steps",
            "evaluate_every_n_epochs",
        ):
            with self.assertRaisesRegex(
                ValueError, f"Invalid value provided for {field_name}"
            ):
                PhaseState(dataloader=[], **{field_name: -2})

    def test_active_phase_into_phase(self) -> None:
        # Each ActivePhase maps onto the checkpoint Phase of the same name.
        phase_mapping = {
            ActivePhase.TRAIN: Phase.TRAIN,
            ActivePhase.EVALUATE: Phase.EVALUATE,
            ActivePhase.PREDICT: Phase.PREDICT,
        }
        for active_phase, expected_phase in phase_mapping.items():
            self.assertEqual(active_phase.into_phase(), expected_phase)

    def test_active_phase_str(self) -> None:
        # str() yields the lowercase short name of each phase.
        for active_phase, expected_text in (
            (ActivePhase.TRAIN, "train"),
            (ActivePhase.EVALUATE, "eval"),
            (ActivePhase.PREDICT, "predict"),
        ):
            self.assertEqual(str(active_phase), expected_text)

    def test_set_evaluate_every_n_steps_or_epochs(self) -> None:
        # Setters accept None and positive values but reject negatives.
        state = PhaseState(dataloader=[], evaluate_every_n_steps=2)
        state.evaluate_every_n_steps = None
        state.evaluate_every_n_steps = 100
        with self.assertRaisesRegex(
            ValueError, "Invalid value provided for evaluate_every_n_steps"
        ):
            state.evaluate_every_n_steps = -2

        state = PhaseState(dataloader=[], evaluate_every_n_epochs=2)
        state.evaluate_every_n_epochs = None
        state.evaluate_every_n_epochs = 100
        with self.assertRaisesRegex(
            ValueError, "Invalid value provided for evaluate_every_n_epochs"
        ):
            state.evaluate_every_n_epochs = -2


# --------------------------------------------------------------------------
# /tests/framework/test_unit.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest

from typing import Iterator

import torch
from torchtnt.framework._test_utils import get_dummy_train_state
from torchtnt.framework.state import State
from torchtnt.framework.unit import EvalUnit, PredictUnit, TrainUnit


class TestUnit(
    EvalUnit[Iterator[torch.Tensor]], PredictUnit[torch.Tensor], TrainUnit[torch.Tensor]
):
    """Minimal unit combining all three roles; only eval_step consumes an Iterator."""

    def __init__(self) -> None:
        super().__init__()

    def train_step(self, state: State, data: torch.Tensor) -> None:
        return

    def eval_step(self, state: State, data: Iterator[torch.Tensor]) -> None:
        return

    def predict_step(self, state: State, data: torch.Tensor) -> None:
        return


class UnitTest(unittest.TestCase):
    def test_initialization_and_get_next_batch(self) -> None:
        unit = TestUnit()
        # All three progress trackers exist right after construction.
        for progress in (unit.train_progress, unit.eval_progress, unit.predict_progress):
            self.assertIsNotNone(progress)

        ones = torch.ones(1)
        zeros = torch.zeros(1)
        state = get_dummy_train_state()

        # Train batches: expect the individual elements of the iterable.
        train_data_iter = iter([ones, zeros])
        self.assertEqual(unit.get_next_train_batch(state, train_data_iter), ones)
        self.assertEqual(unit.get_next_train_batch(state, train_data_iter), zeros)

        # Predict batches: expect the individual elements of the iterable.
        self.assertEqual(
            unit.get_next_predict_batch(state, iter([ones, zeros])), ones
        )

        # Eval step is annotated with Iterator, so the whole iterator is handed back.
        eval_data_iter = iter([ones, zeros])
        self.assertEqual(unit.get_next_eval_batch(state, eval_data_iter), eval_data_iter)


# --------------------------------------------------------------------------
# /tests/framework/test_unit_utils.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import Dict, Iterator

import torch
from torch.optim import Optimizer
from torchtnt.framework._unit_utils import (
    _find_optimizers_for_module,
    _step_requires_iterator,
)
from torchtnt.framework.state import State


class UnitUtilsTest(unittest.TestCase):
    def test_step_func_requires_iterator(self) -> None:
        class Foo:
            def bar(self, state: State, data: object) -> object:
                return data

            def baz(self, state: State, data: Iterator[torch.Tensor]) -> object:
                pass

        def dummy(a: int, b: str, data: Iterator[str]) -> None:
            pass

        foo = Foo()

        # Only step functions whose `data` parameter is annotated as an
        # Iterator should be detected as iterator-consuming.
        self.assertFalse(_step_requires_iterator(foo.bar))
        self.assertTrue(_step_requires_iterator(foo.baz))
        self.assertTrue(_step_requires_iterator(dummy))

    def test_find_optimizers_for_module(self) -> None:
        module1 = torch.nn.Linear(10, 10)
        module2 = torch.nn.Linear(10, 10)
        opts: Dict[str, Optimizer] = {
            "optim1": torch.optim.Adam(module1.parameters()),
            "optim2": torch.optim.Adagrad(module2.parameters()),
        }
        # The first optimizer returned is the one that owns the module's params.
        for module, expected_name in ((module1, "optim1"), (module2, "optim2")):
            optim_name, _ = _find_optimizers_for_module(module, opts)[0]
            self.assertEqual(optim_name, expected_name)


# --------------------------------------------------------------------------
# /tests/framework/test_unit_utils_gpu.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import Dict

import torch

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer
from torchtnt.framework._unit_utils import _find_optimizers_for_module
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class UnitUtilsGPUTest(unittest.TestCase):
    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_find_optimizers_for_FSDP_module(self) -> None:
        """Run the FSDP optimizer-lookup check on two NCCL processes."""
        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)

    @staticmethod
    def _find_optimizers_for_FSDP_module() -> None:
        # Runs inside each spawned process, so assertions go through a local
        # TestCase instance rather than `self`.
        device = init_from_env()
        fsdp_module = FSDP(torch.nn.Linear(10, 10).to(device))
        plain_module = torch.nn.Linear(10, 10)
        opts: Dict[str, Optimizer] = {
            "optim1": torch.optim.Adam(fsdp_module.parameters()),
            "optim2": torch.optim.Adagrad(plain_module.parameters()),
        }

        tc = unittest.TestCase()
        # Lookup must work for FSDP-wrapped and plain modules alike.
        for module, expected_name in ((fsdp_module, "optim1"), (plain_module, "optim2")):
            optim_name, _ = _find_optimizers_for_module(module, opts)[0]
            tc.assertEqual(optim_name, expected_name)


# --------------------------------------------------------------------------
# /tests/framework/test_utils.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import time
import unittest
from unittest.mock import MagicMock, patch

from torchtnt.framework.utils import get_timing_context

from torchtnt.utils.timer import Timer


class UtilsTest(unittest.TestCase):
    @patch("torchtnt.framework.utils.record_function")
    def test_get_timing_context(self, mock_record_function: MagicMock) -> None:
        """get_timing_context always profiles and, when a timer exists, times too."""
        state = MagicMock()
        state.timer = None

        # No timer on the state: only the profiler record_function is used.
        with get_timing_context(state, "a"):
            time.sleep(1)
        mock_record_function.assert_called_with("a")

        # With a timer, the duration is additionally recorded under the name.
        state.timer = Timer()
        with get_timing_context(state, "b"):
            time.sleep(1)
        self.assertTrue("b" in state.timer.recorded_durations.keys())
        mock_record_function.assert_called_with("b")


# --------------------------------------------------------------------------
# /tests/utils/__init__.py
# --------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------------------------
# /tests/utils/data/test_data_prefetcher.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

import unittest
from typing import Tuple
from unittest.mock import MagicMock, patch

import torch
from torch.utils.data.dataset import Dataset, TensorDataset
from torchtnt.utils.data.data_prefetcher import CudaDataPrefetcher

Batch = Tuple[torch.Tensor, torch.Tensor]


class DataPrefetcherTest(unittest.TestCase):
    """Validates CudaDataPrefetcher's constructor argument checking (CPU-only host)."""

    def _generate_dataset(self, num_samples: int, input_dim: int) -> Dataset[Batch]:
        """Returns a dataset of random inputs and labels for binary classification."""
        data = torch.randn(num_samples, input_dim)
        labels = torch.randint(low=0, high=2, size=(num_samples,))
        return TensorDataset(data, labels)

    def test_device_data_prefetcher(self) -> None:
        # A non-CUDA device must be rejected at construction time.
        device = torch.device("cpu")

        num_samples = 12
        batch_size = 4
        dataloader = torch.utils.data.DataLoader(
            self._generate_dataset(num_samples, 2), batch_size=batch_size
        )

        num_prefetch_batches = 2
        with self.assertRaisesRegex(ValueError, "expects a CUDA device"):
            _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)

    @patch("torch.cuda.Stream")
    def test_num_prefetch_batches_data_prefetcher(self, mock_stream: MagicMock) -> None:
        # torch.cuda.Stream is patched so this runs on machines without a GPU.
        device = torch.device("cuda:0")

        num_samples = 12
        batch_size = 4
        dataloader = torch.utils.data.DataLoader(
            self._generate_dataset(num_samples, 2), batch_size=batch_size
        )

        # Non-positive prefetch counts are invalid.
        with self.assertRaisesRegex(
            ValueError, "`num_prefetch_batches` must be greater than 0"
        ):
            _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=-1)

        with self.assertRaisesRegex(
            ValueError, "`num_prefetch_batches` must be greater than 0"
        ):
            _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=0)

        # no exceptions raised
        _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=1)
        _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=2)

        # Check that CUDA streams were created
        # NOTE(review): count of 2 presumably reflects one stream per successful
        # construction above — confirm against CudaDataPrefetcher's implementation.
        self.assertEqual(mock_stream.call_count, 2)


# --------------------------------------------------------------------------
# /tests/utils/data/test_data_prefetcher_gpu.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import unittest
from typing import Tuple

import torch

from torch.utils.data import Dataset, TensorDataset
from torchtnt.utils.data.data_prefetcher import CudaDataPrefetcher
from torchtnt.utils.test_utils import skip_if_not_gpu

Batch = Tuple[torch.Tensor, torch.Tensor]


class DataPrefetcherGPUTest(unittest.TestCase):
    """End-to-end CudaDataPrefetcher iteration test; requires a real GPU."""

    def _generate_dataset(self, num_samples: int, input_dim: int) -> Dataset[Batch]:
        """Returns a dataset of random inputs and labels for binary classification."""
        data = torch.randn(num_samples, input_dim)
        labels = torch.randint(low=0, high=2, size=(num_samples,))
        return TensorDataset(data, labels)

    @skip_if_not_gpu
    def test_cuda_data_prefetcher(self) -> None:
        device = torch.device("cuda:0")

        num_samples = 12
        batch_size = 4
        dataloader = torch.utils.data.DataLoader(
            self._generate_dataset(num_samples, 2), batch_size=batch_size
        )

        num_prefetch_batches = 2
        data_prefetcher = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)
        self.assertEqual(num_prefetch_batches, data_prefetcher.num_prefetch_batches)

        # make sure data_prefetcher has same number of samples as original dataloader
        num_batches_in_data_prefetcher = 0
        for inputs, targets in data_prefetcher:
            num_batches_in_data_prefetcher += 1
            # len(inputs) should equal the batch size
            self.assertEqual(len(inputs), batch_size)
            self.assertEqual(len(targets), batch_size)
            # make sure batch is on correct device
            self.assertEqual(inputs.device, device)
            self.assertEqual(targets.device, device)

        self.assertEqual(num_batches_in_data_prefetcher, num_samples / batch_size)


# --------------------------------------------------------------------------
# /tests/utils/data/test_iterators.py
# --------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

from torchtnt.utils.data.iterators import StoppingMechanism


class TestIterators(unittest.TestCase):

    def test_stopping_mechanism_comparison(self) -> None:
        # StoppingMechanism members compare equal both to themselves and to
        # their string names, and unequal to any other member/name.
        self.assertTrue(
            StoppingMechanism.ALL_DATASETS_EXHAUSTED == "ALL_DATASETS_EXHAUSTED"
        )
        self.assertTrue(
            StoppingMechanism.ALL_DATASETS_EXHAUSTED
            == StoppingMechanism.ALL_DATASETS_EXHAUSTED
        )
        self.assertFalse(
            StoppingMechanism.ALL_DATASETS_EXHAUSTED == "SMALLEST_DATASET_EXHAUSTED"
        )
        self.assertFalse(
            StoppingMechanism.ALL_DATASETS_EXHAUSTED
            == StoppingMechanism.SMALLEST_DATASET_EXHAUSTED
        )


# --------------------------------------------------------------------------
# /tests/utils/data/test_profile_dataloader.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import Iterator

import torch

# pyre-fixme[21]: Could not find name `ProfilerActivity` in `torch.profiler`.
from torch.profiler import ProfilerActivity
from torchtnt.utils.data.profile_dataloader import profile_dataloader
from torchtnt.utils.env import init_from_env


class DummyIterable:
    """Iterable yielding 0..count-1, standing in for a dataloader."""

    def __init__(self, count: int) -> None:
        self.count: int = count

    def __iter__(self) -> Iterator[int]:
        for i in range(self.count):
            yield i


class ProfileDataLoaderTest(unittest.TestCase):
    """Checks profile_dataloader records one `next(iter)` duration per batch."""

    def test_profile_dataloader(self) -> None:
        # Full pass: one recorded `next(iter)` duration per yielded item.
        max_length = 10
        iterable = DummyIterable(max_length)
        with _get_torch_profiler() as p:
            timer = profile_dataloader(iterable, p)
        self.assertEqual(len(timer.recorded_durations["next(iter)"]), max_length)

    def test_profile_dataloader_max_steps(self) -> None:
        # max_steps caps iteration before the iterable is exhausted.
        max_length = 10
        max_steps = 5
        iterable = DummyIterable(max_length)
        with _get_torch_profiler() as p:
            timer = profile_dataloader(iterable, p, max_steps=max_steps)
        self.assertEqual(len(timer.recorded_durations["next(iter)"]), max_steps)

    def test_profile_dataloader_profiler(self) -> None:
        # NOTE(review): this test is identical to test_profile_dataloader above —
        # likely intended to exercise a different profiler configuration; confirm.
        max_length = 10
        iterable = DummyIterable(max_length)
        with _get_torch_profiler() as p:
            timer = profile_dataloader(iterable, p)
        self.assertEqual(len(timer.recorded_durations["next(iter)"]), max_length)

    def test_profile_dataloader_device(self) -> None:
        # When a device is passed, each batch is also timed being copied to it.
        device = init_from_env()
        max_length = 10
        iterable = DummyIterable(max_length)
        with _get_torch_profiler() as p:
            timer = profile_dataloader(iterable, p, device=device)
        self.assertEqual(len(timer.recorded_durations["next(iter)"]), max_length)
        self.assertEqual(
            len(timer.recorded_durations["copy_data_to_device"]), max_length
        )


def _get_torch_profiler() -> torch.profiler.profile:
    """Return a CPU-only torch profiler with a short warmup/active schedule."""
    profiler_schedule = torch.profiler.schedule(
        wait=0,
        warmup=1,
        active=1,
    )
    return torch.profiler.profile(
        # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`.
        activities=[ProfilerActivity.CPU],
        schedule=profiler_schedule,
    )


# --------------------------------------------------------------------------
# /tests/utils/loggers/__init__.py
# --------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------------------------
# /tests/utils/loggers/test_csv.py
# --------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
class CSVLoggerTest(unittest.TestCase):
    def _assert_roundtrip(self, **logger_kwargs: object) -> None:
        """Log one scalar through CSVLogger and assert it round-trips from disk.

        Extra keyword arguments are forwarded to the CSVLogger constructor so
        both the sync and async write paths share this body.
        """
        with TemporaryDirectory() as tmpdir:
            csv_path = Path(tmpdir, "test.csv").as_posix()
            logger = CSVLogger(csv_path, steps_before_flushing=1, **logger_kwargs)
            log_name = "asdf"
            log_value = 123.0
            log_step = 10
            logger.log(log_name, log_value, log_step)
            # close() must flush any buffered rows before we read the file back.
            logger.close()

            with open(csv_path) as f:
                rows = list(csv.DictReader(f))
            self.assertEqual(float(rows[0][log_name]), log_value)
            self.assertEqual(int(rows[0]["step"]), log_step)

    def test_csv_log(self) -> None:
        """Synchronous write path."""
        self._assert_roundtrip()

    def test_csv_log_async(self) -> None:
        """Asynchronous write path."""
        self._assert_roundtrip(async_write=True)
class InMemoryLoggerTest(unittest.TestCase):
    def test_in_memory_log(self) -> None:
        """Logged metrics accumulate in the buffer; flush prints, close clears."""
        logger = InMemoryLogger()
        logger.log(name="metric1", data=123.0, step=0)
        logger.log(name="metric1", data=456.0, step=1)
        logger.log(name="metric1", data=789.0, step=2)
        # Test flushing. Match only the "OrderedDict(" prefix: OrderedDict's
        # repr changed from "OrderedDict([...])" to "OrderedDict({...})" in
        # Python 3.12, so pinning the old "OrderedDict([" prefix breaks there.
        with captured_output() as (out, err):
            logger.flush()
        out = cast(StringIO, out)
        err = cast(StringIO, err)
        self.assertTrue(out.getvalue().startswith("OrderedDict("))
        self.assertEqual(err.getvalue(), "")
        logger.log_dict(payload={"metric2": 1.0, "metric3": 2.0}, step=3)
        # Check the buffer directly: three single-metric entries plus one
        # log_dict entry, each carrying its step.
        buf = logger.log_buffer
        self.assertEqual(len(buf), 4)
        self.assertEqual(buf[0]["metric1"], 123.0)
        self.assertEqual(buf[0]["step"], 0)
        self.assertEqual(buf[1]["metric1"], 456.0)
        self.assertEqual(buf[1]["step"], 1)
        self.assertEqual(buf[2]["metric1"], 789.0)
        self.assertEqual(buf[2]["step"], 2)
        self.assertEqual(buf[3]["metric2"], 1.0)
        self.assertEqual(buf[3]["metric3"], 2.0)
        self.assertEqual(buf[3]["step"], 3)
        # Test flushing again after log_dict.
        with captured_output() as (out, err):
            logger.flush()
        out = cast(StringIO, out)
        err = cast(StringIO, err)
        self.assertTrue(out.getvalue().startswith("OrderedDict("))
        self.assertEqual(err.getvalue(), "")
        # Closing the logger clears the buffer.
        logger.close()
        self.assertEqual(logger.log_buffer, OrderedDict([]))


class JSONLoggerTest(unittest.TestCase):
    def test_json_log(self) -> None:
        """A logged scalar round-trips through the JSON file after close()."""
        with TemporaryDirectory() as tmpdir:
            json_path = Path(tmpdir, "test.json").as_posix()
            logger = JSONLogger(json_path, steps_before_flushing=1)
            log_name = "asdf"
            log_value = 123.0
            log_step = 10
            logger.log(log_name, log_value, log_step)
            logger.close()

            with open(json_path) as f:
                d = json.load(f)
            # (Removed a leftover debug print(d) that cluttered test output.)
            self.assertTrue(len(d))
            self.assertEqual(d[0][log_name], log_value)
            self.assertEqual(d[0]["step"], log_step)
class TestUtilities(unittest.TestCase):
    def test_scalar_to_float(self) -> None:
        """scalar_to_float accepts only single-element tensors/ndarrays."""
        expected = 3.45

        # Multi-element inputs of either kind are rejected.
        with self.assertRaises(ValueError):
            scalar_to_float(torch.Tensor([1, 2, 3]))
        with self.assertRaises(ValueError):
            scalar_to_float(np.array([23.45, 15.21]))

        # Single-element inputs convert, regardless of nesting depth.
        self.assertAlmostEqual(scalar_to_float(torch.Tensor([expected])), expected)
        self.assertAlmostEqual(scalar_to_float(np.array([[[expected]]])), expected)

    def test_scalar_to_float_bf16(self) -> None:
        """bf16 tensors convert, within bf16's reduced precision."""
        expected = 3.45
        bf16_tensor = torch.Tensor([expected]).to(torch.bfloat16)
        self.assertAlmostEqual(scalar_to_float(bf16_tensor), expected, delta=0.01)
class TestAnomalyLogger(unittest.TestCase):

    def test_threshold(self) -> None:
        """ThresholdEvaluator flags values outside the configured bounds."""
        evaluator = ThresholdEvaluator(min_val=0.5, max_val=0.9)
        # Before any update there is nothing to flag.
        self.assertFalse(evaluator.is_anomaly())

        # The most recent update decides the verdict.
        for value, expected in ((0.4, True), (0.6, False), (0.95, True)):
            evaluator.update(value)
            self.assertEqual(evaluator.is_anomaly(), expected)

        # With only max_val set, anything at or below it is acceptable.
        evaluator = ThresholdEvaluator(max_val=1)
        evaluator.update(100.0)
        self.assertTrue(evaluator.is_anomaly())
        evaluator.update(-500.0)
        self.assertFalse(evaluator.is_anomaly())

    def test_isnan(self) -> None:
        """IsNaNEvaluator flags only NaN values."""
        evaluator = IsNaNEvaluator()
        self.assertFalse(evaluator.is_anomaly())

        evaluator.update(0.4)
        self.assertFalse(evaluator.is_anomaly())

        evaluator.update(math.nan)
        self.assertTrue(evaluator.is_anomaly())
class TestCheckpointUtilsGPU(unittest.TestCase):
    """Distributed (NCCL) tests for checkpoint-directory discovery utilities."""

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_get_checkpoint_dirpaths_distributed(self) -> None:
        # Run the static helper on 2 GPU ranks; generous timeout for NCCL setup.
        spawn_multi_process(
            2, "nccl", self._test_get_checkpoint_dirpaths, timeout_s=180
        )

    @staticmethod
    def _test_get_checkpoint_dirpaths() -> None:
        """
        Tests retrieving checkpoint directories from a given root directory
        using NCCL on GPUs with custom state for pickling.
        """
        init_from_env()
        # Mix of plain and metric-annotated checkpoint directory names.
        paths = [
            "epoch_0_step_10",
            "epoch_1_step_10_val_loss=10.5",
            "epoch_2_step_10",
            "epoch_0_step_5",
            "epoch_0_step_6_acc=0.03",
            "epoch_0_step_3",
        ]

        # Only rank 0 creates the directories; other ranks deliberately hold
        # None so the utility's cross-rank result sharing is exercised.
        if get_global_rank() == 0:
            temp_dir = tempfile.mkdtemp()
            for path in paths:
                os.mkdir(os.path.join(temp_dir, path))
        else:
            temp_dir = None

        tc = unittest.TestCase()
        # Only rank 0 will know about temp_dir
        if get_global_rank() != 0:
            tc.assertIsNone(temp_dir)

        # NOTE(review): presumably get_checkpoint_dirpaths shares rank 0's
        # listing across the process group so every rank gets the same result
        # despite receiving temp_dir=None — confirm against its implementation.
        ckpt_dirpaths = get_checkpoint_dirpaths(
            # pyre-fixme[6]: For 1st argument expected `str` but got `Optional[str]`.
            temp_dir,
            process_group=dist.group.WORLD,
        )

        # Broadcast temp_dir to verify successful execution
        temp_dir = [temp_dir] if get_global_rank() == 0 else [None]
        dist.broadcast_object_list(temp_dir, src=0, group=dist.group.WORLD)
        temp_dir = temp_dir[0]
        tc.assertIsNotNone(temp_dir)

        # Every rank should now see rank 0's checkpoint directories.
        tc.assertEqual(
            {str(x) for x in ckpt_dirpaths},
            {os.path.join(temp_dir, path) for path in paths},
        )

        # Clean up on the rank that created the directory.
        if get_global_rank() == 0:
            shutil.rmtree(temp_dir)
test_get_gpu_stats(self) -> None: 39 | """Get Nvidia GPU stats, check that values are populated.""" 40 | device = torch.device("cuda:0") 41 | 42 | with mock.patch("shutil.which"), mock.patch( 43 | "torchtnt.utils.device.subprocess.run" 44 | ) as subprocess_run_mock: 45 | subprocess_run_mock.return_value.stdout = "0, 0, 0, 2, 16273, 38, 15" 46 | gpu_stats = get_nvidia_smi_gpu_stats(device) 47 | 48 | # Check that percentages are between 0 and 100 49 | self.assertGreaterEqual(gpu_stats["utilization_gpu_percent"], 0) 50 | self.assertLessEqual(gpu_stats["utilization_gpu_percent"], 100) 51 | self.assertGreaterEqual(gpu_stats["utilization_memory_percent"], 0) 52 | self.assertLessEqual(gpu_stats["utilization_memory_percent"], 100) 53 | self.assertGreaterEqual(gpu_stats["fan_speed_percent"], 0) 54 | self.assertLessEqual(gpu_stats["fan_speed_percent"], 100) 55 | 56 | # Check that values are greater than zero 57 | self.assertGreaterEqual(gpu_stats["memory_used_mb"], 0) 58 | self.assertGreaterEqual(gpu_stats["memory_free_mb"], 0) 59 | self.assertGreaterEqual(gpu_stats["temperature_gpu_celsius"], 0) 60 | self.assertGreaterEqual(gpu_stats["temperature_memory_celsius"], 0) 61 | -------------------------------------------------------------------------------- /tests/utils/test_distributed_gpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
class DistributedGPUTest(unittest.TestCase):
    """GPU/NCCL variants of the distributed collective utility tests."""

    @skip_if_not_gpu
    @skip_if_not_distributed
    def test_gather_uneven_multidim_nccl(self) -> None:
        # 2 ranks gather differently-shaped tensors over NCCL.
        spawn_multi_process(
            2,
            "nccl",
            self._test_ddp_gather_uneven_tensors_multidim_nccl,
        )

    @staticmethod
    def _test_ddp_gather_uneven_tensors_multidim_nccl() -> None:
        # Each rank contributes an all-ones tensor of shape (rank + 1, 4 - rank),
        # so both dims differ per rank; gather must preserve shape and contents.
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        tensor = torch.ones(rank + 1, 4 - rank, device=get_device_from_env())
        result = all_gather_tensors(tensor)
        assert len(result) == world_size
        for idx in range(world_size):
            val = result[idx]
            assert val.shape == (idx + 1, 4 - idx)
            assert (val == 1).all()

    @skip_if_not_gpu
    @skip_if_not_distributed
    def test_pg_wrapper_scatter_object_list_nccl(self) -> None:
        spawn_multi_process(
            2,
            "nccl",
            self._test_pg_wrapper_scatter_object_list,
            timeout_s=180,
        )

    @classmethod
    def _test_pg_wrapper_scatter_object_list(
        cls,
    ) -> None:
        # Rank 0 scatters [1, 2]; rank r should receive r + 1.
        init_from_env()
        pg_wrapper = PGWrapper(dist.group.WORLD)
        output_list = [None] * 2
        pg_wrapper.scatter_object_list(
            output_list=output_list,
            input_list=[1, 2] if get_local_rank() == 0 else [None] * 2,
            src=0,
        )
        tc = unittest.TestCase()
        tc.assertEqual(output_list[0], get_local_rank() + 1)

    @staticmethod
    def _test_method(offset_arg: int, offset_kwarg: int) -> int:
        # Returns a per-rank value so the parent can assert both positional and
        # keyword arguments were forwarded to every spawned process.
        return get_global_rank() + offset_arg - offset_kwarg

    @skip_if_not_gpu
    @skip_if_not_distributed
    def test_spawn_multi_process(self) -> None:
        # rank + 3 - 2 => [1, 2] for ranks [0, 1].
        mp_list = spawn_multi_process(2, "nccl", self._test_method, 3, offset_kwarg=2)
        self.assertEqual(mp_list, [1, 2])

    @skip_if_not_gpu
    @skip_if_not_distributed
    def test_broadcast_str(self) -> None:
        # NOTE(review): uses the gloo backend even though the test is GPU-gated —
        # presumably because broadcast_str handles CPU python objects; confirm
        # whether nccl was intended here.
        spawn_multi_process(2, "gloo", self._test_broadcast_str)

    @staticmethod
    def _test_broadcast_str() -> None:
        """
        Tests that broadcast_str works as expected
        """

        val = None
        if dist.get_rank() == 0:
            val = "foo"

        broadcasted_val = broadcast_str(val)

        # Every rank, not just the source, should receive rank 0's value.
        tc = unittest.TestCase()
        tc.assertEqual(broadcasted_val, "foo")
class EarlyStopCheckerGPUTest(unittest.TestCase):
    @skip_if_not_gpu
    def test_early_stop_min_delta_on_gpu(self) -> None:
        """EarlyStopChecker honors min_delta with GPU-resident loss tensors.

        Two checkers with different patience observe the same loss curve; the
        less patient one must trip first.
        """
        device = torch.device("cuda:0")

        # Loss decreases beyond 0.25 but never by more than min_delta.
        losses = [
            torch.tensor([0.4], device=device),
            torch.tensor([0.38], device=device),
            torch.tensor([0.31], device=device),
            torch.tensor([0.25], device=device),
            torch.tensor([0.27], device=device),
            torch.tensor([0.24], device=device),
        ]
        es1 = EarlyStopChecker("min", 3, min_delta=0.05)
        es2 = EarlyStopChecker("min", 4, min_delta=0.05)

        for loss in losses:
            # Pass the GPU tensor directly: re-wrapping an existing tensor with
            # torch.tensor(loss) emits a UserWarning and makes a needless copy.
            should_stop = es1.check(loss)
            self.assertFalse(should_stop)
            should_stop = es2.check(loss)
            self.assertFalse(should_stop)

        # Patience of es1 (3) should run out.
        should_stop = es1.check(torch.tensor(0.25))
        self.assertTrue(should_stop)

        # es2 has more patience than es1.
        should_stop = es2.check(torch.tensor(0.25))
        self.assertFalse(should_stop)
        should_stop = es2.check(torch.tensor(0.26))
        self.assertTrue(should_stop)
class FsTest(unittest.TestCase):
    def _test_operations(
        self,
        fs: fsspec.AbstractFileSystem,
        directory: str,
        filename: str,
        **kwargs: Any,
    ) -> None:
        """Write a small file through ``fs`` and verify it exists as a file.

        Args:
            fs: The filesystem to use when testing.
            directory: The directory containing the file.
            filename: The name of the file.
            kwargs: Passed to any write operations as additional arguments.
        """
        filepath = os.path.join(directory, filename)

        with fs.open(filepath, mode="w", **kwargs) as handle:
            handle.write("blob")
        self.assertTrue(fs.exists(filepath))
        self.assertTrue(fs.isfile(filepath))

    def test_get_filesystem(self) -> None:
        """get_filesystem returns a working filesystem for a plain local path."""
        with tempfile.TemporaryDirectory() as temp_dir:
            self._test_operations(
                fs=get_filesystem(temp_dir),
                directory=temp_dir,
                filename="test_fs.txt",
            )
class MemorySnapshotProfilerTest(unittest.TestCase):
    def test_validation(self) -> None:
        """Test parameter validation.

        Each invalid MemorySnapshotParams combination must raise a ValueError
        whose message matches the expected pattern.
        """
        # (expected error regex, MemorySnapshotParams kwargs) pairs. Params are
        # constructed inside the assertRaises context, as in the inline form.
        invalid_cases = [
            ("start_step must be nonnegative.", {"start_step": -1, "stop_step": 0}),
            (
                "stop_step must be specified when start_step is set.",
                {"start_step": 2, "stop_step": None},
            ),
            ("start_step must be < stop_step.", {"start_step": 2, "stop_step": 0}),
            ("stop_step must be positive.", {"stop_step": 0}),
            (
                "stop_step must be enabled with either start_step or enable_oom_observer.",
                {"stop_step": 2, "enable_oom_observer": False},
            ),
            (
                "At least one of start_step/stop_step or enable_oom_observer must be set.",
                {"start_step": None, "stop_step": None, "enable_oom_observer": False},
            ),
        ]

        with tempfile.TemporaryDirectory() as temp_dir:
            for expected_msg, params_kwargs in invalid_cases:
                with self.subTest(expected_msg=expected_msg):
                    with self.assertRaisesRegex(ValueError, expected_msg):
                        _ = MemorySnapshotProfiler(
                            output_dir=temp_dir,
                            memory_snapshot_params=MemorySnapshotParams(
                                **params_kwargs
                            ),
                        )
class MemorySnapshotProfilerGPUTest(unittest.TestCase):
    @skip_if_not_gpu
    def test_stop_step(self) -> None:
        """Test that a memory snapshot is saved when stop_step is reached."""
        with tempfile.TemporaryDirectory() as temp_dir:
            memory_snapshot_profiler = MemorySnapshotProfiler(
                output_dir=temp_dir,
                memory_snapshot_params=MemorySnapshotParams(start_step=0, stop_step=2),
            )

            # initialize device & allocate memory for tensors
            device = get_device_from_env()
            a = torch.rand((1024, 1024), device=device)
            b = torch.rand((1024, 1024), device=device)
            _ = (a + b) * (a - b)

            memory_snapshot_profiler.step()

            # Check if the corresponding files exist.
            # Output directory is named after the stop step and this rank (0).
            save_dir = os.path.join(temp_dir, "step_2_rank0")

            pickle_dump_path = os.path.join(save_dir, "snapshot.pickle")
            trace_path = os.path.join(save_dir, "trace_plot.html")
            segment_plot_path = os.path.join(save_dir, "segment_plot.html")

            # after first step files do not exist (stop_step=2 not yet reached)
            self.assertFalse(os.path.exists(pickle_dump_path))
            self.assertFalse(os.path.exists(trace_path))
            self.assertFalse(os.path.exists(segment_plot_path))

            # after second step stop_step is reached and files should exist
            memory_snapshot_profiler.step()
            self.assertTrue(os.path.exists(pickle_dump_path))
            self.assertTrue(os.path.exists(trace_path))
            self.assertTrue(os.path.exists(segment_plot_path))
class OomTest(unittest.TestCase):
    def test_is_out_of_cpu_memory(self) -> None:
        """Test CPU OOM error detection."""
        cpu_oom = RuntimeError("DefaultCPUAllocator: can't allocate memory")
        self.assertTrue(is_out_of_cpu_memory(cpu_oom))
        self.assertFalse(is_out_of_cpu_memory(RuntimeError("RuntimeError: blah")))

    def test_is_out_of_cuda_memory(self) -> None:
        """Test cuda OOM error detection."""
        # Both historical CUDA OOM message formats must be recognized.
        for message in (
            "CUDA out of memory. Tried to allocate ...",
            "RuntimeError: cuda runtime error (2) : out of memory",
        ):
            self.assertTrue(is_out_of_cuda_memory(RuntimeError(message)))
        self.assertFalse(is_out_of_cuda_memory(RuntimeError("RuntimeError: blah")))

    def test_is_out_of_memory_error(self) -> None:
        """Test general OOM error detection (CPU and CUDA variants)."""
        for message in (
            "DefaultCPUAllocator: can't allocate memory",
            "CUDA out of memory. Tried to allocate ...",
            "RuntimeError: cuda runtime error (2) : out of memory",
        ):
            self.assertTrue(is_out_of_memory_error(RuntimeError(message)))
        self.assertFalse(is_out_of_memory_error(RuntimeError("RuntimeError: blah")))

    def test_bytes_to_mb_gb(self) -> None:
        """Byte counts format as binary MB/GB strings."""
        for num_bytes, expected in (
            (0, "0.0 MB"),
            (100000, "0.1 MB"),
            (1000000, "0.95 MB"),
            (1000000000, "0.93 GB"),
            (1000000000000, "931.32 GB"),
        ):
            self.assertEqual(expected, _bytes_to_mb_gb(num_bytes))
class OomGPUTest(unittest.TestCase):
    @skip_if_not_gpu
    def test_log_memory_snapshot(self) -> None:
        """log_memory_snapshot writes pickle + HTML artifacts under <dir>/<prefix>_rank<rank>."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Record history so the snapshot has allocation events to capture.
            torch.cuda.memory._record_memory_history(enabled="all", max_entries=10000)

            # initialize device & allocate memory for tensors
            device = get_device_from_env()
            a = torch.rand((1024, 1024), device=device)
            b = torch.rand((1024, 1024), device=device)
            _ = (a + b) * (a - b)

            # save a snapshot
            log_memory_snapshot(temp_dir, "foo")

            # Check if the corresponding files exist.
            # Output is nested under "<prefix>_rank<rank>"; this test runs on rank 0.
            save_dir = os.path.join(temp_dir, "foo_rank0")

            pickle_dump_path = os.path.join(save_dir, "snapshot.pickle")
            self.assertTrue(os.path.exists(pickle_dump_path))

            trace_path = os.path.join(save_dir, "trace_plot.html")
            self.assertTrue(os.path.exists(trace_path))

            segment_plot_path = os.path.join(save_dir, "segment_plot.html")
            self.assertTrue(os.path.exists(segment_plot_path))
class OptimizerTest(unittest.TestCase):
    def test_init_optim_state(self) -> None:
        """init_optim_state should populate the optimizer's state skeleton without mutating parameters."""
        device = init_from_env()
        module = torch.nn.Linear(1, 1, device=device)
        params_before = module.state_dict().copy()
        optimizer = torch.optim.AdamW(module.parameters(), lr=0.01)

        # Freshly constructed optimizers carry no per-parameter state.
        self.assertEqual(optimizer.state, {})

        init_optim_state(optimizer)

        # State has been initialized ...
        self.assertNotEqual(optimizer.state, {})

        # ... and the module's parameters were left untouched.
        for key in ("weight", "bias"):
            with self.subTest(param=key):
                self.assertTrue(
                    torch.allclose(params_before[key], module.state_dict()[key])
                )
class PrecisionTest(unittest.TestCase):
    def test_convert_precision_str_to_dtype_success(self) -> None:
        """Known precision strings map to their torch dtypes; fp32 maps to None."""
        expected = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": None,
        }
        for precision_str, expected_dtype in expected.items():
            with self.subTest(
                precision_str=precision_str, expected_dtype=expected_dtype
            ):
                self.assertEqual(
                    convert_precision_str_to_dtype(precision_str), expected_dtype
                )

    def test_convert_precision_str_to_dtype_throws(self) -> None:
        """Unknown precision strings raise a descriptive ValueError."""
        with self.assertRaisesRegex(
            ValueError,
            "Precision foo not supported. Please use one of .*",
        ):
            convert_precision_str_to_dtype("foo")

    def test_get_grad_scaler_from_precision(self) -> None:
        """fp32 -> no scaler; fp16 -> GradScaler; fp16 + FSDP1 -> ShardedGradScaler."""
        self.assertIsNone(
            get_grad_scaler_from_precision(torch.float32, is_fsdp1_module=False)
        )
        self.assertIsInstance(
            get_grad_scaler_from_precision(torch.float16, is_fsdp1_module=False),
            GradScaler,
        )
        self.assertIsInstance(
            get_grad_scaler_from_precision(torch.float16, is_fsdp1_module=True),
            ShardedGradScaler,
        )
class RankZeroLogTest(unittest.TestCase):
    @patch.dict("os.environ", {"RANK": "0"}, clear=True)
    def test_rank_zero_fn_rank_zero(self) -> None:
        """On rank 0 each helper forwards to the logger, with stacklevel when supported."""
        logger = MagicMock()
        supports_stacklevel = _supports_stacklevel()

        # Pair each rank_zero_* helper with the logger method it should invoke.
        helpers = [
            (rank_zero_debug, logger.debug),
            (rank_zero_info, logger.info),
            (rank_zero_warn, logger.warning),
            (rank_zero_error, logger.error),
            (rank_zero_critical, logger.critical),
        ]
        for helper, mock_method in helpers:
            helper("foo", logger=logger)
            if supports_stacklevel:
                mock_method.assert_called_once_with("foo", stacklevel=2)
            else:
                mock_method.assert_called_once_with("foo")

    @patch.dict("os.environ", {"RANK": "1"}, clear=True)
    def test_rank_zero_fn_rank_non_zero(self) -> None:
        """On non-zero ranks every helper must be a no-op."""
        logger = MagicMock()

        helpers = [
            (rank_zero_debug, logger.debug),
            (rank_zero_info, logger.info),
            (rank_zero_warn, logger.warning),
            (rank_zero_error, logger.error),
            (rank_zero_critical, logger.critical),
        ]
        for helper, mock_method in helpers:
            helper("foo", logger=logger)
            mock_method.assert_not_called()
class TimerGPUTest(unittest.TestCase):
    @skip_if_not_gpu
    @patch("torch.cuda.synchronize")
    def test_timer_synchronize(self, mock_synchronize: MagicMock) -> None:
        """Make sure that torch.cuda.synchronize() is called when GPU is present."""

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        timer = Timer()

        # Do not explicitly call synchronize, timer must call it for test to pass.

        with timer.time("action_1"):
            start_event.record()
            time.sleep(0.5)
            end_event.record()

        # NOTE(review): the expected count of 2 presumably corresponds to one
        # synchronize on entering and one on exiting the timed region — confirm
        # against the Timer implementation if this assertion changes.
        self.assertEqual(mock_synchronize.call_count, 2)
7 | 8 | # pyre-strict 9 | 10 | import sys 11 | import unittest 12 | from io import StringIO 13 | from unittest.mock import MagicMock, patch 14 | 15 | from torchtnt.utils.tqdm import create_progress_bar 16 | 17 | 18 | class TQDMTest(unittest.TestCase): 19 | @patch("sys.stdout", new_callable=StringIO) 20 | @patch("sys.stderr", new_callable=StringIO) 21 | def test_tqdm_file(self, mock_stderr: MagicMock, mock_stdout: MagicMock) -> None: 22 | """ 23 | Test the file argument to create_progress_bar 24 | """ 25 | 26 | create_progress_bar( 27 | dataloader=["foo", "bar"], 28 | desc="foo", 29 | num_epochs_completed=0, 30 | num_steps_completed=0, 31 | max_steps=None, 32 | max_steps_per_epoch=None, 33 | file=None, 34 | ) 35 | self.assertIn( 36 | "foo 0: 0%| | 0/2 [00:00 None: 22 | mock_system.return_value = "Linux" 23 | self.assertFalse(version.is_windows()) 24 | 25 | mock_system.return_value = "Darwin" 26 | self.assertFalse(version.is_windows()) 27 | 28 | mock_system.return_value = "Windows" 29 | self.assertTrue(version.is_windows()) 30 | 31 | @patch("platform.python_version") 32 | def test_get_python_version(self, mock_python_version: MagicMock) -> None: 33 | mock_python_version.return_value = "3.8.0" 34 | self.assertEqual(version.get_python_version(), Version("3.8.0")) 35 | self.assertNotEqual(version.get_python_version(), Version("3.10.5")) 36 | 37 | mock_python_version.return_value = "3.10.5" 38 | self.assertNotEqual(version.get_python_version(), Version("3.8.0")) 39 | self.assertEqual(version.get_python_version(), Version("3.10.5")) 40 | 41 | def test_get_torch_version(self) -> None: 42 | with patch.object(torch, "__version__", "1.8.3"): 43 | self.assertEqual(version.get_torch_version(), Version("1.8.3")) 44 | self.assertNotEqual(version.get_torch_version(), Version("1.12.0")) 45 | 46 | with patch.object(torch, "__version__", "1.12.0"): 47 | self.assertNotEqual(version.get_torch_version(), Version("1.8.3")) 48 | self.assertEqual(version.get_torch_version(), 
    def test_torch_version_comparators(self) -> None:
        """Pre-release tags must compare correctly: 2.0.0a0 is below 2.1.0."""
        with patch.object(torch, "__version__", "2.0.0a0"):
            self.assertFalse(version.is_torch_version_geq("2.1.0"))
# pyre-strict

from .auto_unit import AutoPredictUnit, AutoUnit
from .callback import Callback
from .evaluate import evaluate
from .fit import fit
from .predict import predict
from .state import ActivePhase, EntryPoint, PhaseState, State
from .train import train
from .unit import EvalUnit, PredictUnit, TEvalUnit, TPredictUnit, TrainUnit, TTrainUnit

# Public API of torchtnt.framework: entry points (train/evaluate/fit/predict),
# unit base classes and their type aliases, the callback hook class, and the
# loop-state types.
__all__ = [
    "AutoPredictUnit",
    "AutoUnit",
    "Callback",
    "evaluate",
    "fit",
    "predict",
    "ActivePhase",
    "EntryPoint",
    "PhaseState",
    "State",
    "train",
    "EvalUnit",
    "PredictUnit",
    "TEvalUnit",
    "TPredictUnit",
    "TrainUnit",
    "TTrainUnit",
]
28 | """ 29 | argspec = inspect.getfullargspec(step_func) 30 | annotations = argspec.annotations 31 | if "data" not in annotations: 32 | _logger.warning( 33 | f"Expected step function to have an annotated argument named ``data``. Found {annotations}." 34 | ) 35 | return False 36 | annotated_type = annotations["data"] 37 | return typing_extensions.get_origin(annotated_type) is collections.abc.Iterator 38 | 39 | 40 | def _find_optimizers_for_module( 41 | module: torch.nn.Module, optimizers: Dict[str, torch.optim.Optimizer] 42 | ) -> List[Tuple[str, torch.optim.Optimizer]]: 43 | """ 44 | Given a module, returns a list of optimizers that are associated with it. 45 | """ 46 | optimizer_list = [] 47 | module_params = [param.data_ptr() for param in module.parameters()] 48 | for optim_name, optimizer in optimizers.items(): 49 | optimizer_params = [ 50 | param.data_ptr() for param in optimizer.param_groups[0]["params"] 51 | ] 52 | if all(module_param in optimizer_params for module_param in module_params): 53 | optimizer_list.append((optim_name, optimizer)) 54 | return optimizer_list 55 | -------------------------------------------------------------------------------- /torchtnt/framework/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
# pyre-strict

from .base_csv_writer import BaseCSVWriter
from .dcp_saver import DistributedCheckpointSaver
from .early_stopping import EarlyStopping
from .empty_cuda_cache import EmptyCudaCache
from .garbage_collector import GarbageCollector
from .iteration_time_logger import IterationTimeLogger
from .lambda_callback import Lambda
from .learning_rate_monitor import LearningRateMonitor
from .memory_snapshot import MemorySnapshot
from .module_summary import ModuleSummary
from .periodic_distributed_sync import PeriodicDistributedSync
from .progress_reporter import ProgressReporter
from .pytorch_profiler import PyTorchProfiler
from .slow_rank_detector import SlowRankDetector
from .system_resources_monitor import SystemResourcesMonitor
from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
from .tensorfloat32 import EnableTensorFloat32
from .throughput_logger import ThroughputLogger
from .time_limit_interrupter import TimeLimitInterrupter
from .time_wait_for_batch_logger import TimeWaitForBatchLogger
from .torch_compile import TorchCompile
from .torchsnapshot_saver import TorchSnapshotSaver
from .tqdm_progress_bar import TQDMProgressBar
from .train_progress_monitor import TrainProgressMonitor

# Public callback API, kept alphabetically sorted. ("DistributedCheckpointSaver"
# was previously appended at the end, breaking the otherwise-sorted order.)
__all__ = [
    "BaseCSVWriter",
    "DistributedCheckpointSaver",
    "EarlyStopping",
    "EmptyCudaCache",
    "EnableTensorFloat32",
    "GarbageCollector",
    "IterationTimeLogger",
    "Lambda",
    "LearningRateMonitor",
    "MemorySnapshot",
    "ModuleSummary",
    "PeriodicDistributedSync",
    "ProgressReporter",
    "PyTorchProfiler",
    "SlowRankDetector",
    "SystemResourcesMonitor",
    "TensorBoardParameterMonitor",
    "ThroughputLogger",
    "TimeLimitInterrupter",
    "TimeWaitForBatchLogger",
    "TorchCompile",
    "TorchSnapshotSaver",
    "TQDMProgressBar",
    "TrainProgressMonitor",
]
# TODO: eventually support overriding all knobs
@dataclass
class KnobOptions:
    """
    Controls the knobs for Checkpoints.

    Args:
        max_per_rank_io_concurrency: Maximum number of concurrent IO operations per rank in checkpointing.
            Defaults to 16.
        enable_storage_optimization: Enable storage efficiency optimizations for Distributed Checkpointing.
    """

    # use a more conservative number of concurrent IO operations per rank in Checkpointing
    # the default value of 16 is too bandwidth hungry for most users
    # None means "defer to the checkpointing library's default" (16, per the docstring above).
    max_per_rank_io_concurrency: Optional[int] = None
    # This would enable storage efficiency optimizations (model store):
    # e.g. Compression, Batching, Quantization etc.
    enable_storage_optimization: bool = True
@dataclass
class RestoreOptions:
    """
    Options when restoring a snapshot.

    Args:
        restore_modules: Whether to restore the module state dict.
        restore_train_progress: Whether to restore the training progress state.
        restore_eval_progress: Whether to restore the evaluation progress state.
        restore_predict_progress: Whether to restore the prediction progress state.
        restore_optimizers: Whether to restore the optimizer states.
        restore_lr_schedulers: Whether to restore the lr scheduler states.
        strict: Whether to strictly restore app state and the module state dict.
        init_optim_states: Whether to initialize the optimizer state. Defaults to True. Toggle off
            if running into issues with loading optimizer state. This will reset optimizer state,
            which may affect training in some cases.
    """

    # Every flag defaults to True: a full, strict restore is the safe default.
    restore_modules: bool = True
    restore_train_progress: bool = True
    restore_eval_progress: bool = True
    restore_predict_progress: bool = True
    restore_optimizers: bool = True
    restore_lr_schedulers: bool = True
    strict: bool = True
    init_optim_states: bool = True
class EmptyCudaCache(Callback):
    """
    A callback that performs periodic emptying of cuda cache using `torch.cuda.empty_cache _.

    On different ranks, reserved memory and fragmentation might diverge after several iterations.
    If different ranks trigger de-fragmentation (i.e. cudaFree and redo cudaMalloc later)
    at different times, there will be different stragglers in different iterations, which will
    hurt the performance and will get worse with larger clusters. To avoid this, this callback
    calls empty_cache() at the same cadence across all ranks.

    Args:
        step_interval: number of steps to run before emptying cuda cache

    Raises:
        ValueError: if ``step_interval`` is not a positive integer.
    """

    def __init__(self, step_interval: int) -> None:
        # Fail fast on a non-positive interval; previously step_interval=0
        # surfaced only later as a ZeroDivisionError inside the step hooks.
        if step_interval <= 0:
            raise ValueError(
                f"Invalid value '{step_interval}' for argument step_interval. Expected a positive integer."
            )
        self._step_interval = step_interval

    def _empty_cache_if_due(self, total_num_steps_completed: int) -> None:
        """Empty the CUDA cache when the combined step count hits the cadence."""
        if total_num_steps_completed % self._step_interval == 0:
            torch.cuda.empty_cache()

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        total_num_steps_completed = unit.train_progress.num_steps_completed
        if state.entry_point == EntryPoint.FIT:
            # if fitting, unit should also subclass EvalUnit; include the num
            # eval steps completed so the cadence covers both phases
            unit_as_eval_unit = cast(TEvalUnit, unit)
            total_num_steps_completed += (
                unit_as_eval_unit.eval_progress.num_steps_completed
            )
        self._empty_cache_if_due(total_num_steps_completed)

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        total_num_steps_completed = unit.eval_progress.num_steps_completed
        if state.entry_point == EntryPoint.FIT:
            # if fitting, unit should also subclass TrainUnit; include the num
            # train steps completed so the cadence covers both phases
            unit_as_train_unit = cast(TTrainUnit, unit)
            total_num_steps_completed += (
                unit_as_train_unit.train_progress.num_steps_completed
            )
        self._empty_cache_if_due(total_num_steps_completed)

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        self._empty_cache_if_due(unit.predict_progress.num_steps_completed)
class LearningRateMonitor(Callback):
    """
    A callback which logs learning rate of tracked optimizers and learning rate schedulers.
    Logs learning rate for each parameter group associated with an optimizer.

    Args:
        loggers: Either a :class:`torchtnt.loggers.logger.MetricLogger` or
            list of :class:`torchtnt.loggers.logger.MetricLogger`
        logging_interval: "epoch" to log at the start of each train epoch,
            or "step" to log at the start of each train step.
    """

    def __init__(
        self,
        loggers: Union[MetricLogger, List[MetricLogger]],
        *,
        logging_interval: str = "epoch",
    ) -> None:
        if not isinstance(loggers, list):
            loggers = [loggers]

        expected_intervals = ("epoch", "step")
        if logging_interval not in expected_intervals:
            raise ValueError(
                f"Invalid value '{logging_interval}' for argument logging_interval. Accepted values are {expected_intervals}."
            )

        self._loggers: List[MetricLogger] = loggers
        self.logging_interval = logging_interval

    def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
        # Only active in "epoch" mode, and only when there is somewhere to log to.
        if self._loggers and self.logging_interval == "epoch":
            _write_stats(
                self._loggers,
                self._extract_lr(unit),
                unit.train_progress.num_steps_completed,
            )

    def on_train_step_start(self, state: State, unit: TTrainUnit) -> None:
        # Only active in "step" mode, and only when there is somewhere to log to.
        if self._loggers and self.logging_interval == "step":
            _write_stats(
                self._loggers,
                self._extract_lr(unit),
                unit.train_progress.num_steps_completed,
            )

    @classmethod
    def _extract_lr(cls, unit: TTrainUnit) -> Dict[str, float]:
        """
        Extracts learning rates from optimizers and LR schedulers and returns them as a dictionary.
        """
        stats: Dict[str, float] = {}

        # Tracked optimizers first, then tracked LR schedulers (via their
        # underlying optimizer); later entries overwrite on key collision.
        for name, optim in unit.tracked_optimizers().items():
            stats.update(extract_lr_from_optimizer(optim, f"optimizers/{name}"))

        for name, lr_scheduler in unit.tracked_lr_schedulers().items():
            stats.update(
                extract_lr_from_optimizer(
                    lr_scheduler.optimizer, f"lr_schedulers/{name}"
                )
            )

        return stats
class MemorySnapshot(Callback):
    """
    A callback for memory snapshot collection during training, saving pickle files to the user-specified directory.
    Uses `Memory Snapshots `.

    Args:
        memory_snapshot_profiler: Instance of MemorySnapshotProfilerBase, controls when and where to save the memory snapshots.

    Note: It is recommended to instantiate this callback **as early as possible** in your training/eval/prediction script,
    ideally before model initialization, to make sure all memory allocation is captured.

    """

    def __init__(
        self,
        *,
        memory_snapshot_profiler: MemorySnapshotProfilerBase,
    ) -> None:
        # The profiler owns all snapshot policy (cadence, output location);
        # this callback only forwards step events to it.
        self.memory_snapshot_profiler = memory_snapshot_profiler

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        # Delegate to the profiler, which decides whether this step dumps a snapshot.
        self.memory_snapshot_profiler.step()

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        self.memory_snapshot_profiler.step()

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        self.memory_snapshot_profiler.step()
def _log_module_summary_tables(module_summaries: List[ModuleSummaryObj]) -> None:
    # Default processing: log each summary table on rank 0.
    for summary in module_summaries:
        rank_zero_info("\n" + get_summary_table(summary))


class ModuleSummary(Callback):
    """
    A callback which generates and logs a summary of the modules.

    Args:
        max_depth: The maximum depth of module summaries to keep.
        process_fn: Function to print the module summaries. Default is to log all module summary tables.
        module_inputs: A mapping from module name to (args, kwargs) for that module. Useful when wanting FLOPS, activation sizes, etc.

    Raises:
        RuntimeError:
            If torcheval is not installed.
    """

    def __init__(
        self,
        max_depth: Optional[int] = None,
        process_fn: Callable[
            [List[ModuleSummaryObj]], None
        ] = _log_module_summary_tables,
        # pyre-fixme
        module_inputs: Optional[
            MutableMapping[str, Tuple[Tuple[Any, ...], Dict[str, Any]]]
        ] = None,
    ) -> None:
        self._max_depth = max_depth
        self._process_fn = process_fn
        self._module_inputs = module_inputs

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        self._get_and_process_summaries(unit)

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        # During fit, summaries were already produced in on_train_start.
        if state.entry_point == EntryPoint.EVALUATE:
            self._get_and_process_summaries(unit)

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        self._get_and_process_summaries(unit)

    def _retrieve_module_summaries(self, unit: AppStateMixin) -> List[ModuleSummaryObj]:
        """Build one (optionally pruned) summary per tracked module."""
        summaries: List[ModuleSummaryObj] = []
        module_inputs = self._module_inputs or {}
        for module_name, module in unit.tracked_modules().items():
            args, kwargs = module_inputs.get(module_name, ((), {}))
            summary = get_module_summary(
                module, module_args=args, module_kwargs=kwargs
            )
            summary._module_name = module_name
            if self._max_depth:
                prune_module_summary(summary, max_depth=self._max_depth)
            summaries.append(summary)
        return summaries

    def _get_and_process_summaries(self, unit: AppStateMixin) -> None:
        self._process_fn(self._retrieve_module_summaries(unit))
class PeriodicDistributedSync(Callback):
    """
    A callback to sync all distributed workers at a given frequency.
    Helpful when using distributed without DDP/FSDP but would still like to ensure that the workers are in sync with each other, for example large predict jobs.
    Both predict and evaluate are supported.

    Args:
        sync_every_n_steps: the frequency at which to sync the workers.
    """

    def __init__(self, sync_every_n_steps: int = 1000) -> None:
        self.sync_every_n_steps = sync_every_n_steps
        self._global_rank: int = get_global_rank()

    def _sync_if_due(self, num_steps: int) -> None:
        # Barrier only when the completed-step count is a multiple of the cadence.
        if num_steps % self.sync_every_n_steps == 0:
            logger.info(f"Barrier at step {num_steps} on rank {self._global_rank}")
            barrier()

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        self._sync_if_due(unit.predict_progress.num_steps_completed)

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        self._sync_if_due(unit.eval_progress.num_steps_completed)
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch

from torchtnt.framework.callback import Callback
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit


class PyTorchProfiler(Callback):
    """
    A callback which profiles user code using the PyTorch Profiler
    (https://pytorch.org/docs/stable/profiler.html).

    Args:
        profiler: a torch.profiler.profile context manager which will be used
    """

    def __init__(
        self,
        profiler: torch.profiler.profile,
    ) -> None:
        self.profiler: torch.profiler.profile = profiler

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        self.profiler.start()

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        self.profiler.step()

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        self.profiler.stop()

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        # During fit the profiler was already started in on_train_start.
        if state.entry_point != EntryPoint.EVALUATE:
            return
        self.profiler.start()

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        self.profiler.step()

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        # During fit the profiler will be stopped in on_train_end.
        if state.entry_point != EntryPoint.EVALUATE:
            return
        self.profiler.stop()

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        self.profiler.start()

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        self.profiler.step()

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        self.profiler.stop()

-------------------------------------------------------------------------------- /torchtnt/framework/callbacks/tensorboard_parameter_monitor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from typing import Dict, Optional, Union 10 | 11 | import torch 12 | 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torchtnt.framework.callback import Callback 15 | from torchtnt.framework.state import State 16 | from torchtnt.framework.unit import TTrainUnit 17 | from torchtnt.utils.loggers.tensorboard import TensorBoardLogger 18 | 19 | 20 | def _write_histogram_parameters( 21 | summary_writer: SummaryWriter, modules: Dict[str, torch.nn.Module], step: int 22 | ) -> None: 23 | for module_name, module in modules.items(): 24 | for param_name, parameter in module.named_parameters(): 25 | summary_writer.add_histogram( 26 | f"Parameters/{module_name}/{param_name}", 27 | parameter, 28 | global_step=step, 29 | ) 30 | 31 | 32 | class TensorBoardParameterMonitor(Callback): 33 | """ 34 | A callback which logs module parameters as histograms to TensorBoard. 35 | https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter 36 | 37 | Args: 38 | logger: Either a :class:`torchtnt.loggers.tensorboard.TensorBoardLogger` 39 | or a :class:`torch.utils.tensorboard.SummaryWriter` instance. 
40 | """ 41 | 42 | def __init__(self, logger: Union[TensorBoardLogger, SummaryWriter]) -> None: 43 | if isinstance(logger, TensorBoardLogger): 44 | logger = logger.writer 45 | self._writer: Optional[SummaryWriter] = logger 46 | 47 | def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: 48 | writer = self._writer 49 | if not writer: 50 | return 51 | 52 | step = unit.train_progress.num_steps_completed 53 | modules = unit.tracked_modules() 54 | _write_histogram_parameters(writer, modules, step) 55 | -------------------------------------------------------------------------------- /torchtnt/framework/callbacks/tensorfloat32.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import logging 10 | from typing import Optional 11 | 12 | import torch 13 | from torchtnt.framework.callback import Callback 14 | from torchtnt.framework.state import EntryPoint, State 15 | from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit 16 | from torchtnt.utils.rank_zero_log import rank_zero_info 17 | 18 | logger: logging.Logger = logging.getLogger(__name__) 19 | 20 | 21 | class EnableTensorFloat32(Callback): 22 | """ 23 | A callback that enables TensorFloat32 operations on CUDA. 24 | 25 | Args: 26 | float32_matmul_precision: precision to use for float32 matmul operations. 27 | See `torch.set_float32_matmul_precision` for details. 
28 | """ 29 | 30 | def __init__(self, float32_matmul_precision: str = "high") -> None: 31 | self.float32_matmul_precision = float32_matmul_precision 32 | 33 | self.original_float32_matmul_precision: Optional[str] = None 34 | self.original_cuda_matmul: Optional[bool] = None 35 | self.original_cudnn: Optional[bool] = None 36 | 37 | def _enable(self) -> None: 38 | rank_zero_info("Enabling TensorFloat32 operations on CUDA", logger=logger) 39 | assert self.original_float32_matmul_precision is None 40 | assert self.original_cuda_matmul is None 41 | assert self.original_cudnn is None 42 | 43 | self.original_float32_matmul_precision = torch.get_float32_matmul_precision() 44 | self.original_cuda_matmul = torch.backends.cuda.matmul.allow_tf32 45 | self.original_cudnn = torch.backends.cudnn.allow_tf32 46 | 47 | torch.set_float32_matmul_precision(self.float32_matmul_precision) 48 | torch.backends.cuda.matmul.allow_tf32 = True 49 | torch.backends.cudnn.allow_tf32 = True 50 | 51 | def _reset(self) -> None: 52 | rank_zero_info( 53 | "Restoring original TensorFloat32 permissions on CUDA", logger=logger 54 | ) 55 | if self.original_float32_matmul_precision is not None: 56 | torch.set_float32_matmul_precision(self.original_float32_matmul_precision) 57 | self.original_float32_matmul_precision = None 58 | 59 | if self.original_cuda_matmul is not None: 60 | torch.backends.cuda.matmul.allow_tf32 = self.original_cuda_matmul 61 | self.original_cuda_matmul = None 62 | 63 | if self.original_cudnn is not None: 64 | torch.backends.cudnn.allow_tf32 = self.original_cudnn 65 | self.original_cudnn = None 66 | 67 | def on_train_start(self, state: State, unit: TTrainUnit) -> None: 68 | self._enable() 69 | 70 | def on_train_end(self, state: State, unit: TTrainUnit) -> None: 71 | self._reset() 72 | 73 | def on_eval_start(self, state: State, unit: TEvalUnit) -> None: 74 | if state.entry_point == EntryPoint.FIT: 75 | return # if fitting, this is already handled in on_train_start 76 | self._enable() 77 
| 78 | def on_eval_end(self, state: State, unit: TEvalUnit) -> None: 79 | if state.entry_point == EntryPoint.FIT: 80 | return # if fitting, this is already handled in on_train_end 81 | self._reset() 82 | 83 | def on_predict_start(self, state: State, unit: TPredictUnit) -> None: 84 | self._enable() 85 | 86 | def on_predict_end(self, state: State, unit: TPredictUnit) -> None: 87 | self._reset() 88 | -------------------------------------------------------------------------------- /torchtnt/framework/callbacks/torch_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import logging 10 | 11 | try: 12 | from torch._inductor.async_compile import shutdown_compile_workers 13 | except ImportError: 14 | 15 | def shutdown_compile_workers() -> None: 16 | logging.warning( 17 | "shutdown_compile_workers is not available in your version of PyTorch. \ 18 | Please use nightly version to enable this feature." 19 | ) 20 | 21 | 22 | from torchtnt.framework.callback import Callback 23 | from torchtnt.framework.state import State 24 | from torchtnt.framework.unit import TTrainUnit 25 | 26 | logger: logging.Logger = logging.getLogger(__name__) 27 | 28 | 29 | class TorchCompile(Callback): 30 | """ 31 | A callback for using torch.compile. 32 | 33 | Args: 34 | step_shutdown_compile_workers: step after which compiler workers 35 | will be shut down. 
36 | """ 37 | 38 | def __init__(self, step_shutdown_compile_workers: int) -> None: 39 | self._step_shutdown_compile_workers = step_shutdown_compile_workers 40 | 41 | def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: 42 | total_num_steps_completed = unit.train_progress.num_steps_completed 43 | if total_num_steps_completed == self._step_shutdown_compile_workers: 44 | logger.info( 45 | f"Shutdown compile workers after step {total_num_steps_completed}" 46 | ) 47 | shutdown_compile_workers() 48 | -------------------------------------------------------------------------------- /torchtnt/framework/callbacks/train_progress_monitor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | from typing import List, Union 10 | 11 | from torchtnt.framework.callback import Callback 12 | from torchtnt.framework.state import State 13 | from torchtnt.framework.unit import TTrainUnit 14 | from torchtnt.utils.loggers.logger import MetricLogger 15 | from torchtnt.utils.progress import Progress 16 | 17 | 18 | def _write_training_progress( 19 | train_progress: Progress, loggers: List[MetricLogger] 20 | ) -> None: 21 | if not loggers: 22 | return 23 | 24 | step = train_progress.num_steps_completed 25 | epoch = train_progress.num_epochs_completed 26 | for logger in loggers: 27 | logger.log("Training steps completed vs epochs", step, epoch) 28 | 29 | 30 | class TrainProgressMonitor(Callback): 31 | """ 32 | A callback which logs training progress in terms of steps vs epochs. This is helpful to visualize when the end of data occurs across epochs, especially for iterable datasets. 33 | This callback writes to the logger at the beginning of training, and at the end of every epoch. 
34 | 35 | Args: 36 | loggers: Either a :class:`torchtnt.loggers.logger.MetricLogger` or 37 | list of :class:`torchtnt.loggers.logger.MetricLogger` 38 | """ 39 | 40 | def __init__( 41 | self, 42 | loggers: Union[MetricLogger, List[MetricLogger]], 43 | ) -> None: 44 | if not isinstance(loggers, list): 45 | loggers = [loggers] 46 | self._loggers: List[MetricLogger] = loggers 47 | 48 | def on_train_start(self, state: State, unit: TTrainUnit) -> None: 49 | _write_training_progress(unit.train_progress, self._loggers) 50 | 51 | def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: 52 | _write_training_progress(unit.train_progress, self._loggers) 53 | -------------------------------------------------------------------------------- /torchtnt/framework/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | 9 | import logging 10 | from contextlib import contextmanager, nullcontext 11 | from typing import ContextManager, Generator, Tuple, TypeVar 12 | 13 | from torch.profiler import record_function 14 | from torchtnt.framework.state import State 15 | 16 | _logger: logging.Logger = logging.getLogger(__name__) 17 | T = TypeVar("T") 18 | 19 | 20 | @contextmanager 21 | def get_timing_context( 22 | state: State, 23 | event_name: str, 24 | # pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter. 25 | ) -> Generator[Tuple[ContextManager, ContextManager], None, None]: 26 | """ 27 | Returns a context manager that records an event to a :class:`~torchtnt.utils.timer.Timer` and to PyTorch Profiler. 
28 | 29 | Args: 30 | state: an instance of :class:`~torchtnt.framework.state.State` 31 | event_name: string identifier to use for timing 32 | """ 33 | timer_context = ( 34 | state.timer.time(event_name) if state.timer is not None else nullcontext() 35 | ) 36 | profiler_context = record_function(event_name) 37 | with timer_context, profiler_context: 38 | yield (timer_context, profiler_context) 39 | -------------------------------------------------------------------------------- /torchtnt/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/tnt/cb31137b0928acc24f4d341faa8c7d88b8ed4696/torchtnt/py.typed -------------------------------------------------------------------------------- /torchtnt/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 

# pyre-strict

from .data_prefetcher import CudaDataPrefetcher
from .iterators import (
    AllDatasetBatchesIterator,
    DataIterationStrategy,
    DataIterationStrategyRegistry,
    InOrderIterator,
    MultiIterator,
    RandomizedBatchSamplerIterator,
    RoundRobinIterator,
)
from .multi_dataloader import MultiDataLoader
from .profile_dataloader import profile_dataloader
from .synthetic_data import AbstractRandomDataset

__all__ = [
    "AbstractRandomDataset",
    "AllDatasetBatchesIterator",
    "CudaDataPrefetcher",
    "DataIterationStrategy",
    "DataIterationStrategyRegistry",
    "InOrderIterator",
    "MultiDataLoader",
    "MultiIterator",
    "RandomizedBatchSamplerIterator",
    "RoundRobinIterator",
    "profile_dataloader",
]

# ============================================================================
# torchtnt/utils/data/profile_dataloader.py
# ============================================================================
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Iterable, Optional

import torch
from torch.profiler import record_function
from torchtnt.utils.device import copy_data_to_device
from torchtnt.utils.timer import Timer, TimerProtocol

_log: logging.Logger = logging.getLogger(__name__)


def profile_dataloader(
    dataloader: Iterable[object],
    profiler: torch.profiler.profile,
    *,
    max_steps: Optional[int] = None,
    timer: Optional[TimerProtocol] = None,
    device: Optional[torch.device] = None,
) -> TimerProtocol:
    """
    A helper function that profiles the dataloader iterations.

    Args:
        dataloader: dataloader to be profiled.
        profiler: PyTorch profiler to be used. The profiler is only stepped, so it is the responsibility of the caller to start/stop the profiler.
        max_steps (optional): maximum number of steps to run for. If not set, the dataloader will run until its iterator is exhausted.
        timer (optional): timer to be used to track duration.
        device (optional): device to copy the data to. If set, this function will profile copying data to device.
    """
    if timer is None:
        timer = Timer(cuda_sync=False)

    with timer.time("iter(dataloader)"), record_function("iter(dataloader)"):
        data_iter = iter(dataloader)

    steps_done = 0
    while True:
        # If max_steps is set, stop once reached; otherwise run to exhaustion.
        if max_steps is not None and steps_done >= max_steps:
            break

        try:
            with timer.time("next(iter)"), record_function("next(iter)"):
                batch = next(data_iter)
        except StopIteration:
            break

        if device is not None:
            with timer.time("copy_data_to_device"), record_function(
                "copy_data_to_device"
            ):
                batch = copy_data_to_device(batch, device)

        steps_done += 1
        if profiler:
            profiler.step()

    return timer

# ============================================================================
# torchtnt/utils/event.py
# ============================================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass, field
from typing import Dict, Union

# Values allowed in event metadata dictionaries.
EventMetadataValue = Union[str, int, float, bool, None]


@dataclass
class Event:
    """
    The class represents the generic event that occurs during a TorchTNT
    loop. The event can be any kind of meaningful action.

    Args:
        name: event name.
        metadata: additional data that is associated with the event.
    """

    name: str
    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)

# ============================================================================
# torchtnt/utils/event_handlers.py
# ============================================================================
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
import random
from contextlib import contextmanager
from functools import lru_cache
from typing import Dict, Generator, List, Optional

import importlib_metadata
from typing_extensions import Protocol, runtime_checkable

from .event import Event

logger: logging.Logger = logging.getLogger(__name__)


@runtime_checkable
class EventHandler(Protocol):
    """Structural interface for objects that can consume :class:`Event` instances."""

    def handle_event(self, event: Event) -> None: ...


_log_handlers: List[EventHandler] = []


@lru_cache(maxsize=None)
def get_event_handlers() -> List[EventHandler]:
    """
    Return the event handlers registered through the ``tnt_event_handlers``
    entry-point group. Cached, so entry points are loaded at most once per process.

    Raises:
        RuntimeError: if an entry-point factory does not return an ``EventHandler``.
    """
    global _log_handlers

    # Registered event handlers through entry points
    eps = importlib_metadata.entry_points(group="tnt_event_handlers")
    for entry in eps:
        logger.debug(
            f"Attempting to register event handler {entry.name}: {entry.value}"
        )
        factory = entry.load()
        handler = factory()

        if not isinstance(handler, EventHandler):
            # Fixed message formatting: the previous f-string wrapped
            # entry.value in a set literal and printed e.g. "{'pkg:factory'}".
            raise RuntimeError(
                f"The factory function for {entry.value} "
                "did not return a EventHandler object."
            )
        _log_handlers.append(handler)
    return _log_handlers


def log_event(event: Event) -> None:
    """
    Handle an event by dispatching it to every registered handler.

    Args:
        event: The event to handle.
    """

    for handler in get_event_handlers():
        handler.handle_event(event)


@contextmanager
def log_interval(
    name: str, metadata: Optional[Dict[str, str]] = None
) -> Generator[None, None, None]:
    """
    Context manager that logs a "start" event on entry and an "end" event on
    exit. Both events share a randomly generated ``unique_id`` so they can be
    correlated by downstream handlers.

    Args:
        name: name used for both the start and end events.
        metadata: optional extra metadata; mutated in place with the
            ``action`` and ``unique_id`` keys.
    """
    unique_id = _generate_random_int64()
    if metadata is None:
        metadata = {}
    metadata.update({"action": "start", "unique_id": unique_id})
    start_event = Event(name=name, metadata=metadata)
    log_event(start_event)

    try:
        yield
    finally:
        # Emit the end event even if the wrapped block raises, so every
        # start event always has a matching end event.
        metadata["action"] = "end"
        end_event = Event(name=name, metadata=metadata)
        log_event(end_event)


def _generate_random_int64() -> int:
    """Return a random non-negative 64-bit integer.

    A private Random instance is used to avoid being influenced by an
    externally set global seed.
    """
    local_random = random.Random()
    return local_random.randint(0, 2**63 - 1)

# ============================================================================
# torchtnt/utils/fsspec.py
# ============================================================================
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any

import fsspec
from fsspec.core import url_to_fs


def get_filesystem(path: str, **kwargs: Any) -> fsspec.AbstractFileSystem:
    """Return the fsspec filesystem implementation appropriate for ``path``."""
    filesystem, _ = url_to_fs(path, **kwargs)
    return filesystem

# ============================================================================
# torchtnt/utils/loggers/__init__.py
# ============================================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from .anomaly_logger import AnomalyLogger, TrackedMetric
from .csv import CSVLogger
from .file import FileLogger
from .in_memory import InMemoryLogger
from .json import JSONLogger
from .logger import MetricLogger, Scalar
from .stdout import StdoutLogger
from .tensorboard import TensorBoardLogger
from .utils import scalar_to_float


__all__ = [
    "AnomalyLogger",
    "TrackedMetric",
    "CSVLogger",
    "FileLogger",
    "InMemoryLogger",
    "JSONLogger",
    "MetricLogger",
    "Scalar",
    "StdoutLogger",
    "TensorBoardLogger",
    "scalar_to_float",
]

# ============================================================================
# torchtnt/utils/loggers/csv.py
# ============================================================================
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import csv 11 | import logging 12 | from threading import Thread 13 | from typing import Dict, List, Optional 14 | 15 | from fsspec import open as fs_open 16 | from torchtnt.utils.loggers.file import FileLogger 17 | from torchtnt.utils.loggers.logger import MetricLogger 18 | 19 | logger: logging.Logger = logging.getLogger(__name__) 20 | 21 | 22 | class CSVLogger(FileLogger, MetricLogger): 23 | """ 24 | CSV file logger. CSV headers are time, step, and names passed to `log`. 25 | 26 | Args: 27 | path (str): path to write logs to 28 | steps_before_flushing: (int, optional): Number of steps to buffer in logger before flushing 29 | log_all_ranks: (bool, optional): Log all ranks if true, else log only on rank 0. 30 | async_write: (bool, optional): Whether to write asynchronously or not. Defaults to False. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | path: str, 36 | steps_before_flushing: int = 100, 37 | log_all_ranks: bool = False, 38 | async_write: bool = False, 39 | ) -> None: 40 | super().__init__(path, steps_before_flushing, log_all_ranks) 41 | 42 | self._async_write = async_write 43 | self._thread: Optional[Thread] = None 44 | 45 | def flush(self) -> None: 46 | if self._rank == 0 or self._log_all_ranks: 47 | buffer = self._log_buffer 48 | if not buffer: 49 | logger.debug("No logs to write.") 50 | return 51 | 52 | if self._thread: 53 | # ensure previous thread is completed before next write 54 | self._thread.join() 55 | 56 | data_list = list(buffer.values()) 57 | if not self._async_write: 58 | _write_csv(self.path, data_list) 59 | return 60 | 61 | self._thread = Thread(target=_write_csv, args=(self.path, data_list)) 62 | self._thread.start() 63 | 64 | def close(self) -> None: 65 | # toggle off async writing for final flush 66 | self._async_write = False 67 | self.flush() 68 | 69 | 70 | def 
_write_csv(path: str, data_list: List[Dict[str, float]]) -> None: 71 | with fs_open(path, "w") as f: 72 | w = csv.DictWriter(f, data_list[0].keys()) 73 | w.writeheader() 74 | w.writerows(data_list) 75 | -------------------------------------------------------------------------------- /torchtnt/utils/loggers/file.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import atexit 11 | import logging 12 | from abc import ABC, abstractmethod 13 | from collections import OrderedDict 14 | from time import monotonic 15 | from typing import Dict, Mapping 16 | 17 | from torchtnt.utils.distributed import get_global_rank 18 | 19 | from torchtnt.utils.loggers.logger import Scalar 20 | from torchtnt.utils.loggers.utils import scalar_to_float 21 | 22 | 23 | logger: logging.Logger = logging.getLogger(__name__) 24 | 25 | 26 | class FileLogger(ABC): 27 | """ 28 | Abstract file logger. 29 | 30 | Args: 31 | path (str): path to write logs to 32 | steps_before_flushing: (int): Number of steps to store in log before flushing 33 | log_all_ranks: (bool): Log all ranks if true, else log only on rank 0. 
34 | """ 35 | 36 | def __init__( 37 | self, 38 | path: str, 39 | steps_before_flushing: int, 40 | log_all_ranks: bool, 41 | ) -> None: 42 | self._path: str = path 43 | self._rank: int = get_global_rank() 44 | self._log_all_ranks = log_all_ranks 45 | self._log_buffer: OrderedDict[int, Dict[str, float]] = OrderedDict() 46 | self._len_before_flush: int = 0 47 | self._steps_before_flushing: int = steps_before_flushing 48 | 49 | if self._rank == 0 or log_all_ranks: 50 | logger.info(f"Logging metrics to path: {path}") 51 | else: 52 | logger.debug( 53 | f"Not logging metrics on this host because host rank is {self._rank} != 0" 54 | ) 55 | atexit.register(self.close) 56 | 57 | @property 58 | def path(self) -> str: 59 | return self._path 60 | 61 | def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: 62 | """Add multiple scalar values. 63 | 64 | Args: 65 | payload (dict): dictionary of tag name and scalar value 66 | step (int): step value to record 67 | """ 68 | 69 | for k, v in payload.items(): 70 | self.log(k, v, step) 71 | 72 | def log(self, name: str, data: Scalar, step: int) -> None: 73 | """Log scalar data to file. 74 | 75 | Args: 76 | name (string): a unique name to group scalars 77 | data (float/int/Tensor): scalar data to log 78 | step (int): step value to record 79 | """ 80 | 81 | if self._rank == 0 or self._log_all_ranks: 82 | self._log_buffer.setdefault(step, {})[name] = scalar_to_float(data) 83 | self._log_buffer[step]["step"] = step 84 | self._log_buffer[step]["time"] = monotonic() 85 | 86 | if ( 87 | len(self._log_buffer) - self._len_before_flush 88 | >= self._steps_before_flushing 89 | ): 90 | self.flush() 91 | self._len_before_flush = len(self._log_buffer) 92 | 93 | @abstractmethod 94 | def flush(self) -> None: ... 95 | 96 | @abstractmethod 97 | def close(self) -> None: ... 
98 | -------------------------------------------------------------------------------- /torchtnt/utils/loggers/in_memory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import atexit 11 | import logging 12 | from collections import OrderedDict 13 | from time import monotonic 14 | from typing import Dict, Mapping 15 | 16 | from torchtnt.utils.loggers.logger import MetricLogger, Scalar 17 | from torchtnt.utils.loggers.utils import scalar_to_float 18 | 19 | logger: logging.Logger = logging.getLogger(__name__) 20 | 21 | 22 | class InMemoryLogger(MetricLogger): 23 | """ 24 | Simple logger that buffers data in-memory. 25 | 26 | Example: 27 | from torchtnt.utils.loggers import InMemoryLogger 28 | logger = InMemoryLogger() 29 | logger.log("accuracy", 23.56, 10) 30 | logger.close() 31 | """ 32 | 33 | def __init__(self) -> None: 34 | self._log_buffer: OrderedDict[int, Dict[str, float]] = OrderedDict() 35 | logger.info("Logging metrics in-memory") 36 | atexit.register(self.close) 37 | 38 | @property 39 | def log_buffer(self) -> Dict[int, Dict[str, float]]: 40 | """Directly access the buffer.""" 41 | 42 | return self._log_buffer 43 | 44 | def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: 45 | """Add multiple scalar values. 46 | 47 | Args: 48 | payload (dict): dictionary of tag name and scalar value 49 | step (int): step value to record 50 | """ 51 | 52 | for k, v in payload.items(): 53 | self.log(k, v, step) 54 | 55 | def log(self, name: str, data: Scalar, step: int) -> None: 56 | """Log scalar data to the in-memory buffer. 
57 | 58 | Args: 59 | name (string): a unique name to group scalars 60 | data (float/int/Tensor): scalar data to log 61 | step (int): step value to record 62 | """ 63 | 64 | self._log_buffer.setdefault(step, {})[name] = scalar_to_float(data) 65 | self._log_buffer[step]["step"] = step 66 | self._log_buffer[step]["time"] = monotonic() 67 | 68 | def flush(self) -> None: 69 | print(self._log_buffer) 70 | 71 | def close(self) -> None: 72 | self._log_buffer.clear() 73 | -------------------------------------------------------------------------------- /torchtnt/utils/loggers/json.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import json 11 | import logging 12 | 13 | from fsspec import open as fs_open 14 | from torchtnt.utils.loggers.file import FileLogger 15 | from torchtnt.utils.loggers.logger import MetricLogger 16 | 17 | logger: logging.Logger = logging.getLogger(__name__) 18 | 19 | 20 | class JSONLogger(FileLogger, MetricLogger): 21 | """ 22 | JSON file logger. 23 | 24 | Args: 25 | path (str): path to write logs to 26 | steps_before_flushing: (int, optional): Number of steps to store in log before flushing 27 | log_all_ranks: (bool, optional): Log all ranks if true, else log only on rank 0. 
28 | """ 29 | 30 | def __init__( 31 | self, 32 | path: str, 33 | steps_before_flushing: int = 100, 34 | log_all_ranks: bool = False, 35 | ) -> None: 36 | super().__init__(path, steps_before_flushing, log_all_ranks) 37 | 38 | def flush(self) -> None: 39 | if self._rank == 0 or self._log_all_ranks: 40 | data = self._log_buffer 41 | if not data: 42 | logger.debug("No logs to write.") 43 | return 44 | with fs_open(self.path, "w") as f: 45 | json.dump(list(data.values()), f) 46 | 47 | def close(self) -> None: 48 | self.flush() 49 | -------------------------------------------------------------------------------- /torchtnt/utils/loggers/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import Mapping, Union 11 | 12 | from numpy import ndarray 13 | from torch import Tensor 14 | from typing_extensions import Protocol 15 | 16 | Scalar = Union[Tensor, ndarray, int, float] 17 | 18 | 19 | class MetricLogger(Protocol): 20 | """ 21 | Abstract metric logger. 22 | """ 23 | 24 | def log( 25 | self, 26 | name: str, 27 | data: Scalar, 28 | step: int, 29 | ) -> None: 30 | """Log scalar data. 31 | 32 | Args: 33 | name (string): tag name used to group scalars 34 | data (float/int/Tensor): scalar data to log 35 | step (int): step value to record 36 | """ 37 | pass 38 | 39 | def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None: 40 | """Log multiple scalar values. 41 | 42 | Args: 43 | payload (dict): dictionary of tag name and scalar value 44 | step (int): step value to record 45 | """ 46 | pass 47 | 48 | def close(self) -> None: 49 | """ 50 | Close log resource, flushing if necessary. 51 | Logs should not be written after `close` is called. 
class StdoutLogger(MetricLogger):
    """
    Metric logger that writes to stdout from rank 0 only. Each step occupies
    one output line; metrics that share a step are appended to that line in
    the order they were logged. The step value is an opaque identifier —
    successive steps need not be consecutive, though making them so is good
    practice.

    Args:
        precision (int): number of digits printed after the decimal point
            (default 4), rounded per the usual rounding rules.

    Example:
        from torchtnt.utils.loggers import StdoutLogger

        logger = StdoutLogger()
        logger.log(step=1, name="accuracy", data=0.982378)
        logger.log(step=1, name="loss", data=0.23112)
        logger.log_dict(step=2, payload={"accuracy": 0.99123, "loss": 0.18787})

    This will print the following to stdout in order:
        [Step 1] accuracy=0.9824 loss=0.2311
        [Step 2] accuracy=0.9912 loss=0.1879
    """

    def __init__(self, precision: int = 4) -> None:
        self._current_step: Optional[int] = None
        self._precision = precision
        logger.info("Logging metrics to stdout")
        # Guarantee the final line gets terminated even if close() is never
        # called explicitly.
        atexit.register(self.close)

    def _start_new_step_if_needed(self, step: int) -> None:
        # Already emitting onto this step's line: nothing to do.
        if self._current_step is not None and step == self._current_step:
            return
        self._current_step = step
        print(f"\n[Step {step}]", end="")

    def _log_metric(self, metric_name: str, metric_value: Scalar) -> None:
        as_float = scalar_to_float(metric_value)
        formatted = f" {metric_name}={as_float:.{self._precision}f}"
        print(formatted, end="", flush=True)

    @rank_zero_fn
    def log(self, name: str, data: Scalar, step: int) -> None:
        """Append one metric to the line for ``step``, starting it if needed."""
        self._start_new_step_if_needed(step)
        self._log_metric(name, data)

    @rank_zero_fn
    def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        """Append every metric in ``payload`` to the line for ``step``."""
        self._start_new_step_if_needed(step)
        for metric_name, metric_value in payload.items():
            self._log_metric(metric_name, metric_value)

    def close(self) -> None:
        """Terminate the current line and flush stdout."""
        print("\n")
        sys.stdout.flush()
def scalar_to_float(scalar: Scalar) -> float:
    """Coerce a ``Scalar`` (Tensor, ndarray, int, or float) to a plain float.

    Tensor and ndarray inputs must hold exactly one element (tensors are
    squeezed first); anything else raises ValueError.
    """
    if isinstance(scalar, Tensor):
        squeezed = scalar.squeeze()
        count = squeezed.numel()
        if count != 1:
            raise ValueError(
                f"Scalar tensor must contain a single item, {count} given."
            )
        # Move to CPU and detach before extracting the python value.
        return float(squeezed.cpu().detach().float().numpy().item())

    if isinstance(scalar, ndarray):
        count = scalar.size
        if count != 1:
            raise ValueError(
                f"Scalar ndarray must contain a single item, {count} given."
            )
        return float(scalar.item())

    # Plain int/float fall through to the builtin conversion.
    return float(scalar)
_SEC_IN_DAY: int = 60 * 60 * 24  # seconds in one day


def days_to_secs(days: Optional[int]) -> Optional[int]:
    """Convert a duration in days to seconds.

    Args:
        days: number of days; ``None`` is passed through unchanged.

    Returns:
        The equivalent number of seconds, or ``None`` if ``days`` is None.

    Raises:
        ValueError: if ``days`` is negative.
    """
    if days is None:
        return None
    if days < 0:
        raise ValueError(f"days must be non-negative, but was given {days}")
    return _SEC_IN_DAY * days
dst_batch_norm_module.running_var.device 68 | ) 69 | ) 70 | dst_batch_norm_module.num_batches_tracked.detach().copy_( 71 | src_batch_norm_module.num_batches_tracked.to( 72 | dst_batch_norm_module.num_batches_tracked.device 73 | ) 74 | ) 75 | dst_batch_norm_module.momentum = src_batch_norm_module.momentum 76 | -------------------------------------------------------------------------------- /torchtnt/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import Dict 11 | 12 | import torch 13 | 14 | 15 | def init_optim_state(optimizer: torch.optim.Optimizer) -> None: 16 | """ 17 | Initialize optimizer states by calling step() with zero grads. This is necessary because some optimizers like AdamW 18 | initialize some states in their state_dicts lazily, only after calling step() for the first time. Certain checkpointing 19 | solutions may rely on in-place loading, re-using existing tensor allocated memory from the optimizer state dict. This 20 | optimization does not work with optimizers that lazily initialize their states, as certain states will not be restored. 21 | Calling this function ensures that these states are available in the state dict for in place loading. 22 | 23 | Args: 24 | optimizer: A PyTorch optimizer. 25 | """ 26 | if optimizer.state: 27 | # The optimizer state is initialized. 28 | return 29 | 30 | for param_group in optimizer.param_groups: 31 | for param in param_group["params"]: 32 | if param.grad is not None: 33 | raise RuntimeError( 34 | "Initializing the optimizer states requires that no existing gradients for parameters are found." 
35 | ) 36 | if param.requires_grad: 37 | param.grad = torch.zeros_like(param) 38 | optimizer.step(closure=None) 39 | optimizer.zero_grad(set_to_none=True) 40 | 41 | 42 | def extract_lr_from_optimizer( 43 | optim: torch.optim.Optimizer, prefix: str 44 | ) -> Dict[str, float]: 45 | """ 46 | Retrieves the learning rate values from an optimizer and returns them as a dictionary. 47 | """ 48 | lr_stats = {} 49 | seen_pg_keys = {} 50 | for pg in optim.param_groups: 51 | lr = pg["lr"] 52 | name = _get_deduped_name(seen_pg_keys, pg.get("name", "pg")) 53 | key = f"{prefix}/{name}" 54 | assert key not in lr_stats 55 | lr_stats[key] = lr 56 | return lr_stats 57 | 58 | 59 | def _get_deduped_name(seen_keys: Dict[str, int], name: str) -> str: 60 | if name not in seen_keys: 61 | seen_keys[name] = 0 62 | 63 | seen_keys[name] += 1 64 | return name + f":{seen_keys[name]-1}" 65 | -------------------------------------------------------------------------------- /torchtnt/utils/precision.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import Mapping, Optional 11 | 12 | import torch 13 | from torch.amp.grad_scaler import GradScaler 14 | 15 | _DTYPE_STRING_TO_DTYPE_MAPPING: Mapping[str, Optional[torch.dtype]] = { 16 | "fp16": torch.float16, 17 | "bf16": torch.bfloat16, 18 | "fp32": None, 19 | } 20 | 21 | 22 | def convert_precision_str_to_dtype(precision: str) -> Optional[torch.dtype]: 23 | """ 24 | Converts precision as a string to a torch.dtype 25 | 26 | Args: 27 | precision: string containing the precision 28 | 29 | Raises: 30 | ValueError if an invalid precision string is passed. 
31 | 32 | """ 33 | if precision not in _DTYPE_STRING_TO_DTYPE_MAPPING: 34 | raise ValueError( 35 | f"Precision {precision} not supported. Please use one of {list(_DTYPE_STRING_TO_DTYPE_MAPPING.keys())}" 36 | ) 37 | return _DTYPE_STRING_TO_DTYPE_MAPPING[precision] 38 | 39 | 40 | def get_grad_scaler_from_precision( 41 | precision: torch.dtype, *, is_fsdp1_module: Optional[bool] = False 42 | ) -> Optional[GradScaler]: 43 | """ 44 | Returns the correct grad scaler to use based on the precision and whether 45 | or not the model is FSDP. FSDP required it's own sharded grad scaler. FSDP2 uses 46 | the original grad scaler (amp.grad_scaler). See https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md 47 | 48 | Args: 49 | precision: the precision being used 50 | is_fsdp1_module: whether the grad scaler is for an FSDP1 module 51 | 52 | Returns: 53 | The appropriate grad scaler to use, ``None`` if no grad scaler should be used. 54 | """ 55 | 56 | if precision == torch.float16: 57 | if is_fsdp1_module: 58 | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler 59 | 60 | return ShardedGradScaler() 61 | else: 62 | return GradScaler("cuda") 63 | return None 64 | -------------------------------------------------------------------------------- /torchtnt/utils/rank_zero_log.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
def rank_zero_print(*args: Any, **kwargs: Any) -> None:
    """Call print function only from rank 0."""
    if get_global_rank() != 0:
        return
    print(*args, **kwargs)


def _log_rank_zero(
    level: int,
    args: Any,
    logger: Optional[logging.Logger],
    kwargs: Any,
) -> None:
    """Shared implementation for the ``rank_zero_*`` logging wrappers.

    Emits ``args``/``kwargs`` at ``level`` on ``logger`` (falling back to the
    module logger), only from global rank 0.
    """
    if get_global_rank() != 0:
        return
    logger = logger or _LOGGER
    if _supports_stacklevel():
        # stacklevel=3 skips this helper and the public wrapper, so the log
        # record points at the wrapper's caller — same effect as the
        # stacklevel=2 each wrapper used when it called the logger directly.
        kwargs["stacklevel"] = 3
    logger.log(level, *args, **kwargs)


def rank_zero_debug(
    *args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
    """Log debug message only from rank 0."""
    _log_rank_zero(logging.DEBUG, args, logger, kwargs)


def rank_zero_info(
    *args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
    """Log info message only from rank 0."""
    _log_rank_zero(logging.INFO, args, logger, kwargs)


def rank_zero_warn(
    *args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
    """Log warn message only from rank 0."""
    _log_rank_zero(logging.WARNING, args, logger, kwargs)


def rank_zero_error(
    *args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
    """Log error message only from rank 0."""
    _log_rank_zero(logging.ERROR, args, logger, kwargs)


def rank_zero_critical(
    *args: Any, logger: Optional[logging.Logger] = None, **kwargs: Any
) -> None:
    """Log critical message only from rank 0."""
    _log_rank_zero(logging.CRITICAL, args, logger, kwargs)


def _supports_stacklevel() -> bool:
    """The ``stacklevel`` kwarg on logging calls requires Python >= 3.8."""
    return get_python_version() >= Version("3.8.0")
class DictStateful(Stateful, Dict[str, Any]):
    """A plain dict that satisfies the :class:`Stateful` protocol, so free-form
    key/value data can be saved to and loaded from checkpoints alongside other
    stateful objects."""

    def state_dict(self) -> Dict[str, Any]:
        # The dict itself is its own state; returned by reference (not copied).
        return self

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Replace (not merge) the current contents with the checkpointed state.
        self.clear()
        self.update(state_dict)
def get_pet_launch_config(nproc: int) -> pet.LaunchConfig:
    """Build a single-node torch elastic launch config with ``nproc`` workers.

    Args:
        nproc: The number of processes to launch.

    Returns:
        A ``pet.LaunchConfig`` pinned to one node, using a fresh c10d
        rendezvous on an OS-assigned local port, with no restarts.

    Example:
        >>> from torch.distributed import launcher
        >>> launch_config = get_pet_launch_config(nproc=8)
        >>> launcher.elastic_launch(config=launch_config, entrypoint=train)()
    """
    config_kwargs = {
        "min_nodes": 1,
        "max_nodes": 1,
        "nproc_per_node": nproc,
        "run_id": str(uuid.uuid4()),  # unique id so concurrent runs don't collide
        "rdzv_backend": "c10d",
        "rdzv_endpoint": "localhost:0",  # port 0: let the OS pick a free port
        "max_restarts": 0,
        "monitor_interval": 1,
    }
    return pet.LaunchConfig(**config_kwargs)
@contextmanager
def captured_output() -> Generator[Tuple[TextIO, TextIO], None, None]:
    """Temporarily redirect ``sys.stdout``/``sys.stderr`` into StringIO buffers.

    Yields the (stdout, stderr) buffers; the original streams are restored on
    exit even if the body raises.
    """
    captured_out = StringIO()
    captured_err = StringIO()
    saved_streams = (sys.stdout, sys.stderr)
    sys.stdout, sys.stderr = captured_out, captured_err
    try:
        yield captured_out, captured_err
    finally:
        sys.stdout, sys.stderr = saved_streams
7 | 8 | # pyre-strict 9 | 10 | import io 11 | import logging 12 | from typing import Iterable, Optional, TextIO, Union 13 | 14 | from torchtnt.utils.progress import estimated_steps_in_epoch 15 | from tqdm.auto import tqdm 16 | 17 | logger: logging.Logger = logging.getLogger(__name__) 18 | 19 | 20 | def create_progress_bar( 21 | dataloader: Iterable[object], 22 | *, 23 | desc: str, 24 | num_epochs_completed: int, 25 | num_steps_completed: int, 26 | max_steps: Optional[int], 27 | max_steps_per_epoch: Optional[int], 28 | mininterval: float | None = None, 29 | file: Optional[Union[TextIO, io.StringIO]] = None, 30 | ) -> tqdm: 31 | """Constructs a :func:`tqdm` progress bar. The number of steps in an epoch is inferred from the dataloader, num_steps_completed, max_steps and max_steps_per_epoch. 32 | 33 | Args: 34 | dataloader: an iterable of data, used to infer number of steps in an epoch. 35 | desc: a description for the progress bar. 36 | num_epochs_completed: an integer for the number of epochs completed so far int he loop. 37 | num_steps_completed: an integer for the number of steps completed so far in the loop. 38 | max_steps: an optional integer for the number of max steps in the loop. 39 | max_steps_per_epoch: an optional integer for the number of max steps per epoch. 40 | mininterval: Minimum display update interval (in seconds). If None, use TQDM's default. 
def update_progress_bar(
    progress_bar: tqdm, num_steps_completed: int, refresh_rate: int
) -> None:
    """Advance the bar by ``refresh_rate`` on every ``refresh_rate``-th step;
    other steps leave it untouched."""
    if num_steps_completed % refresh_rate != 0:
        return
    progress_bar.update(refresh_rate)


def close_progress_bar(
    progress_bar: tqdm, num_steps_completed: int, refresh_rate: int
) -> None:
    """Flush any steps not yet reflected in the bar, then close it."""
    # Steps accumulated since the last refresh_rate-aligned update.
    remainder = num_steps_completed % refresh_rate
    progress_bar.update(remainder)
    progress_bar.close()
def get_python_version() -> Version:
    """
    Get the current runtime Python version as a Version.

    Example::

        # if running in Python 3.8.0
        >>> get_python_version()
        '3.8.0'
    """
    return Version(platform.python_version())


def get_torch_version() -> Version:
    """
    Get the PyTorch version for the current runtime environment as a Version.

    Reads ``torch.__version__`` when present; otherwise falls back to the
    installed distribution metadata.

    Raises:
        TypeError: if the version string cannot be parsed.

    Example::

        # if running PyTorch 1.12.0
        >>> get_torch_version()
        '1.12.0'
    """
    try:
        if hasattr(torch, "__version__"):
            pkg_version = Version(torch.__version__)
        else:
            # Fall back to distribution metadata. Imported lazily, and using
            # stdlib importlib.metadata (Python 3.8+) instead of the
            # deprecated pkg_resources API.
            from importlib.metadata import version as _dist_version

            pkg_version = Version(_dist_version("torch"))
    except TypeError as e:
        raise TypeError("PyTorch version could not be detected automatically.") from e

    return pkg_version


def is_torch_version_geq(version: str) -> bool:
    """Return True iff the installed torch version is >= ``version``."""
    return get_torch_version() >= Version(version)