├── .github ├── ISSUE_TEMPLATE │ ├── 01_bug.yml │ ├── 02_enchancement.yml │ ├── 03_docs.yml │ └── config.yml ├── dependabot.yml ├── pull_request_template.md └── workflows │ ├── cache.yml │ ├── integration.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── examples ├── README.md ├── __init__.py ├── language │ ├── __init__.py │ ├── dataset.py │ ├── engine.py │ └── transformer.py ├── py.typed ├── requirements.txt ├── torch_cifar10_resnet.py ├── torch_imagenet_resnet.py ├── torch_language_model.py ├── utils.py └── vision │ ├── __init__.py │ ├── cifar_resnet.py │ ├── datasets.py │ ├── engine.py │ └── optimizers.py ├── kfac ├── __init__.py ├── assignment.py ├── base_preconditioner.py ├── distributed.py ├── enums.py ├── gpt_neox │ ├── __init__.py │ ├── assignment.py │ ├── layer.py │ ├── modules.py │ ├── mpu.py │ └── preconditioner.py ├── hyperparams.py ├── layers │ ├── __init__.py │ ├── base.py │ ├── eigen.py │ ├── inverse.py │ ├── modules.py │ ├── register.py │ └── utils.py ├── preconditioner.py ├── py.typed ├── scheduler.py ├── tracing.py └── warnings.py ├── pyproject.toml ├── scripts ├── README.md ├── copy_and_extract.sh ├── kill_python_procs.sh └── run_imagenet.sh ├── testing ├── __init__.py ├── assignment.py ├── distributed.py ├── gpt_neox.py ├── models.py └── utils.py ├── tests ├── __init__.py ├── assignment_test.py ├── base_preconditioner_test.py ├── distributed_test.py ├── gpt_neox │ ├── __init__.py │ ├── gpt_assignment_test.py │ ├── gpt_modules_test.py │ ├── gpt_mpu_test.py │ └── gpt_preconditioner_test.py ├── hyperparams_test.py ├── integration │ ├── __init__.py │ └── mnist_integration_test.py ├── layers │ ├── __init__.py │ ├── layers_test.py │ ├── modules_test.py │ ├── register_test.py │ └── utils_test.py ├── preconditioner_test.py ├── scheduler_test.py ├── testing │ ├── __init__.py │ └── distributed_wrapper_test.py ├── tracing_test.py └── training_test.py └── tox.ini /.github/ISSUE_TEMPLATE/01_bug.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Report errors or unexpected results. 3 | labels: ["bug"] 4 | assignees: 5 | - gpauloski 6 | body: 7 | - type: textarea 8 | id: install 9 | attributes: 10 | label: How did you install K-FAC and PyTorch? 11 | description: > 12 | E.g., install via pip, install from source, etc. **Note:** this will 13 | be rendered as console text automatically. 14 | placeholder: | 15 | $ pip install torch 16 | $ pip install -e . 17 | ... 18 | render: console 19 | validations: 20 | required: true 21 | 22 | - type: input 23 | id: version 24 | attributes: 25 | label: What commit are you using? 26 | description: > 27 | Current commit hash. 28 | placeholder: Output of $(git rev-parse HEAD) 29 | validations: 30 | required: true 31 | 32 | - type: textarea 33 | id: freeform 34 | attributes: 35 | label: Describe the problem. 36 | description: > 37 | Please provide sample code and directions for reproducing 38 | your problem and what you expected to happen. 39 | validations: 40 | required: true 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/02_enchancement.yml: -------------------------------------------------------------------------------- 1 | name: Enhancement Request 2 | description: Request a new feature. 
3 | labels: ["enhancement"] 4 | assignees: 5 | - gpauloski 6 | body: 7 | - type: textarea 8 | id: request 9 | attributes: 10 | label: Describe the Request 11 | description: > 12 | Please describe your use case and why the current feature set does 13 | not satisfy your needs. 14 | validations: 15 | required: true 16 | 17 | - type: textarea 18 | id: example 19 | attributes: 20 | label: Sample Code 21 | description: > 22 | If relevant, please provide sample code such as the proposed 23 | interface, usage, or results. **Note:** this will be rendered as 24 | Python code automatically. 25 | render: python 26 | validations: 27 | required: false 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/03_docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation Improvements 2 | description: Suggest improvements to the documentation. 3 | labels: ["documentation"] 4 | assignees: 5 | - gpauloski 6 | body: 7 | - type: textarea 8 | id: freeform 9 | attributes: 10 | label: Describe the Request 11 | description: > 12 | Please describe limitations of the current documentation (either 13 | the docstrings in the code or the GitHub wiki and READMEs) or 14 | suggested improvements. 15 | validations: 16 | required: true 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | # Check for updates to GitHub Actions every week 8 | interval: "weekly" 9 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | 4 | 5 | ### Fixes 6 | 7 | 8 | - Fixes #XX 9 | - Fixes #XX 10 | 11 | ### Type of Change 12 | 13 | 14 | - [ ] Bug fix (non-breaking change which fixes an issue) 15 | - [ ] New feature (non-breaking change which adds functionality) 16 | - [ ] Refactoring (internal implementation changes) 17 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 18 | - [ ] Documentation update (no changes to the code) 19 | - [ ] CI change (changes to CI workflows, packages, templates, etc.) 20 | - [ ] Version changes (changes to the package or dependency versions) 21 | 22 | ## Testing 23 | 24 | 25 | N/A 26 | 27 | ## Pull Request Checklist 28 | 29 | Please confirm the PR meets the following requirements. 30 | - [ ] Code changes pass `pre-commit` (e.g., mypy, ruff, etc.). 31 | - [ ] Tests have been added to show the fix is effective or that the new feature works. 32 | - [ ] New and existing unit tests pass locally with the changes. 33 | - [ ] Docs have been updated and reviewed if relevant. 
34 | -------------------------------------------------------------------------------- /.github/workflows/cache.yml: -------------------------------------------------------------------------------- 1 | name: cache-cleanup 2 | 3 | on: 4 | pull_request: 5 | types: 6 | - closed 7 | 8 | jobs: 9 | cleanup: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Check out code 13 | uses: actions/checkout@v4 14 | 15 | - name: Cleanup 16 | run: | 17 | gh extension install actions/gh-actions-cache 18 | REPO=${{ github.repository }} 19 | BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge" 20 | echo "Fetching list of cache keys" 21 | cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 ) 22 | ## Setting this to not fail the workflow while deleting cache keys. 23 | set +e 24 | echo "Deleting caches..." 25 | for cacheKey in $cacheKeysForPR 26 | do 27 | gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm 28 | done 29 | echo "Done" 30 | env: 31 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 32 | -------------------------------------------------------------------------------- /.github/workflows/integration.yml: -------------------------------------------------------------------------------- 1 | name: integration 2 | 3 | on: 4 | push: 5 | branches: [main, test-me-*] 6 | tags: 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | jobs: 11 | integration: 12 | timeout-minutes: 15 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Python 3.12 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.12' 24 | 25 | - name: Get pip cache dir 26 | id: pip-cache-dir 27 | run: echo "PIP_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV 28 | 29 | - name: Use pip cache 30 | id: pip-cache 31 | uses: actions/cache@v4 32 | with: 33 | path: ${{ env.PIP_CACHE_DIR }} 34 | key: integration-ubuntu-latest-pip-3.12-${{ hashFiles('pyproject.toml') }} 35 | restore-keys: | 36 | integration-ubuntu-latest-pip-3.12 37 | 38 | - name: Install KFAC 39 | run: python -m pip install .
--extra-index-url https://download.pytorch.org/whl/cpu 40 | 41 | - name: Install Dependencies 42 | run: python -m pip install --upgrade torchvision --extra-index-url https://download.pytorch.org/whl/cpu 43 | 44 | - name: Run MNIST Integration Test 45 | run: python tests/integration/mnist_integration_test.py 46 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [main, test-me-*] 6 | tags: 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | jobs: 11 | tests: 12 | timeout-minutes: 15 13 | 14 | strategy: 15 | matrix: 16 | include: 17 | - os: ubuntu-latest 18 | python: 3.9 19 | toxenv: py39 20 | - os: ubuntu-latest 21 | python: '3.10' 22 | toxenv: py310 23 | - os: ubuntu-latest 24 | python: '3.11' 25 | toxenv: py311 26 | - os: ubuntu-latest 27 | python: '3.12' 28 | toxenv: py312 29 | - os: ubuntu-latest 30 | python: '3.13' 31 | toxenv: py313 32 | runs-on: ${{ matrix.os }} 33 | 34 | steps: 35 | - name: Checkout 36 | uses: actions/checkout@v4 37 | 38 | - name: Set up Python ${{matrix.python}} 39 | uses: actions/setup-python@v5 40 | with: 41 | python-version: ${{ matrix.python }} 42 | 43 | - name: Get pip cache dir 44 | id: pip-cache-dir 45 | run: echo "PIP_CACHE_DIR=$(pip cache dir)" >> $GITHUB_ENV 46 | 47 | - name: Use pip cache 48 | id: pip-cache 49 | uses: actions/cache@v4 50 | with: 51 | path: ${{ env.PIP_CACHE_DIR }} 52 | key: ${{ matrix.os }}-pip-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} 53 | restore-keys: | 54 | ${{ matrix.os }}-pip-${{ matrix.python }}- 55 | 56 | - name: Install Packages 57 | run: python -mpip install --upgrade pip tox 58 | 59 | - name: Run Tox 60 | run: tox -e ${{ matrix.toxenv }} 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | sbatch_logs/ 3 | saved_logs/ 4 | 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | repos: 4 | - repo: 'https://github.com/pre-commit/pre-commit-hooks' 5 | rev: v5.0.0 6 | hooks: 7 | - id: mixed-line-ending 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - id: check-added-large-files 11 | - id: check-docstring-first 12 | - id: check-json 13 | - id: check-yaml 14 | - id: check-merge-conflict 15 | - id: name-tests-test 16 | - repo: 'https://github.com/codespell-project/codespell' 17 | rev: v2.4.1 18 | hooks: 19 | - id: codespell 20 | - repo: 'https://github.com/charliermarsh/ruff-pre-commit' 21 | rev: v0.11.12 22 | hooks: 23 | - id: ruff 24 | args: 25 | - '--fix' 26 | - id: ruff-format 27 | - repo: 'https://github.com/pre-commit/mirrors-mypy' 28 | rev: v1.16.0 29 | hooks: 30 | - id: mypy 31 | additional_dependencies: [types-tqdm] 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Greg Pauloski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Distributed K-FAC Preconditioner for PyTorch 2 | 3 | [![DOI](https://zenodo.org/badge/240976400.svg)](https://zenodo.org/badge/latestdoi/240976400) 4 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/gpauloski/kfac_pytorch/main.svg)](https://results.pre-commit.ci/latest/github/gpauloski/kfac_pytorch/main) 5 | [![Tests](https://github.com/gpauloski/kfac_pytorch/actions/workflows/tests.yml/badge.svg)](https://github.com/gpauloski/kfac_pytorch/actions) 6 | [![Integration](https://github.com/gpauloski/kfac_pytorch/actions/workflows/integration.yml/badge.svg)](https://github.com/gpauloski/kfac_pytorch/actions) 7 | 8 | K-FAC, Kronecker-factored Approximate Curvature, is a second-order optimization method based on an efficient approximation of the Fisher information matrix (see the [original paper](https://arxiv.org/abs/1503.05671)). 9 | This repository provides a PyTorch implementation of K-FAC as a preconditioner to standard PyTorch optimizers with support for single-device or distributed training. 10 | The distributed strategy is implemented using KAISA, a K-FAC-enabled, Adaptable, Improved, and Scalable second-order optimizer framework, where the placement of the second-order computations and gradient preconditioning is controlled by the *gradient worker fraction* parameter (see the [paper](https://arxiv.org/abs/2107.01739) for more details). 11 | KAISA has been shown to reduce time-to-convergence in [PyTorch distributed training](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) applications such as ResNet-50, Mask R-CNN, and BERT. 12 | 13 | ## Publications 14 | 15 | - J. G. Pauloski, L. Huang, W. Xu, K. Chard, I. T. Foster and Z. Zhang, "[Deep Neural Network Training With Distributed K-FAC](https://ieeexplore.ieee.org/document/9739867)," in IEEE Transactions on Parallel and Distributed Systems, vol. 33, no. 12, pp. 3616-3627, 1 Dec. 2022, doi: 10.1109/TPDS.2022.3161187. 16 | - J. Gregory Pauloski, Qi Huang, Lei Huang, Shivaram Venkataraman, Kyle Chard, Ian Foster, and Zhao Zhang. 2021. [KAISA: An Adaptive Second-order Optimizer Framework for Deep Neural Networks](https://dl.acm.org/doi/10.1145/3458817.3476152). International Conference for High Performance Computing, Networking, Storage and Analysis (SC '21). Association for Computing Machinery, New York, NY, USA, Article 13, 1–14. 17 | - J. Gregory Pauloski, Zhao Zhang, Lei Huang, Weijia Xu, and Ian T. Foster. 2020. 
[Convolutional Neural Network Training with Distributed K-FAC](https://dl.acm.org/doi/10.5555/3433701.3433826). International Conference for High Performance Computing, Networking, Storage and Analysis (SC ‘20). IEEE Press, Article 94, 1–14. 18 | 19 | ## Table of Contents 20 | 21 | - [Install](#install) 22 | - [Usage](#usage) 23 | - [Examples](#examples) 24 | - [Developing](#developing) 25 | - [Citations and References](#citations-and-references) 26 | 27 | ## Install 28 | 29 | ### Requirements 30 | 31 | K-FAC only requires PyTorch 1.8 or later. 32 | The example scripts have additional requirements defined in [examples/requirements.txt](examples/requirements.txt). 33 | 34 | ### Installation 35 | 36 | ``` 37 | $ git clone https://github.com/gpauloski/kfac_pytorch.git 38 | $ cd kfac_pytorch 39 | $ pip install . # Use -e to install in development mode 40 | ``` 41 | 42 | If [NVIDIA Apex](https://github.com/NVIDIA/apex) is installed with C extensions, the optimized `flatten` and `unflatten` operations will be used during collective communication operations. 43 | 44 | ## Usage 45 | 46 | K-FAC requires minimal code to incorporate with existing training scripts. 47 | See the [K-FAC docstring](kfac/preconditioner.py) for a detailed list of K-FAC parameters. 48 | 49 | ```Python 50 | from kfac.preconditioner import KFACPreconditioner 51 | 52 | ... 53 | 54 | model = torch.nn.parallel.DistributedDataParallel(...) 55 | optimizer = optim.SGD(model.parameters(), ...) 56 | preconditioner = KFACPreconditioner(model, ...) 57 | 58 | ... 59 | 60 | for data, target in train_loader: 61 | optimizer.zero_grad() 62 | output = model(data) 63 | loss = criterion(output, target) 64 | loss.backward() 65 | preconditioner.step() 66 | optimizer.step() 67 | 68 | ... 69 | ``` 70 | 71 | See the [wiki](https://github.com/gpauloski/kfac_pytorch/wiki) for more details on K-FAC's features. 72 | 73 | ## Examples 74 | 75 | Example scripts for training ResNet models on Cifar10 and ImageNet-1k are provided in [examples/](examples/). 76 | 77 | ## Developing 78 | 79 | [tox](https://tox.wiki/en/latest/index.html) and [pre-commit](https://pre-commit.com) are used for development. 80 | Pre-commit enforces the code formatting, linting, and type-checking in this repository. 81 | 82 | To get started with local development (note: Python 3.11 is supported but some testing dependencies are not available): 83 | ``` 84 | $ tox --devenv venv -e py310 85 | $ . venv/bin/activate 86 | $ pre-commit install 87 | ``` 88 | Note that the `tox` recipes install CPU-only PyTorch as GPUs are not available in CI. 89 | 90 | To verify code passes pre-commit, run: 91 | ``` 92 | $ pre-commit run --all-files 93 | ``` 94 | 95 | Tox can also be used to run the test suite: 96 | ``` 97 | $ tox -e py39 # run all tests in Python 3.9 98 | ``` 99 | 100 | ## Citations and References 101 | 102 | The K-FAC code is based on Chaoqi Wang's [KFAC-PyTorch](https://github.com/alecwangcq/KFAC-Pytorch). 103 | The ResNet models for Cifar10 are from Yerlan Idelbayev's [pytorch_resnet_cifar10](https://github.com/akamaster/pytorch_resnet_cifar10). 104 | The CIFAR-10 and ImageNet-1k training scripts are modeled after Horovod's example PyTorch training scripts. 105 | 106 | The code used in "[Convolutional Neural Network Training with Distributed K-FAC](https://dl.acm.org/doi/10.5555/3433701.3433826)" is frozen in the `kfac-lw` and `kfac-opt` branches. 
107 | The code used in "[KAISA: An Adaptive Second-order Optimizer Framework for Deep Neural Networks](https://dl.acm.org/doi/10.1145/3458817.3476152)" is frozen in the `hybrid-opt` branch. 108 | 109 | If you use this code in your work, please cite the SC '20 and '21 papers. 110 | 111 | ``` 112 | @inproceedings{pauloski2020kfac, 113 | author = {Pauloski, J. Gregory and Zhang, Zhao and Huang, Lei and Xu, Weijia and Foster, Ian T.}, 114 | title = {Convolutional {N}eural {N}etwork {T}raining with {D}istributed {K}-{FAC}}, 115 | year = {2020}, 116 | isbn = {9781728199986}, 117 | publisher = {IEEE Press}, 118 | booktitle = {Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis}, 119 | articleno = {94}, 120 | numpages = {14}, 121 | location = {Atlanta, Georgia}, 122 | series = {SC '20}, 123 | doi = {10.5555/3433701.3433826} 124 | } 125 | 126 | @inproceedings{pauloski2021kaisa, 127 | author = {Pauloski, J. Gregory and Huang, Qi and Huang, Lei and Venkataraman, Shivaram and Chard, Kyle and Foster, Ian and Zhang, Zhao}, 128 | title = {KAISA: {A}n {A}daptive {S}econd-{O}rder {O}ptimizer {F}ramework for {D}eep {N}eural {N}etworks}, 129 | year = {2021}, 130 | isbn = {9781450384421}, 131 | publisher = {Association for Computing Machinery}, 132 | address = {New York, NY, USA}, 133 | url = {https://doi.org/10.1145/3458817.3476152}, 134 | doi = {10.1145/3458817.3476152}, 135 | booktitle = {Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis}, 136 | articleno = {13}, 137 | numpages = {14}, 138 | location = {St. Louis, Missouri}, 139 | series = {SC '21} 140 | } 141 | ``` 142 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # KFAC Training Examples 2 | 3 | Distributed training with K-FAC examples for computer vision (ImageNet and 4 | CIFAR-10) and language modeling (PennTreebank and WikiText) tasks. 5 | 6 | ## Requirements 7 | 8 | The provided example training scripts require KFAC to be installed and the 9 | additional requirements defined in `examples/requirements.txt`. 10 | 11 | ``` 12 | $ pip install -e . 13 | $ pip install -r examples/requirements.txt 14 | ``` 15 | 16 | Python >=3.9 is recommended. 17 | 18 | ## Usage 19 | 20 | Note: these examples use the `torchrun` launcher which is only available in 21 | PyTorch 1.10 and later. For PyTorch 1.9, use `python -m torch.distributed.run` 22 | and for PyTorch 1.8, use `python -m torch.distributed.launch`. 23 | 24 | #### Single Node, Multi-GPU 25 | ``` 26 | $ torchrun --standalone --nnodes 1 --nproc_per_node=[NGPUS] \ 27 | examples/torch_{...}.py [ARGS] 28 | ``` 29 | 30 | #### Multi-Node, Multi-GPU 31 | On each node, run: 32 | ``` 33 | $ torchrun --nnodes=[NNODES] --nproc_per_node=[NGPUS] --rdzv_backend=c10d --rdzv_endpoint=[HOSTADDR] \ 34 | examples/torch_{...}.py [ARGS] 35 | ``` 36 | 37 | The full list of arguments can be found with the `--help` argument. 38 | E.g., `python examples/torch_cifar10_resnet.py --help`. 
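For example, a single-node run of the language modeling example with K-FAC enabled (all flags below are defined in `examples/torch_language_model.py`; adjust the GPU count for your machine) might look like:

```
$ torchrun --standalone --nnodes 1 --nproc_per_node=4 \
    examples/torch_language_model.py --kfac --dataset penntreebank \
    --batch-size 20 --epochs 20
```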
39 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/examples/__init__.py -------------------------------------------------------------------------------- /examples/language/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/examples/language/__init__.py -------------------------------------------------------------------------------- /examples/language/dataset.py: -------------------------------------------------------------------------------- 1 | """Language modeling datasets.""" 2 | 3 | from __future__ import annotations 4 | 5 | import sys 6 | from typing import Callable 7 | from typing import List 8 | from typing import NamedTuple 9 | from typing import Tuple 10 | from typing import Union 11 | 12 | if sys.version_info >= (3, 9): # pragma: >=3.9 cover 13 | from typing import Literal 14 | else: # pragma: <3.9 cover 15 | from typing_extensions import Literal 16 | 17 | import torch 18 | from torch.utils.data import DataLoader 19 | from torch.utils.data.dataset import Dataset 20 | from torch.utils.data.dataset import IterableDataset 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torchtext.data.utils import get_tokenizer 23 | from torchtext.datasets import PennTreebank 24 | from torchtext.datasets import WikiText2 25 | from torchtext.datasets import WikiText103 26 | from torchtext.vocab import build_vocab_from_iterator 27 | from torchtext.vocab import Vocab 28 | 29 | DType = Tuple[torch.Tensor, torch.Tensor] 30 | IndicesType = Union[List[int], torch.Tensor] 31 | 32 | 33 | class LoaderSampler(NamedTuple): 34 | """Tuple of a dataloader and the corresponding datasampler.""" 35 | 36 | loader: DataLoader[DType] 37 | sampler: DistributedSampler[IndicesType] 38 | 39 | 40 | class Datasets(NamedTuple): 41 | """Train/val/test tuple of LoaderSamplers.""" 42 | 43 | train: LoaderSampler 44 | val: LoaderSampler 45 | test: LoaderSampler 46 | 47 | 48 | class _Dataset(Dataset[DType]): 49 | def __init__(self, data: torch.Tensor, seq_len: int) -> None: 50 | self._data = data 51 | self._seq_len = seq_len 52 | 53 | def __len__(self) -> int: 54 | return len(self._data) // self._seq_len 55 | 56 | def __getitem__(self, idx: int) -> DType: 57 | start = self._seq_len * idx 58 | end = self._seq_len * (idx + 1) 59 | data = self._data[start:end] 60 | target = self._data[start + 1 : end + 1] 61 | return data, target 62 | 63 | 64 | def download_dataset( 65 | dataset: Literal['penntreebank', 'wikitext2', 'wikitext103'], 66 | data_dir: str, 67 | ) -> tuple[IterableDataset[str], IterableDataset[str], IterableDataset[str]]: 68 | """Get a torchtext language modeling dataset. 69 | 70 | Args: 71 | dataset (str): one of 'penntreebank', 'wikitext2', or 'wikitext103'. 72 | data_dir (str): directory to download datasets to. 73 | 74 | Returns: 75 | tuple of train, validation, and testing sets for the specified 76 | dataset.
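Example (mirroring the call made by ``get_dataset`` below): ``train_iter, val_iter, test_iter = download_dataset('wikitext2', '/tmp/torchtext-data')``.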
77 | """ 78 | datasets = {} 79 | for split in ('train', 'valid', 'test'): 80 | if dataset.lower() == 'penntreebank': 81 | datasets[split] = PennTreebank(root=data_dir, split=split) 82 | elif dataset.lower() == 'wikitext2': 83 | datasets[split] = WikiText2(root=data_dir, split=split) 84 | elif dataset.lower() == 'wikitext103': 85 | datasets[split] = WikiText103(root=data_dir, split=split) 86 | else: 87 | raise AssertionError(f'Unsupported dataset {dataset}.') 88 | 89 | return (datasets['train'], datasets['valid'], datasets['test']) 90 | 91 | 92 | def encode_and_flatten( 93 | raw_text_iter: IterableDataset[str], 94 | tokenizer: Callable[[str], list[str]], 95 | vocab: Vocab, 96 | ) -> torch.Tensor: 97 | """Tokenizes, encodes, and flattens a dataset.""" 98 | data = [ 99 | torch.tensor(vocab(tokenizer(item)), dtype=torch.long) 100 | for item in raw_text_iter 101 | ] 102 | return torch.cat(tuple(filter(lambda t: t.numel() > 0, data))) 103 | 104 | 105 | def get_dataset( 106 | dataset: Literal['penntreebank', 'wikitext2', 'wikitext103'], 107 | data_dir: str, 108 | seq_len: int, 109 | batch_size: int, 110 | *, 111 | cuda: bool = False, 112 | rank: int | None = None, 113 | world_size: int | None = None, 114 | ) -> tuple[Datasets, Vocab]: 115 | """Get language modeling datasets. 116 | 117 | Args: 118 | dataset (str): one of 'penntreebank', 'wikitext2', or 'wikitext102'. 119 | data_dir (str): directory to download datasets to. 120 | seq_len (int): number of tokens in a training sequence. 121 | batch_size (int): batch size. 122 | cuda (bool): set as True if training with CUDA. 123 | rank (int): optional rank of this worker for initializing the 124 | distributed sampler. 125 | world_size (int): optional world size if using distributed training. 126 | 127 | Returns: 128 | Datasets, a named tuple with attributes train, val, and test, each 129 | corresponding to another tuple with the dataloader and sampler for 130 | that training data split. Also returns the vocab used to encode 131 | the datasets. 
132 | """ 133 | train_iter, val_iter, test_iter = download_dataset(dataset, data_dir) 134 | 135 | tokenizer = get_tokenizer('basic_english') 136 | vocab = build_vocab_from_iterator( 137 | map(tokenizer, train_iter), 138 | specials=[''], 139 | ) 140 | vocab.set_default_index(vocab['']) 141 | 142 | train_data = encode_and_flatten(train_iter, tokenizer, vocab) 143 | val_data = encode_and_flatten(val_iter, tokenizer, vocab) 144 | test_data = encode_and_flatten(test_iter, tokenizer, vocab) 145 | 146 | datasets = [ 147 | _Dataset(data, seq_len=seq_len) 148 | for data in (train_data, val_data, test_data) 149 | ] 150 | 151 | num_replicas = ( 152 | torch.distributed.get_world_size() 153 | if world_size is None 154 | else world_size 155 | ) 156 | rank = torch.distributed.get_rank() if rank is None else rank 157 | 158 | samplers: list[DistributedSampler[IndicesType]] = [ 159 | DistributedSampler(dataset, num_replicas=num_replicas, rank=rank) 160 | for dataset in datasets 161 | ] 162 | 163 | loaders: list[DataLoader[DType]] = [ 164 | DataLoader( 165 | dataset, 166 | batch_size=batch_size, 167 | drop_last=True, 168 | sampler=sampler, 169 | num_workers=4 if cuda else 0, 170 | pin_memory=cuda, 171 | ) 172 | for dataset, sampler in zip(datasets, samplers) 173 | ] 174 | 175 | return ( 176 | Datasets( 177 | train=LoaderSampler(loaders[0], samplers[0]), 178 | val=LoaderSampler(loaders[1], samplers[1]), 179 | test=LoaderSampler(loaders[2], samplers[2]), 180 | ), 181 | vocab, 182 | ) 183 | 184 | 185 | if __name__ == '__main__': 186 | datasets, vocab = get_dataset( 187 | 'penntreebank', 188 | '/tmp/torchtext-data', 189 | 12, 190 | 4, 191 | world_size=1, 192 | rank=0, 193 | ) 194 | 195 | datasets.train.sampler.set_epoch(0) 196 | for batch, (data, target) in enumerate(datasets.train.loader): 197 | if batch > 2: 198 | break 199 | print(f'BATCH {batch}') 200 | for sample in range(len(data)): 201 | print(f'SAMPLE {sample}') 202 | print(vocab.lookup_tokens(list(data[sample]))) 203 | print(vocab.lookup_tokens(list(target[sample]))) 204 | -------------------------------------------------------------------------------- /examples/language/engine.py: -------------------------------------------------------------------------------- 1 | """Training and eval functions for the language modeling example.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Tuple 6 | 7 | import torch 8 | from tqdm import tqdm 9 | 10 | import kfac 11 | from examples.language.transformer import gen_square_subsequent_mask 12 | from examples.utils import Metric 13 | 14 | DType = Tuple[torch.Tensor, torch.Tensor] 15 | 16 | 17 | def train( 18 | model: torch.nn.Module, 19 | *, 20 | criterion: torch.nn.Module, 21 | optimizer: torch.optim.Optimizer, 22 | preconditioner: kfac.base_preconditioner.BaseKFACPreconditioner | None, 23 | dataloader: torch.utils.data.DataLoader[DType], 24 | epoch: int, 25 | epochs: int, 26 | ) -> float: 27 | """Perform one training epoch.""" 28 | model.train() 29 | train_loss = Metric('train_loss') 30 | src_mask: torch.Tensor | None = None 31 | 32 | with tqdm( 33 | total=len(dataloader), 34 | bar_format='{l_bar}{bar:8}{r_bar}', 35 | desc=f'Epoch {epoch:2d}/{epochs:2d}', 36 | disable=torch.distributed.get_rank() > 0, 37 | ) as t: 38 | for data, target in dataloader: 39 | if src_mask is None: 40 | seq_len = data.size(0) 41 | device = next(model.parameters()).device 42 | src_mask = gen_square_subsequent_mask(seq_len).to(device) 43 | 44 | optimizer.zero_grad() 45 | 46 | data = data.to(model.device) 47 | target = 
target.to(model.device).reshape(-1) 48 | 49 | output = model(data, src_mask) 50 | ntokens = output.size(-1) 51 | output_flat = output.view(-1, ntokens) 52 | 53 | loss = criterion(output_flat, target) 54 | loss.backward() 55 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 56 | 57 | if preconditioner is not None: 58 | preconditioner.step() 59 | optimizer.step() 60 | 61 | loss = loss.detach() 62 | train_loss.update(loss) 63 | 64 | t.set_postfix_str( 65 | 'loss: {:.2f}, ppl: {:.2f}, lr: {:.1E}'.format( 66 | train_loss.avg, 67 | torch.exp(train_loss.avg), 68 | optimizer.param_groups[0]['lr'], 69 | ), 70 | ) 71 | t.update(1) 72 | 73 | return train_loss.avg.item() 74 | 75 | 76 | def evaluate( 77 | model: torch.nn.Module, 78 | *, 79 | criterion: torch.nn.Module, 80 | dataloader: torch.utils.data.DataLoader[DType], 81 | prefix: str, 82 | ) -> float: 83 | """Evaluate model.""" 84 | model.eval() 85 | eval_loss = Metric('eval_loss') 86 | src_mask: torch.Tensor | None = None 87 | 88 | with ( 89 | torch.no_grad(), 90 | tqdm( 91 | total=len(dataloader), 92 | bar_format='{l_bar}{bar:8}{r_bar}', 93 | desc=prefix[:11].ljust(11, ' '), 94 | disable=torch.distributed.get_rank() > 0, 95 | ) as t, 96 | ): 97 | for data, target in dataloader: 98 | if src_mask is None: 99 | seq_len = data.size(0) 100 | device = next(model.parameters()).device 101 | src_mask = gen_square_subsequent_mask(seq_len).to(device) 102 | 103 | data = data.to(model.device) 104 | target = target.to(model.device).reshape(-1) 105 | 106 | output = model(data, src_mask) 107 | ntokens = output.size(-1) 108 | output_flat = output.view(-1, ntokens) 109 | 110 | loss = criterion(output_flat, target) 111 | loss = loss.detach() 112 | eval_loss.update(loss) 113 | 114 | t.update(1) 115 | t.set_postfix_str( 116 | 'loss: {:.2f}, ppl: {:.2f}'.format( 117 | eval_loss.avg, 118 | torch.exp(eval_loss.avg), 119 | ), 120 | refresh=False, 121 | ) 122 | 123 | return eval_loss.avg.item() 124 | -------------------------------------------------------------------------------- /examples/language/transformer.py: -------------------------------------------------------------------------------- 1 | """Simple Transformer Model. 2 | 3 | Based on Attention is All You Need and 4 | https://pytorch.org/tutorials/beginner/transformer_tutorial.html. 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import math 10 | 11 | import torch 12 | from torch import nn 13 | 14 | 15 | class TransformerModel(nn.Module): 16 | """Transformer Model.""" 17 | 18 | def __init__( 19 | self, 20 | ntoken: int, 21 | d_model: int, 22 | nhead: int, 23 | d_hid: int, 24 | nlayers: int, 25 | dropout: float = 0.5, 26 | ) -> None: 27 | """Init TransformerModel. 28 | 29 | Args: 30 | ntoken (int): number of tokens in vocabulary. 31 | d_model (int): number of expected features in encoder/decoder 32 | inputs. 33 | nhead (int): number of attention heads. 34 | d_hid (int): hidden dimension size. 35 | nlayers (int): number of encoder layers in the model. 36 | dropout (float): dropout layer probability. 
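Example (the defaults used by ``examples/torch_language_model.py``): ``TransformerModel(ntoken=len(vocab), d_model=256, nhead=4, d_hid=256, nlayers=2, dropout=0.2)``.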
37 | """ 38 | super().__init__() 39 | self.model_type = 'Transformer' 40 | self.pos_encoder = PositionalEncoding(d_model, dropout) 41 | encoder_layers = nn.TransformerEncoderLayer( 42 | d_model, 43 | nhead, 44 | d_hid, 45 | dropout, 46 | ) 47 | self.transformer_encoder = nn.TransformerEncoder( 48 | encoder_layers, 49 | nlayers, 50 | ) 51 | self.encoder = nn.Embedding(ntoken, d_model) 52 | self.d_model = d_model 53 | self.decoder = nn.Linear(d_model, ntoken) 54 | 55 | self.init_weights() 56 | 57 | def init_weights(self) -> None: 58 | """Initialize weights.""" 59 | initrange = 0.1 60 | self.encoder.weight.data.uniform_(-initrange, initrange) 61 | self.decoder.bias.data.zero_() 62 | self.decoder.weight.data.uniform_(-initrange, initrange) 63 | 64 | def forward( 65 | self, 66 | src: torch.Tensor, 67 | src_mask: torch.Tensor, 68 | ) -> torch.Tensor: 69 | """Transformer forward pass. 70 | 71 | Args: 72 | src (Tensor): tensor with shape [seq_len, batch_size]. 73 | src_mask (Tensor): tensor with shape [seq_len, seq_len]. 74 | 75 | Returns: 76 | output tensor with shape [seq_len, batch_size, ntoken]. 77 | """ 78 | src = self.encoder(src) * math.sqrt(self.d_model) 79 | src = self.pos_encoder(src) 80 | output = self.transformer_encoder(src, src_mask) 81 | output = self.decoder(output) 82 | return output 83 | 84 | 85 | class PositionalEncoding(nn.Module): 86 | """Positional Encoder.""" 87 | 88 | def __init__( 89 | self, 90 | d_model: int, 91 | dropout: float = 0.1, 92 | max_len: int = 5000, 93 | ) -> None: 94 | """Init PositionalEncoding. 95 | 96 | Args: 97 | d_model (int): number of expected features in encoder/decoder 98 | inputs. 99 | dropout (float): dropout layer probability. 100 | max_len (int): max vocabulary size (I think). 101 | """ 102 | super().__init__() 103 | self.dropout = nn.Dropout(p=dropout) 104 | 105 | position = torch.arange(max_len).unsqueeze(1) 106 | div_term = torch.exp( 107 | torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model), 108 | ) 109 | self.pe: torch.Tensor 110 | pe = torch.zeros(max_len, 1, d_model) 111 | pe[:, 0, 0::2] = torch.sin(position * div_term) 112 | pe[:, 0, 1::2] = torch.cos(position * div_term) 113 | self.register_buffer('pe', pe) 114 | 115 | def forward(self, x: torch.Tensor) -> torch.Tensor: 116 | """Positional encoder forward pass. 117 | 118 | Args: 119 | x (Tensor): tensor with shape [seq_len, batch_size, embedding_dim]. 120 | 121 | Returns: 122 | tensor with same shape as input injected with some information 123 | about the relative or absolute position of the tokens in the 124 | sequence. 
125 | """ 126 | x = x + self.pe[: x.size(0)] 127 | return self.dropout(x) 128 | 129 | 130 | def gen_square_subsequent_mask(sz: int) -> torch.Tensor: 131 | """Generates an upper-triangular matrix of -inf, with zeros on diag.""" 132 | return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1) 133 | -------------------------------------------------------------------------------- /examples/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/examples/py.typed -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | torchdata 3 | torchinfo==1.5.2 4 | torchtext 5 | torchvision 6 | tqdm 7 | -------------------------------------------------------------------------------- /examples/torch_language_model.py: -------------------------------------------------------------------------------- 1 | """Language modeling with Transformers Example. 2 | 3 | Based on the PyTorch example and modified for distributed training with KFAC: 4 | https://pytorch.org/tutorials/beginner/transformer_tutorial.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import argparse 10 | import logging 11 | import os 12 | import sys 13 | import time 14 | from typing import Sequence 15 | 16 | import torch 17 | from torch.utils import collect_env 18 | 19 | import kfac 20 | from examples.language.dataset import get_dataset 21 | from examples.language.engine import evaluate 22 | from examples.language.engine import train 23 | from examples.language.transformer import TransformerModel 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: 29 | """Parse command line arguments.""" 30 | argv = argv if argv is not None else sys.argv[1:] 31 | parser = argparse.ArgumentParser( 32 | description='Language Modeling Example', 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 34 | ) 35 | 36 | model_group = parser.add_argument_group('Model Parameters') 37 | model_group.add_argument( 38 | '--embedding-dim', 39 | default=256, 40 | type=int, 41 | help='embedding dimension size', 42 | ) 43 | model_group.add_argument( 44 | '--hidden-dim', 45 | default=256, 46 | type=int, 47 | help='hidden dimension size', 48 | ) 49 | model_group.add_argument( 50 | '--attention-heads', 51 | default=4, 52 | type=int, 53 | help='number of attention heads', 54 | ) 55 | model_group.add_argument( 56 | '--layers', 57 | default=2, 58 | type=int, 59 | help='number of layers', 60 | ) 61 | model_group.add_argument( 62 | '--dropout', 63 | default=0.2, 64 | type=float, 65 | help='dropout probability', 66 | ) 67 | 68 | data_group = parser.add_argument_group('Data Parameters') 69 | data_group.add_argument( 70 | '--dataset', 71 | default='penntreebank', 72 | choices=['penntreebank', 'wikitext2', 'wikitext103'], 73 | help='dataset to train language model on', 74 | ) 75 | data_group.add_argument( 76 | '--download-dir', 77 | default='/tmp/torchtext-data', 78 | help='directory to download dataset to', 79 | ) 80 | data_group.add_argument( 81 | '--seq-len', 82 | default=64, 83 | type=int, 84 | help='number of tokens in each training sample', 85 | ) 86 | data_group.add_argument( 87 | '--batch-size', 88 | default=20, 89 | type=int, 90 | help='batch size', 91 | ) 92 | 93 | training_group = 
parser.add_argument_group('Training Parameters') 94 | training_group.add_argument( 95 | '--epochs', 96 | default=20, 97 | type=int, 98 | help='training epochs', 99 | ) 100 | training_group.add_argument( 101 | '--lr', 102 | default=1.0, 103 | type=float, 104 | help='initial learning rate', 105 | ) 106 | training_group.add_argument( 107 | '--backend', 108 | choices=['gloo', 'mpi', 'nccl'], 109 | default='nccl', 110 | help='distributed training backend', 111 | ) 112 | training_group.add_argument( 113 | '--no-cuda', 114 | action='store_true', 115 | default=False, 116 | help='disable CUDA training', 117 | ) 118 | training_group.add_argument( 119 | '--seed', 120 | default=42, 121 | type=int, 122 | help='training seed', 123 | ) 124 | 125 | kfac_group = parser.add_argument_group('KFAC Parameters') 126 | kfac_group.add_argument( 127 | '--kfac', 128 | action='store_true', 129 | default=False, 130 | help='enable KFAC preconditioning', 131 | ) 132 | kfac_group.add_argument( 133 | '--inv-update-steps', 134 | type=int, 135 | default=10, 136 | help='iters between updating second-order information', 137 | ) 138 | kfac_group.add_argument( 139 | '--factor-update-steps', 140 | type=int, 141 | default=1, 142 | help='iters between updates of the Kronecker factors', 143 | ) 144 | kfac_group.add_argument( 145 | '--factor-decay', 146 | type=float, 147 | default=0.95, 148 | help='alpha value for factor accumulation', 149 | ) 150 | kfac_group.add_argument( 151 | '--damping', 152 | type=float, 153 | default=0.003, 154 | help='damping factor', 155 | ) 156 | kfac_group.add_argument( 157 | '--kl-clip', 158 | type=float, 159 | default=0.001, 160 | help='KL clip', 161 | ) 162 | kfac_group.add_argument( 163 | '--skip-layers', 164 | nargs='+', 165 | type=str, 166 | default=['embedding', 'decoder', 'self_attn'], 167 | help='layers to skip KFAC registration for', 168 | ) 169 | kfac_group.add_argument( 170 | '--strategy', 171 | choices=['MEM_OPT', 'HYBRID_OPT', 'COMM_OPT'], 172 | default='COMM_OPT', 173 | help='distribution strategy for KFAC computations', 174 | ) 175 | 176 | args = parser.parse_args(argv) 177 | 178 | args.cuda = not args.no_cuda and torch.cuda.is_available() 179 | args.device = 'cuda' if args.cuda else 'cpu' 180 | args.local_rank = int(os.environ['LOCAL_RANK']) 181 | 182 | return args 183 | 184 | 185 | def main(argv: Sequence[str] | None = None) -> int: 186 | """Train and validate a language model.""" 187 | args = parse_args(argv) 188 | 189 | torch.distributed.init_process_group( 190 | backend=args.backend, 191 | init_method='env://', 192 | ) 193 | 194 | logging.basicConfig( 195 | format='[%(asctime)s] %(levelname)-5s (%(name)s): %(message)s', 196 | datefmt='%Y-%m-%d %H:%M:%S', 197 | level=logging.INFO 198 | if torch.distributed.get_rank() == 0 199 | else logging.ERROR, 200 | stream=sys.stdout, 201 | ) 202 | 203 | if args.cuda: 204 | torch.cuda.set_device(args.local_rank) 205 | torch.cuda.manual_seed(args.seed) 206 | 207 | if torch.distributed.get_rank() == 0: 208 | logger.info('Collecting env info...') 209 | logger.info(collect_env.get_pretty_env_info()) 210 | logger.info(f'Training arguments:\n{args}') 211 | 212 | datasets, vocab = get_dataset( 213 | args.dataset, 214 | args.download_dir, 215 | seq_len=args.seq_len, 216 | batch_size=args.batch_size, 217 | cuda=args.cuda, 218 | rank=torch.distributed.get_rank(), 219 | world_size=torch.distributed.get_world_size(), 220 | ) 221 | 222 | model: torch.nn.Module = TransformerModel( 223 | ntoken=len(vocab), 224 | d_model=args.embedding_dim, 225 |
nhead=args.attention_heads, 226 | d_hid=args.hidden_dim, 227 | nlayers=args.layers, 228 | dropout=args.dropout, 229 | ) 230 | model.to(args.device) 231 | model = torch.nn.parallel.DistributedDataParallel( 232 | model, 233 | device_ids=[args.local_rank] if args.cuda else None, 234 | ) 235 | 236 | criterion = torch.nn.CrossEntropyLoss() 237 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) 238 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 239 | optimizer, 240 | factor=0.1, 241 | patience=2, 242 | min_lr=1e-4, 243 | ) 244 | 245 | logger.info(f'Transformer model:\n{model}') 246 | preconditioner: kfac.preconditioner.KFACPreconditioner | None = None 247 | if args.kfac: 248 | strategy = kfac.enums.DistributedStrategy[args.strategy.upper()] 249 | preconditioner = kfac.preconditioner.KFACPreconditioner( 250 | model, 251 | factor_update_steps=args.factor_update_steps, 252 | inv_update_steps=args.inv_update_steps, 253 | damping=args.damping, 254 | factor_decay=args.factor_decay, 255 | kl_clip=args.kl_clip, 256 | lr=lambda x: optimizer.param_groups[0]['lr'], 257 | grad_worker_fraction=strategy, 258 | skip_layers=args.skip_layers, 259 | loglevel=logging.INFO, 260 | ) 261 | if torch.distributed.get_rank() == 0: 262 | logger.info(f'Preconditioner config:\n{preconditioner}') 263 | 264 | start = time.perf_counter() 265 | for epoch in range(args.epochs): 266 | datasets.train.sampler.set_epoch(epoch) 267 | train( 268 | model, 269 | criterion=criterion, 270 | optimizer=optimizer, 271 | preconditioner=preconditioner, 272 | dataloader=datasets.train.loader, 273 | epoch=epoch + 1, 274 | epochs=args.epochs, 275 | ) 276 | eval_loss = evaluate( 277 | model, 278 | criterion=criterion, 279 | dataloader=datasets.val.loader, 280 | prefix='Validation', 281 | ) 282 | scheduler.step(eval_loss) 283 | end = time.perf_counter() 284 | logger.info(f'Training completed in {end - start:.2f} seconds.') 285 | 286 | evaluate( 287 | model, 288 | criterion=criterion, 289 | dataloader=datasets.test.loader, 290 | prefix='Test', 291 | ) 292 | 293 | return 0 294 | 295 | 296 | if __name__ == '__main__': 297 | raise SystemExit(main()) 298 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | """Training utilities.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Callable 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torch.nn.functional import log_softmax 10 | 11 | import kfac 12 | 13 | 14 | def accuracy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 15 | """Get prediction accuracy.""" 16 | pred = output.max(1, keepdim=True)[1] 17 | return pred.eq(target.view_as(pred)).float().mean() 18 | 19 | 20 | def save_checkpoint( 21 | model: torch.nn.Module, 22 | optimizer: torch.optim.Optimizer, 23 | preconditioner: kfac.preconditioner.KFACPreconditioner | None, 24 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler | None, 25 | filepath: str, 26 | ) -> None: 27 | """Save model checkpoint.""" 28 | state = { 29 | 'model': model.state_dict(), 30 | 'optimizer': optimizer.state_dict(), 31 | 'preconditioner': preconditioner.state_dict() 32 | if preconditioner is not None 33 | else None, 34 | 'lr_scheduler': lr_scheduler.state_dict() 35 | if lr_scheduler is not None 36 | else None, 37 | } 38 | torch.save(state, filepath) 39 | 40 | 41 | class LabelSmoothLoss(torch.nn.Module): 42 | """Label smoothing loss.""" 43 | 44 | def __init__(self, 
smoothing: float = 0.0): 45 | """Init LabelSmoothLoss.""" 46 | super().__init__() 47 | self.smoothing = smoothing 48 | 49 | def forward( 50 | self, 51 | input_: torch.Tensor, 52 | target: torch.Tensor, 53 | ) -> torch.Tensor: 54 | """Forward pass.""" 55 | log_prob = log_softmax(input_, dim=-1) 56 | weight = ( 57 | input_.new_ones(input_.size()) 58 | * self.smoothing 59 | / (input_.size(-1) - 1.0) 60 | ) 61 | weight.scatter_(-1, target.unsqueeze(-1), (1.0 - self.smoothing)) 62 | loss = (-weight * log_prob).sum(dim=-1).mean() 63 | return loss 64 | 65 | 66 | class Metric: 67 | """Metric tracking class.""" 68 | 69 | def __init__(self, name: str): 70 | """Init Metric.""" 71 | self.name = name 72 | self.total = torch.tensor(0.0) 73 | self.n = torch.tensor(0.0) 74 | 75 | def update(self, val: torch.Tensor, n: int = 1) -> None: 76 | """Update metric. 77 | 78 | Args: 79 | val (Tensor): new value to add. 80 | n (int): weight of new value. 81 | """ 82 | dist.all_reduce(val, async_op=False) 83 | self.total += val.cpu() / dist.get_world_size() 84 | self.n += n 85 | 86 | @property 87 | def avg(self) -> torch.Tensor: 88 | """Get average of metric.""" 89 | return self.total / self.n 90 | 91 | 92 | def create_lr_schedule( 93 | workers: int, 94 | warmup_epochs: int, 95 | decay_schedule: list[int], 96 | alpha: float = 0.1, 97 | ) -> Callable[[int], float]: 98 | """Return lr scheduler lambda.""" 99 | 100 | def lr_schedule(epoch: int) -> float: 101 | """Compute lr scale factor.""" 102 | lr_adj = 1.0 103 | if epoch < warmup_epochs: 104 | lr_adj = ( 105 | 1.0 / workers * (epoch * (workers - 1) / warmup_epochs + 1) 106 | ) 107 | else: 108 | decay_schedule.sort(reverse=True) 109 | for e in decay_schedule: 110 | if epoch >= e: 111 | lr_adj *= alpha 112 | return lr_adj 113 | 114 | return lr_schedule 115 | -------------------------------------------------------------------------------- /examples/vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/examples/vision/__init__.py -------------------------------------------------------------------------------- /examples/vision/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | """Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 2 | 3 | The implementation and structure of this file are hugely influenced by [2], 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web are copy-paste from 6 | torchvision's resnet and have the wrong number of params. 7 | 8 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 9 | number of layers and parameters: 10 | 11 | name | layers | params 12 | ResNet20 | 20 | 0.27M 13 | ResNet32 | 32 | 0.46M 14 | ResNet44 | 44 | 0.66M 15 | ResNet56 | 56 | 0.85M 16 | ResNet110 | 110 | 1.7M 17 | ResNet1202| 1202 | 19.4M 18 | 19 | which this implementation indeed has. 20 | 21 | Reference: 22 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Deep Residual Learning for Image Recognition.
arXiv:1512.03385 24 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 25 | 26 | Author: Yerlan Idelbayev 27 | Source: https://raw.githubusercontent.com/akamaster/pytorch_resnet_cifar10 28 | """ 29 | 30 | from __future__ import annotations 31 | 32 | from typing import Callable 33 | 34 | import torch 35 | import torch.nn as nn 36 | import torch.nn.init as init 37 | from torch.nn.functional import avg_pool2d 38 | from torch.nn.functional import pad 39 | from torch.nn.functional import relu 40 | 41 | __all__ = [ 42 | 'ResNet', 43 | 'resnet20', 44 | 'resnet32', 45 | 'resnet44', 46 | 'resnet56', 47 | 'resnet110', 48 | 'resnet1202', 49 | 'get_model', 50 | ] 51 | 52 | 53 | def get_model(model: str) -> torch.nn.Module: 54 | """Get PyTorch model by name.""" 55 | if model.lower() == 'resnet20': 56 | model_ = resnet20() 57 | elif model.lower() == 'resnet32': 58 | model_ = resnet32() 59 | elif model.lower() == 'resnet44': 60 | model_ = resnet44() 61 | elif model.lower() == 'resnet56': 62 | model_ = resnet56() 63 | elif model.lower() == 'resnet110': model_ = resnet110() 64 | else: raise ValueError(f'Unknown model: {model}') 65 | return model_ 66 | 67 | 68 | def _weights_init(m: torch.nn.Module) -> None: 69 | """Initialize weights of linear or conv2d module.""" 70 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 71 | init.kaiming_normal_(m.weight) 72 | 73 | 74 | class LambdaLayer(nn.Module): 75 | """LambdaLayer.""" 76 | 77 | def __init__(self, lambd: Callable[[torch.Tensor], torch.Tensor]): 78 | """Init LambdaLayer.""" 79 | super().__init__() 80 | self.lambd = lambd 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | """Forward pass which applies lambda to x.""" 84 | return self.lambd(x) 85 | 86 | 87 | class BasicBlock(nn.Module): 88 | """Basic ResNet block implementation.""" 89 | 90 | expansion = 1 91 | 92 | def __init__( 93 | self, 94 | in_planes: int, 95 | planes: int, 96 | stride: int = 1, 97 | option: str = 'A', 98 | ) -> None: 99 | """Init BasicBlock.""" 100 | super().__init__() 101 | self.conv1 = nn.Conv2d( 102 | in_planes, 103 | planes, 104 | kernel_size=3, 105 | stride=stride, 106 | padding=1, 107 | bias=False, 108 | ) 109 | self.bn1 = nn.BatchNorm2d(planes) 110 | self.conv2 = nn.Conv2d( 111 | planes, 112 | planes, 113 | kernel_size=3, 114 | stride=1, 115 | padding=1, 116 | bias=False, 117 | ) 118 | self.bn2 = nn.BatchNorm2d(planes) 119 | 120 | self.shortcut: torch.nn.Module = nn.Sequential() 121 | if stride != 1 or in_planes != planes: 122 | if option == 'A': 123 | """ 124 | For CIFAR10, the ResNet paper uses option A.
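Option A is the parameter-free shortcut: the input is subsampled with stride 2 (the ``x[:, :, ::2, ::2]`` indexing below) and the extra channels are zero-padded, so the identity mapping adds no parameters.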
125 | """ 126 | self.shortcut = LambdaLayer( 127 | lambda x: pad( 128 | x[:, :, ::2, ::2], 129 | (0, 0, 0, 0, planes // 4, planes // 4), 130 | 'constant', 131 | 0, 132 | ), 133 | ) 134 | elif option == 'B': 135 | self.shortcut = nn.Sequential( 136 | nn.Conv2d( 137 | in_planes, 138 | self.expansion * planes, 139 | kernel_size=1, 140 | stride=stride, 141 | bias=False, 142 | ), 143 | nn.BatchNorm2d(self.expansion * planes), 144 | ) 145 | 146 | def forward(self, x: torch.Tensor) -> torch.Tensor: 147 | """Forward pass.""" 148 | out = relu(self.bn1(self.conv1(x))) 149 | out = self.bn2(self.conv2(out)) 150 | out += self.shortcut(x) 151 | out = relu(out) 152 | return out 153 | 154 | 155 | class ResNet(nn.Module): 156 | """ResNet model implementation.""" 157 | 158 | def __init__( 159 | self, 160 | block: type[BasicBlock], 161 | num_blocks: list[int], 162 | num_classes: int = 10, 163 | ) -> None: 164 | """Init ResNet.""" 165 | super().__init__() 166 | self.in_planes = 16 167 | 168 | self.conv1 = nn.Conv2d( 169 | 3, 170 | 16, 171 | kernel_size=3, 172 | stride=1, 173 | padding=1, 174 | bias=False, 175 | ) 176 | self.bn1 = nn.BatchNorm2d(16) 177 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 178 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 179 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 180 | self.linear = nn.Linear(64, num_classes) 181 | 182 | self.apply(_weights_init) 183 | 184 | def _make_layer( 185 | self, 186 | block: type[BasicBlock], 187 | planes: int, 188 | num_blocks: int, 189 | stride: int, 190 | ) -> torch.nn.Sequential: 191 | """Make individual layer.""" 192 | strides = [stride] + [1] * (num_blocks - 1) 193 | layers = [] 194 | for stride in strides: 195 | layers.append(block(self.in_planes, planes, stride)) 196 | self.in_planes = planes * block.expansion 197 | 198 | return nn.Sequential(*layers) 199 | 200 | def forward(self, x: torch.Tensor) -> torch.Tensor: 201 | """Forward pass.""" 202 | out = relu(self.bn1(self.conv1(x))) 203 | out = self.layer1(out) 204 | out = self.layer2(out) 205 | out = self.layer3(out) 206 | out = avg_pool2d(out, out.size()[3]) 207 | out = out.view(out.size(0), -1) 208 | out = self.linear(out) 209 | return out 210 | 211 | 212 | def resnet20() -> ResNet: 213 | """Get ResNet20 model.""" 214 | return ResNet(BasicBlock, [3, 3, 3]) 215 | 216 | 217 | def resnet32() -> ResNet: 218 | """Get ResNet20 model.""" 219 | return ResNet(BasicBlock, [5, 5, 5]) 220 | 221 | 222 | def resnet44() -> ResNet: 223 | """Get ResNet20 model.""" 224 | return ResNet(BasicBlock, [7, 7, 7]) 225 | 226 | 227 | def resnet56() -> ResNet: 228 | """Get ResNet20 model.""" 229 | return ResNet(BasicBlock, [9, 9, 9]) 230 | 231 | 232 | def resnet110() -> ResNet: 233 | """Get ResNet20 model.""" 234 | return ResNet(BasicBlock, [18, 18, 18]) 235 | 236 | 237 | def resnet1202() -> ResNet: 238 | """Get ResNet20 model.""" 239 | return ResNet(BasicBlock, [200, 200, 200]) 240 | 241 | 242 | def test(net: ResNet) -> None: 243 | """Print sizes of all ResNet models.""" 244 | import numpy as np 245 | 246 | total_params = 0 247 | 248 | for x in filter(lambda p: p.requires_grad, net.parameters()): 249 | total_params += np.prod(x.data.numpy().shape) 250 | print('Total number of params', total_params) 251 | print( 252 | 'Total layers', 253 | len( 254 | list( 255 | filter( 256 | lambda p: p.requires_grad and len(p.data.size()) > 1, 257 | net.parameters(), 258 | ), 259 | ), 260 | ), 261 | ) 262 | 263 | 264 | if __name__ == '__main__': 265 | for net_name in 
266 | if net_name.startswith('resnet'):
267 | print(net_name)
268 | test(globals()[net_name]())
269 | print()
270 | 
--------------------------------------------------------------------------------
/examples/vision/datasets.py:
--------------------------------------------------------------------------------
1 | """Functions for getting computer vision datasets."""
2 | 
3 | from __future__ import annotations
4 | 
5 | import os
6 | from typing import Any
7 | from typing import Tuple
8 | 
9 | import torch
10 | import torch.distributed as dist
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data import Dataset
13 | from torch.utils.data.distributed import DistributedSampler
14 | from torchvision import datasets
15 | from torchvision import transforms
16 | 
17 | T = Tuple[torch.Tensor, torch.Tensor]
18 | 
19 | 
20 | def get_cifar(
21 | args: Any,
22 | ) -> tuple[
23 | DistributedSampler[T],
24 | DataLoader[T],
25 | DistributedSampler[T],
26 | DataLoader[T],
27 | ]:
28 | """Get CIFAR-10 dataset."""
29 | transform_train = transforms.Compose(
30 | [
31 | transforms.RandomCrop(32, padding=4),
32 | transforms.RandomHorizontalFlip(),
33 | transforms.ToTensor(),
34 | transforms.Normalize(
35 | (0.4914, 0.4822, 0.4465),
36 | (0.2023, 0.1994, 0.2010),
37 | ),
38 | ],
39 | )
40 | transform_test = transforms.Compose(
41 | [
42 | transforms.ToTensor(),
43 | transforms.Normalize(
44 | (0.4914, 0.4822, 0.4465),
45 | (0.2023, 0.1994, 0.2010),
46 | ),
47 | ],
48 | )
49 | 
50 | os.makedirs(args.data_dir, exist_ok=True)
51 | 
52 | download = args.local_rank == 0
53 | if not download:
54 | dist.barrier()
55 | train_dataset = datasets.CIFAR10(
56 | root=args.data_dir,
57 | train=True,
58 | download=download,
59 | transform=transform_train,
60 | )
61 | test_dataset = datasets.CIFAR10(
62 | root=args.data_dir,
63 | train=False,
64 | download=download,
65 | transform=transform_test,
66 | )
67 | if download:
68 | dist.barrier()
69 | 
70 | return make_sampler_and_loader(args, train_dataset, test_dataset)
71 | 
72 | 
73 | def get_imagenet(
74 | args: Any,
75 | ) -> tuple[
76 | DistributedSampler[T],
77 | DataLoader[T],
78 | DistributedSampler[T],
79 | DataLoader[T],
80 | ]:
81 | """Get ImageNet dataset."""
82 | train_dataset = datasets.ImageFolder(
83 | args.train_dir,
84 | transform=transforms.Compose(
85 | [
86 | transforms.RandomResizedCrop(224),
87 | transforms.RandomHorizontalFlip(),
88 | transforms.ToTensor(),
89 | transforms.Normalize(
90 | mean=[0.485, 0.456, 0.406],
91 | std=[0.229, 0.224, 0.225],
92 | ),
93 | ],
94 | ),
95 | )
96 | val_dataset = datasets.ImageFolder(
97 | args.val_dir,
98 | transform=transforms.Compose(
99 | [
100 | transforms.Resize(256),
101 | transforms.CenterCrop(224),
102 | transforms.ToTensor(),
103 | transforms.Normalize(
104 | mean=[0.485, 0.456, 0.406],
105 | std=[0.229, 0.224, 0.225],
106 | ),
107 | ],
108 | ),
109 | )
110 | 
111 | return make_sampler_and_loader(args, train_dataset, val_dataset)
112 | 
113 | 
114 | def make_sampler_and_loader(
115 | args: Any,
116 | train_dataset: Dataset[T],
117 | val_dataset: Dataset[T],
118 | ) -> tuple[
119 | DistributedSampler[T],
120 | DataLoader[T],
121 | DistributedSampler[T],
122 | DataLoader[T],
123 | ]:
124 | """Create sampler and dataloader for train and val datasets."""
125 | torch.set_num_threads(4)
126 | kwargs: dict[str, Any] = {}
127 | if args.cuda:
128 | # prefetch_factor and persistent_workers require num_workers > 0,
129 | # so only set them alongside the worker options on the CUDA path.
130 | kwargs = {'num_workers': 4, 'pin_memory': True}
131 | kwargs.update(prefetch_factor=8, persistent_workers=True)
132 | train_sampler: 
DistributedSampler[T] = DistributedSampler( 133 | train_dataset, 134 | num_replicas=dist.get_world_size(), 135 | rank=dist.get_rank(), 136 | ) 137 | train_loader: DataLoader[T] = DataLoader( 138 | train_dataset, 139 | batch_size=args.batch_size, 140 | sampler=train_sampler, 141 | **kwargs, 142 | ) 143 | val_sampler: DistributedSampler[T] = DistributedSampler( 144 | val_dataset, 145 | num_replicas=dist.get_world_size(), 146 | rank=dist.get_rank(), 147 | ) 148 | val_loader: DataLoader[T] = DataLoader( 149 | val_dataset, 150 | batch_size=args.val_batch_size, 151 | sampler=val_sampler, 152 | **kwargs, 153 | ) 154 | 155 | return train_sampler, train_loader, val_sampler, val_loader 156 | -------------------------------------------------------------------------------- /examples/vision/engine.py: -------------------------------------------------------------------------------- 1 | """Train and Eval functions for computer vision examples.""" 2 | 3 | from __future__ import annotations 4 | 5 | import argparse 6 | import math 7 | from typing import Tuple 8 | 9 | import torch 10 | from tqdm import tqdm 11 | 12 | import kfac 13 | from examples.utils import accuracy 14 | from examples.utils import Metric 15 | 16 | SampleT = Tuple[torch.Tensor, torch.Tensor] 17 | 18 | 19 | def train( 20 | epoch: int, 21 | model: torch.nn.Module, 22 | optimizer: torch.optim.Optimizer, 23 | preconditioner: kfac.preconditioner.KFACPreconditioner | None, 24 | loss_func: torch.nn.Module, 25 | train_sampler: torch.utils.data.distributed.DistributedSampler[SampleT], 26 | train_loader: torch.utils.data.DataLoader[SampleT], 27 | args: argparse.Namespace, 28 | ) -> None: 29 | """Train model.""" 30 | model.train() 31 | train_sampler.set_epoch(epoch) 32 | train_loss = Metric('train_loss') 33 | train_accuracy = Metric('train_accuracy') 34 | scaler = args.grad_scaler if 'grad_scaler' in args else None 35 | mini_step = 0 36 | step_loss = torch.tensor(0.0).to('cuda' if args.cuda else 'cpu') 37 | step_accuracy = torch.tensor(0.0).to('cuda' if args.cuda else 'cpu') 38 | 39 | with tqdm( 40 | total=math.ceil(len(train_loader) / args.batches_per_allreduce), 41 | bar_format='{l_bar}{bar:10}{r_bar}', 42 | desc=f'Epoch {epoch:3d}/{args.epochs:3d}', 43 | disable=not args.verbose, 44 | ) as t: 45 | for batch_idx, (data, target) in enumerate(train_loader): 46 | mini_step += 1 47 | if args.cuda: 48 | data, target = data.cuda(), target.cuda() 49 | 50 | if scaler is not None: 51 | with torch.cuda.amp.autocast(): 52 | output = model(data) 53 | loss = loss_func(output, target) 54 | else: 55 | output = model(data) 56 | loss = loss_func(output, target) 57 | 58 | with torch.no_grad(): 59 | step_loss += loss 60 | step_accuracy += accuracy(output, target) 61 | 62 | loss = loss / args.batches_per_allreduce 63 | 64 | if mini_step % args.batches_per_allreduce == 0 or ( 65 | batch_idx + 1 == len(train_loader) 66 | ): 67 | if scaler is not None: 68 | scaler.scale(loss).backward() 69 | else: 70 | loss.backward() 71 | else: 72 | with model.no_sync(): # type: ignore 73 | if scaler is not None: 74 | scaler.scale(loss).backward() 75 | else: 76 | loss.backward() 77 | 78 | if mini_step % args.batches_per_allreduce == 0 or ( 79 | batch_idx + 1 == len(train_loader) 80 | ): 81 | if preconditioner is not None: 82 | if scaler is not None: 83 | scaler.unscale_(optimizer) 84 | preconditioner.step() 85 | if scaler is not None: 86 | scaler.step(optimizer) 87 | scaler.update() 88 | else: 89 | optimizer.step() 90 | optimizer.zero_grad() 91 | 92 | train_loss.update(step_loss / 
mini_step) 93 | train_accuracy.update(step_accuracy / mini_step) 94 | step_loss.zero_() 95 | step_accuracy.zero_() 96 | 97 | t.set_postfix_str( 98 | 'loss: {:.4f}, acc: {:.2f}%, lr: {:.4f}'.format( 99 | train_loss.avg, 100 | 100 * train_accuracy.avg, 101 | optimizer.param_groups[0]['lr'], 102 | ), 103 | ) 104 | t.update(1) 105 | mini_step = 0 106 | 107 | if args.log_writer is not None: 108 | args.log_writer.add_scalar('train/loss', train_loss.avg, epoch) 109 | args.log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch) 110 | args.log_writer.add_scalar( 111 | 'train/lr', 112 | optimizer.param_groups[0]['lr'], 113 | epoch, 114 | ) 115 | 116 | 117 | def test( 118 | epoch: int, 119 | model: torch.nn.Module, 120 | loss_func: torch.nn.Module, 121 | val_loader: torch.utils.data.DataLoader[SampleT], 122 | args: argparse.Namespace, 123 | ) -> None: 124 | """Test the model.""" 125 | model.eval() 126 | val_loss = Metric('val_loss') 127 | val_accuracy = Metric('val_accuracy') 128 | 129 | with tqdm( 130 | total=len(val_loader), 131 | bar_format='{l_bar}{bar:10}|{postfix}', 132 | desc=' ', 133 | disable=not args.verbose, 134 | ) as t: 135 | with torch.no_grad(): 136 | for i, (data, target) in enumerate(val_loader): 137 | if args.cuda: 138 | data, target = data.cuda(), target.cuda() 139 | output = model(data) 140 | val_loss.update(loss_func(output, target)) 141 | val_accuracy.update(accuracy(output, target)) 142 | 143 | t.update(1) 144 | if i + 1 == len(val_loader): 145 | t.set_postfix_str( 146 | '\b\b val_loss: {:.4f}, val_acc: {:.2f}%'.format( 147 | val_loss.avg, 148 | 100 * val_accuracy.avg, 149 | ), 150 | refresh=False, 151 | ) 152 | 153 | if args.log_writer is not None: 154 | args.log_writer.add_scalar('val/loss', val_loss.avg, epoch) 155 | args.log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch) 156 | -------------------------------------------------------------------------------- /examples/vision/optimizers.py: -------------------------------------------------------------------------------- 1 | """Utilities for getting optimizers for computer vision examples.""" 2 | 3 | from __future__ import annotations 4 | 5 | import argparse 6 | from typing import Callable 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.optim as optim 11 | 12 | import kfac 13 | from examples.utils import create_lr_schedule 14 | 15 | 16 | def get_optimizer( 17 | model: torch.nn.Module, 18 | args: argparse.Namespace, 19 | ) -> tuple[ 20 | optim.Optimizer, 21 | kfac.preconditioner.KFACPreconditioner | None, 22 | tuple[ 23 | optim.lr_scheduler._LRScheduler, 24 | kfac.scheduler.LambdaParamScheduler | None, 25 | ], 26 | ]: 27 | """Get optimizer, preconditioner, and scheduler.""" 28 | use_kfac = True if args.kfac_inv_update_steps > 0 else False 29 | 30 | optimizer = optim.SGD( 31 | model.parameters(), 32 | lr=args.base_lr, 33 | momentum=args.momentum, 34 | weight_decay=args.weight_decay, 35 | ) 36 | lrs = create_lr_schedule( 37 | dist.get_world_size(), 38 | args.warmup_epochs, 39 | args.lr_decay, 40 | ) 41 | lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lrs) 42 | 43 | grad_worker_fraction: kfac.enums.DistributedStrategy | float 44 | if args.kfac_strategy == 'comm-opt': 45 | grad_worker_fraction = kfac.enums.DistributedStrategy.COMM_OPT 46 | elif args.kfac_strategy == 'mem-opt': 47 | grad_worker_fraction = kfac.enums.DistributedStrategy.MEM_OPT 48 | elif args.kfac_strategy == 'hybrid-opt': 49 | grad_worker_fraction = args.kfac_grad_worker_fraction 50 | else: 51 | raise 
ValueError( 52 | f'Unknown KFAC Comm Method: {args.kfac_strategy}', 53 | ) 54 | 55 | if use_kfac: 56 | preconditioner = kfac.preconditioner.KFACPreconditioner( 57 | model, 58 | factor_update_steps=args.kfac_factor_update_steps, 59 | inv_update_steps=args.kfac_inv_update_steps, 60 | damping=args.kfac_damping, 61 | factor_decay=args.kfac_factor_decay, 62 | kl_clip=args.kfac_kl_clip, 63 | lr=lambda x: optimizer.param_groups[0]['lr'], 64 | accumulation_steps=args.batches_per_allreduce, 65 | allreduce_bucket_cap_mb=25, 66 | colocate_factors=args.kfac_colocate_factors, 67 | compute_method=kfac.enums.ComputeMethod.INVERSE 68 | if args.kfac_inv_method 69 | else kfac.enums.ComputeMethod.EIGEN, 70 | grad_worker_fraction=grad_worker_fraction, 71 | grad_scaler=args.grad_scaler if 'grad_scaler' in args else None, 72 | skip_layers=args.kfac_skip_layers, 73 | ) 74 | 75 | def get_lambda( 76 | alpha: int, 77 | epochs: list[int] | None, 78 | ) -> Callable[[int], float]: 79 | """Create lambda function for param scheduler.""" 80 | if epochs is None: 81 | _epochs = [] 82 | else: 83 | _epochs = epochs 84 | 85 | def scale(epoch: int) -> float: 86 | """Compute current scale factor using epoch.""" 87 | factor = 1.0 88 | for e in _epochs: 89 | if epoch >= e: 90 | factor *= alpha 91 | return factor 92 | 93 | return scale 94 | 95 | kfac_param_scheduler = kfac.scheduler.LambdaParamScheduler( 96 | preconditioner, 97 | damping_lambda=get_lambda( 98 | args.kfac_damping_alpha, 99 | args.kfac_damping_decay, 100 | ), 101 | factor_update_steps_lambda=get_lambda( 102 | args.kfac_update_steps_alpha, 103 | args.kfac_update_steps_decay, 104 | ), 105 | inv_update_steps_lambda=get_lambda( 106 | args.kfac_update_steps_alpha, 107 | args.kfac_update_steps_decay, 108 | ), 109 | ) 110 | else: 111 | preconditioner = None 112 | kfac_param_scheduler = None 113 | 114 | return optimizer, preconditioner, (lr_scheduler, kfac_param_scheduler) 115 | -------------------------------------------------------------------------------- /kfac/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level module for K-FAC.""" 2 | 3 | from __future__ import annotations 4 | 5 | import importlib.metadata as importlib_metadata 6 | import sys 7 | 8 | import kfac.assignment as assignment 9 | import kfac.base_preconditioner as base_preconditioner 10 | import kfac.distributed as distributed 11 | import kfac.enums as enums 12 | import kfac.gpt_neox as gpt_neox 13 | import kfac.layers as layers 14 | import kfac.preconditioner as preconditioner 15 | import kfac.scheduler as scheduler 16 | import kfac.tracing as tracing 17 | import kfac.warnings as warnings 18 | 19 | __version__ = importlib_metadata.version('kfac-pytorch') 20 | -------------------------------------------------------------------------------- /kfac/enums.py: -------------------------------------------------------------------------------- 1 | """KFAC enum types.""" 2 | 3 | from __future__ import annotations 4 | 5 | from enum import Enum 6 | 7 | 8 | class AllreduceMethod(Enum): 9 | """Allreduce method.""" 10 | 11 | ALLREDUCE = 1 12 | ALLREDUCE_BUCKETED = 2 13 | 14 | 15 | class AssignmentStrategy(Enum): 16 | """KFAC Factor Distribution Method. 17 | 18 | KFAC assigns factors for second-order computation using a heuristic-based 19 | longest-processing time greedy algorithm. 
AssignmentStrategy.COMPUTE
20 | uses an estimation of the second-order computation time as the heuristic
21 | and AssignmentStrategy.MEMORY uses the memory requirements of storing
22 | the second-order results as the heuristic.
23 | """
24 | 
25 | COMPUTE = 1
26 | MEMORY = 2
27 | 
28 | 
29 | class ComputeMethod(Enum):
30 | """KFAC Second Order Computation Method.
31 | 
32 | Controls if eigen decompositions or inverse of the factors will be used
33 | to precondition the gradients.
34 | """
35 | 
36 | EIGEN = 1
37 | INVERSE = 2
38 | 
39 | 
40 | class DistributedStrategy(Enum):
41 | """KFAC Distribution Strategy.
42 | 
43 | Shortcuts for common grad_worker_fractions.
44 | - COMM_OPT: grad_worker_fraction = 1
45 | - HYBRID_OPT: grad_worker_fraction = 0.5
46 | - MEM_OPT: grad_worker_fraction = 1 / world_size
47 | 
48 | See https://arxiv.org/pdf/2107.01739.pdf for more details on distribution
49 | strategies.
50 | """
51 | 
52 | COMM_OPT = 1
53 | MEM_OPT = 2
54 | HYBRID_OPT = 3
55 | 
--------------------------------------------------------------------------------
/kfac/gpt_neox/__init__.py:
--------------------------------------------------------------------------------
1 | """Custom KFAC support for GPT-NeoX."""
2 | 
3 | from __future__ import annotations
4 | 
--------------------------------------------------------------------------------
/kfac/gpt_neox/assignment.py:
--------------------------------------------------------------------------------
1 | """Custom Assignment for GPT-NeoX."""
2 | 
3 | from __future__ import annotations
4 | 
5 | import torch.distributed as dist
6 | 
7 | from kfac.assignment import WorkAssignment
8 | from kfac.gpt_neox.mpu import get_group_with_rank
9 | 
10 | try:
11 | from deepspeed.runtime.pipe.topology import ( # type: ignore
12 | PipeModelDataParallelTopology, # type: ignore
13 | )
14 | 
15 | deepspeed_import_error = None
16 | except ImportError as e: # pragma: no cover
17 | deepspeed_import_error = e
18 | 
19 | 
20 | class GPTNeoXAssignment(WorkAssignment):
21 | """Pipeline parallel aware work assignment for GPT-NeoX."""
22 | 
23 | def __init__(
24 | self,
25 | work: dict[str, dict[str, float]],
26 | *,
27 | local_rank: int,
28 | topology: PipeModelDataParallelTopology,
29 | data_parallel_group: dist.ProcessGroup | None,
30 | model_parallel_group: dist.ProcessGroup | None,
31 | ) -> None:
32 | """Init GPTNeoXAssignment.
33 | 
34 | Args:
35 | work (dict[str, dict[str, float]]): dictionary mapping unique layer
36 | names to sub-dictionaries where the keys are the str names for
37 | each factor associated with the layer and the values are the
38 | cost of each factor computation for load balancing. Note that
39 | this should only be the work performed by the data parallel
40 | group.
41 | local_rank (int): local rank of this process.
42 | topology (PipeModelDataParallelTopology): topology created
43 | by DeepSpeed.
44 | data_parallel_group (ProcessGroup): DeepSpeed data parallel
45 | process group.
46 | model_parallel_group (ProcessGroup): DeepSpeed model parallel
47 | process group.
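
Example:
A sketch of the expected work dict structure (the layer and
factor names here are purely illustrative):

work = {
'module.layer1': {'A': 4.0, 'G': 4.0},
'module.layer2': {'A': 2.0, 'G': 2.0},
}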
48 | """ 49 | if deepspeed_import_error is not None: # pragma: no cover 50 | raise deepspeed_import_error 51 | if not isinstance(topology, PipeModelDataParallelTopology): 52 | raise TypeError( 53 | 'Expected topology to be of type ' 54 | f'{PipeModelDataParallelTopology.__name__} but got ' 55 | f'{type(topology)} instead.', 56 | ) 57 | 58 | self.local_rank = local_rank 59 | self.data_parallel_group = data_parallel_group 60 | self.model_parallel_group = model_parallel_group 61 | 62 | # global information 63 | self.data_parallel_groups = topology.get_axis_comm_lists('data') 64 | self.model_parallel_groups = topology.get_axis_comm_lists('model') 65 | self.pipe_parallel_groups = topology.get_axis_comm_lists('pipe') 66 | 67 | self.data_parallel_peers = get_group_with_rank( 68 | self.local_rank, 69 | self.data_parallel_groups, 70 | ) 71 | self.model_parallel_peers = get_group_with_rank( 72 | self.local_rank, 73 | self.model_parallel_groups, 74 | ) 75 | self.pipe_parallel_rank = topology.get_coord(self.local_rank).pipe 76 | # List of ranks with same pipe rank as us. These are the ranks that 77 | # have the same layers as us so they are all we care about for the 78 | # purpose of assigning work 79 | self.pipe_parallel_peers = [ 80 | r 81 | for r in range(topology.world_size()) 82 | if topology.get_coord(r).pipe == self.pipe_parallel_rank 83 | ] 84 | 85 | # Reuse existing groups if possible 86 | if set(self.pipe_parallel_peers) == set(self.model_parallel_peers): 87 | self.pipe_parallel_peer_group = self.model_parallel_group 88 | elif set(self.pipe_parallel_peers) == set(self.data_parallel_peers): 89 | self.pipe_parallel_peer_group = self.data_parallel_group 90 | else: 91 | self.pipe_parallel_peer_group = dist.new_group( 92 | self.pipe_parallel_peers, 93 | ) 94 | 95 | worker_loads = [0.0 for _ in self.pipe_parallel_peers] 96 | self._inv_assignments = { 97 | layer: {factor: -1 for factor in factors} 98 | for layer, factors in work.items() 99 | } 100 | summed_work = [ 101 | (layer, sum(factors.values())) for layer, factors in work.items() 102 | ] 103 | sorted_work = sorted( 104 | summed_work, 105 | key=lambda item: (item[1], item[0]), 106 | reverse=True, 107 | ) 108 | 109 | for layer, cost in sorted_work: 110 | min_worker_index = worker_loads.index(min(worker_loads)) 111 | min_worker = self.pipe_parallel_peers[min_worker_index] 112 | for factor in self._inv_assignments[layer]: 113 | self._inv_assignments[layer][factor] = min_worker 114 | worker_loads[min_worker_index] += cost 115 | 116 | def broadcast_gradients(self) -> bool: 117 | """Return if gradients need to be broadcast. 118 | 119 | GPT-NeoX uses MEM-OPT training (grad worker fraction = 1/world_size) 120 | so gradient broadcast is necessary. 121 | """ 122 | return True 123 | 124 | def broadcast_inverses(self) -> bool: 125 | """Return if inverses need to be broadcast. 126 | 127 | GPT-NeoX uses MEM-OPT training (grad worker fraction = 1/world_size) 128 | so inverse broadcast is not necessary. 
129 | """ 130 | return False 131 | 132 | def get_layers(self) -> tuple[str, ...]: 133 | """Return tuple of layers assigned.""" 134 | return tuple(self._inv_assignments.keys()) 135 | 136 | def get_factors(self, layer: str) -> tuple[str, ...]: 137 | """Return tuple of factors associated with the layer.""" 138 | return tuple(self._inv_assignments[layer].keys()) 139 | 140 | def inv_worker(self, layer: str, factor: str) -> int: 141 | """Return rank that computes inverse factor for this layer.""" 142 | return self._inv_assignments[layer][factor] 143 | 144 | def factor_worker(self, layer: str, factor: str) -> int: 145 | """Worker that gathers the factor from model parallel group peers. 146 | 147 | Also referred to as the primary worker in the layer code. 148 | """ 149 | inv_ranks = set(self._inv_assignments[layer].values()) 150 | assert len(inv_ranks) == 1 151 | inv_rank = inv_ranks.pop() 152 | 153 | data_parallel_ranks = get_group_with_rank( 154 | inv_rank, 155 | self.data_parallel_groups, 156 | ) 157 | factor_workers = set(data_parallel_ranks) & set( 158 | self.model_parallel_peers, 159 | ) 160 | assert len(factor_workers) == 1 161 | return factor_workers.pop() 162 | 163 | def is_grad_worker(self, layer: str) -> bool: 164 | """Return if this rank is a gradient worker for this layer. 165 | 166 | GPTNeoXKFACEigen.precondition_grad() requires every worker in the 167 | model parallelism group of the inv_worker to enter 168 | to decide if the grad needs to be gathered to and scatter from the 169 | true grad worker within the model parallel group, so we just return 170 | True here and let that method handle which ranks actually do work. 171 | """ 172 | return ( 173 | len( 174 | set(self._inv_assignments[layer].values()) 175 | & set(self.model_parallel_peers), 176 | ) 177 | == 1 178 | ) 179 | 180 | def src_grad_worker(self, layer: str) -> int: 181 | """Return rank that will share preconditioned gradient. 182 | 183 | If process is a gradient worker, this method should return the 184 | process rank. Otherwise, if the process is a gradient receiver, this 185 | method returns the rank that is responsible for sending the 186 | preconditioned gradient to this process. 187 | 188 | With model parallelism, the src rank is the rank that received the 189 | partial preconditioned gradient from the inv_worker. 190 | """ 191 | ranks = list(self._inv_assignments[layer].values()) 192 | assert ranks.count(ranks[0]) == len(ranks) 193 | # This is just the src rank that computes the preconditioned gradient 194 | # and then scatters it to the other ranks in its model parallel group 195 | src_rank = ranks[0] 196 | 197 | model_parallel_ranks = get_group_with_rank( 198 | src_rank, 199 | self.model_parallel_groups, 200 | ) 201 | src = set(self.data_parallel_peers) & set(model_parallel_ranks) 202 | assert len(src) == 1 203 | return src.pop() 204 | 205 | def factor_group( 206 | self, 207 | layer: str, 208 | factor: str, 209 | ) -> dist.ProcessGroup | None: 210 | """Communication group for allreducing factors. 211 | 212 | The GPTNeoXKFACEigenLayer will ignore this. 213 | """ 214 | return None 215 | 216 | def grad_worker_group(self, layer: str) -> dist.ProcessGroup | None: 217 | """Return communication group for inverse factor broadcast. 218 | 219 | This communication group is used for the broadcasts of the inverses 220 | from the inverse worker to the remaining gradient workers for the 221 | layer. 
222 | """ 223 | raise NotImplementedError( 224 | 'The GPT-NeoX assignment strategy only supports MEM-OPT ' 225 | 'and therefore should not be performing inverse factor ' 226 | 'communication.', 227 | ) 228 | 229 | def grad_receiver_group(self, layer: str) -> dist.ProcessGroup | None: 230 | """Return communication group for preconditioned gradient broadcast. 231 | 232 | This communication group is used for the broadcasts of the gradients 233 | from the gradient worker to the remaining gradient receivers for the 234 | layer. 235 | """ 236 | return self.data_parallel_group 237 | -------------------------------------------------------------------------------- /kfac/gpt_neox/modules.py: -------------------------------------------------------------------------------- 1 | """Helper wrappers for supported PyTorch modules.""" 2 | 3 | from __future__ import annotations 4 | 5 | import sys 6 | 7 | if sys.version_info >= (3, 9): # pragma: >=3.9 cover 8 | from typing import Literal 9 | else: # pragma: <3.9 cover 10 | from typing_extensions import Literal 11 | 12 | import torch 13 | import torch.distributed as dist 14 | 15 | from kfac.layers.modules import LinearModuleHelper 16 | 17 | 18 | class GPTNeoXLinearModuleHelper(LinearModuleHelper): 19 | """ModuleHelper for GPTNeoX layers.""" 20 | 21 | def __init__( 22 | self, 23 | module: torch.nn.Module, 24 | model_parallel_group: dist.ProcessGroup | None, 25 | parallelism: Literal['input', 'output'], 26 | ): 27 | """Init ModuleHelper. 28 | 29 | Args: 30 | module (torch.nn.Module): module in model to wrap. 31 | model_parallel_group (ProcessGroup): model parallel distributed 32 | process group this rank belongs to. If None, it is assumed 33 | model parallelism size is 1 (i.e., there is no model 34 | parallelism). 35 | parallelism (str): "input" if the layer is split on the input or 36 | "output" if split on the output. 37 | """ 38 | self.module = module 39 | self.model_parallel_group = model_parallel_group 40 | self.model_parallel_world_size = ( 41 | 1 42 | if self.model_parallel_group is None 43 | else dist.get_world_size(self.model_parallel_group) 44 | ) 45 | self.parallelism = parallelism 46 | 47 | @property 48 | def a_factor_shape(self) -> tuple[int, int]: 49 | """Get shape of A factor.""" 50 | dim1_size = self.module.weight.size(1) # type: ignore 51 | if self.parallelism == 'input': 52 | x = (dim1_size * self.model_parallel_world_size) + int( 53 | self.has_bias(), 54 | ) 55 | else: 56 | x = dim1_size + int(self.has_bias()) 57 | return (x, x) 58 | 59 | @property 60 | def g_factor_shape(self) -> tuple[int, int]: 61 | """Get shape of G factor.""" 62 | dim0_size = self.module.weight.size(0) # type: ignore 63 | if self.parallelism == 'output': 64 | x = dim0_size * self.model_parallel_world_size 65 | else: 66 | x = dim0_size 67 | return (x, x) 68 | -------------------------------------------------------------------------------- /kfac/gpt_neox/mpu.py: -------------------------------------------------------------------------------- 1 | """Extensions of MPU functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def gather_from_model_parallel_region( 10 | tensor: torch.Tensor, 11 | dst: int, 12 | model_parallel_group: dist.ProcessGroup | None, 13 | fp32_allreduce: bool = False, 14 | dim: int = -1, 15 | ) -> torch.Tensor | None: 16 | """Gather model parallel partitions into single tensor. 
17 | 
18 | Note:
19 | This is a true `gather` whereas mpu.gather_from_model_parallel_region
20 | is an `all gather`.
21 | 
22 | Note:
23 | The concatenation is done along the last axis. I.e., this is the
24 | inverse operation of mpu.scatter_to_model_parallel_region().
25 | 
26 | Args:
27 | tensor (torch.Tensor): tensor partition to gather.
28 | dst (int): destination rank to gather the full tensor on.
29 | model_parallel_group (ProcessGroup): model parallel process group.
30 | If None, model parallel region will be assumed to have size 1.
31 | fp32_allreduce (bool): if True and tensor is bf16, the tensor will
32 | be cast to float before communication. Note: this is to match
33 | the functionality of megatron's
34 | gather_from_model_parallel_region().
35 | dim (int): dimension along which to concatenate tensors.
36 | 
37 | Returns:
38 | Gathered tensor on rank `dst` else None.
39 | """
40 | world_size = (
41 | 1
42 | if model_parallel_group is None
43 | else dist.get_world_size(model_parallel_group)
44 | )
45 | # Bypass the function if we are using only 1 GPU.
46 | if world_size == 1:
47 | return tensor
48 | 
49 | # Bf16 convert
50 | dt = tensor.dtype
51 | if dt == torch.bfloat16 and fp32_allreduce:
52 | tensor = tensor.float()
53 | 
54 | tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
55 | 
56 | # TODO(gpauloski): PyTorch>=1.11 supports gather directly
57 | # which will be much faster
58 | torch.distributed.all_gather(
59 | tensor_list,
60 | tensor,
61 | group=model_parallel_group,
62 | )
63 | 
64 | if dist.get_rank() == dst:
65 | # Note: torch.cat already creates a contiguous tensor.
66 | output = torch.cat(tensor_list, dim=dim).contiguous()
67 | 
68 | # Bf16 convert
69 | if dt == torch.bfloat16 and fp32_allreduce:
70 | output = output.bfloat16()
71 | 
72 | return output
73 | else:
74 | return None
75 | 
76 | 
77 | def get_group_with_rank(rank: int, groups: list[list[int]]) -> list[int]:
78 | """Returns first group from list of groups containing rank.
79 | 
80 | Args:
81 | rank (int): rank to search for.
82 | groups (list[list[int]]): list of groups where each group is a list
83 | of ranks.
84 | 
85 | Returns:
86 | group (list of ranks) containing rank.
87 | 
88 | Raises:
89 | ValueError:
90 | if a matching group is not found.
91 | """
92 | for group in groups:
93 | if rank in group:
94 | return group
95 | raise ValueError(f'Rank {rank} was not in any of the groups.')
96 | 
97 | 
98 | def split_tensor_along_dim(
99 | tensor: torch.Tensor,
100 | num_partitions: int,
101 | dim: int,
102 | contiguous_split_chunks: bool = False,
103 | ) -> tuple[torch.Tensor, ...]:
104 | """Split a tensor along a given dimension.
105 | 
106 | Source: https://github.com/EleutherAI/gpt-neox/blob/d7af1e7a8e3a816610b7d169456f81ca62d34ff7/megatron/mpu/utils.py
107 | 
108 | Args:
109 | tensor (torch.Tensor): input tensor
110 | num_partitions (int): number of partitions to split the tensor
111 | dim (int): dimension along which to split the tensor.
112 | contiguous_split_chunks (bool): If True, make each chunk contiguous
113 | in memory.
114 | 
115 | Returns:
116 | tuple of tensors
117 | """ # noqa: E501
118 | dim_size = tensor.size()[dim]
119 | 
120 | if dim_size % num_partitions != 0:
121 | raise ValueError(
122 | f'Tensor dim {dim} (size={dim_size}) is not divisible '
123 | f'into {num_partitions} parts.',
124 | )
125 | 
126 | dim_size = dim_size // num_partitions
127 | tensor_list = torch.split(tensor, dim_size, dim=dim)
128 | 
129 | # Note: torch.split does not create contiguous tensors by default.
130 | if contiguous_split_chunks: 131 | return tuple(chunk.contiguous() for chunk in tensor_list) 132 | 133 | return tuple(tensor_list) 134 | -------------------------------------------------------------------------------- /kfac/hyperparams.py: -------------------------------------------------------------------------------- 1 | """Common hyperparameter schedules.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Callable 6 | 7 | 8 | def exp_decay_factor_averaging( 9 | min_value: float = 0.95, 10 | ) -> Callable[[int], float]: 11 | """Exponentially decaying factor averaging schedule. 12 | 13 | Implements the running average estimate strategy for the Kronecker factors 14 | A and G from "Optimizing Neural Networks with Kronecker-factored 15 | Approximate Curvature" (Martens et al., 2015). 16 | 17 | The running average weight e at K-FAC step k is min(1 - 1/k, min_value) 18 | where the min_value is 0.95 by default. 19 | 20 | Args: 21 | min_value (float): minimum value for the running average weight. 22 | 23 | Returns: 24 | callable that takes an integer value for the current K-FAC step and 25 | returns a float value for the running average weight. This callable 26 | can be passed as the value of `factor_decay` to instances of 27 | `kfac.base_preconditioner.BaseKFACPreconditioner`. Note: that if the 28 | current step is 0, 1 / k is undefined so k = 1 will be used, 29 | and if the current step is negative, a ValueError will be raised. 30 | 31 | Raises: 32 | ValueError: 33 | if `min_value` is less than or equal to zero. 34 | """ 35 | if min_value <= 0: 36 | raise ValueError('min_value must be greater than 0') 37 | 38 | def _factor_weight(step: int) -> float: 39 | if step < 0: 40 | raise ValueError( 41 | f'step value cannot be negative. Got step={step}.', 42 | ) 43 | if step == 0: 44 | step = 1 45 | return min(1 - (1 / step), min_value) 46 | 47 | return _factor_weight 48 | -------------------------------------------------------------------------------- /kfac/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """KFAC layer approximation module.""" 2 | 3 | from __future__ import annotations 4 | -------------------------------------------------------------------------------- /kfac/layers/inverse.py: -------------------------------------------------------------------------------- 1 | """Inverse preconditioning implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Callable 6 | from typing import cast 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | from kfac.distributed import Future 12 | from kfac.distributed import FutureType 13 | from kfac.distributed import get_rank 14 | from kfac.distributed import TorchDistributedCommunicator 15 | from kfac.enums import AllreduceMethod 16 | from kfac.layers.base import KFACBaseLayer 17 | from kfac.layers.modules import ModuleHelper 18 | 19 | 20 | class KFACInverseLayer(KFACBaseLayer): 21 | """KFAC layer that preconditions gradients with inverse factors.""" 22 | 23 | def __init__( 24 | self, 25 | module: ModuleHelper, 26 | *, 27 | tdc: TorchDistributedCommunicator, 28 | allreduce_method: AllreduceMethod = AllreduceMethod.ALLREDUCE, 29 | factor_dtype: torch.dtype | None = None, 30 | grad_scaler: ( 31 | torch.cuda.amp.GradScaler | Callable[[], float] | None 32 | ) = None, 33 | inv_dtype: torch.dtype = torch.float32, 34 | symmetry_aware: bool = False, 35 | ) -> None: 36 | """Init KFACInverseLayer. 
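
Example:
A minimal construction sketch (assumes TorchDistributedCommunicator
can be constructed with its defaults; the Linear module is
illustrative):

from kfac.layers.modules import LinearModuleHelper

helper = LinearModuleHelper(torch.nn.Linear(4, 2))
layer = KFACInverseLayer(helper, tdc=TorchDistributedCommunicator())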
37 | 38 | Args: 39 | module (ModuleHelper): module helper that exposes interfaces for 40 | getting the factors and gradients of a PyTorch module. 41 | tdc (TorchDistributedCommunicator): communicator object. Typically 42 | the communicator object should be shared by all KFACBaseLayers. 43 | allreduce_method (AllreduceMethod): allreduce method (default: 44 | AllreduceMethod.ALLREDUCE). 45 | factor_dtype (torch.dtype): data format to store factors in. If 46 | None, factors are stored in the format used in training 47 | (default: None). 48 | grad_scaler (optional): optional GradScaler or callable that 49 | returns the scale factor used in AMP training (default: None). 50 | inv_dtype (torch.dtype): data format to store inverses in. 51 | Inverses (or eigen decompositions) may be unstable in half- 52 | precision (default: torch.float32). 53 | symmetry_aware (bool): use symmetry aware communication method. 54 | This is typically more helpful when the factors are very 55 | large (default: False). 56 | """ 57 | super().__init__( 58 | module=module, 59 | tdc=tdc, 60 | allreduce_method=allreduce_method, 61 | factor_dtype=factor_dtype, 62 | grad_scaler=grad_scaler, 63 | inv_dtype=inv_dtype, 64 | symmetry_aware=symmetry_aware, 65 | ) 66 | 67 | # Inverse state variables 68 | # Inverse of self.a_factor 69 | self._a_inv: torch.Tensor | FutureType | None = None 70 | # Inverse of self.g_factor 71 | self._g_inv: torch.Tensor | FutureType | None = None 72 | 73 | @property 74 | def a_inv(self) -> torch.Tensor | None: 75 | """Get A inverse.""" 76 | if isinstance(self._a_inv, Future): 77 | self._a_inv = cast(torch.Tensor, self._a_inv.wait()) 78 | return self._a_inv 79 | 80 | @a_inv.setter 81 | def a_inv(self, value: torch.Tensor | FutureType | None) -> None: 82 | """Set A inverse.""" 83 | self._a_inv = value 84 | 85 | @property 86 | def g_inv(self) -> torch.Tensor | None: 87 | """Get G inverse.""" 88 | if isinstance(self._g_inv, Future): 89 | self._g_inv = cast(torch.Tensor, self._g_inv.wait()) 90 | return self._g_inv 91 | 92 | @g_inv.setter 93 | def g_inv(self, value: torch.Tensor | FutureType | None) -> None: 94 | """Set G inverse.""" 95 | self._g_inv = value 96 | 97 | def memory_usage(self) -> dict[str, int]: 98 | """Get memory usage for all variables in the layer.""" 99 | sizes = super().memory_usage() 100 | sizes['a_inverses'] = ( 101 | self.a_inv.nelement() * self.a_inv.element_size() 102 | if self.a_inv is not None 103 | else 0 104 | ) 105 | sizes['g_inverses'] = ( 106 | self.g_inv.nelement() * self.g_inv.element_size() 107 | if self.g_inv is not None 108 | else 0 109 | ) 110 | return sizes 111 | 112 | def broadcast_a_inv( 113 | self, 114 | src: int, 115 | group: dist.ProcessGroup | None = None, 116 | ) -> None: 117 | """Initiate A inv broadcast and store future to result. 118 | 119 | Note: 120 | all ranks must enter this function even if the rank is not 121 | a part of the inverse broadcast group. 122 | 123 | Args: 124 | src (int): src rank that computed A inverse. 125 | group (ProcessGroup): process group to which src should broadcast 126 | A inv. All ranks in group should enter this function. 127 | Defaults to None, the default process group. 
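
Note:
Ranks other than src that have not yet computed A inv allocate an
empty buffer with A's shape (in inv_dtype) so the broadcast has a
destination tensor, as shown in the body below.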
128 | """ 129 | if self.a_inv is None: 130 | if get_rank() == src: 131 | raise RuntimeError( 132 | f'Attempt to broadcast A inv from src={src} but this rank ' 133 | 'has not computed A inv yet.', 134 | ) 135 | assert isinstance(self.a_factor, torch.Tensor) 136 | self.a_inv = torch.empty( 137 | self.a_factor.shape, 138 | device=self.a_factor.device, 139 | dtype=self.inv_dtype, 140 | ) 141 | 142 | self.a_inv = self.tdc.broadcast( # type: ignore 143 | self.a_inv, 144 | src=src, 145 | group=group, 146 | symmetric=self.symmetric_factors and self.symmetry_aware, 147 | ) 148 | 149 | def broadcast_g_inv( 150 | self, 151 | src: int, 152 | group: dist.ProcessGroup | None = None, 153 | ) -> None: 154 | """Initiate G inv broadcast and store future to result. 155 | 156 | Note: 157 | all ranks must enter this function even if the rank is not 158 | a part of the inverse broadcast group. 159 | 160 | Args: 161 | src (int): src rank that computed G inverse. 162 | group (ProcessGroup): process group to which src should broadcast 163 | G inv. All ranks in group should enter this function. 164 | Defaults to None, the default process group. 165 | """ 166 | if self.g_inv is None: 167 | if get_rank() == src: 168 | raise RuntimeError( 169 | f'Attempt to broadcast G inv from src={src} but this rank ' 170 | 'has not computed G inv yet.', 171 | ) 172 | assert isinstance(self.g_factor, torch.Tensor) 173 | self.g_inv = torch.empty( 174 | self.g_factor.shape, 175 | device=self.g_factor.device, 176 | dtype=self.inv_dtype, 177 | ) 178 | 179 | self.g_inv = self.tdc.broadcast( # type: ignore 180 | self.g_inv, 181 | src=src, 182 | group=group, 183 | symmetric=self.symmetric_factors and self.symmetry_aware, 184 | ) 185 | 186 | def compute_a_inv(self, damping: float = 0.001) -> None: 187 | """Compute A inverse on assigned rank. 188 | 189 | update_a_factor() must be called at least once before this function. 190 | 191 | Args: 192 | damping (float, optional): damping value to condition inverse 193 | (default: 0.001). 194 | """ 195 | if self.a_factor is None: 196 | raise RuntimeError('Cannot invert A before A has been computed') 197 | 198 | d = torch.diag( 199 | self.a_factor.new(self.a_factor.shape[0]).fill_(damping), 200 | ) 201 | a = self.a_factor + d 202 | self.a_inv = torch.linalg.inv(a.to(torch.float32)).to(self.inv_dtype) 203 | 204 | def compute_g_inv(self, damping: float = 0.001) -> None: 205 | """See `compute_g_inv`.""" 206 | if self.g_factor is None: 207 | raise RuntimeError('Cannot invert G before G has been computed') 208 | 209 | d = torch.diag( 210 | self.g_factor.new(self.g_factor.shape[0]).fill_(damping), 211 | ) 212 | g = self.g_factor + d 213 | self.g_inv = torch.linalg.inv(g.to(torch.float32)).to(self.inv_dtype) 214 | 215 | def preconditioned_grad(self, damping: float = 0.001) -> None: 216 | """Compute precondition gradient of each weight in module. 217 | 218 | Preconditioned gradients can be applied to the actual gradients with 219 | `update_gradient()`. Note the steps are separate in the event that 220 | intermediate steps will be applied to the preconditioned gradient. 221 | 222 | Args: 223 | damping (float, optional): damping to use if preconditioning using 224 | the eigendecomposition method (default: 0.001). 
225 | """ 226 | if self.a_inv is None or self.g_inv is None: 227 | raise RuntimeError( 228 | 'Cannot precondition gradient before A and G have been ' 229 | 'inverted', 230 | ) 231 | grad = self.module.get_grad() 232 | grad_type = grad.dtype 233 | grad = grad.to(self.a_inv.dtype) 234 | self.grad = (self.g_inv @ grad @ self.a_inv).to(grad_type) 235 | -------------------------------------------------------------------------------- /kfac/layers/modules.py: -------------------------------------------------------------------------------- 1 | """Helper wrappers for supported PyTorch modules.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import cast 6 | from typing import List 7 | 8 | import torch 9 | 10 | from kfac.layers.utils import append_bias_ones 11 | from kfac.layers.utils import get_cov 12 | 13 | 14 | class ModuleHelper: 15 | """PyTorch module helper. 16 | 17 | This base class provides the interface which the KFACBaseLayer expects 18 | as input. Namely, the interface provides methods to compute the factors 19 | of a module, get the shapes of the factors, and get and set the gradients. 20 | """ 21 | 22 | def __init__(self, module: torch.nn.Module): 23 | """Init ModuleHelper. 24 | 25 | Args: 26 | module (torch.nn.Module): module in model to wrap. 27 | """ 28 | self.module = module 29 | 30 | def __repr__(self) -> str: 31 | """Representation of the ModuleHelper instance.""" 32 | return f'{self.__class__.__name__}({repr(self.module)})' 33 | 34 | @property 35 | def a_factor_shape(self) -> tuple[int, int]: 36 | """Get shape of A factor.""" 37 | raise NotImplementedError 38 | 39 | @property 40 | def g_factor_shape(self) -> tuple[int, int]: 41 | """Get shape of G factor.""" 42 | raise NotImplementedError 43 | 44 | @property 45 | def device(self) -> torch.device: 46 | """Get device that the modules parameters are on.""" 47 | return next(self.module.parameters()).device 48 | 49 | def get_a_factor(self, a: torch.Tensor) -> torch.Tensor: 50 | """Compute A factor with the input from the forward pass.""" 51 | raise NotImplementedError 52 | 53 | def get_g_factor(self, g: torch.Tensor) -> torch.Tensor: 54 | """Compute G factor with the gradient w.r.t. the output.""" 55 | raise NotImplementedError 56 | 57 | def get_grad(self) -> torch.Tensor: 58 | """Get formatted gradients (weight and bias) of module. 59 | 60 | Returns: 61 | gradient of shape If bias != None, 62 | concats bias. 
63 | """ 64 | g = cast(torch.Tensor, self.module.weight.grad) 65 | if self.has_bias(): 66 | g = torch.cat( 67 | [g, self.module.bias.grad.view(-1, 1)], # type: ignore 68 | 1, 69 | ) 70 | return g 71 | 72 | def get_bias_grad(self) -> torch.Tensor: 73 | """Get the gradient of the bias.""" 74 | return cast(torch.Tensor, self.module.bias.grad) 75 | 76 | def get_weight_grad(self) -> torch.Tensor: 77 | """Get the gradient of the weight.""" 78 | return cast(torch.Tensor, self.module.weight.grad) 79 | 80 | def has_bias(self) -> bool: 81 | """Check if module has a bias parameter.""" 82 | return hasattr(self.module, 'bias') and self.module.bias is not None 83 | 84 | def has_symmetric_factors(self) -> bool: 85 | """Check if module has symmetric factors.""" 86 | return True 87 | 88 | def set_grad(self, grad: torch.Tensor) -> None: 89 | """Update the gradient of the module.""" 90 | if self.has_bias(): 91 | weight_grad = grad[:, :-1].view(self.get_weight_grad().size()) 92 | bias_grad = grad[:, -1:].view(self.get_bias_grad().size()) 93 | else: 94 | weight_grad = grad.view(self.get_weight_grad().size()) 95 | 96 | if self.has_bias(): 97 | self.module.bias.grad = bias_grad.contiguous() 98 | self.module.weight.grad = weight_grad.contiguous() 99 | 100 | 101 | class LinearModuleHelper(ModuleHelper): 102 | """ModuleHelper for torch.nn.Linear modules.""" 103 | 104 | @property 105 | def a_factor_shape(self) -> tuple[int, int]: 106 | """Get shape of A factor. 107 | 108 | A shape = (in_features + int(has_bias), in_features + int(has_bias)) 109 | """ 110 | x = self.module.weight.size(1) + int(self.has_bias()) # type: ignore 111 | return (x, x) 112 | 113 | @property 114 | def g_factor_shape(self) -> tuple[int, int]: 115 | """Get shape of G factor. 116 | 117 | G shape = (out_features, out_features) 118 | """ 119 | return ( 120 | self.module.weight.size(0), # type: ignore 121 | self.module.weight.size(0), # type: ignore 122 | ) 123 | 124 | def get_a_factor(self, a: torch.Tensor) -> torch.Tensor: 125 | """Compute A factor with the input from the forward pass. 126 | 127 | Args: 128 | a (torch.Tensor): tensor with shape batch_size * in_dim. 129 | """ 130 | a = a.view(-1, a.size(-1)) 131 | if self.has_bias(): 132 | a = append_bias_ones(a) 133 | return get_cov(a) 134 | 135 | def get_g_factor(self, g: torch.Tensor) -> torch.Tensor: 136 | """Compute G factor with the gradient w.r.t. the output. 137 | 138 | Args: 139 | g (torch.Tensor): tensor with shape batch_size * out_dim. 140 | """ 141 | g = g.reshape(-1, g.size(-1)) 142 | return get_cov(g) 143 | 144 | 145 | class Conv2dModuleHelper(ModuleHelper): 146 | """ModuleHelper for torch.nn.Conv2d layers.""" 147 | 148 | def __init__(self, module: torch.nn.Conv2d): 149 | """Init ModuleHelper. 150 | 151 | Args: 152 | module (torch.nn.Conv2d): Conv2d module in model to wrap. 
153 | """ 154 | self.module = module 155 | 156 | @property 157 | def a_factor_shape(self) -> tuple[int, int]: 158 | """Get shape of A factor.""" 159 | ksize0: int = self.module.kernel_size[0] # type: ignore 160 | ksize1: int = self.module.kernel_size[1] # type: ignore 161 | in_ch: int = self.module.in_channels # type: ignore 162 | x = in_ch * ksize0 * ksize1 + int(self.has_bias()) 163 | return (x, x) 164 | 165 | @property 166 | def g_factor_shape(self) -> tuple[int, int]: 167 | """Get shape of G factor.""" 168 | out_ch: int = self.module.out_channels # type: ignore 169 | return (out_ch, out_ch) 170 | 171 | def get_a_factor(self, a: torch.Tensor) -> torch.Tensor: 172 | """Compute A factor with the input from the forward pass.""" 173 | a = self._extract_patches(a) 174 | spatial_size = a.size(1) * a.size(2) 175 | a = a.view(-1, a.size(-1)) 176 | if self.has_bias(): 177 | a = append_bias_ones(a) 178 | a = a / spatial_size 179 | return get_cov(a) 180 | 181 | def get_g_factor(self, g: torch.Tensor) -> torch.Tensor: 182 | """Compute G factor with the gradient w.r.t. the output. 183 | 184 | Args: 185 | g (torch.Tensor): tensor with shape batch_size * n_filters * 186 | out_h * out_w n_filters is actually the output dimension 187 | (analogous to Linear layer). 188 | """ 189 | spatial_size = g.size(2) * g.size(3) 190 | g = g.transpose(1, 2).transpose(2, 3) 191 | g = g.reshape(-1, g.size(-1)) 192 | g = g / spatial_size 193 | return get_cov(g) 194 | 195 | def get_grad(self) -> torch.Tensor: 196 | """Get formmated gradients (weight and bias) of module.""" 197 | grad = cast( 198 | torch.Tensor, 199 | self.module.weight.grad.view( # type: ignore 200 | self.module.weight.grad.size(0), # type: ignore 201 | -1, 202 | ), 203 | ) 204 | if self.has_bias(): 205 | grad = torch.cat( 206 | [grad, self.module.bias.grad.view(-1, 1)], # type: ignore 207 | 1, 208 | ) 209 | return grad 210 | 211 | def _extract_patches(self, x: torch.Tensor) -> torch.Tensor: 212 | """Extract patches from convolutional layer. 213 | 214 | Args: 215 | x (torch.Tensor): input feature maps with shape 216 | (batch_size, in_c, h, w). 217 | 218 | Returns: 219 | tensor of shape (batch_size, out_h, out_w, in_c*kh*kw) 220 | """ 221 | padding = cast(List[int], self.module.padding) 222 | kernel_size = cast(List[int], self.module.kernel_size) 223 | stride = cast(List[int], self.module.stride) 224 | if padding[0] + padding[1] > 0: 225 | x = torch.nn.functional.pad( 226 | x, 227 | (padding[1], padding[1], padding[0], padding[0]), 228 | ).data 229 | x = x.unfold(2, kernel_size[0], stride[0]) 230 | x = x.unfold(3, kernel_size[1], stride[1]) 231 | x = x.transpose_(1, 2).transpose_(2, 3).contiguous() 232 | x = x.view( 233 | x.size(0), 234 | x.size(1), 235 | x.size(2), 236 | x.size(3) * x.size(4) * x.size(5), 237 | ) 238 | return x 239 | -------------------------------------------------------------------------------- /kfac/layers/register.py: -------------------------------------------------------------------------------- 1 | """Utilities for registering PyTorch modules to KFAC layers.""" 2 | 3 | from __future__ import annotations 4 | 5 | import re 6 | from typing import Any 7 | 8 | import torch 9 | 10 | from kfac.layers.base import KFACBaseLayer 11 | from kfac.layers.modules import Conv2dModuleHelper 12 | from kfac.layers.modules import LinearModuleHelper 13 | from kfac.layers.modules import ModuleHelper 14 | 15 | KNOWN_MODULES = {'linear', 'conv2d'} 16 | LINEAR_TYPES: tuple[type[torch.nn.Module], ...] 
17 | CONV2D_TYPES: tuple[type[torch.nn.Module], ...] = (torch.nn.Conv2d,)
18 | 
19 | 
20 | def get_flattened_modules(
21 | root: torch.nn.Module,
22 | ) -> list[tuple[str, torch.nn.Module]]:
23 | """Returns flattened view of leaves of module tree."""
24 | return [
25 | (name, module)
26 | for name, module in root.named_modules()
27 | if len(list(module.children())) == 0
28 | ]
29 | 
30 | 
31 | def requires_grad(module: torch.nn.Module) -> bool:
32 | """Return False if any module param has requires_grad=False."""
33 | return all([p.requires_grad for p in module.parameters()])
34 | 
35 | 
36 | def get_module_helper(module: torch.nn.Module) -> ModuleHelper | None:
37 | """Return KFAC module helper that wraps a PyTorch module."""
38 | if isinstance(module, LINEAR_TYPES):
39 | return LinearModuleHelper(module)
40 | elif isinstance(module, CONV2D_TYPES):
41 | return Conv2dModuleHelper(module) # type: ignore
42 | else:
43 | return None
44 | 
45 | 
46 | def any_match(query: str, patterns: list[str]) -> bool:
47 | """Check if a query string matches any pattern in a list.
48 | 
49 | Note:
50 | `search()` is used rather than `match()` so True will be returned
51 | if there is a match anywhere in the query string.
52 | """
53 | regexes = [re.compile(p) for p in patterns]
54 | return any(regex.search(query) for regex in regexes)
55 | 
56 | 
57 | def register_modules(
58 | model: torch.nn.Module,
59 | kfac_layer_type: type[KFACBaseLayer],
60 | skip_layers: list[str],
61 | **layer_kwargs: Any,
62 | ) -> dict[torch.nn.Module, tuple[str, KFACBaseLayer]]:
63 | """Register supported modules in model with a KFACLayer.
64 | 
65 | Args:
66 | model (torch.nn.Module): model to scan for modules to register.
67 | kfac_layer_type (type[KFACBaseLayer]): type of subclass of
68 | KFACBaseLayer to use.
69 | skip_layers (list[str]): regex patterns that if matched, will cause
70 | the layer to not be registered. The patterns will be applied
71 | against the layer's name and class name.
72 | **layer_kwargs (dict[str, Any]): optional keyword arguments to
73 | pass to the kfac_layer_type constructor.
74 | """
75 | modules = get_flattened_modules(model)
76 | 
77 | kfac_layers: dict[torch.nn.Module, tuple[str, KFACBaseLayer]] = {}
78 | for name, module in modules:
79 | if (
80 | not any_match(name, skip_layers)
81 | and not any_match(module.__class__.__name__, skip_layers)
82 | and requires_grad(module)
83 | ):
84 | module_helper = get_module_helper(module)
85 | if module_helper is None:
86 | continue
87 | 
88 | kfac_layer = kfac_layer_type(module_helper, **layer_kwargs)
89 | 
90 | # get_flattened_modules() should never give us the same module
91 | # twice
92 | assert module not in kfac_layers
93 | kfac_layers[module] = (name, kfac_layer)
94 | 
95 | return kfac_layers
96 | 
--------------------------------------------------------------------------------
/kfac/layers/utils.py:
--------------------------------------------------------------------------------
1 | """Utilities for KFAC computations."""
2 | 
3 | from __future__ import annotations
4 | 
5 | import torch
6 | 
7 | 
8 | def append_bias_ones(tensor: torch.Tensor) -> torch.Tensor:
9 | """Appends vector of ones to last dimension of tensor.
10 | 
11 | For example, if the input is of shape [4, 6], then the output has shape
12 | [4, 7] where the slice [:, -1] is a tensor of all ones.
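
Example:
>>> import torch
>>> x = torch.zeros(4, 6)
>>> append_bias_ones(x).shape
torch.Size([4, 7])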
13 | """ 14 | shape = list(tensor.shape[:-1]) + [1] 15 | return torch.cat([tensor, tensor.new_ones(shape)], dim=-1) 16 | 17 | 18 | def get_cov( 19 | a: torch.Tensor, 20 | b: torch.Tensor | None = None, 21 | scale: float | None = None, 22 | ) -> torch.Tensor: 23 | """Computes the empirical second moment of a 2D tensor. 24 | 25 | Reference: 26 | - https://github.com/tensorflow/kfac/blob/master/kfac/python/ops/fisher_factors.py#L220 27 | - https://arxiv.org/pdf/1602.01407.pdf#subsection.2.2 28 | 29 | Args: 30 | a (tensor): 2D tensor to compute second moment of using 31 | cov_a = a^T @ a. 32 | b (tensor, optional): optional tensor of equal shape to a such that 33 | cov_a = a^T @ b. 34 | scale (float, optional): optional tensor to divide cov_a by. Default 35 | is a.size(0). 36 | 37 | Returns: 38 | square tensor representing the second moment of a. 39 | """ # noqa: E501 40 | if len(a.shape) != 2: 41 | raise ValueError( 42 | 'Input tensor must have 2 dimensions. Got tensor with shape ' 43 | f'{a.shape}', 44 | ) 45 | if b is not None and a.shape != b.shape: 46 | raise ValueError( 47 | 'Input tensors must have same shape. Got tensors of ' 48 | 'shape {} and {}.'.format(a.shape, b.shape), 49 | ) 50 | 51 | if scale is None: 52 | scale = a.size(0) 53 | 54 | if b is None: 55 | cov_a = a.t() @ (a / scale) 56 | # TODO(gpauloski): is this redundant? 57 | return (cov_a + cov_a.t()) / 2.0 58 | else: 59 | return a.t() @ (b / scale) 60 | 61 | 62 | def reshape_data( 63 | data_list: list[torch.Tensor], 64 | batch_first: bool = True, 65 | collapse_dims: bool = False, 66 | ) -> torch.Tensor: 67 | """Concat input/output data and clear buffers. 68 | 69 | Args: 70 | data_list (list): list of tensors of equal, arbitrary shape where the 71 | batch_dim is either 0 or 1 depending on self.batch_first. 72 | batch_first (bool, optional): is batch dim first. (default: True) 73 | collapse_dims (bool, optional): if True, collapse all but the last dim 74 | together forming a 2D output tensor. 75 | 76 | Returns: 77 | single tensor with all tensors from data_list concatenated across 78 | batch_dim. Guaranteed to be 2D if collapse_dims=True. 79 | """ 80 | d = torch.cat(data_list, dim=int(not batch_first)) 81 | if collapse_dims and len(d.shape) > 2: 82 | d = d.view(-1, d.shape[-1]) 83 | return d 84 | -------------------------------------------------------------------------------- /kfac/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/kfac/py.typed -------------------------------------------------------------------------------- /kfac/scheduler.py: -------------------------------------------------------------------------------- 1 | """KFAC preconditioner parameter scheduler.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Callable 6 | 7 | from kfac.base_preconditioner import BaseKFACPreconditioner 8 | 9 | 10 | class LambdaParamScheduler: 11 | """Lambda param scheduler for KFAC preconditioner. 12 | 13 | Note: 14 | The lambda functions take as input the step value of the 15 | preconditioner. This step value is not necessarily the global number 16 | of optimization steps but rather the number of times 17 | preconditioner.step() has been called. This can be overridden by 18 | passing the step value to scheduler.step(step). 19 | 20 | Warning: 21 | KFACBasePreconditioner can take callables for the parameters instead 22 | of constant values. 
23 | the LambdaParamScheduler at the same time is not possible.
24 | """
25 | 
26 | def __init__(
27 | self,
28 | preconditioner: BaseKFACPreconditioner,
29 | *,
30 | factor_update_steps_lambda: Callable[[int], float] | None = None,
31 | inv_update_steps_lambda: Callable[[int], float] | None = None,
32 | damping_lambda: Callable[[int], float] | None = None,
33 | factor_decay_lambda: Callable[[int], float] | None = None,
34 | kl_clip_lambda: Callable[[int], float] | None = None,
35 | lr_lambda: Callable[[int], float] | None = None,
36 | ):
37 | """Init LambdaParamScheduler.
38 | 
39 | Args:
40 | preconditioner (BaseKFACPreconditioner): preconditioner to
41 | update parameters for.
42 | factor_update_steps_lambda (callable, optional): function which
43 | computes a multiplicative factor for the factor_update_steps
44 | given an integer value of the number of steps from the KFAC
45 | preconditioner. The result will be cast to an int
46 | (default: None).
47 | inv_update_steps_lambda (callable, optional): function which
48 | computes a multiplicative factor for the inv_update_steps
49 | given an integer value of the number of steps from the KFAC
50 | preconditioner. The result will be cast to an int
51 | (default: None).
52 | damping_lambda (callable, optional): function which
53 | computes a multiplicative factor for the damping
54 | given an integer value of the number of steps from the KFAC
55 | preconditioner (default: None).
56 | factor_decay_lambda (callable, optional): function which
57 | computes a multiplicative factor for the factor_decay
58 | given an integer value of the number of steps from the KFAC
59 | preconditioner (default: None).
60 | kl_clip_lambda (callable, optional): function which
61 | computes a multiplicative factor for the kl_clip
62 | given an integer value of the number of steps from the KFAC
63 | preconditioner (default: None).
64 | lr_lambda (callable, optional): function which
65 | computes a multiplicative factor for the lr
66 | given an integer value of the number of steps from the KFAC
67 | preconditioner (default: None).
68 | 
69 | Raises:
70 | ValueError:
71 | if a lambda is passed for a parameter but the parameter in the
72 | preconditioner is already a callable.
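
Example:
An illustrative sketch (the decay schedule is made up). Each lambda
returns a multiplicative factor, so returning 1.0 leaves the value
unchanged and 0.5 halves it at the given steps:

scheduler = LambdaParamScheduler(
preconditioner,
damping_lambda=lambda step: 0.5 if step in (10, 20) else 1.0,
)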
73 | """ 74 | self._preconditioner = preconditioner 75 | self._factor_update_steps_lambda = factor_update_steps_lambda 76 | self._inv_update_steps_lambda = inv_update_steps_lambda 77 | self._damping_lambda = damping_lambda 78 | self._factor_decay_lambda = factor_decay_lambda 79 | self._kl_clip_lambda = kl_clip_lambda 80 | self._lr_lambda = lr_lambda 81 | 82 | if self._factor_update_steps_lambda is not None: 83 | if callable(self._preconditioner._factor_update_steps): 84 | raise ValueError( 85 | 'preconditioner.factor_update_steps is already a callable ' 86 | 'and cannot be updated by the lambdaparamscheduler.', 87 | ) 88 | if self._inv_update_steps_lambda is not None: 89 | if callable(self._preconditioner._inv_update_steps): 90 | raise ValueError( 91 | 'preconditioner.inv_update_steps is already a callable ' 92 | 'and cannot be updated by the lambdaparamscheduler.', 93 | ) 94 | if self._damping_lambda is not None: 95 | if callable(self._preconditioner._damping): 96 | raise ValueError( 97 | 'preconditioner.damping is already a callable ' 98 | 'and cannot be updated by the lambdaparamscheduler.', 99 | ) 100 | if self._factor_decay_lambda is not None: 101 | if callable(self._preconditioner._factor_decay): 102 | raise ValueError( 103 | 'preconditioner.factor_decay is already a callable ' 104 | 'and cannot be updated by the lambdaparamscheduler.', 105 | ) 106 | if self._kl_clip_lambda is not None: 107 | if callable(self._preconditioner._kl_clip): 108 | raise ValueError( 109 | 'preconditioner.kl_clip is already a callable ' 110 | 'and cannot be updated by the lambdaparamscheduler.', 111 | ) 112 | if self._lr_lambda is not None: 113 | if callable(self._preconditioner._lr): 114 | raise ValueError( 115 | 'preconditioner.lr is already a callable ' 116 | 'and cannot be updated by the lambdaparamscheduler.', 117 | ) 118 | 119 | def step(self, step: int | None = None) -> None: 120 | """Update KFAC preconditioner params. 121 | 122 | Note: 123 | This should be called after preconditioner.step(). 124 | 125 | Args: 126 | step (int, optional): optionally override the current step. 
127 | """ 128 | if self._factor_update_steps_lambda is not None: 129 | factor = self._factor_update_steps_lambda( 130 | step if step is not None else self._preconditioner.steps, 131 | ) 132 | assert not callable(self._preconditioner._factor_update_steps) 133 | self._preconditioner._factor_update_steps = int( 134 | self._preconditioner._factor_update_steps * factor, 135 | ) 136 | if self._inv_update_steps_lambda is not None: 137 | factor = self._inv_update_steps_lambda( 138 | step if step is not None else self._preconditioner.steps, 139 | ) 140 | assert not callable(self._preconditioner._inv_update_steps) 141 | self._preconditioner._inv_update_steps = int( 142 | self._preconditioner._inv_update_steps * factor, 143 | ) 144 | if self._damping_lambda is not None: 145 | factor = self._damping_lambda( 146 | step if step is not None else self._preconditioner.steps, 147 | ) 148 | assert not callable(self._preconditioner._damping) 149 | self._preconditioner._damping *= factor 150 | if self._factor_decay_lambda is not None: 151 | factor = self._factor_decay_lambda( 152 | step if step is not None else self._preconditioner.steps, 153 | ) 154 | assert not callable(self._preconditioner._factor_decay) 155 | self._preconditioner._factor_decay *= factor 156 | if self._kl_clip_lambda is not None: 157 | factor = self._kl_clip_lambda( 158 | step if step is not None else self._preconditioner.steps, 159 | ) 160 | assert not callable(self._preconditioner._kl_clip) 161 | self._preconditioner._kl_clip *= factor 162 | if self._lr_lambda is not None: 163 | factor = self._lr_lambda( 164 | step if step is not None else self._preconditioner.steps, 165 | ) 166 | assert not callable(self._preconditioner._lr) 167 | self._preconditioner._lr *= factor 168 | -------------------------------------------------------------------------------- /kfac/tracing.py: -------------------------------------------------------------------------------- 1 | """Utilities for tracing function execution time.""" 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | import time 7 | from typing import Any 8 | from typing import Callable 9 | from typing import TypeVar 10 | 11 | import torch 12 | 13 | RT = TypeVar('RT') 14 | 15 | _func_traces: dict[str, list[float]] = {} 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def clear_trace() -> None: 20 | """Clear recorded traces globally.""" 21 | _func_traces.clear() 22 | 23 | 24 | def get_trace( 25 | average: bool = True, 26 | max_history: int | None = None, 27 | ) -> dict[str, float]: 28 | """Get recorded traces. 29 | 30 | Args: 31 | average (bool): if true, return the average of the function 32 | execution times for each function. Otherwise, return the sum 33 | of time spent in each function (default: True). 34 | max_history (int, optional): if not None, only return statistics for 35 | the previous max_history calls. 36 | 37 | Returns: 38 | dict mapping function names to execution time. 39 | """ 40 | out = {} 41 | for fname, times in _func_traces.items(): 42 | if max_history is not None and len(times) > max_history: 43 | times = times[-max_history:] 44 | out[fname] = sum(times) 45 | if average: 46 | out[fname] /= len(times) 47 | return out 48 | 49 | 50 | def log_trace( 51 | average: bool = True, 52 | max_history: int | None = None, 53 | loglevel: int = logging.INFO, 54 | ) -> None: 55 | """Log function execution times recorded with @trace. 56 | 57 | To trace function execution times, use the @kfac.utils.trace() 58 | decorator on all functions to be traced. 
Then to get the average 59 | execution times, call kfac.tracing.log_trace(). 60 | 61 | Args: 62 | average (bool): if true, log the average times otherwise log the 63 | sum of times. 64 | max_history (int, optional): most recent `max_history` times to use 65 | for average. If None, all are used. 66 | loglevel (int): logging level for trace (default: logging.INFO). 67 | """ 68 | if len(_func_traces) == 0: 69 | return 70 | for fname, times in get_trace(average, max_history).items(): 71 | logger.log(loglevel, f'{fname}: {times}') 72 | 73 | 74 | def trace( 75 | sync: bool = False, 76 | ) -> Callable[[Callable[..., RT]], Callable[..., RT]]: 77 | """Return decorator for function execution time tracing. 78 | 79 | Args: 80 | sync (bool): if true, sync distributed ranks before and after entering 81 | the decorated function. 82 | 83 | Returns: 84 | function decorator. 85 | """ 86 | 87 | def decorator(func: Callable[..., RT]) -> Callable[..., RT]: 88 | """Decorator for function execution time tracing.""" 89 | 90 | def func_timer(*args: list[Any], **kwargs: dict[str, Any]) -> Any: 91 | """Time and execute function.""" 92 | if sync: 93 | torch.distributed.barrier() 94 | t = time.time() 95 | out = func(*args, **kwargs) 96 | if sync: 97 | torch.distributed.barrier() 98 | t = time.time() - t 99 | 100 | if func.__name__ not in _func_traces: 101 | _func_traces[func.__name__] = [t] 102 | else: 103 | _func_traces[func.__name__].append(t) 104 | return out 105 | 106 | return func_timer 107 | 108 | return decorator 109 | -------------------------------------------------------------------------------- /kfac/warnings.py: -------------------------------------------------------------------------------- 1 | """KFAC Warnings.""" 2 | 3 | from __future__ import annotations 4 | 5 | 6 | class ExperimentalFeatureWarning(Warning): 7 | """Experimental features warning.""" 8 | 9 | pass 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64.0", "setuptools_scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "kfac_pytorch" 7 | version = "0.4.2" 8 | authors = [ 9 | {name = "Greg Pauloski", email = "jgpauloski@uchicago.edu"}, 10 | ] 11 | description = "Distributed K-FAC Preconditioner for PyTorch."
12 | readme = "README.md" 13 | requires-python = ">=3.9" 14 | license = {file = "LICENSE"} 15 | classifiers = [ 16 | "License :: OSI Approved :: MIT License", 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3 :: Only", 19 | "Programming Language :: Python :: Implementation :: CPython", 20 | ] 21 | dependencies = [ 22 | "torch>=2", 23 | ] 24 | 25 | [project.urls] 26 | repository = "https://github.com/gpauloski/kfac_pytorch" 27 | 28 | [project.optional-dependencies] 29 | dev = [ 30 | "covdefaults>=2.2", 31 | "coverage", 32 | "mypy", 33 | "numpy", 34 | "pre-commit", 35 | "protobuf==3.20.2", 36 | "pytest", 37 | "pytest-cov", 38 | "ruff", 39 | "torchtext", 40 | "torchvision", 41 | "tox", 42 | "types-tqdm", 43 | "virtualenv", 44 | ] 45 | 46 | [tool.codespell] 47 | skip = """ 48 | .git, 49 | .github, 50 | __pycache__, 51 | build, 52 | dist, 53 | .*egg-info 54 | """ 55 | 56 | [tool.coverage.run] 57 | plugins = ["covdefaults"] 58 | omit = ["examples/*", "testing/*", "tests/integration/*"] 59 | concurrency = ["multiprocessing", "thread"] 60 | parallel = true 61 | 62 | [tool.mypy] 63 | python_version = "3.9" 64 | check_untyped_defs = true 65 | disallow_any_generics = true 66 | disallow_incomplete_defs = true 67 | disallow_untyped_defs = true 68 | ignore_missing_imports = true 69 | no_implicit_optional = true 70 | warn_redundant_casts = true 71 | warn_unused_configs = true 72 | warn_unused_ignores = false 73 | 74 | [[tool.mypy.overrides]] 75 | module = "testing.*" 76 | allow_incomplete_defs = true 77 | allow_untyped_defs = true 78 | 79 | [[tool.mypy.overrides]] 80 | module = "tests.*" 81 | allow_incomplete_defs = true 82 | allow_untyped_defs = true 83 | 84 | [tool.ruff] 85 | line-length = 79 86 | target-version = "py39" 87 | 88 | [tool.ruff.lint] 89 | # pycodestyle, pyflakes, flake8-builtins, flake8-bugbear, isort, pep8-naming, 90 | # pydocstyle, flake8-debugger, flake8-commas 91 | select = ["E", "F", "A", "B", "I", "N", "D", "T10", "COM"] 92 | # Ignore D202 because issue with inner function and black will fix it anyways 93 | extend-ignore = ["D202", "D401", "A005"] 94 | 95 | [tool.ruff.lint.isort] 96 | force-single-line = true 97 | known-first-party = ["kfac_pytorch", "test", "testing"] 98 | order-by-type = false 99 | required-imports = ["from __future__ import annotations"] 100 | 101 | [tool.ruff.lint.per-file-ignores] 102 | "*/__init__.py" = ["F401"] 103 | "examples/*__init__.py" = ["D104"] 104 | "tests/*__init__.py" = ["D104"] 105 | 106 | [tool.ruff.lint.pydocstyle] 107 | convention = "google" 108 | 109 | [tool.ruff.lint.flake8-quotes] 110 | inline-quotes = "single" 111 | 112 | [tool.ruff.format] 113 | indent-style = "space" 114 | quote-style = "single" 115 | 116 | [tool.setuptools.packages.find] 117 | exclude = ["tests*", "testing*"] 118 | namespaces = false 119 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Utility Scripts 2 | 3 | Example scripts for launching PyTorch distributed training locally or on multiple nodes, submitting training to Cobalt and Slurm schedulers, and more. 
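For example, a typical workflow stages the dataset on every node and then launches training (paths and flags below are illustrative; see the header comments in each script for full usage):

```console
$ ./scripts/copy_and_extract.sh /path/to/imagenet.tar /tmp/imagenet
$ ./scripts/run_imagenet.sh --epochs 55 --batch-size 128
```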
4 | -------------------------------------------------------------------------------- /scripts/copy_and_extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copy tar file to destination dir and extract on all nodes in the environment 3 | # 4 | # Usage: 5 | # $ ./scripts/copy_and_extract.sh /path/to/source.tar /path/to/dest 6 | # 7 | # Nodes in the environment are inferred from $NODEFILE, $SLURM_NODELIST, or 8 | # $COBALT_NODEFILE in this order. If none of these variables are set, the 9 | # script just executes the copy and extract locally. 10 | 11 | SOURCE_TAR=$1 12 | DEST_DIR=$2 13 | 14 | mkdir -p $DEST_DIR 15 | 16 | FULL_CMD="cp $SOURCE_TAR $DEST_DIR ; " 17 | FULL_CMD+="cd $DEST_DIR ; " 18 | FULL_CMD+="tar -xf $SOURCE_TAR " 19 | 20 | if [[ -z "${NODEFILE}" ]]; then 21 | if [[ -n "${SLURM_NODELIST}" ]]; then 22 | NODEFILE=/tmp/imagenet_slurm_nodelist 23 | scontrol show hostnames $SLURM_NODELIST > $NODEFILE 24 | elif [[ -n "${COBALT_NODEFILE}" ]]; then 25 | NODEFILE=$COBALT_NODEFILE 26 | fi 27 | fi 28 | if [[ -z "${NODEFILE}" ]]; then 29 | RANKS=$HOSTNAME 30 | else 31 | RANKS=$(tr '\n' ' ' < $NODEFILE) 32 | fi 33 | 34 | echo "Command: $FULL_CMD" 35 | 36 | # Launch the command on each worker (use ssh for remote nodes) 37 | RANK=0 38 | for NODE in $RANKS; do 39 | if [[ "$NODE" == "$HOSTNAME" ]]; then 40 | echo "Launching rank $RANK on local node $NODE" 41 | eval $FULL_CMD & 42 | else 43 | echo "Launching rank $RANK on remote node $NODE" 44 | ssh $NODE "cd $PWD; $FULL_CMD" & 45 | fi 46 | RANK=$((RANK+1)) 47 | done 48 | 49 | wait 50 | -------------------------------------------------------------------------------- /scripts/kill_python_procs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Execute "pkill python" on every node in the environment (as inferred from 3 | # $NODEFILE, $SLURM_NODELIST, or $COBALT_NODEFILE) or locally if no 4 | # environment can be found 5 | 6 | FULL_CMD="pkill python" 7 | 8 | if [[ -z "${NODEFILE}" ]]; then 9 | if [[ -n "${SLURM_NODELIST}" ]]; then 10 | NODEFILE=/tmp/imagenet_slurm_nodelist 11 | scontrol show hostnames $SLURM_NODELIST > $NODEFILE 12 | elif [[ -n "${COBALT_NODEFILE}" ]]; then 13 | NODEFILE=$COBALT_NODEFILE 14 | fi 15 | fi 16 | if [[ -z "${NODEFILE}" ]]; then 17 | RANKS=$HOSTNAME 18 | else 19 | RANKS=$(tr '\n' ' ' < $NODEFILE) 20 | fi 21 | 22 | echo "Command: $FULL_CMD" 23 | 24 | # Launch the command on each worker (use ssh for remote nodes) 25 | RANK=0 26 | for NODE in $RANKS; do 27 | if [[ "$NODE" == "$HOSTNAME" ]]; then 28 | echo "Launching rank $RANK on local node $NODE" 29 | eval $FULL_CMD & 30 | else 31 | echo "Launching rank $RANK on remote node $NODE" 32 | ssh $NODE "cd $PWD; $FULL_CMD" & 33 | fi 34 | RANK=$((RANK+1)) 35 | done 36 | 37 | wait 38 | -------------------------------------------------------------------------------- /scripts/run_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # USAGE: 3 | # 4 | # To launch pretraining with this script, first customize the PRELOAD and 5 | # CMD variables for your training configuration.
6 | # 7 | # Run locally on a compute node: 8 | # 9 | # $ ./run_imagenet.sh 10 | # 11 | # Submit as a Cobalt or Slurm job: 12 | # 13 | # $ qsub -q QUEUE -A ALLOC -n NODES -t TIME run_imagenet.sh 14 | # $ sbatch -p QUEUE -A ALLOC -N NODES -t TIME run_imagenet.sh 15 | # 16 | # Notes: 17 | # - training configuration (e.g., # nodes, # gpus / node, etc.) will be 18 | # automatically inferred from the nodelist 19 | # - additional arguments to the python script can be specified by passing 20 | # them as arguments to this script. E.g., 21 | # 22 | # $ ./run_imagenet.sh --epochs 55 --batch-size 128 23 | # 24 | 25 | PRELOAD="module load conda ; " 26 | PRELOAD+="conda activate pytorch ; " 27 | PRELOAD+="export OMP_NUM_THREADS=8 ; " 28 | 29 | # Arguments to the training script are passed as arguments to this script 30 | CMD="examples/torch_imagenet_resnet.py $@" 31 | 32 | # Example: copy imagenet and extract to /tmp on each worker 33 | # ./scripts/copy_and_extract.sh /path/to/imagenet.tar /tmp/imagenet 34 | 35 | # Figure out training environment 36 | if [[ -z "${NODEFILE}" ]]; then 37 | if [[ -n "${SLURM_NODELIST}" ]]; then 38 | NODEFILE=/tmp/imagenet_slurm_nodelist 39 | scontrol show hostnames $SLURM_NODELIST > $NODEFILE 40 | elif [[ -n "${COBALT_NODEFILE}" ]]; then 41 | NODEFILE=$COBALT_NODEFILE 42 | fi 43 | fi 44 | if [[ -z "${NODEFILE}" ]]; then 45 | RANKS=$HOSTNAME 46 | NNODES=1 47 | else 48 | MAIN_RANK=$(head -n 1 $NODEFILE) 49 | RANKS=$(tr '\n' ' ' < $NODEFILE) 50 | NNODES=$(< $NODEFILE wc -l) 51 | fi 52 | 53 | # Torch Distributed Launcher 54 | LAUNCHER="python -m torch.distributed.run " 55 | LAUNCHER+="--nnodes=$NNODES --nproc_per_node=auto --max_restarts 0 " 56 | if [[ "$NNODES" -eq 1 ]]; then 57 | LAUNCHER+="--standalone " 58 | else 59 | LAUNCHER+="--rdzv_backend=c10d --rdzv_endpoint=$MAIN_RANK " 60 | fi 61 | 62 | # Combine preload, launcher, and script+args into full command 63 | FULL_CMD="$PRELOAD $LAUNCHER $CMD" 64 | echo "Training command: $FULL_CMD" 65 | 66 | # Launch the pytorch processes on each worker (use ssh for remote nodes) 67 | RANK=0 68 | for NODE in $RANKS; do 69 | if [[ "$NODE" == "$HOSTNAME" ]]; then 70 | echo "Launching rank $RANK on local node $NODE" 71 | eval $FULL_CMD & 72 | else 73 | echo "Launching rank $RANK on remote node $NODE" 74 | ssh $NODE "cd $PWD; $FULL_CMD" & 75 | fi 76 | RANK=$((RANK+1)) 77 | done 78 | 79 | wait 80 | -------------------------------------------------------------------------------- /testing/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilities for unit tests.""" 2 | 3 | from __future__ import annotations 4 | -------------------------------------------------------------------------------- /testing/assignment.py: -------------------------------------------------------------------------------- 1 | """Lazy WorkAssignment implementation for testing.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch.distributed as dist 6 | 7 | from kfac.assignment import WorkAssignment 8 | 9 | 10 | class LazyAssignment(WorkAssignment): 11 | """Lazy assignment where every worker is an inverse worker. 12 | 13 | Used in unit tests to force a KFACPreconditioner to execute all options 14 | in the distributed control flow. 15 | """ 16 | 17 | def __init__(self, rank: int = 0, broadcast: bool = False) -> None: 18 | """Init LazyAssignment. 19 | 20 | Args: 21 | rank (int): process rank to simulate (default: 0).
22 | broadcast (bool): value returned by broadcast_gradients() and 23 | broadcast_inverses() (default: False). 24 | """ 25 | self.rank = rank 26 | self.broadcast = broadcast 27 | 28 | def broadcast_gradients(self) -> bool: 29 | """Return if gradients need to be broadcast.""" 30 | return self.broadcast 31 | 32 | def broadcast_inverses(self) -> bool: 33 | """Return if inverses need to be broadcast.""" 34 | return self.broadcast 35 | 36 | def get_layers(self) -> tuple[str, ...]: 37 | """Return tuple of layers assigned.""" 38 | return tuple() 39 | 40 | def get_factors(self, layer: str) -> tuple[str, ...]: 41 | """Return tuple of factors associated with the layer.""" 42 | return tuple() 43 | 44 | def inv_worker(self, layer: str, factor: str) -> int: 45 | """Return rank that computes inverse factor for this layer.""" 46 | return self.rank 47 | 48 | def is_grad_worker(self, layer: str) -> bool: 49 | """Return if this rank is a gradient worker for this layer.""" 50 | return True 51 | 52 | def src_grad_worker(self, layer: str) -> int: 53 | """Return rank that will share preconditioned gradient. 54 | 55 | If process is a gradient worker, this method should return the 56 | process rank. Otherwise, if the process is a gradient receiver, this 57 | method returns the rank that is responsible for sending the 58 | preconditioned gradient to this process. 59 | """ 60 | return self.rank 61 | 62 | def factor_group( 63 | self, 64 | layer: str, 65 | factor: str, 66 | ) -> dist.ProcessGroup | None: 67 | """Communication group for allreducing factors.""" 68 | return None 69 | 70 | def grad_worker_group(self, layer: str) -> dist.ProcessGroup | None: 71 | """Return communication group for inverse factor broadcast. 72 | 73 | This communication group is used for the broadcasts of the inverses 74 | from the inverse worker to the remaining gradient workers for the 75 | layer. 76 | """ 77 | return None 78 | 79 | def grad_receiver_group(self, layer: str) -> dist.ProcessGroup | None: 80 | """Return communication group for preconditioned gradient broadcast. 81 | 82 | This communication group is used for the broadcasts of the gradients 83 | from the gradient worker to the remaining gradient receivers for the 84 | layer. 85 | """ 86 | return None 87 | -------------------------------------------------------------------------------- /testing/distributed.py: -------------------------------------------------------------------------------- 1 | """Decorator for running tests in simulated distributed environments.""" 2 | 3 | from __future__ import annotations 4 | 5 | import multiprocessing 6 | import os 7 | import time 8 | from typing import Any 9 | from typing import Callable 10 | from typing import cast 11 | from typing import TypeVar 12 | 13 | import pytest 14 | import torch.distributed as dist 15 | 16 | from testing.utils import open_port 17 | 18 | # Worker timeout *after* the first worker has completed. 19 | UNIT_WORKER_TIMEOUT = 30 20 | 21 | FuncT = TypeVar('FuncT', bound=Callable[..., Any]) 22 | 23 | 24 | def distributed_test( 25 | world_size: int | list[int] = 2, 26 | ) -> Callable[[FuncT], FuncT]: 27 | """Decorator for running tests in distributed environment. 28 | 29 | A decorator for executing a function (e.g., a unit test) in a distributed 30 | manner. This decorator manages the spawning and joining of processes, 31 | initialization of torch.distributed, and catching of errors.
32 | 33 | This function is copied from: https://github.com/EleutherAI/DeeperSpeed/blob/24026e5bb37c528a222b8635c46256b1e1825d2e/tests/unit/common.py#L16 34 | 35 | Example: 36 | >>> @distributed_test(world_size=[2,3]) 37 | >>> def my_test(): 38 | >>> rank = dist.get_rank() 39 | >>> world_size = dist.get_world_size() 40 | >>> assert(rank < world_size) 41 | 42 | Args: 43 | world_size (int, list[int]): number of ranks to spawn. Can be a list 44 | to run the test multiple times with different world sizes. 45 | """ # noqa: E501 46 | 47 | port = open_port() 48 | 49 | def dist_wrap(run_func: FuncT) -> FuncT: 50 | """Second-level decorator that actually wraps the func.""" 51 | 52 | def dist_init( 53 | local_rank: int, 54 | num_procs: int, 55 | *func_args: list[Any], 56 | **func_kwargs: dict[str, Any], 57 | ) -> None: 58 | """Initialize torch.distributed and execute the user function.""" 59 | os.environ['MASTER_ADDR'] = '127.0.0.1' 60 | os.environ['MASTER_PORT'] = str(port) 61 | os.environ['LOCAL_RANK'] = str(local_rank) 62 | # NOTE: unit tests don't support multi-node so 63 | # local_rank == global rank 64 | os.environ['RANK'] = str(local_rank) 65 | os.environ['WORLD_SIZE'] = str(num_procs) 66 | 67 | dist.init_process_group('gloo') 68 | 69 | run_func(*func_args, **func_kwargs) 70 | 71 | # Keep faster ranks from exiting and breaking process group 72 | dist.barrier() 73 | 74 | def dist_launcher( 75 | num_procs: int, 76 | *func_args: list[Any], 77 | **func_kwargs: dict[str, Any], 78 | ) -> None: 79 | """Launch processes and gracefully handle failures.""" 80 | # Set multiprocessing to use fork because on MacOS/Windows, the 81 | # default in Python 3.8 and later is "spawn" which cannot 82 | # pickle lambda functions. 83 | # NOTE: fork does not work with CUDA tensors but that is okay 84 | # because the test suite does not use CUDA 85 | ctx = multiprocessing.get_context('fork') 86 | 87 | # Spawn all workers on subprocesses. 88 | processes = [] 89 | for local_rank in range(num_procs): 90 | p = ctx.Process( 91 | target=dist_init, 92 | args=(local_rank, num_procs, *func_args), 93 | kwargs=func_kwargs, 94 | ) 95 | p.start() 96 | processes.append(p) 97 | 98 | # Wait for all other processes to complete 99 | for p in processes: 100 | p.join(UNIT_WORKER_TIMEOUT) 101 | 102 | failed = [ 103 | (rank, p) 104 | for rank, p in enumerate(processes) 105 | if p.exitcode != 0 106 | ] 107 | for rank, p in failed: 108 | # If it still hasn't terminated, kill it because it hung.
109 | if p.exitcode is None: 110 | p.terminate() 111 | pytest.fail(f'Worker {rank} hung.', pytrace=False) 112 | elif p.exitcode < 0: 113 | pytest.fail( 114 | f'Worker {rank} killed by signal {-p.exitcode}', 115 | pytrace=False, 116 | ) 117 | elif p.exitcode > 0: 118 | pytest.fail( 119 | f'Worker {rank} exited with code {p.exitcode}', 120 | pytrace=False, 121 | ) 122 | 123 | def run_func_decorator( 124 | *func_args: list[Any], 125 | **func_kwargs: dict[str, Any], 126 | ) -> Any: 127 | """Entry point for @distributed_test().""" 128 | if isinstance(world_size, int): 129 | dist_launcher(world_size, *func_args, **func_kwargs) 130 | elif isinstance(world_size, list): 131 | for procs in world_size: 132 | dist_launcher(procs, *func_args, **func_kwargs) 133 | time.sleep(0.5) 134 | else: 135 | raise TypeError( 136 | 'world_size must be an integer or a list of integers.', 137 | ) 138 | 139 | return cast(FuncT, run_func_decorator) 140 | 141 | return cast(Callable[[FuncT], FuncT], dist_wrap) 142 | -------------------------------------------------------------------------------- /testing/gpt_neox.py: -------------------------------------------------------------------------------- 1 | """Testing utilities for GPT NeoX code.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | from unittest import mock 7 | 8 | import torch 9 | import torch.distributed as dist 10 | from deepspeed.pipe import PipelineModule # type: ignore 11 | from deepspeed.runtime.pipe.topology import ( # type: ignore 12 | PipeModelDataParallelTopology, # type: ignore 13 | ) 14 | 15 | 16 | class ColumnParallelLinear(torch.nn.Linear): 17 | """Mock ColumnParallelLinear from Megatron.""" 18 | 19 | pass 20 | 21 | 22 | class RowParallelLinear(torch.nn.Linear): 23 | """Mock RowParallelLinear from Megatron.""" 24 | 25 | pass 26 | 27 | 28 | def get_pipeline_module(*args: Any, **kwargs: Any) -> PipelineModule: 29 | """Create pipeline module with correct topology type.""" 30 | with mock.patch.object(PipelineModule, 'to', mock.MagicMock()): 31 | m = PipelineModule(*args, **kwargs) 32 | m._topo = PipeModelDataParallelTopology( 33 | num_pp=m.num_stages, 34 | num_dp=dist.get_world_size(m.world_group) // m.num_stages, 35 | num_mp=1, 36 | ) 37 | return m 38 | 39 | 40 | def sequential_model(layers: int, hidden_dim: int) -> torch.nn.Sequential: 41 | """Returns simple sequential linear model.""" 42 | if layers <= 0: 43 | raise ValueError('Num layers must be greater than 0') 44 | 45 | ls: list[torch.nn.Module] = [] 46 | ls.append(ColumnParallelLinear(hidden_dim, hidden_dim)) 47 | layers -= 1 48 | ls.extend( 49 | [RowParallelLinear(hidden_dim, hidden_dim) for _ in range(layers)], 50 | ) 51 | 52 | return torch.nn.Sequential(*ls) 53 | -------------------------------------------------------------------------------- /testing/models.py: -------------------------------------------------------------------------------- 1 | """PyTorch Models for Testing. 
2 | 3 | Examples borrowed from: 4 | https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | import torch 10 | from torch.nn import functional 11 | 12 | 13 | class TinyModel(torch.nn.Module): 14 | """Tiny model with two linear layers.""" 15 | 16 | def __init__(self): 17 | """Init TinyModel.""" 18 | super().__init__() 19 | 20 | self.linear1 = torch.nn.Linear(10, 20, bias=False) 21 | self.activation = torch.nn.ReLU() 22 | self.linear2 = torch.nn.Linear(20, 10) 23 | self.softmax = torch.nn.Softmax(dim=1) 24 | 25 | def forward(self, x): 26 | """Forward pass.""" 27 | x = self.linear1(x) 28 | x = self.activation(x) 29 | x = self.linear2(x) 30 | x = self.softmax(x) 31 | return x 32 | 33 | 34 | class LeNet(torch.nn.Module): 35 | """LeNet implementation.""" 36 | 37 | def __init__(self): 38 | """Init LeNet.""" 39 | super().__init__() 40 | # 1 input image channel (black & white), 6 output channels, 41 | # 5x5 square convolution kernel 42 | self.conv1 = torch.nn.Conv2d(1, 6, 5) 43 | self.conv2 = torch.nn.Conv2d(6, 16, 3) 44 | # an affine operation: y = Wx + b 45 | self.fc1 = torch.nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension 46 | self.fc2 = torch.nn.Linear(120, 84) 47 | self.fc3 = torch.nn.Linear(84, 10) 48 | 49 | def forward(self, x): 50 | """Forward pass.""" 51 | # Max pooling over a (2, 2) window 52 | x = functional.max_pool2d(functional.relu(self.conv1(x)), (2, 2)) 53 | # If the size is a square you can only specify a single number 54 | x = functional.max_pool2d(functional.relu(self.conv2(x)), 2) 55 | x = x.view(-1, self.num_flat_features(x)) 56 | x = functional.relu(self.fc1(x)) 57 | x = functional.relu(self.fc2(x)) 58 | x = self.fc3(x) 59 | return x 60 | 61 | def num_flat_features(self, x): 62 | """Return number of flat features in x.""" 63 | size = x.size()[1:] # all dimensions except the batch dimension 64 | num_features = 1 65 | for s in size: 66 | num_features *= s 67 | return num_features 68 | -------------------------------------------------------------------------------- /testing/utils.py: -------------------------------------------------------------------------------- 1 | """Fixtures and utilities for testing.""" 2 | 3 | from __future__ import annotations 4 | 5 | import socket 6 | 7 | _used_ports: set[int] = set() 8 | 9 | 10 | def open_port() -> int: 11 | """Return open port. 
12 | 13 | Sources: 14 | https://stackoverflow.com/questions/2838244 15 | https://github.com/proxystore/proxystore/blob/598b26072784c0d38e034fd8e73ef615a19974a9/testing/utils.py 16 | """ 17 | while True: 18 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 19 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 20 | s.bind(('', 0)) 21 | s.listen(1) 22 | port = s.getsockname()[1] 23 | s.close() 24 | if port not in _used_ports: # pragma: no branch 25 | _used_ports.add(port) 26 | return port 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/tests/__init__.py -------------------------------------------------------------------------------- /tests/gpt_neox/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | 5 | import pytest 6 | 7 | # DeepSpeed currently only supports Python 3.12 and older so skip 8 | # this entire test module in Python 3.12 or later 9 | if sys.version_info >= (3, 13): # pragma: >=3.13 cover 10 | pytest.skip( 11 | 'DeepSpeed does not support Python 3.13 and later.', 12 | allow_module_level=True, 13 | ) 14 | -------------------------------------------------------------------------------- /tests/gpt_neox/gpt_assignment_test.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for kfac/gpt_neox.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | from unittest import mock 6 | 7 | import pytest 8 | from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology 9 | 10 | from kfac.gpt_neox.assignment import GPTNeoXAssignment 11 | from kfac.gpt_neox.mpu import get_group_with_rank 12 | 13 | 14 | @pytest.mark.parametrize( 15 | 'work,ranks', 16 | ( 17 | ({}, [0, 1]), 18 | ({'l1': {'A': 1, 'G': 1}, 'l2': {'A': 1, 'G': 1}}, [0]), 19 | ({'l1': {'A': 1, 'G': 1}, 'l2': {'A': 1, 'G': 1}}, [0, 1]), 20 | ({'l1': {'A': 1, 'G': 1}, 'l2': {'A': 1, 'G': 1}}, [0, 1, 2]), 21 | ), 22 | ) 23 | def test_gpt_neox_assignment( 24 | work: dict[str, dict[str, float]], 25 | ranks: list[int], 26 | ) -> None: 27 | """Test GPTNeoXAssignment.""" 28 | with pytest.raises(TypeError): 29 | GPTNeoXAssignment( 30 | work, 31 | local_rank=99999, 32 | topology=object(), 33 | data_parallel_group=None, 34 | model_parallel_group=None, 35 | ) 36 | 37 | assignments = [] 38 | topology = PipeModelDataParallelTopology(1, len(ranks), 1) 39 | for rank in ranks: 40 | assignment = GPTNeoXAssignment( 41 | work, 42 | local_rank=rank, 43 | topology=topology, 44 | data_parallel_group=None, 45 | model_parallel_group=None, 46 | ) 47 | assignments.append((rank, assignment)) 48 | 49 | for rank, assignment in assignments: 50 | # GPTNeoXAssignment uses MEM-OPT so we should always broadcast 51 | # gradients and never inverses. 
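# (Under MEM-OPT, the worker that computes the inverses also preconditions
# the gradient locally and broadcasts the result, so the inverse factors
# themselves never have to be communicated.)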
52 | assert assignment.broadcast_gradients() 53 | assert not assignment.broadcast_inverses() 54 | 55 | assert set(assignment.get_layers()) == set(work.keys()) 56 | for layer, factors in work.items(): 57 | assert set(assignment.get_factors(layer)) == set(factors.keys()) 58 | 59 | for layer, factors in work.items(): 60 | inv_workers = [ 61 | assignment.inv_worker(layer, factor) for factor in factors 62 | ] 63 | # Check every factor is assigned to same inv worker 64 | assert inv_workers.count(inv_workers[0]) == len(inv_workers) 65 | assert inv_workers[0] in ranks 66 | 67 | model_parallel_peers = get_group_with_rank( 68 | rank, 69 | topology.get_axis_comm_lists('model'), 70 | ) 71 | assert assignment.is_grad_worker(layer) == ( 72 | rank in model_parallel_peers 73 | and inv_workers[0] in model_parallel_peers 74 | ) 75 | 76 | for layer in work: 77 | with pytest.raises(NotImplementedError): 78 | assignment.grad_worker_group(layer) 79 | 80 | for layer in work: 81 | src_grad_workers = [ 82 | assignment.src_grad_worker(layer) for _, assignment in assignments 83 | ] 84 | 85 | assert src_grad_workers.count(src_grad_workers[0]) == 1 86 | 87 | factor_workers = set() 88 | for factor in work[layer]: 89 | factor_workers.add(assignment.factor_worker(layer, factor)) 90 | assert len(factor_workers) == 1 91 | 92 | groups = [ 93 | assignment.factor_group(layer, 'A') 94 | for _, assignment in assignments 95 | ] 96 | groups += [ 97 | assignment.grad_receiver_group(layer) 98 | for _, assignment in assignments 99 | ] 100 | assert groups.count(groups[0]) == len(groups) 101 | 102 | 103 | @pytest.mark.parametrize( 104 | 'work,ranks,expected', 105 | ( 106 | ( 107 | {'l1': {'A': 1, 'G': 1}, 'l2': {'A': 1, 'G': 1}}, 108 | [0], 109 | {'l1': {'A': 0, 'G': 0}, 'l2': {'A': 0, 'G': 0}}, 110 | ), 111 | ( 112 | {'l1': {'A': 1, 'G': 1}, 'l2': {'A': 1, 'G': 1}}, 113 | [0, 1], 114 | {'l1': {'A': 1, 'G': 1}, 'l2': {'A': 0, 'G': 0}}, 115 | ), 116 | ( 117 | {'l1': {'A': 1, 'G': 1}, 'l2': {'A': 1, 'G': 1}}, 118 | [0, 1, 2], 119 | {'l1': {'A': 1, 'G': 1}, 'l2': {'A': 0, 'G': 0}}, 120 | ), 121 | ( 122 | { 123 | 'l1': {'A': 10, 'G': 10}, 124 | 'l2': {'A': 1, 'G': 1}, 125 | 'l3': {'A': 1, 'G': 1}, 126 | }, 127 | [0, 1], 128 | { 129 | 'l1': {'A': 0, 'G': 0}, 130 | 'l2': {'A': 1, 'G': 1}, 131 | 'l3': {'A': 1, 'G': 1}, 132 | }, 133 | ), 134 | ), 135 | ) 136 | def test_gpt_neox_assignment_load_balancing( 137 | work: dict[str, dict[str, float]], 138 | ranks: list[int], 139 | expected: dict[str, dict[str, float]], 140 | ) -> None: 141 | """Test GPTNeoXAssignment load balancing.""" 142 | topology = PipeModelDataParallelTopology(1, len(ranks), 1) 143 | for rank in ranks: 144 | assignment = GPTNeoXAssignment( 145 | work, 146 | local_rank=rank, 147 | topology=topology, 148 | data_parallel_group=None, 149 | model_parallel_group=None, 150 | ) 151 | 152 | for layer, factors in expected.items(): 153 | for factor in factors: 154 | inv_worker = assignment.inv_worker(layer, factor) 155 | assert inv_worker == factors[factor] 156 | 157 | model_parallel_peers = get_group_with_rank( 158 | rank, 159 | topology.get_axis_comm_lists('model'), 160 | ) 161 | assert assignment.is_grad_worker(layer) == ( 162 | rank in model_parallel_peers 163 | and inv_worker in model_parallel_peers 164 | ) 165 | 166 | 167 | def test_reuse_comm_groups() -> None: 168 | """Test that we reuse existing comm groups when possible.""" 169 | with mock.patch('torch.distributed.new_group', return_value=-1): 170 | topology = PipeModelDataParallelTopology(2, 1, 2) 171 | assignment = 
GPTNeoXAssignment( 172 | {}, 173 | local_rank=0, 174 | topology=topology, 175 | data_parallel_group=-2, # type: ignore 176 | model_parallel_group=-3, # type: ignore 177 | ) 178 | assert ( 179 | assignment.pipe_parallel_peer_group 180 | == assignment.data_parallel_group 181 | ) 182 | 183 | topology = PipeModelDataParallelTopology(2, 2, 2) 184 | assignment = GPTNeoXAssignment( 185 | {}, 186 | local_rank=0, 187 | topology=topology, 188 | data_parallel_group=-2, # type: ignore 189 | model_parallel_group=-3, # type: ignore 190 | ) 191 | assert ( 192 | assignment.pipe_parallel_peer_group 193 | != assignment.data_parallel_group 194 | != assignment.model_parallel_group 195 | ) 196 | -------------------------------------------------------------------------------- /tests/gpt_neox/gpt_modules_test.py: -------------------------------------------------------------------------------- 1 | """Test for custom GPT NeoX Module Helpers.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from kfac.gpt_neox.modules import GPTNeoXLinearModuleHelper 8 | from testing.distributed import distributed_test 9 | from testing.gpt_neox import ColumnParallelLinear 10 | from testing.gpt_neox import RowParallelLinear 11 | 12 | 13 | @pytest.mark.parametrize('world_size,bias', ((1, True), (2, False), (4, True))) 14 | def test_linear_module(world_size: int, bias: bool) -> None: 15 | """Test custom module helper for GPT NeoX.""" 16 | 17 | @distributed_test(world_size) 18 | def _test() -> None: 19 | import torch.distributed as dist 20 | 21 | in_shape = 10 22 | out_shape = 5 23 | 24 | row_linear = RowParallelLinear(in_shape, out_shape, bias=bias) 25 | helper = GPTNeoXLinearModuleHelper( 26 | row_linear, 27 | dist.new_group(), 28 | parallelism='input', 29 | ) 30 | 31 | a_dim_size = (in_shape * dist.get_world_size()) + int(bias) 32 | assert helper.a_factor_shape == (a_dim_size, a_dim_size) 33 | assert helper.g_factor_shape == (out_shape, out_shape) 34 | 35 | col_linear = ColumnParallelLinear(in_shape, out_shape, bias=bias) 36 | helper = GPTNeoXLinearModuleHelper( 37 | col_linear, 38 | dist.new_group(), 39 | parallelism='output', 40 | ) 41 | 42 | a_dim_size = in_shape + int(bias) 43 | g_dim_size = out_shape * dist.get_world_size() 44 | assert helper.a_factor_shape == (a_dim_size, a_dim_size) 45 | assert helper.g_factor_shape == (g_dim_size, g_dim_size) 46 | 47 | _test() 48 | -------------------------------------------------------------------------------- /tests/gpt_neox/gpt_mpu_test.py: -------------------------------------------------------------------------------- 1 | """Test for custom GPT NeoX Module Helpers.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | import torch 7 | 8 | from kfac.gpt_neox.mpu import gather_from_model_parallel_region 9 | from kfac.gpt_neox.mpu import get_group_with_rank 10 | from kfac.gpt_neox.mpu import split_tensor_along_dim 11 | from testing.distributed import distributed_test 12 | 13 | 14 | @pytest.mark.parametrize( 15 | 'world_size,shape,dtype,fp32_allreduce', 16 | ( 17 | (1, (1,), torch.float, False), 18 | (2, (10, 10), torch.bfloat16, True), 19 | (4, (4, 4, 4), torch.float, False), 20 | ), 21 | ) 22 | def test_gather_model_parallel( 23 | world_size: int, 24 | shape: tuple[int], 25 | dtype: torch.dtype, 26 | fp32_allreduce: bool, 27 | ) -> None: 28 | """Test gather_from_model_parallel_region.""" 29 | 30 | @distributed_test(world_size) 31 | def _test() -> None: 32 | group = torch.distributed.new_group() 33 | rank = torch.distributed.get_rank(group) 34 | 
world_size = torch.distributed.get_world_size(group) 35 | dst = 0 36 | 37 | partial = torch.ones(shape, dtype=dtype) * rank 38 | result = gather_from_model_parallel_region( 39 | partial, 40 | dst, 41 | group, 42 | fp32_allreduce, 43 | ) 44 | 45 | if rank != dst: 46 | assert result is None 47 | else: 48 | expected_size = list(shape) 49 | expected_size[-1] = expected_size[-1] * world_size 50 | 51 | assert isinstance(result, torch.Tensor) 52 | assert result.size() == tuple(expected_size) 53 | 54 | _test() 55 | 56 | 57 | @pytest.mark.parametrize( 58 | 'rank,groups,result,error', 59 | ( 60 | (0, [[0], [1], [2]], [0], False), 61 | (-1, [[0], [1], [2]], None, True), 62 | (0, [[0, 1], [2], [0]], [0, 1], False), 63 | (4, [[0, 1, 2, 3], [2, 3, 4]], [2, 3, 4], False), 64 | ), 65 | ) 66 | def test_get_group_with_rank( 67 | rank: int, 68 | groups: list[list[int]], 69 | result: list[int] | None, 70 | error: bool, 71 | ) -> None: 72 | """Test get_group_with_rank.""" 73 | if error: 74 | with pytest.raises(ValueError): 75 | get_group_with_rank(rank, groups) 76 | else: 77 | assert get_group_with_rank(rank, groups) == result 78 | 79 | 80 | def test_split_tensor_along_dim() -> None: 81 | """Test split_tensor_along_dim.""" 82 | x = torch.zeros([1, 4]) 83 | with pytest.raises(ValueError): 84 | split_tensor_along_dim(x, 2, 0) 85 | 86 | x = torch.zeros([2, 11]) 87 | with pytest.raises(ValueError): 88 | split_tensor_along_dim(x, 2, -1) 89 | 90 | x = torch.zeros([6, 18]) 91 | xs = split_tensor_along_dim(x, 6, -1) 92 | 93 | # Every split should be same size 94 | shapes = [t.size() for t in xs] 95 | assert len(set(shapes)) == 1 96 | assert shapes[0] == (6, 3) 97 | for t in xs: 98 | assert not t.is_contiguous() 99 | 100 | x = torch.zeros([6, 18]) 101 | xs = split_tensor_along_dim(x, 6, -1, True) 102 | for t in xs: 103 | assert t.is_contiguous() 104 | -------------------------------------------------------------------------------- /tests/gpt_neox/gpt_preconditioner_test.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for kfac/gpt_neox.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | import os 7 | import pathlib 8 | from contextlib import redirect_stderr 9 | from contextlib import redirect_stdout 10 | from typing import Any 11 | from unittest import mock 12 | 13 | import deepspeed 14 | import pytest 15 | import torch 16 | 17 | from kfac.gpt_neox.preconditioner import GPTNeoXKFACPreconditioner 18 | from testing.distributed import distributed_test 19 | from testing.gpt_neox import get_pipeline_module 20 | from testing.gpt_neox import RowParallelLinear 21 | from testing.gpt_neox import sequential_model 22 | 23 | 24 | @pytest.mark.parametrize( 25 | 'num_stages,kwargs', 26 | ( 27 | (1, {'assignment_strategy': 'memory'}), 28 | # (2, {'compute_method': 'eigen'}), 29 | (4, {'allreduce_bucket_cap_mb': 0}), 30 | ), 31 | ) 32 | def test_gpt_neox_kfac_preconditioner( 33 | num_stages: int, 34 | kwargs: dict[str, Any], 35 | ) -> None: 36 | """Test GPTNeoXKFACPreconditioner.""" 37 | 38 | @distributed_test(world_size=num_stages) 39 | def check() -> None: 40 | num_layers = 6 41 | model = sequential_model(layers=num_layers, hidden_dim=32) 42 | 43 | deepspeed.init_distributed() 44 | 45 | # This one should not be registered because it is not 46 | # a Column/RowParallelLinear 47 | model.append(torch.nn.Linear(32, 32)) 48 | # This one should not be registered because it does not require grad 49 | module = RowParallelLinear(32, 32) 50 | 
module.requires_grad_(False) 51 | model.append(module) 52 | 53 | # Trashing stdout/stderr because get_pipeline_module prints stuff 54 | with redirect_stdout(None), redirect_stderr(None): 55 | logging.disable(10000) 56 | model = get_pipeline_module(layers=model, num_stages=num_stages) 57 | p = GPTNeoXKFACPreconditioner(model, **kwargs) 58 | 59 | # Check only num_layers layers are registered (not the extra two) 60 | layers_per_rank = [ 61 | 0 for _ in range(torch.distributed.get_world_size()) 62 | ] 63 | torch.distributed.all_gather_object( 64 | layers_per_rank, 65 | len(p._layers), 66 | ) 67 | assert sum(layers_per_rank) == num_layers 68 | 69 | check() 70 | 71 | 72 | def test_input_validation() -> None: 73 | """Test GPTNeoXKFACPreconditioner input validation.""" 74 | 75 | @distributed_test(world_size=1) 76 | def check() -> None: 77 | model = sequential_model(1, 1) 78 | 79 | deepspeed.init_distributed() 80 | 81 | # Trashing stdout/stderr because get_pipeline_module prints stuff 82 | with redirect_stdout(None), redirect_stderr(None): 83 | logging.disable(10000) 84 | model_ = get_pipeline_module(model, num_stages=1) 85 | with pytest.raises(ValueError, match='Inverse'): 86 | GPTNeoXKFACPreconditioner(model_, compute_method='inverse') 87 | 88 | with pytest.raises(ValueError, match='PipelineModule'): 89 | GPTNeoXKFACPreconditioner(model) 90 | 91 | # Trashing stdout/stderr because get_pipeline_module prints stuff 92 | with redirect_stdout(None), redirect_stderr(None): 93 | logging.disable(10000) 94 | model_ = get_pipeline_module(layers=model, num_stages=1) 95 | with pytest.raises(ValueError, match='allreduce_bucket_cap_mb'): 96 | GPTNeoXKFACPreconditioner(model_, allreduce_bucket_cap_mb=-1) 97 | 98 | check() 99 | 100 | 101 | def test_state_dict() -> None: 102 | """Test GPTNeoXKFACPreconditioner state dict.""" 103 | world_size = 2 104 | 105 | @distributed_test(world_size=world_size) 106 | def check() -> None: 107 | num_layers = 6 108 | model = sequential_model(layers=num_layers, hidden_dim=32) 109 | 110 | deepspeed.init_distributed() 111 | 112 | # Trashing stdout/stderr because get_pipeline_module prints stuff 113 | with redirect_stdout(None), redirect_stderr(None): 114 | logging.disable(10000) 115 | model = get_pipeline_module(layers=model, num_stages=1) 116 | 117 | p = GPTNeoXKFACPreconditioner(model) 118 | 119 | state_dict = p.state_dict(include_factors=False) 120 | assert 'layers' not in state_dict 121 | p.load_state_dict(state_dict) 122 | 123 | p._assignment.inv_worker = mock.MagicMock( # type: ignore 124 | return_value=1, 125 | ) 126 | 127 | for name, layer in p._layers.values(): 128 | if torch.distributed.get_rank() == p._assignment.inv_worker( 129 | name, 130 | 'A', 131 | ): 132 | layer.a_factor = torch.rand([5, 5]) 133 | layer.g_factor = torch.rand([5, 5]) 134 | 135 | state_dict = p.state_dict(include_factors=True) 136 | assert 'layers' in state_dict 137 | assert len(state_dict['layers']) == num_layers 138 | p.load_state_dict(state_dict.copy(), compute_inverses=False) 139 | 140 | for _, layer in p._layers.values(): 141 | layer.compute_a_inv = mock.MagicMock() # type: ignore 142 | layer.compute_g_inv = mock.MagicMock() # type: ignore 143 | p.load_state_dict(state_dict) 144 | for _, layer in p._layers.values(): 145 | assert layer.compute_a_inv.called # type: ignore 146 | assert layer.compute_g_inv.called # type: ignore 147 | 148 | check() 149 | 150 | 151 | def test_state_dict_save_factor_to_file_error() -> None: 152 | """Test param validation for saving factors to disk.""" 153 |
@distributed_test(world_size=1) 155 | def check() -> None: 156 | model = sequential_model(layers=1, hidden_dim=32) 157 | 158 | deepspeed.init_distributed() 159 | 160 | # Trashing stdout/stderr because get_pipeline_module prints stuff 161 | with redirect_stdout(None), redirect_stderr(None): 162 | logging.disable(10000) 163 | model = get_pipeline_module(layers=model, num_stages=1) 164 | 165 | p = GPTNeoXKFACPreconditioner(model) 166 | 167 | with pytest.raises(ValueError, match='factor_checkpoint_dir'): 168 | p.save_factors_to_dir() 169 | 170 | with pytest.raises(ValueError, match='factor_checkpoint_dir'): 171 | p.load_factors_from_dir() 172 | 173 | check() 174 | 175 | 176 | def test_load_factors_from_dir_warning(tmp_path: pathlib.Path) -> None: 177 | """Test warning if checkpoint dir does not exist.""" 178 | 179 | @distributed_test(world_size=1) 180 | def check() -> None: 181 | model = sequential_model(layers=1, hidden_dim=32) 182 | 183 | deepspeed.init_distributed() 184 | 185 | # Trashing stdout/stderr because get_pipeline_module prints stuff 186 | with redirect_stdout(None), redirect_stderr(None): 187 | logging.disable(10000) 188 | model = get_pipeline_module(layers=model, num_stages=1) 189 | 190 | path = str(tmp_path / 'checkpoint') 191 | p = GPTNeoXKFACPreconditioner(model, factor_checkpoint_dir=path) 192 | 193 | with pytest.warns(UserWarning, match='not a directory'): 194 | p.load_factors_from_dir() 195 | 196 | check() 197 | 198 | 199 | def test_state_dict_save_factors_to_file(tmp_path: pathlib.Path) -> None: 200 | """Test GPTNeoXKFACPreconditioner state dict.""" 201 | world_size = 2 202 | 203 | @distributed_test(world_size=world_size) 204 | def check() -> None: 205 | num_layers = 6 206 | model = sequential_model(layers=num_layers, hidden_dim=32) 207 | 208 | deepspeed.init_distributed() 209 | 210 | # Trashing stdout/stderr because get_pipeline_module prints stuff 211 | with redirect_stdout(None), redirect_stderr(None): 212 | logging.disable(10000) 213 | model = get_pipeline_module(layers=model, num_stages=1) 214 | 215 | p = GPTNeoXKFACPreconditioner( 216 | model, 217 | factor_checkpoint_dir=str(tmp_path), 218 | ) 219 | 220 | # Force the second rank to be the inverse worker for everything 221 | # such that the first rank should not save anything 222 | p._assignment.inv_worker = mock.MagicMock( # type: ignore 223 | return_value=1, 224 | ) 225 | p._assignment.factor_worker = mock.MagicMock( # type: ignore 226 | return_value=1, 227 | ) 228 | 229 | for name, layer in p._layers.values(): 230 | if torch.distributed.get_rank() == p._assignment.inv_worker( 231 | name, 232 | 'A', 233 | ): 234 | layer.a_factor = torch.rand([5, 5]) 235 | layer.g_factor = torch.rand([5, 5]) 236 | 237 | state_dict = p.state_dict(include_factors=True) 238 | torch.distributed.barrier() 239 | p.load_state_dict(state_dict, compute_inverses=False) 240 | 241 | assert 'layers' not in state_dict 242 | assert os.path.isdir(tmp_path) 243 | files = [ 244 | f 245 | for f in os.listdir(tmp_path) 246 | if os.path.isfile(os.path.join(tmp_path, f)) 247 | ] 248 | assert len(files) == num_layers 249 | 250 | torch.distributed.barrier() 251 | 252 | if torch.distributed.get_rank() == 0: 253 | # Delete file to check we only load files that exist 254 | os.remove(os.path.join(tmp_path, files[-1])) 255 | 256 | torch.distributed.barrier() 257 | 258 | for _, layer in p._layers.values(): 259 | layer.compute_a_inv = mock.MagicMock() # type: ignore 260 | layer.compute_g_inv = mock.MagicMock() # type: ignore 261 | p.load_state_dict(state_dict) 262 
| a_inv_called = 0 263 | g_inv_called = 0 264 | for _, layer in p._layers.values(): 265 | a_inv_called += int(layer.compute_a_inv.called) # type: ignore 266 | g_inv_called += int(layer.compute_g_inv.called) # type: ignore 267 | 268 | if torch.distributed.get_rank() == 1: 269 | # We remove one layer checkpoint so one less of each inverse should 270 | # be computed 271 | assert a_inv_called == num_layers - 1 272 | assert g_inv_called == num_layers - 1 273 | 274 | check() 275 | -------------------------------------------------------------------------------- /tests/hyperparams_test.py: -------------------------------------------------------------------------------- 1 | """Unit tests for kfac/hyperparams.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from kfac.hyperparams import exp_decay_factor_averaging 8 | 9 | 10 | def test_exp_decay_factor_averaging_types() -> None: 11 | """Test types and exceptions of exp_decay_factor_averaging().""" 12 | assert callable(exp_decay_factor_averaging()) 13 | assert isinstance(exp_decay_factor_averaging()(1), float) 14 | with pytest.raises(ValueError): 15 | exp_decay_factor_averaging(0) 16 | with pytest.raises(ValueError): 17 | exp_decay_factor_averaging(-1) 18 | with pytest.raises(ValueError): 19 | exp_decay_factor_averaging()(-1) 20 | 21 | 22 | def test_exp_decay_factor_averaging_non_decreasing() -> None: 23 | """Test exp_decay_factor_averaging() produces non decreasing values.""" 24 | func = exp_decay_factor_averaging() 25 | values = [func(step) for step in range(1000)] 26 | assert all(a <= b for a, b in zip(values, values[1:])) 27 | 28 | 29 | @pytest.mark.parametrize( 30 | 'min_value,values', 31 | ( 32 | ( 33 | 0.95, 34 | [(0, 0), (1, 0), (5, 0.8), (10, 0.9), (100, 0.95), (1000, 0.95)], 35 | ), 36 | (0.1, [(1, 0), (10, 0.1), (100, 0.1), (1000, 0.1)]), 37 | (1, [(1, 0), (10, 0.9), (100, 0.99)]), 38 | ), 39 | ) 40 | def test_exp_decay_factor_averaging_values( 41 | min_value: float, 42 | values: list[tuple[int, float]], 43 | ) -> None: 44 | """Test exp_decay_factor_averaging() input/outputs.""" 45 | func = exp_decay_factor_averaging(min_value) 46 | for step, expected_value in values: 47 | assert func(step) == expected_value 48 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/mnist_integration_test.py: -------------------------------------------------------------------------------- 1 | """MNIST integration test. 2 | 3 | Source: https://github.com/pytorch/examples/blob/0cb38ebb1b6e50426464b3485435c0c6affc2b65/mnist/main.py 4 | """ # noqa: E501 5 | 6 | from __future__ import annotations 7 | 8 | import time 9 | from typing import Any 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from torch.utils.data import DataLoader 15 | from torchvision.datasets import MNIST 16 | 17 | from kfac.preconditioner import KFACPreconditioner 18 | 19 | 20 | class FastMNIST(MNIST): 21 | """Fast MNIST dataset wrapper. 
22 | 23 | Source: https://github.com/y0ast/pytorch-snippets/tree/main/fast_mnist 24 | """ 25 | 26 | def __init__(self, *args: Any, **kwargs: Any) -> None: 27 | """Init FastMNIST.""" 28 | super().__init__(*args, **kwargs) 29 | 30 | # Scale data to [0,1] 31 | self.data: torch.Tensor = self.data.unsqueeze(1).float().div(255) 32 | # Normalize it with the usual MNIST mean and std 33 | self.data = self.data.sub_(0.1307).div_(0.3081) 34 | 35 | def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: 36 | """Get training image and target class.""" 37 | return self.data[index], self.targets[index] 38 | 39 | 40 | class Net(nn.Module): 41 | """MNIST Classifier Network.""" 42 | 43 | def __init__(self) -> None: 44 | """Init Net.""" 45 | super().__init__() 46 | self.conv1 = nn.Conv2d(1, 4, 3, 1) 47 | self.conv2 = nn.Conv2d(4, 4, 3, 1) 48 | self.dropout1 = nn.Dropout(0.25) 49 | self.dropout2 = nn.Dropout(0.5) 50 | self.fc1 = nn.Linear(576, 64) 51 | self.fc2 = nn.Linear(64, 10) 52 | 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: 54 | """Forward pass.""" 55 | x = self.conv1(x) 56 | x = nn.functional.relu(x) 57 | x = self.conv2(x) 58 | x = nn.functional.relu(x) 59 | x = nn.functional.max_pool2d(x, 2) 60 | x = self.dropout1(x) 61 | x = torch.flatten(x, 1) 62 | x = self.fc1(x) 63 | x = nn.functional.relu(x) 64 | x = self.dropout2(x) 65 | x = self.fc2(x) 66 | return nn.functional.log_softmax(x, dim=1) 67 | 68 | 69 | def train( 70 | model: torch.nn.Module, 71 | train_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]], 72 | optimizer: optim.Optimizer, 73 | preconditioner: KFACPreconditioner | None, 74 | ) -> None: 75 | """Train model for one epoch.""" 76 | model.train() 77 | for data, target in train_loader: 78 | for param in model.parameters(): 79 | param.grad = None 80 | output = model(data) 81 | loss = nn.functional.nll_loss(output, target) 82 | loss.backward() 83 | if preconditioner is not None: 84 | preconditioner.step() 85 | optimizer.step() 86 | 87 | 88 | def evaluate( 89 | model: torch.nn.Module, 90 | test_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]], 91 | ) -> float: 92 | """Measure accuracy on test dataset.""" 93 | model.eval() 94 | correct = 0 95 | with torch.no_grad(): 96 | for data, target in test_loader: 97 | output = model(data) 98 | pred = output.argmax(dim=1, keepdim=True) 99 | correct += pred.eq(target.view_as(pred)).sum().item() 100 | total_samples = len(test_loader.dataset) # type: ignore 101 | return 100 * (correct / total_samples) 102 | 103 | 104 | def train_and_evaluate(precondition: bool, epochs: int) -> float: 105 | """Train and test.""" 106 | torch.manual_seed(42) 107 | 108 | train_dataset = FastMNIST('/tmp/MNIST-data', train=True, download=True) 109 | test_dataset = FastMNIST('/tmp/MNIST-data', train=False, download=True) 110 | 111 | train_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = DataLoader( 112 | train_dataset, 113 | batch_size=64, 114 | shuffle=True, 115 | num_workers=0, 116 | ) 117 | test_loader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = DataLoader( 118 | test_dataset, 119 | batch_size=1000, 120 | shuffle=False, 121 | num_workers=0, 122 | ) 123 | 124 | model = Net() 125 | optimizer = optim.Adadelta(model.parameters(), lr=0.1) 126 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7) 127 | if precondition: 128 | preconditioner = KFACPreconditioner( 129 | model, 130 | factor_update_steps=10, 131 | inv_update_steps=100, 132 | lr=lambda x: optimizer.param_groups[0]['lr'], 133 | update_factors_in_hook=False, 
134 | ) 135 | else: 136 | preconditioner = None 137 | 138 | accuracy = 0.0 139 | for epoch in range(1, epochs + 1): 140 | start = time.perf_counter() 141 | train(model, train_loader, optimizer, preconditioner) 142 | accuracy = evaluate(model, test_loader) 143 | scheduler.step() 144 | end = time.perf_counter() 145 | print( 146 | f'Epoch {epoch}: accuracy={accuracy:.2f}%, ' 147 | f'time={end - start:.2f} seconds', 148 | ) 149 | 150 | return accuracy 151 | 152 | 153 | def main() -> bool: 154 | """MNIST integration test runner. 155 | 156 | Returns: 157 | True if training with KFAC produces a higher final validation 158 | accuracy than without, otherwise returns False. 159 | """ 160 | start = time.perf_counter() 161 | print('Starting MNIST integration test...') 162 | print('Training without KFAC:') 163 | adadelta_acc = train_and_evaluate(False, 5) 164 | print('Training with KFAC:') 165 | kfac_acc = train_and_evaluate(True, 5) 166 | failure = kfac_acc <= adadelta_acc 167 | runtime = time.perf_counter() - start 168 | print(f'Integration test runtime: {runtime:.2f} seconds.') 169 | if failure: 170 | print( 171 | 'Failure: KFAC accuracy is worse than default. ' 172 | f'KFAC acc. = {kfac_acc} vs. default acc. = {adadelta_acc}.', 173 | ) 174 | else: 175 | print('Success.') 176 | return failure 177 | 178 | 179 | if __name__ == '__main__': 180 | raise SystemExit(main()) 181 | -------------------------------------------------------------------------------- /tests/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/tests/layers/__init__.py -------------------------------------------------------------------------------- /tests/layers/modules_test.py: -------------------------------------------------------------------------------- 1 | """Unit tests for kfac/layers/modules.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | import torch 7 | 8 | from kfac.layers.modules import Conv2dModuleHelper 9 | from kfac.layers.modules import LinearModuleHelper 10 | 11 | 12 | @pytest.mark.parametrize( 13 | 'in_ch,out_ch,kernel_size,stride,padding,batch_size,hin,win,bias', 14 | [(3, 5, 2, 1, 0, 4, 10, 10, True), (3, 5, 3, 2, 1, 4, 10, 10, False)], 15 | ) 16 | def test_conv2d_module( 17 | in_ch: int, 18 | out_ch: int, 19 | kernel_size: int, 20 | stride: int, 21 | padding: int, 22 | batch_size: int, 23 | hin: int, 24 | win: int, 25 | bias: bool, 26 | ) -> None: 27 | """Test Conv2dModuleHelper.""" 28 | hout = int((hin + 2 * padding - 1 * (kernel_size - 1) - 1) / stride) + 1 29 | wout = int((win + 2 * padding - 1 * (kernel_size - 1) - 1) / stride) + 1 30 | 31 | conv2d = torch.nn.Conv2d( 32 | in_ch, 33 | out_ch, 34 | kernel_size, 35 | stride, 36 | padding, 37 | bias=bias, 38 | ) 39 | helper = Conv2dModuleHelper(conv2d) 40 | assert isinstance(repr(helper), str) 41 | assert conv2d.weight.device == helper.device 42 | 43 | data = torch.rand((batch_size, in_ch, hin, win)) 44 | target = torch.rand((batch_size, out_ch, hout, wout)) 45 | loss = (conv2d(data) - target).sum() 46 | loss.backward() 47 | 48 | grad_shape = (out_ch, in_ch * kernel_size * kernel_size + int(bias)) 49 | assert helper.get_grad().shape == grad_shape 50 | assert helper.has_bias() == bias 51 | assert helper.has_symmetric_factors() 52 | 53 | old_weight_grad = helper.get_weight_grad() 54 | if bias: 55 | old_bias_grad = helper.get_bias_grad() 56 | merged_grad = helper.get_grad() 57 | 58 | # Test 
set_grad() sets weight and bias (if exists) 59 | helper.set_grad(merged_grad) 60 | assert torch.equal(old_weight_grad, helper.get_weight_grad()) 61 | if bias: 62 | assert torch.equal(old_bias_grad, helper.get_bias_grad()) 63 | 64 | a = helper.get_a_factor(data) 65 | g = helper.get_g_factor(target) 66 | assert ( 67 | a.shape 68 | == helper.a_factor_shape 69 | == ( 70 | in_ch * kernel_size * kernel_size + int(bias), 71 | in_ch * kernel_size * kernel_size + int(bias), 72 | ) 73 | ) 74 | assert g.shape == helper.g_factor_shape == (out_ch, out_ch) 75 | 76 | 77 | @pytest.mark.parametrize('bias', [True, False]) 78 | def test_linear_module(bias: bool) -> None: 79 | """Test LinearModuleHelper.""" 80 | in_shape = 5 81 | out_shape = 3 82 | batch_size = 4 83 | 84 | linear = torch.nn.Linear(in_shape, out_shape, bias=bias) 85 | helper = LinearModuleHelper(linear) 86 | assert isinstance(repr(helper), str) 87 | assert linear.weight.device == helper.device 88 | 89 | data = torch.rand(in_shape) 90 | target = torch.rand(out_shape) 91 | loss = (linear(data) - target).sum() 92 | loss.backward() 93 | 94 | grad_shape = (out_shape, in_shape + int(bias)) 95 | assert helper.get_grad().shape == grad_shape 96 | assert helper.has_bias() == bias 97 | assert helper.has_symmetric_factors() 98 | 99 | old_weight_grad = helper.get_weight_grad() 100 | if bias: 101 | old_bias_grad = helper.get_bias_grad() 102 | merged_grad = helper.get_grad() 103 | 104 | # Test set_grad() sets weight and bias (if exists) 105 | helper.set_grad(merged_grad) 106 | assert torch.equal(old_weight_grad, helper.get_weight_grad()) 107 | if bias: 108 | assert torch.equal(old_bias_grad, helper.get_bias_grad()) 109 | 110 | a = helper.get_a_factor(torch.rand([batch_size, in_shape])) 111 | g = helper.get_g_factor(torch.rand([batch_size, out_shape])) 112 | assert ( 113 | a.shape 114 | == helper.a_factor_shape 115 | == (in_shape + int(bias), in_shape + int(bias)) 116 | ) 117 | assert g.shape == helper.g_factor_shape == (out_shape, out_shape) 118 | -------------------------------------------------------------------------------- /tests/layers/register_test.py: -------------------------------------------------------------------------------- 1 | """Unit tests for kfac/layers/register.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | import torch 7 | 8 | from kfac.distributed import TorchDistributedCommunicator 9 | from kfac.enums import AllreduceMethod 10 | from kfac.layers.base import KFACBaseLayer 11 | from kfac.layers.eigen import KFACEigenLayer 12 | from kfac.layers.inverse import KFACInverseLayer 13 | from kfac.layers.modules import Conv2dModuleHelper 14 | from kfac.layers.modules import LinearModuleHelper 15 | from kfac.layers.modules import ModuleHelper 16 | from kfac.layers.register import any_match 17 | from kfac.layers.register import get_flattened_modules 18 | from kfac.layers.register import get_module_helper 19 | from kfac.layers.register import register_modules 20 | from kfac.layers.register import requires_grad 21 | from testing.models import LeNet 22 | from testing.models import TinyModel 23 | 24 | 25 | class NestedTinyModel(torch.nn.Module): 26 | """Nested model for testing recursive module discovery.""" 27 | 28 | def __init__(self) -> None: 29 | """Init NestedTinyModel.""" 30 | super().__init__() 31 | 32 | self.tiny1 = TinyModel() 33 | self.tiny2 = TinyModel() 34 | self.tiny3 = TinyModel() 35 | 36 | 37 | @pytest.mark.parametrize( 38 | 'module,expected', 39 | ( 40 | ( 41 | TinyModel(), 42 | [ 43 | ('linear1', 
torch.nn.Linear), 44 | ('activation', torch.nn.ReLU), 45 | ('linear2', torch.nn.Linear), 46 | ('softmax', torch.nn.Softmax), 47 | ], 48 | ), 49 | ( 50 | NestedTinyModel(), 51 | [ 52 | ('tiny1.linear1', torch.nn.Linear), 53 | ('tiny1.activation', torch.nn.ReLU), 54 | ('tiny1.linear2', torch.nn.Linear), 55 | ('tiny1.softmax', torch.nn.Softmax), 56 | ('tiny2.linear1', torch.nn.Linear), 57 | ('tiny2.activation', torch.nn.ReLU), 58 | ('tiny2.linear2', torch.nn.Linear), 59 | ('tiny2.softmax', torch.nn.Softmax), 60 | ('tiny3.linear1', torch.nn.Linear), 61 | ('tiny3.activation', torch.nn.ReLU), 62 | ('tiny3.linear2', torch.nn.Linear), 63 | ('tiny3.softmax', torch.nn.Softmax), 64 | ], 65 | ), 66 | ( 67 | LeNet(), 68 | [ 69 | ('conv1', torch.nn.Conv2d), 70 | ('conv2', torch.nn.Conv2d), 71 | ('fc1', torch.nn.Linear), 72 | ('fc2', torch.nn.Linear), 73 | ('fc3', torch.nn.Linear), 74 | ], 75 | ), 76 | ), 77 | ) 78 | def test_get_flattened_modules( 79 | module: torch.nn.Module, 80 | expected: list[tuple[str, type[torch.nn.Module]]], 81 | ) -> None: 82 | """Test get_flattened_modules.""" 83 | modules = get_flattened_modules(module) 84 | for (name, module), (exp_name, exp_type) in zip(modules, expected): 85 | assert name == exp_name 86 | assert isinstance(module, exp_type) 87 | 88 | 89 | def test_requires_grad() -> None: 90 | """Test requires_grad.""" 91 | linear = torch.nn.Linear(1, 1) 92 | assert requires_grad(linear) 93 | linear.bias.requires_grad = False 94 | assert not requires_grad(linear) 95 | 96 | 97 | @pytest.mark.parametrize( 98 | 'module,expected', 99 | ( 100 | (torch.nn.Linear(1, 1), LinearModuleHelper), 101 | (torch.nn.Conv2d(1, 1, 1), Conv2dModuleHelper), 102 | (torch.nn.Conv3d(1, 1, 1), type(None)), 103 | ), 104 | ) 105 | def test_get_module_helper( 106 | module: torch.nn.Module, 107 | expected: type[ModuleHelper | None], 108 | ) -> None: 109 | """Test get_module_helper.""" 110 | assert isinstance(get_module_helper(module), expected) 111 | 112 | 113 | @pytest.mark.parametrize( 114 | 'model,layer_type,skip_layers,expected_count', 115 | ( 116 | (TinyModel(), KFACEigenLayer, [], 2), 117 | (TinyModel(), KFACInverseLayer, [], 2), 118 | (NestedTinyModel(), KFACEigenLayer, [], 6), 119 | (NestedTinyModel(), KFACInverseLayer, [], 6), 120 | (LeNet(), KFACEigenLayer, [], 5), 121 | (LeNet(), KFACInverseLayer, [], 5), 122 | (torch.nn.Conv3d(1, 1, 1), KFACEigenLayer, [], 0), 123 | # Test skip_layers: both by name or class and case invariant 124 | (LeNet(), KFACEigenLayer, ['fc1'], 4), 125 | (LeNet(), KFACEigenLayer, ['Conv2d'], 3), 126 | (LeNet(), KFACEigenLayer, ['Conv2d', 'Linear'], 0), 127 | ), 128 | ) 129 | def test_register_modules( 130 | model: torch.nn.Module, 131 | layer_type: type[KFACBaseLayer], 132 | skip_layers: list[str], 133 | expected_count: int, 134 | ) -> None: 135 | """Test register_modules.""" 136 | kwargs = dict( 137 | allreduce_method=AllreduceMethod.ALLREDUCE, 138 | grad_scaler=None, 139 | factor_dtype=None, 140 | inv_dtype=torch.float32, 141 | symmetry_aware=False, 142 | tdc=TorchDistributedCommunicator(), 143 | ) 144 | kfac_layers = register_modules( 145 | model, 146 | layer_type, 147 | skip_layers=skip_layers, 148 | **kwargs, 149 | ) 150 | assert len(kfac_layers) == expected_count 151 | 152 | 153 | @pytest.mark.parametrize( 154 | 'query,patterns,match', 155 | ( 156 | ('mystring', [], False), 157 | ('mystring', ['yourstring'], False), 158 | ('mystring', ['mystring'], True), 159 | ('mystring', ['string'], True), 160 | ('mystring', ['^string'], False), 161 | ('mystring', ['^string', 
'^my'], True), 162 | ( 163 | '2.attention.query_key_value', 164 | ['attention', 'query_key_value'], 165 | True, 166 | ), 167 | ), 168 | ) 169 | def test_any_match(query: str, patterns: list[str], match: bool) -> None: 170 | """Test any_match().""" 171 | assert any_match(query, patterns) == match 172 | -------------------------------------------------------------------------------- /tests/layers/utils_test.py: -------------------------------------------------------------------------------- 1 | """Unit tests for kfac/layers/utils.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | import torch 7 | 8 | from kfac.layers.utils import append_bias_ones 9 | from kfac.layers.utils import get_cov 10 | from kfac.layers.utils import reshape_data 11 | 12 | 13 | @pytest.mark.parametrize( 14 | 'shape,out_shape', 15 | [((1,), (2,)), ((4, 6), (4, 7)), ((1, 2, 3), (1, 2, 4))], 16 | ) 17 | def test_append_bias_ones(shape: tuple[int], out_shape: tuple[int]) -> None: 18 | """Test append_bias_ones.""" 19 | x = torch.rand(shape) 20 | x_out = append_bias_ones(x) 21 | assert x_out.shape == out_shape 22 | assert x_out[..., -1].sum() == x_out[..., -1].numel() 23 | 24 | 25 | @pytest.mark.parametrize( 26 | 'a,b,scale,expected', 27 | [ 28 | (torch.ones([2, 2]), None, None, torch.ones([2, 2])), 29 | (torch.ones([2, 2]), None, 4, 0.5 * torch.ones([2, 2])), 30 | (torch.ones([2, 2]), torch.zeros([2, 2]), None, torch.zeros([2, 2])), 31 | ( 32 | torch.ones([2, 2]), 33 | 10 * torch.ones([2, 2]), 34 | 5, 35 | 4 * torch.ones([2, 2]), 36 | ), 37 | ( 38 | torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]), 39 | None, 40 | None, 41 | torch.tensor( 42 | [[22.0, 26.0, 30.0], [26.0, 31.0, 36.0], [30.0, 36.0, 42.0]], 43 | ), 44 | ), 45 | ( 46 | torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]), 47 | torch.tensor([[9.0, 8.0, 7.0], [6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]), 48 | 2, 49 | torch.tensor( 50 | [[27.0, 21.0, 15.0], [36.0, 28.5, 21.0], [45.0, 36.0, 27.0]], 51 | ), 52 | ), 53 | # ValueError cases: 54 | (torch.ones([2]), None, None, torch.ones([2])), 55 | (torch.ones([2, 2]), torch.ones([2]), None, torch.ones([2])), 56 | ], 57 | ) 58 | def test_get_cov( 59 | a: torch.Tensor, 60 | b: torch.Tensor | None, 61 | scale: float | None, 62 | expected: torch.Tensor, 63 | ) -> None: 64 | """Test get_cov.""" 65 | if len(a.shape) != 2 or (b is not None and a.shape != b.shape): 66 | with pytest.raises(ValueError): 67 | get_cov(a, b, scale) 68 | else: 69 | out = get_cov(a, b, scale) 70 | assert torch.equal(out, expected) 71 | if b is None: 72 | assert torch.equal(out, out.t()) 73 | 74 | 75 | @pytest.mark.parametrize( 76 | 'shapes,collapse_dims,expected', 77 | [ 78 | (((2, 2),), False, (2, 2)), 79 | (((2, 2, 2),), False, (2, 2, 2)), 80 | (((2, 2, 2),), True, (4, 2)), 81 | (((2, 2), (4, 2)), False, (6, 2)), 82 | (((2, 2, 2), (4, 2, 2)), False, (6, 2, 2)), 83 | (((2, 2, 2), (4, 2, 2)), True, (12, 2)), 84 | ], 85 | ) 86 | def test_reshape_data( 87 | shapes: tuple[tuple[int]], 88 | collapse_dims: bool, 89 | expected: tuple[int], 90 | ) -> None: 91 | """Test reshape_data.""" 92 | # TODO: this test does not check batch_first = False (which assumes the 93 | # batch is the second dimension which is a little strange). 
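    # Note (added): per the parametrize cases above, reshape_data
    # concatenates the per-layer tensors along the batch dimension and,
    # when collapse_dims=True, also folds any intermediate dimensions
    # into the batch dimension, e.g. (2, 2, 2) + (4, 2, 2) -> (12, 2).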
94 | tensors = [torch.ones(shape) for shape in shapes] 95 | out = reshape_data(tensors, batch_first=True, collapse_dims=collapse_dims) 96 | assert out.shape == expected 97 | -------------------------------------------------------------------------------- /tests/preconditioner_test.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for kfac/preconditioner.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | from typing import Any 7 | 8 | import pytest 9 | 10 | from kfac.enums import AllreduceMethod 11 | from kfac.enums import AssignmentStrategy 12 | from kfac.enums import ComputeMethod 13 | from kfac.enums import DistributedStrategy 14 | from kfac.preconditioner import KFACPreconditioner 15 | from testing.distributed import distributed_test 16 | from testing.models import TinyModel 17 | 18 | 19 | def test_preconditioner_init_raises() -> None: 20 | """Test KFACPreconditioner argument validation.""" 21 | with pytest.raises(ValueError): 22 | KFACPreconditioner(TinyModel(), allreduce_bucket_cap_mb=-1) 23 | 24 | KFACPreconditioner( 25 | TinyModel(), 26 | compute_eigenvalue_outer_product=True, 27 | compute_method=ComputeMethod.INVERSE, 28 | colocate_factors=False, 29 | ) 30 | with pytest.raises(ValueError): 31 | KFACPreconditioner( 32 | TinyModel(), 33 | compute_eigenvalue_outer_product=True, 34 | compute_method=ComputeMethod.EIGEN, 35 | colocate_factors=False, 36 | ) 37 | 38 | with pytest.raises(ValueError): 39 | KFACPreconditioner(TinyModel(), grad_worker_fraction=2) 40 | 41 | with pytest.raises(ValueError): 42 | KFACPreconditioner(TinyModel(), grad_worker_fraction=-1) 43 | 44 | @distributed_test(world_size=8) 45 | def _f() -> None: 46 | with pytest.raises(ValueError): 47 | KFACPreconditioner(TinyModel(), grad_worker_fraction=0.33) 48 | 49 | _f() 50 | 51 | with pytest.warns(): 52 | KFACPreconditioner( 53 | TinyModel(), 54 | compute_method=ComputeMethod.INVERSE, 55 | colocate_factors=False, 56 | grad_worker_fraction=DistributedStrategy.MEM_OPT, 57 | ) 58 | 59 | 60 | def test_preconditioner_init() -> None: 61 | """Test KFACPreconditioner initialization.""" 62 | p1 = KFACPreconditioner(TinyModel(), assignment_strategy='memory') 63 | p2 = KFACPreconditioner( 64 | TinyModel(), 65 | assignment_strategy=AssignmentStrategy.MEMORY, 66 | ) 67 | assert p1.assignment_strategy == p2.assignment_strategy 68 | 69 | p1 = KFACPreconditioner(TinyModel(), compute_method='inverse') 70 | p2 = KFACPreconditioner(TinyModel(), compute_method=ComputeMethod.INVERSE) 71 | assert p1.compute_method == p2.compute_method 72 | 73 | @distributed_test(world_size=4) 74 | def _f() -> None: 75 | p1 = KFACPreconditioner(TinyModel(), grad_worker_fraction=1) 76 | p2 = KFACPreconditioner( 77 | TinyModel(), 78 | grad_worker_fraction=DistributedStrategy.COMM_OPT, 79 | ) 80 | assert p1.distributed_strategy == p2.distributed_strategy 81 | assert p1.grad_worker_fraction == p2.grad_worker_fraction 82 | 83 | p1 = KFACPreconditioner( 84 | TinyModel(), 85 | grad_worker_fraction=DistributedStrategy.HYBRID_OPT, 86 | ) 87 | assert p1.grad_worker_fraction == 0.5 88 | 89 | p1 = KFACPreconditioner( 90 | TinyModel(), 91 | grad_worker_fraction=DistributedStrategy.MEM_OPT, 92 | ) 93 | assert p1.grad_worker_fraction == 0.25 94 | 95 | p1 = KFACPreconditioner(TinyModel(), grad_worker_fraction=0) 96 | assert p1.grad_worker_fraction == 0.25 97 | assert p1.distributed_strategy == DistributedStrategy.MEM_OPT 98 | 99 | p1 = KFACPreconditioner( 100 | TinyModel(), 101 | 
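            # Note (added): with world_size=4, grad_worker_fraction maps to
            # a strategy as asserted above: 1.0 -> COMM_OPT,
            # 0.5 -> HYBRID_OPT, and 0.25 (1/world_size) -> MEM_OPT.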
grad_worker_fraction=0.5, 102 | ) 103 | assert p1.distributed_strategy == DistributedStrategy.HYBRID_OPT 104 | 105 | _f() 106 | 107 | p1 = KFACPreconditioner(TinyModel(), allreduce_bucket_cap_mb=25) 108 | assert p1.allreduce_method == AllreduceMethod.ALLREDUCE_BUCKETED 109 | 110 | p1 = KFACPreconditioner(TinyModel(), allreduce_bucket_cap_mb=0) 111 | assert p1.allreduce_method == AllreduceMethod.ALLREDUCE 112 | 113 | 114 | def test_preconditioner_logging(caplog: Any) -> None: 115 | """Test KFACPreconditioner logs relevant info.""" 116 | caplog.set_level(logging.INFO) 117 | 118 | KFACPreconditioner(TinyModel(), loglevel=logging.DEBUG) 119 | assert len(caplog.records) == 0 120 | caplog.clear() 121 | 122 | KFACPreconditioner(TinyModel(), loglevel=logging.INFO) 123 | messages = [r.getMessage() for r in caplog.records] 124 | # Should register two layers in TinyModel and have a record for each 125 | assert sum('Registered' in msg for msg in messages) == 2 126 | # Should print KAISAAssignment once 127 | assert sum('KAISAAssignment' in msg for msg in messages) == 1 128 | -------------------------------------------------------------------------------- /tests/scheduler_test.py: -------------------------------------------------------------------------------- 1 | """Unit tests for kfac/scheduler.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | from typing import Callable 7 | 8 | import pytest 9 | 10 | from kfac.base_preconditioner import BaseKFACPreconditioner 11 | from kfac.preconditioner import KFACPreconditioner 12 | from kfac.scheduler import LambdaParamScheduler 13 | from testing.models import TinyModel 14 | 15 | 16 | def factor_func(scale: int, constant: bool = True) -> Callable[..., int]: 17 | """Get function which returns scale given step.""" 18 | 19 | def factor(step: int = 1) -> int: 20 | """Scale function.""" 21 | return scale if constant else scale * step 22 | 23 | return factor 24 | 25 | 26 | @pytest.mark.parametrize( 27 | 'preconditioner_type,preconditioner_kwargs', 28 | ((KFACPreconditioner, {'model': TinyModel()}),), 29 | ) 30 | def test_input_check( 31 | preconditioner_type: type[BaseKFACPreconditioner], 32 | preconditioner_kwargs: dict[str, Any], 33 | ) -> None: 34 | """Test raises ValueError if preconditioner was already passed lambda.""" 35 | preconditioner = preconditioner_type( 36 | **preconditioner_kwargs, 37 | factor_update_steps=factor_func(1), 38 | ) 39 | with pytest.raises(ValueError): 40 | LambdaParamScheduler( 41 | preconditioner, 42 | factor_update_steps_lambda=factor_func(1), 43 | ) 44 | 45 | preconditioner = KFACPreconditioner( 46 | TinyModel(), 47 | inv_update_steps=factor_func(1), 48 | ) 49 | with pytest.raises(ValueError): 50 | LambdaParamScheduler( 51 | preconditioner, 52 | inv_update_steps_lambda=factor_func(1), 53 | ) 54 | 55 | preconditioner = KFACPreconditioner(TinyModel(), damping=factor_func(1)) 56 | with pytest.raises(ValueError): 57 | LambdaParamScheduler(preconditioner, damping_lambda=factor_func(1)) 58 | 59 | preconditioner = KFACPreconditioner( 60 | TinyModel(), 61 | factor_decay=factor_func(1), 62 | ) 63 | with pytest.raises(ValueError): 64 | LambdaParamScheduler( 65 | preconditioner, 66 | factor_decay_lambda=factor_func(1), 67 | ) 68 | 69 | preconditioner = KFACPreconditioner(TinyModel(), kl_clip=factor_func(1)) 70 | with pytest.raises(ValueError): 71 | LambdaParamScheduler(preconditioner, kl_clip_lambda=factor_func(1)) 72 | 73 | preconditioner = KFACPreconditioner(TinyModel(), lr=factor_func(1)) 74 | with 
pytest.raises(ValueError): 75 | LambdaParamScheduler(preconditioner, lr_lambda=factor_func(1)) 76 | 77 | 78 | @pytest.mark.parametrize( 79 | 'preconditioner_type,preconditioner_kwargs', 80 | ((KFACPreconditioner, {'model': TinyModel()}),), 81 | ) 82 | def test_scheduler( 83 | preconditioner_type: type[BaseKFACPreconditioner], 84 | preconditioner_kwargs: dict[str, Any], 85 | ) -> None: 86 | """Test param scheduler.""" 87 | preconditioner = preconditioner_type( 88 | **preconditioner_kwargs, 89 | factor_update_steps=1, 90 | inv_update_steps=1, 91 | damping=1, 92 | factor_decay=1, 93 | kl_clip=1, 94 | lr=1, 95 | ) 96 | scheduler = LambdaParamScheduler( 97 | preconditioner, 98 | factor_update_steps_lambda=factor_func(2), 99 | inv_update_steps_lambda=factor_func(3), 100 | damping_lambda=factor_func(5), 101 | factor_decay_lambda=factor_func(7), 102 | kl_clip_lambda=factor_func(9), 103 | lr_lambda=factor_func(11), 104 | ) 105 | 106 | for steps in range(1, 10): 107 | preconditioner._steps = steps 108 | scheduler.step() 109 | assert preconditioner.factor_update_steps == 2**steps 110 | assert preconditioner.inv_update_steps == 3**steps 111 | assert preconditioner.damping == 5**steps 112 | assert preconditioner.factor_decay == 7**steps 113 | assert preconditioner.kl_clip == 9**steps 114 | assert preconditioner.lr == 11**steps 115 | 116 | preconditioner = preconditioner_type( 117 | **preconditioner_kwargs, 118 | factor_update_steps=1, 119 | inv_update_steps=1, 120 | damping=1, 121 | factor_decay=1, 122 | kl_clip=1, 123 | lr=1, 124 | ) 125 | scheduler = LambdaParamScheduler( 126 | preconditioner, 127 | factor_update_steps_lambda=factor_func(2, False), 128 | inv_update_steps_lambda=factor_func(3, False), 129 | damping_lambda=factor_func(5, False), 130 | factor_decay_lambda=factor_func(7, False), 131 | kl_clip_lambda=factor_func(9, False), 132 | lr_lambda=factor_func(11, False), 133 | ) 134 | for steps in range(1, 10): 135 | preconditioner._steps = steps 136 | scheduler.step(step=0) 137 | assert preconditioner.factor_update_steps == 0 138 | assert preconditioner.inv_update_steps == 0 139 | assert preconditioner.damping == 0 140 | assert preconditioner.factor_decay == 0 141 | assert preconditioner.kl_clip == 0 142 | assert preconditioner.lr == 0 143 | 144 | scheduler = LambdaParamScheduler(preconditioner) 145 | scheduler.step() 146 | -------------------------------------------------------------------------------- /tests/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpauloski/kfac-pytorch/bd591355e52320892c9e792fd08e60fda616c548/tests/testing/__init__.py -------------------------------------------------------------------------------- /tests/testing/distributed_wrapper_test.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for testing/distributed.py. 
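The distributed_test decorator runs the wrapped function under a process
group of the requested world size so that collective operations can be
exercised in a simulated environment without multiple physical devices.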
2 | 3 | Taken from: https://github.com/EleutherAI/DeeperSpeed/blob/eb7f5cff36678625d23db8a8fe78b4a93e5d2c75/tests/unit/test_dist.py 4 | """ # noqa: E501 5 | 6 | from __future__ import annotations 7 | 8 | import pytest 9 | import torch 10 | import torch.distributed as dist 11 | 12 | from testing.distributed import distributed_test 13 | 14 | 15 | @distributed_test(world_size=3) 16 | def test_distributed_test_init() -> None: 17 | """Test distributed wrapper initialized torch.distributed.""" 18 | assert dist.is_initialized() 19 | assert dist.get_world_size() == 3 20 | assert dist.get_rank() < 3 21 | 22 | 23 | @pytest.mark.parametrize('number,color', [(1138, 'purple')]) 24 | def test_dist_args(number: int, color: str) -> None: 25 | """Outer test function with inputs from pytest.mark.parametrize().""" 26 | 27 | @distributed_test(world_size=2) 28 | def _test_dist_args_helper(x: int, color: str = 'red') -> None: 29 | """Test distributed initialized and parameters are correctly passed.""" 30 | assert dist.get_world_size() == 2 31 | assert x == 1138 32 | assert color == 'purple' 33 | 34 | # Ensure that we can parse args to distributed_test decorated functions. 35 | _test_dist_args_helper(number, color=color) 36 | 37 | 38 | @distributed_test(world_size=[1, 2, 4]) 39 | def test_dist_allreduce() -> None: 40 | """Test collective communication operations work in simulated env.""" 41 | x = torch.ones(1, 3) * (dist.get_rank() + 1) 42 | sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2 43 | result = torch.ones(1, 3) * sum_of_ranks 44 | dist.all_reduce(x) 45 | assert torch.all(x == result) 46 | -------------------------------------------------------------------------------- /tests/tracing_test.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for kfac/tracing.py.""" 2 | 3 | from __future__ import annotations 4 | 5 | import time 6 | 7 | from kfac.tracing import clear_trace 8 | from kfac.tracing import get_trace 9 | from kfac.tracing import log_trace 10 | from kfac.tracing import trace 11 | from testing.distributed import distributed_test 12 | 13 | 14 | def test_trace() -> None: 15 | """Test tracing function execution times.""" 16 | 17 | @trace() 18 | def a(t: float) -> None: 19 | time.sleep(t) 20 | 21 | @trace() 22 | def b(t: float) -> None: 23 | time.sleep(t) 24 | 25 | assert len(get_trace()) == 0 26 | # Check log raises no errors... we won't bother verifying the output 27 | log_trace() 28 | 29 | a(0.01) 30 | traces = get_trace() 31 | assert len(traces) == 1 32 | assert 'a' in traces 33 | assert traces['a'] >= 0.01 34 | 35 | a(0.0) 36 | new_traces = get_trace() 37 | assert new_traces['a'] < traces['a'] 38 | 39 | b(0.01) 40 | traces = get_trace() 41 | assert len(traces) == 2 42 | assert 'b' in traces 43 | 44 | traces = get_trace(average=False) 45 | assert traces['a'] > new_traces['a'] 46 | 47 | traces = get_trace(average=False, max_history=1) 48 | assert traces['a'] < 0.01 # should only use the 0 second sleep call 49 | # Check log raises no errors... 
we won't bother verifying the output
50 |     log_trace()
51 | 
52 |     clear_trace()
53 |     assert len(get_trace()) == 0
54 | 
55 | 
56 | @distributed_test(world_size=2)
57 | def test_synced_trace() -> None:
58 |     """Test syncing function executions in distributed training."""
59 | 
60 |     @trace(sync=True)
61 |     def a(t: float) -> None:
62 |         time.sleep(t)
63 | 
64 |     a(0.01)
65 | 
--------------------------------------------------------------------------------
/tests/training_test.py:
--------------------------------------------------------------------------------
1 | """End-to-end training test for KFACPreconditioner."""
2 | 
3 | from __future__ import annotations
4 | 
5 | from multiprocessing import Process
6 | 
7 | import pytest
8 | import torch
9 | 
10 | from kfac.preconditioner import KFACPreconditioner
11 | from testing.distributed import distributed_test
12 | from testing.models import TinyModel
13 | 
14 | 
15 | def train(grad_worker_frac: float) -> None:
16 |     """Train TinyModel with KFAC on random data."""
17 |     batch_size = 4
18 |     in_features = 10
19 |     out_features = 10
20 |     steps = 20
21 | 
22 |     # https://github.com/pytorch/pytorch/issues/41197#issuecomment-656300677
23 |     torch.set_num_threads(1)
24 | 
25 |     x = torch.rand(batch_size, in_features)
26 |     y = torch.rand(batch_size, out_features)
27 |     if torch.distributed.is_initialized():
28 |         torch.distributed.all_reduce(x)
29 |         torch.distributed.all_reduce(y)
30 | 
31 |     model: torch.nn.Module = TinyModel()
32 |     if torch.distributed.is_initialized():
33 |         model = torch.nn.parallel.DistributedDataParallel(model)
34 |     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
35 |     preconditioner = KFACPreconditioner(
36 |         model,
37 |         factor_update_steps=5,
38 |         inv_update_steps=10,
39 |         grad_worker_fraction=grad_worker_frac,
40 |         allreduce_bucket_cap_mb=0,
41 |         update_factors_in_hook=False,
42 |     )
43 |     criterion = torch.nn.MSELoss(reduction='sum')
44 | 
45 |     losses = []
46 |     for _ in range(steps):
47 |         y_pred = model(x)
48 |         loss = criterion(y_pred, y)
49 |         losses.append(loss.item())
50 |         loss.backward()
51 |         preconditioner.step()
52 |         optimizer.step()
53 |         optimizer.zero_grad()
54 | 
55 |     assert losses[0] > losses[-1]
56 | 
57 | 
58 | @pytest.mark.parametrize(
59 |     'distributed,grad_worker_frac,world_size',
60 |     ((False, 1, 1), (True, 0, 1), (True, 0.5, 2), (True, 0.5, 4)),
61 | )
62 | def test_training(
63 |     distributed: bool,
64 |     grad_worker_frac: float,
65 |     world_size: int,
66 | ) -> None:
67 |     """Test end-to-end training with KFACPreconditioner."""
68 |     if not distributed:
69 |         # Note: torch does not allow forking if autograd has been used
70 |         # in the parent process. So we perform the training in a separate
71 |         # process to keep this parent process "clean". See
72 |         # https://github.com/pytorch/pytorch/issues/69839#issuecomment-993686048
73 |         p = Process(target=train, args=(grad_worker_frac,))
74 |         p.start()
75 |         p.join()
76 |     else:
77 |         _train = distributed_test(world_size=world_size)(train)
78 |         _train(grad_worker_frac)
79 | 
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py39, py310, py311, py312, py313, pre-commit
3 | 
4 | [testenv]
5 | install_command =
6 |     python -I -m pip install {opts} {packages} --extra-index-url https://download.pytorch.org/whl/cpu
7 | extras = dev
8 | deps =
9 |     # Note version 0.8.2 is required for PyTorch 2.0 compatibility. 
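    # Note (added): deepspeed is only installed for py39-py312; the py313
    # commands below correspondingly skip and omit the gpt_neox tests,
    # which appear to rely on it.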
10 | py{39,310,311,312}: deepspeed>=0.8.2 11 | setenv = 12 | CUDA_VISIBLE_DEVICES = "" 13 | commands = 14 | coverage erase 15 | 16 | py{39,310,311,312}: coverage run -m pytest {posargs} 17 | py{313}: coverage run --omit "tests/gpt_neox/*,testing/*,examples/*,tests/integration/*" -m pytest {posargs} --ignore tests/gpt_neox 18 | 19 | coverage combine --quiet 20 | 21 | py{39,310,311,312}: coverage report 22 | py{313}: coverage report --omit kfac/gpt_neox/*.py,tests/gpt_neox/* 23 | 24 | [testenv:pre-commit] 25 | skip_install = true 26 | deps = pre-commit 27 | commands = pre-commit run --all-files --show-diff-on-failure 28 | --------------------------------------------------------------------------------
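
The test files above all exercise the same public usage pattern: wrap a model
in a KFACPreconditioner and call preconditioner.step() between
loss.backward() and optimizer.step(). Below is a minimal, self-contained
sketch of that loop for reference; the model, data, and hyperparameter values
are illustrative (loosely following tests/training_test.py), not taken from
the repository:

import torch

from kfac.preconditioner import KFACPreconditioner

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Accumulate Kronecker factors every 5 steps and recompute their
# inverses every 10 steps (values chosen for illustration).
preconditioner = KFACPreconditioner(
    model,
    factor_update_steps=5,
    inv_update_steps=10,
)
criterion = torch.nn.MSELoss()

x = torch.rand(4, 10)
y = torch.rand(4, 10)
for _ in range(20):
    loss = criterion(model(x), y)
    loss.backward()
    preconditioner.step()  # Precondition .grad in place before SGD applies it.
    optimizer.step()
    optimizer.zero_grad()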