├── .flake8 ├── .github ├── actions │ └── setup_environment │ │ └── action.yml └── workflows │ ├── build_and_upload_wheel.yml │ ├── release_public.yml │ ├── release_test.yml │ ├── run_linting.yml │ └── run_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── ci_requirements.txt ├── dmlcloud ├── __init__.py ├── core │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── checkpoint.py │ │ ├── common.py │ │ ├── cuda.py │ │ ├── diagnostics.py │ │ ├── git.py │ │ ├── metrics.py │ │ ├── profiler.py │ │ ├── table.py │ │ ├── tensorboard.py │ │ ├── timer.py │ │ └── wandb.py │ ├── checkpoint.py │ ├── config.py │ ├── distributed.py │ ├── logging.py │ ├── metrics.py │ ├── model.py │ ├── pipeline.py │ └── stage.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── interleave.py │ ├── sharding.py │ └── xarray.py ├── git.py ├── run.py ├── slurm.py ├── util │ ├── __init__.py │ ├── argparse.py │ ├── logging.py │ ├── seed.py │ ├── tcp.py │ ├── thirdparty.py │ └── wandb.py └── version.py ├── doc ├── Makefile ├── conf.py ├── dmlcloud.data.rst ├── dmlcloud.git.rst ├── dmlcloud.rst ├── dmlcloud.slurm.rst ├── getting_started │ ├── index.rst │ └── mnist.rst ├── index.rst ├── make.bat ├── requirements.txt └── user_guide │ └── index.rst ├── examples ├── README.md ├── custom_epochs.py └── mnist.py ├── misc └── logo │ ├── dmlcloud_color.png │ ├── dmlcloud_color.svg │ ├── dmlcloud_dark.png │ ├── dmlcloud_dark.svg │ ├── dmlcloud_dark2.png │ ├── dmlcloud_dark2.svg │ ├── dmlcloud_light.png │ └── dmlcloud_light.svg ├── pyproject.toml ├── requirements.txt └── test ├── conftest.py ├── test_callback.py ├── test_config.py ├── test_csv.py ├── test_data.py ├── test_global_accessors.py ├── test_import.py ├── test_io_redirector.py ├── test_root_only.py ├── test_seed.py └── test_smoke.py /.flake8: -------------------------------------------------------------------------------- 1 | 2 | [flake8] 3 | max-line-length = 120 4 | ignore = E203, E402, E501, E741 5 | -------------------------------------------------------------------------------- /.github/actions/setup_environment/action.yml: -------------------------------------------------------------------------------- 1 | name: Setup CI Environment 2 | inputs: 3 | python-version: 4 | default: "3.11" 5 | type: string 6 | 7 | runs: 8 | using: composite 9 | steps: 10 | - name: Setup Python 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: ${{ inputs.python-version }} 14 | 15 | - name: Install CI Dependencies 16 | shell: bash 17 | run: | 18 | pip install -r ci_requirements.txt 19 | echo "/home/runner/.local/bin" >> $GITHUB_PATH 20 | -------------------------------------------------------------------------------- /.github/workflows/build_and_upload_wheel.yml: -------------------------------------------------------------------------------- 1 | name: Build and Upload Wheel 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | branch: 7 | required: true 8 | type: string 9 | do-upload: 10 | required: false 11 | default: true 12 | type: boolean 13 | real-pypi: 14 | required: false 15 | default: false 16 | type: boolean 17 | secrets: 18 | PYPI_TOKEN: 19 | required: true 20 | 21 | jobs: 22 | wheel_build: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Checkout Repository 26 | uses: actions/checkout@v4 27 | with: 28 | ref: ${{ inputs.branch }} 29 | submodules: recursive 30 | 31 | - name: Setup Environment 32 | uses: ./.github/actions/setup_environment 33 | 34 | - name: Build Wheel 35 | run: | 36 | python -m build 37 | 38 | - 
name: Upload Wheels to Github 39 | uses: actions/upload-artifact@v4 40 | with: 41 | name: wheels 42 | path: dist/*.whl 43 | 44 | wheel_upload: 45 | if: inputs.do-upload == true 46 | needs: [wheel_build] 47 | runs-on: ubuntu-latest 48 | outputs: 49 | upload: ${{ steps.trigger_upload.outputs.value }} 50 | steps: 51 | - name: Setup Python 52 | uses: actions/setup-python@v5 53 | with: 54 | python-version: "3.10" 55 | 56 | - name: Download Artifacts From Github 57 | continue-on-error: true 58 | uses: actions/download-artifact@v4 59 | with: 60 | name: wheels 61 | 62 | - name: Determine If Wheel Uploading Is Needed 63 | run: | 64 | upload=false 65 | for txt in *.whl; do 66 | upload=true 67 | break 68 | done 69 | echo "value=$upload" >> $GITHUB_OUTPUT 70 | id: trigger_upload 71 | 72 | - name: Display All Wheels 73 | if: steps.trigger_upload.outputs.value == 'true' 74 | run: ls -lh *.whl 75 | 76 | - name: Upload Wheels to PyPI 77 | if: | 78 | steps.trigger_upload.outputs.value == 'true' 79 | env: 80 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 81 | run: | 82 | pip install twine 83 | if [[ "${{ inputs.real-pypi }}" == true ]]; then 84 | python -m twine upload \ 85 | --username __token__ \ 86 | --password "$PYPI_TOKEN" \ 87 | *.whl 88 | else 89 | python -m twine upload \ 90 | -r testpypi \ 91 | --username __token__ \ 92 | --password "$PYPI_TOKEN" \ 93 | *.whl 94 | fi 95 | -------------------------------------------------------------------------------- /.github/workflows/release_public.yml: -------------------------------------------------------------------------------- 1 | name: Public Release 2 | 3 | on: 4 | # [ Note: Manually Trigger the Workflow ] 5 | # 1. Go to Actions under pytorch/data repo 6 | # 2. In the left sidebar, click the workflow you want to run 7 | # 3. Above the list of workflow runs, select Run workflow 8 | # 4. Use the Branch dropdown to select the release/* branch 9 | # 5. Click Run workflow 10 | workflow_dispatch: 11 | 12 | jobs: 13 | build_and_upload_wheel: 14 | uses: ./.github/workflows/build_and_upload_wheel.yml 15 | with: 16 | branch: "" 17 | real-pypi: true 18 | secrets: 19 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 20 | -------------------------------------------------------------------------------- /.github/workflows/release_test.yml: -------------------------------------------------------------------------------- 1 | name: Test Release 2 | 3 | on: 4 | # [ Note: Manually Trigger the Workflow ] 5 | # 1. Go to Actions under pytorch/data repo 6 | # 2. In the left sidebar, click the workflow you want to run 7 | # 3. Above the list of workflow runs, select Run workflow 8 | # 4. Use the Branch dropdown to select the release/* branch 9 | # 5. 
Click Run workflow 10 | workflow_dispatch: 11 | 12 | jobs: 13 | build_and_upload_wheel: 14 | uses: ./.github/workflows/build_and_upload_wheel.yml 15 | with: 16 | branch: "" 17 | secrets: 18 | PYPI_TOKEN: ${{ secrets.TEST_PYPI_TOKEN }} 19 | -------------------------------------------------------------------------------- /.github/workflows/run_linting.yml: -------------------------------------------------------------------------------- 1 | name: Run Linting 2 | 3 | on: 4 | push: 5 | branches: 6 | - develop 7 | tags: # ignores pushes to tags 8 | pull_request: 9 | 10 | jobs: 11 | lint: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout Repository 15 | uses: actions/checkout@v4 16 | with: 17 | submodules: recursive 18 | 19 | - name: Setup Environment 20 | uses: ./.github/actions/setup_environment 21 | 22 | - name: Lint 23 | run: pre-commit run --all-files 24 | 25 | - name: Required modifications 26 | if: ${{ failure() }} 27 | run: git --no-pager diff 28 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | on: 3 | push: 4 | branches: 5 | - develop 6 | tags: # ignores tag pushes 7 | pull_request: 8 | 9 | jobs: 10 | test: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: 16 | - macos-latest 17 | - ubuntu-latest 18 | python-version: 19 | - "3.10" 20 | steps: 21 | - name: Checkout Repository 22 | uses: actions/checkout@v4 23 | with: 24 | submodules: recursive 25 | 26 | - name: Setup Environment 27 | uses: ./.github/actions/setup_environment 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Install Dependencies 32 | run: | 33 | pip install -r requirements.txt 34 | 35 | - name: Install Project 36 | run: | 37 | pip install . 38 | 39 | - name: Run tests & coverage 40 | run: | 41 | coverage run -m pytest --no-header -v test 42 | coverage report -m -i 43 | coverage html -i 44 | 45 | - name: Archive coverage results 46 | if: startsWith(matrix.os, 'ubuntu') 47 | uses: actions/upload-artifact@v4 48 | with: 49 | name: code-coverage-report 50 | path: htmlcov 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /data 2 | /slurm-*.out 3 | /wandb 4 | /doc/_build 5 | /doc/generated/ 6 | /checkpoints 7 | 8 | ############################ Auto Generated ############################ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.5.0 5 | hooks: 6 | - id: trailing-whitespace 7 | - id: end-of-file-fixer 8 | 9 | - repo: https://github.com/asottile/pyupgrade 10 | rev: v3.15.2 11 | hooks: 12 | - id: pyupgrade 13 | args: [--py37-plus] 14 | 15 | - repo: https://github.com/omnilib/ufmt 16 | rev: v2.5.1 17 | hooks: 18 | - id: ufmt 19 | additional_dependencies: 20 | - black == 23.1.0 21 | - usort == 1.1.0b2 22 | 23 | - repo: https://github.com/pycqa/flake8 24 | rev: 7.0.0 25 | hooks: 26 | - id: flake8 27 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.12" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: doc/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | # formats: 24 | # - pdf 25 | # - epub 26 | 27 | # Optional but recommended, declare the Python requirements required 28 | # to build your documentation 29 | # See 
https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 30 | python: 31 | install: 32 | - requirements: doc/requirements.txt 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Sebastian Hoffmann 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Dmlcloud Logo](./misc/logo/dmlcloud_color.png) 2 | --------------- 3 | [![PyPI Status](https://img.shields.io/pypi/v/dmlcloud)](https://pypi.org/project/dmlcloud/) 4 | [![Documentation Status](https://readthedocs.org/projects/dmlcloud/badge/?version=latest)](https://dmlcloud.readthedocs.io/en/latest/?badge=latest) 5 | [![Test Status](https://img.shields.io/github/actions/workflow/status/sehoffmann/dmlcloud/run_tests.yml?label=tests&logo=github)](https://github.com/sehoffmann/dmlcloud/actions/workflows/run_tests.yml) 6 | 7 | A torch library for easy distributed deep learning on HPC clusters. Supports both slurm and MPI. No unnecessary abstractions and overhead. Simple, yet powerful, API. 
8 |
9 | ## Highlights
10 | - Simple, yet powerful, API
11 | - Easy initialization of `torch.distributed`
12 | - Distributed metrics
13 | - Extensive logging and diagnostics
14 | - Wandb support
15 | - Tensorboard support
16 | - A wealth of useful utility functions
17 |
18 | ## Installation
19 | dmlcloud can be installed directly from PyPI:
20 | ```bash
21 | pip install dmlcloud
22 | ```
23 |
24 | Alternatively, you can install the latest development version directly from GitHub:
25 | ```bash
26 | pip install git+https://github.com/sehoffmann/dmlcloud.git
27 | ```
28 |
29 | ### Documentation
30 |
31 | You can find the official documentation at [Read the Docs](https://dmlcloud.readthedocs.io/en/latest/)
32 |
33 | ## Minimal Example
34 | See [examples/mnist.py](https://github.com/sehoffmann/dmlcloud/blob/develop/examples/mnist.py) for a minimal example on how to train MNIST with multiple GPUs. To run it with 4 GPUs, use
35 | ```bash
36 | dmlrun -n 4 python examples/mnist.py
37 | ```
38 | `dmlrun` is a thin wrapper around `torchrun` that makes it easier to prototype on a single node.
39 |
40 | ## Slurm Support
41 | *dmlcloud* automatically looks for slurm environment variables to initialize torch.distributed. On a slurm cluster, you can hence simply use `srun` from within an sbatch script to train on multiple nodes:
42 |
43 | ```bash
44 | #!/bin/bash
45 | #SBATCH --nodes=2
46 | #SBATCH --ntasks-per-node=4
47 | #SBATCH --gpus-per-node=4
48 | #SBATCH --cpus-per-task=8
49 | #SBATCH --gpu-bind=none
50 |
51 | srun python examples/mnist.py
52 | ```
53 |
54 | ## FAQ
55 |
56 | ### How is dmlcloud different from similar libraries like *pytorch lightning* or *fastai*?
57 |
58 | dmlcloud was designed foremost with one underlying principle:
59 | > **No unnecessary abstractions, just help with distributed training**
60 |
61 | As a consequence, dmlcloud code is almost identical to a regular pytorch training loop and only requires a few adjustments here and there.
62 | In contrast, other libraries often introduce extensive APIs that can quickly feel overwhelming due to the sheer number of options.
63 |
64 | For instance, **the constructor of `lightning.Trainer` has 51 arguments! `dml.Pipeline` only has 2.**
--------------------------------------------------------------------------------
/ci_requirements.txt:
--------------------------------------------------------------------------------
1 | setuptools
2 | packaging
3 | wheel
4 | build
5 | pre-commit
6 | pytest
7 | sphinx
8 | sphinx-rtd-theme
9 | sphinx-autodoc-typehints
10 | coverage
--------------------------------------------------------------------------------
/dmlcloud/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | A torch library for easy distributed deep learning on HPC clusters.
3 | Supports both slurm and MPI. No unnecessary abstractions and overhead.
4 | Simple, yet powerful, API.
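
A minimal usage sketch (illustrative only — it uses just the helpers
re-exported below; the exact pipeline wiring is covered in the documentation):

    import dmlcloud as dml

    dml.init()  # initialize torch.distributed (from slurm, MPI, or environment variables)
    dml.info(f'Hello from rank {dml.rank()} of {dml.world_size()}')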
5 | """ 6 | 7 | from .version import __version__ 8 | 9 | __all__ = ['__version__'] 10 | 11 | 12 | ################################### 13 | # Sub Packages 14 | ################################### 15 | 16 | import dmlcloud.data as data 17 | import dmlcloud.git as git 18 | import dmlcloud.slurm as slurm 19 | 20 | 21 | __all__ += [ 22 | 'data', 23 | 'git', 24 | 'slurm', 25 | ] 26 | 27 | 28 | ################################### 29 | # Top-level API 30 | ################################### 31 | 32 | 33 | # Pipeline 34 | 35 | from .core.pipeline import current_pipe, current_stage, log_metric, Pipeline 36 | 37 | __all__ += [ 38 | 'Pipeline', 39 | 'current_pipe', 40 | 'current_stage', 41 | 'log_metric', 42 | ] 43 | 44 | # Stage 45 | 46 | from .core.stage import Stage 47 | 48 | __all__ += [ 49 | 'Stage', 50 | ] 51 | 52 | # Callbacks 53 | 54 | from .core.callbacks import Callback 55 | 56 | __all__ += [ 57 | 'Callback', 58 | ] 59 | 60 | # Distributed helpers 61 | 62 | from .core.distributed import ( 63 | all_gather_object, 64 | broadcast_object, 65 | deinitialize_torch_distributed, 66 | gather_object, 67 | has_environment, 68 | has_mpi, 69 | has_slurm, 70 | init, 71 | is_root, 72 | local_node, 73 | local_rank, 74 | local_world_size, 75 | rank, 76 | root_first, 77 | root_only, 78 | seed, 79 | world_size, 80 | ) 81 | 82 | __all__ += [ 83 | 'has_slurm', 84 | 'has_environment', 85 | 'has_mpi', 86 | 'is_root', 87 | 'root_only', 88 | 'root_first', 89 | 'rank', 90 | 'world_size', 91 | 'local_rank', 92 | 'local_world_size', 93 | 'local_node', 94 | 'all_gather_object', 95 | 'gather_object', 96 | 'broadcast_object', 97 | 'init', 98 | 'deinitialize_torch_distributed', 99 | 'seed', 100 | ] 101 | 102 | # Metrics 103 | 104 | from .core.metrics import Tracker, TrainingHistory 105 | 106 | __all__ += [ 107 | Tracker, 108 | TrainingHistory, 109 | ] 110 | 111 | # Logging 112 | 113 | from .core.logging import ( 114 | critical, 115 | debug, 116 | error, 117 | flush_logger, 118 | info, 119 | log, 120 | logger, 121 | print_root, 122 | print_worker, 123 | reset_logger, 124 | setup_logger, 125 | warning, 126 | ) 127 | 128 | __all__ += [ 129 | 'logger', 130 | 'setup_logger', 131 | 'reset_logger', 132 | 'flush_logger', 133 | 'print_root', 134 | 'print_worker', 135 | 'log', 136 | 'debug', 137 | 'info', 138 | 'warning', 139 | 'error', 140 | 'critical', 141 | ] 142 | 143 | # Model helper 144 | 145 | from .core.model import count_parameters, scale_lr, wrap_ddp 146 | 147 | __all__ += [ 148 | 'wrap_ddp', 149 | 'scale_lr', 150 | 'count_parameters', 151 | ] 152 | 153 | # Config helper 154 | 155 | from .core.config import factory_from_cfg, import_object, obj_from_cfg 156 | 157 | __all__ += [ 158 | 'import_object', 159 | 'factory_from_cfg', 160 | 'obj_from_cfg', 161 | ] 162 | -------------------------------------------------------------------------------- /dmlcloud/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sehoffmann/dmlcloud/9aba8f3c62e3ca52852b7d5334902e52430677ea/dmlcloud/core/__init__.py -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint import CheckpointCallback 2 | from .common import Callback, CallbackList, CbPriority 3 | from .cuda import CudaCallback 4 | from .diagnostics import DiagnosticsCallback 5 | from .git import GitDiffCallback 6 | from .metrics import CsvCallback, 
ReduceMetricsCallback 7 | from .profiler import ProfilerCallback 8 | from .table import TableCallback 9 | from .tensorboard import TensorboardCallback 10 | from .timer import TimerCallback 11 | from .wandb import WandbInitCallback, WandbLoggerCallback 12 | 13 | 14 | __all__ = [ 15 | 'CallbackList', 16 | 'CbPriority', 17 | 'Callback', 18 | 'ProfilerCallback', 19 | 'TimerCallback', 20 | 'TableCallback', 21 | 'ReduceMetricsCallback', 22 | 'CheckpointCallback', 23 | 'CsvCallback', 24 | 'DiagnosticsCallback', 25 | 'GitDiffCallback', 26 | 'WandbInitCallback', 27 | 'WandbLoggerCallback', 28 | 'TensorboardCallback', 29 | 'CudaCallback', 30 | ] 31 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import TYPE_CHECKING, Union 4 | 5 | import dmlcloud.core.checkpoint as dml_checkpoint 6 | from dmlcloud.util.logging import IORedirector 7 | from .common import Callback 8 | 9 | if TYPE_CHECKING: 10 | from dmlcloud.core.pipeline import Pipeline 11 | 12 | 13 | class CheckpointCallback(Callback): 14 | """ 15 | Creates the checkpoint directory and optionally setups io redirection. 16 | """ 17 | 18 | def __init__(self, run_dir: Union[str, Path], redirect_io: bool = True): 19 | """ 20 | Initialize the callback with the given path. 21 | 22 | Args: 23 | run_dir: The path to the checkpoint directory. 24 | redirect_io: Whether to redirect the IO to a file. Defaults to True. 25 | """ 26 | self.run_dir = Path(run_dir) 27 | self.redirect_io = redirect_io 28 | self.io_redirector = None 29 | 30 | def pre_run(self, pipe: 'Pipeline'): 31 | if not dml_checkpoint.is_valid_checkpoint_dir(self.run_dir): 32 | dml_checkpoint.create_checkpoint_dir(self.run_dir) 33 | dml_checkpoint.save_config(pipe.config, self.run_dir) 34 | 35 | self.io_redirector = IORedirector(pipe.run_dir / 'log.txt') 36 | self.io_redirector.install() 37 | 38 | with open(pipe.run_dir / "environment.txt", 'w') as f: 39 | for k, v in os.environ.items(): 40 | f.write(f"{k}={v}\n") 41 | 42 | def cleanup(self, pipe, exc_type, exc_value, traceback): 43 | if self.io_redirector is not None: 44 | self.io_redirector.uninstall() 45 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/common.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | from typing import TYPE_CHECKING 3 | 4 | 5 | if TYPE_CHECKING: 6 | from dmlcloud.core.pipeline import Pipeline 7 | from dmlcloud.core.stage import Stage 8 | 9 | 10 | class CallbackList: 11 | """ 12 | A priority queue of callbacks. 13 | """ 14 | 15 | def __init__(self): 16 | self.callbacks = [] 17 | 18 | def append(self, callback: 'Callback', priority: int = 0): 19 | """ 20 | Append a callback to the list with the given priority. 21 | 22 | Args: 23 | callback (Callback): The callback to append. 24 | priority (int, optional): The priority of the callback. Defaults to 0. 
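Example (sketch — callbacks with lower priority values are yielded first when iterating):
    >>> cbs = CallbackList()
    >>> cbs.append(Callback(), priority=10)
    >>> cbs.append(Callback(), priority=-10)
    >>> len(cbs)  # iteration yields the priority -10 callback first
    2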
25 | """ 26 | self.callbacks.append((priority, callback)) 27 | 28 | def __iter__(self): 29 | for _, callback in sorted(self.callbacks, key=lambda x: x[0]): 30 | yield callback 31 | 32 | def __len__(self): 33 | return len(self.callbacks) 34 | 35 | def __add__(self, other: 'CallbackList'): 36 | result = CallbackList() 37 | result.callbacks = self.callbacks + other.callbacks 38 | return result 39 | 40 | 41 | class CbPriority(IntEnum): 42 | """ 43 | Default priorities for callbacks used by the pipeline and stage classes. 44 | """ 45 | 46 | WANDB_INIT = -200 47 | CHECKPOINT = -190 48 | STAGE_TIMER = -180 49 | DIAGNOSTICS = -170 50 | CUDA = -160 51 | GIT = -150 52 | METRIC_REDUCTION = -100 53 | 54 | OBJECT_METHODS = 0 55 | 56 | PROFILER = 100 57 | WANDB_LOGGER = 110 58 | CSV = 110 59 | TENSORBOARD = 110 60 | TABLE = 120 61 | 62 | 63 | class Callback: 64 | """ 65 | A callback that can be registered to a stage or the whole pipeline to receive updates on the training progress. 66 | """ 67 | 68 | def pre_run(self, pipe: 'Pipeline'): 69 | """ 70 | Executed before the pipeline starts. 71 | """ 72 | pass 73 | 74 | def post_run(self, pipe: 'Pipeline'): 75 | """ 76 | Executed after the pipeline finishes. 77 | """ 78 | pass 79 | 80 | def cleanup(self, pipe: 'Pipeline', exc_type, exc_value, traceback): 81 | """ 82 | Executed after the pipeline finishes, even if an error occurred. 83 | E.g. to close file handles. 84 | 85 | Args: 86 | pipe (Pipeline): The pipeline that is being cleaned up. 87 | exc_type (type): The type of the exception that caused the cleanup or None if no exception occurred. 88 | exc_value (Exception): The exception that caused the cleanup or None if no exception occurred. 89 | traceback (Traceback): The traceback of the exception that caused the cleanup or None if no exception occurred. 90 | """ 91 | pass 92 | 93 | def pre_stage(self, stage: 'Stage'): 94 | """ 95 | Executed before the stage starts. 96 | """ 97 | pass 98 | 99 | def post_stage(self, stage: 'Stage'): 100 | """ 101 | Executed after the stage finishes. 102 | """ 103 | pass 104 | 105 | def pre_epoch(self, stage: 'Stage'): 106 | """ 107 | Executed before each epoch. 108 | """ 109 | pass 110 | 111 | def post_epoch(self, stage: 'Stage'): 112 | """ 113 | Executed after each epoch. 114 | """ 115 | pass 116 | 117 | def post_step(self, stage: 'Stage'): 118 | """ 119 | Executed after each step. Stage must call `finish_step` to trigger this callback. 120 | """ 121 | pass 122 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/cuda.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pynvml 4 | import torch.cuda 5 | 6 | import dmlcloud.core.logging as dml_logging 7 | from dmlcloud.core.distributed import all_gather_object, is_root 8 | from .common import Callback 9 | 10 | 11 | class CudaCallback(Callback): 12 | """ 13 | Logs various properties pertaining to CUDA devices. 
14 | """ 15 | 16 | @staticmethod 17 | def _call_pynvml(method, *args, **kwargs): 18 | try: 19 | return method(*args, **kwargs) 20 | except pynvml.NVMLError: 21 | return None 22 | 23 | def pre_run(self, pipe): 24 | handle = torch.cuda._get_pynvml_handler(pipe.device) 25 | 26 | info = { 27 | 'name': self._call_pynvml(pynvml.nvmlDeviceGetName, handle), 28 | 'uuid': self._call_pynvml(pynvml.nvmlDeviceGetUUID, handle), 29 | 'serial': self._call_pynvml(pynvml.nvmlDeviceGetSerial, handle), 30 | 'torch_device': str(pipe.device), 31 | 'minor_number': self._call_pynvml(pynvml.nvmlDeviceGetMinorNumber, handle), 32 | 'architecture': self._call_pynvml(pynvml.nvmlDeviceGetArchitecture, handle), 33 | 'brand': self._call_pynvml(pynvml.nvmlDeviceGetBrand, handle), 34 | 'vbios_version': self._call_pynvml(pynvml.nvmlDeviceGetVbiosVersion, handle), 35 | 'driver_version': self._call_pynvml(pynvml.nvmlSystemGetDriverVersion), 36 | 'cuda_driver_version': self._call_pynvml(pynvml.nvmlSystemGetCudaDriverVersion_v2), 37 | 'nvml_version': self._call_pynvml(pynvml.nvmlSystemGetNVMLVersion), 38 | 'total_memory': self._call_pynvml(pynvml.nvmlDeviceGetMemoryInfo, handle, pynvml.nvmlMemory_v2).total, 39 | 'reserved_memory': self._call_pynvml(pynvml.nvmlDeviceGetMemoryInfo, handle, pynvml.nvmlMemory_v2).reserved, 40 | 'num_gpu_cores': self._call_pynvml(pynvml.nvmlDeviceGetNumGpuCores, handle), 41 | 'power_managment_limit': self._call_pynvml(pynvml.nvmlDeviceGetPowerManagementLimit, handle), 42 | 'power_managment_default_limit': self._call_pynvml(pynvml.nvmlDeviceGetPowerManagementDefaultLimit, handle), 43 | 'cuda_compute_capability': self._call_pynvml(pynvml.nvmlDeviceGetCudaComputeCapability, handle), 44 | } 45 | all_devices = all_gather_object(info) 46 | 47 | msg = '* CUDA-DEVICES:\n' 48 | info_strings = [ 49 | f'{info["torch_device"]} -> /dev/nvidia{info["minor_number"]} -> {info["name"]} (UUID: {info["uuid"]}) (VRAM: {info["total_memory"] / 1000 ** 2:.0f} MB)' 50 | for info in all_devices 51 | ] 52 | msg += '\n'.join(f' - [{i}] {info_str}' for i, info_str in enumerate(info_strings)) 53 | dml_logging.info(msg) 54 | 55 | if pipe.run_dir and is_root(): 56 | self._save(pipe.run_dir / 'cuda_devices.json', all_devices) 57 | 58 | def _save(self, path, all_devices): 59 | with open(path, 'w') as f: 60 | devices = {f'rank_{i}': device for i, device in enumerate(all_devices)} 61 | obj = {'devices': devices} 62 | json.dump(obj, f, indent=4) 63 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/diagnostics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from datetime import datetime 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch.cuda 8 | from omegaconf import OmegaConf 9 | 10 | import dmlcloud.core.logging as dml_logging 11 | import dmlcloud.slurm as dmlcloud_slurm 12 | from dmlcloud.core.distributed import world_size 13 | from dmlcloud.git import git_hash 14 | from dmlcloud.util.thirdparty import is_imported, ML_MODULES, try_get_version 15 | from dmlcloud.version import __version__ as dmlcloud_version 16 | from .common import Callback 17 | 18 | 19 | class DiagnosticsCallback(Callback): 20 | """ 21 | A callback that logs diagnostics information at the beginning of training. 22 | """ 23 | 24 | def _experiment_header( 25 | self, 26 | name: str | None, 27 | run_dir: str | None, 28 | date: datetime, 29 | ) -> str: 30 | msg = f'............... 
Experiment: {name if name else "N/A"} ...............\n' 31 | msg += f'- Date: {date}\n' 32 | msg += f'- Checkpoint Dir: {run_dir if run_dir else "N/A"}\n' 33 | msg += f'- Training on {world_size()} GPUs\n' 34 | return msg 35 | 36 | def _general_diagnostics(self) -> str: 37 | msg = '* GENERAL:\n' 38 | msg += f' - argv: {sys.argv}\n' 39 | msg += f' - cwd: {Path.cwd()}\n' 40 | 41 | msg += f' - host (root): {os.environ.get("HOSTNAME")}\n' 42 | msg += f' - user: {os.environ.get("USER")}\n' 43 | msg += f' - git-hash: {git_hash()}\n' 44 | msg += f' - conda-env: {os.environ.get("CONDA_DEFAULT_ENV", "N/A")}\n' 45 | msg += f' - sys-prefix: {sys.prefix}\n' 46 | msg += f' - backend: {torch.distributed.get_backend()}\n' 47 | msg += f' - cuda: {torch.cuda.is_available()}\n' 48 | 49 | msg += '* VERSIONS:\n' 50 | msg += f' - python: {sys.version}\n' 51 | msg += f' - cuda (torch): {torch.version.cuda}\n' 52 | try: 53 | msg += ' - ' + Path('/proc/driver/nvidia/version').read_text().splitlines()[0] + '\n' 54 | except (FileNotFoundError, IndexError): 55 | pass 56 | 57 | msg += f' - dmlcloud: {dmlcloud_version}\n' 58 | 59 | for module_name in ML_MODULES: 60 | if is_imported(module_name): 61 | msg += f' - {module_name}: {try_get_version(module_name)}\n' 62 | 63 | if 'SLURM_JOB_ID' in os.environ: 64 | msg += '* SLURM:\n' 65 | msg += f' - SLURM_JOB_ID = {dmlcloud_slurm.slurm_job_id()}\n' 66 | msg += f' - SLURM_STEP_ID = {dmlcloud_slurm.slurm_step_id()}\n' 67 | msg += f' - SLURM_STEP_NODELIST = {os.environ.get("SLURM_STEP_NODELIST")}\n' 68 | msg += f' - SLURM_TASKS_PER_NODE = {os.environ.get("SLURM_TASKS_PER_NODE")}\n' 69 | msg += f' - SLURM_STEP_GPUS = {os.environ.get("SLURM_STEP_GPUS")}\n' 70 | msg += f' - SLURM_GPUS_ON_NODE = {os.environ.get("SLURM_GPUS_ON_NODE")}\n' 71 | msg += f' - SLURM_CPUS_PER_TASK = {os.environ.get("SLURM_CPUS_PER_TASK")}' 72 | 73 | return msg 74 | 75 | def pre_run(self, pipe): 76 | header = '\n' + self._experiment_header(pipe.name, pipe.run_dir, pipe.start_time) 77 | dml_logging.info(header) 78 | 79 | diagnostics = self._general_diagnostics() 80 | 81 | diagnostics += '\n* CONFIG:\n' 82 | diagnostics += '\n'.join(f' {line}' for line in OmegaConf.to_yaml(pipe.config, resolve=True).splitlines()) 83 | 84 | dml_logging.info(diagnostics) 85 | 86 | def post_stage(self, stage): 87 | if len(stage.pipe.stages) > 1: 88 | dml_logging.info(f'Finished stage in {stage.end_time - stage.start_time}') 89 | 90 | def post_run(self, pipe): 91 | dml_logging.info(f'Finished training in {pipe.stop_time - pipe.start_time} ({pipe.stop_time})') 92 | if pipe.has_checkpointing: 93 | dml_logging.info(f'Outputs have been saved to {pipe.run_dir}') 94 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/git.py: -------------------------------------------------------------------------------- 1 | import dmlcloud.core.logging as dml_logging 2 | from dmlcloud.core.callbacks import Callback 3 | from dmlcloud.core.distributed import is_root 4 | from dmlcloud.git import git_diff 5 | 6 | 7 | class GitDiffCallback(Callback): 8 | """ 9 | A callback that prints a git diff and if checkpointing is enabled, saves it to the checkpoint directory. 10 | """ 11 | 12 | def __init__(self, max_chars=2500): 13 | """ 14 | Args: 15 | max_chars: The maximum number of characters to log to console. 
(default: 2500) 16 | """ 17 | self.max_chars = max_chars 18 | 19 | def _log_diff(self, diff): 20 | truncate = self.max_chars and len(diff) > self.max_chars 21 | 22 | if truncate: 23 | msg = '* GIT-DIFF (truncated):\n' 24 | lines = diff[: self.max_chars].splitlines() 25 | else: 26 | msg = '* GIT-DIFF:\n' 27 | lines = diff.splitlines() 28 | 29 | msg += '\n'.join(' ' + line for line in lines) 30 | if truncate: 31 | msg += f'\n ... (truncated to {self.max_chars} characters, total: {len(diff)})' 32 | 33 | dml_logging.info(msg) 34 | 35 | def pre_run(self, pipe): 36 | diff = git_diff() 37 | if diff is None: 38 | return 39 | 40 | if pipe.run_dir and is_root(): 41 | self._save(pipe.run_dir / 'git_diff.txt', diff) 42 | 43 | self._log_diff(diff) 44 | 45 | def _save(self, path, diff): 46 | with open(path, 'w') as f: 47 | f.write(diff) 48 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/metrics.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from pathlib import Path 3 | from typing import TYPE_CHECKING, Union 4 | 5 | from .common import Callback 6 | 7 | 8 | if TYPE_CHECKING: 9 | from dmlcloud.core.stage import Stage 10 | 11 | 12 | class ReduceMetricsCallback(Callback): 13 | """ 14 | A callback that reduces the metrics at the end of each epoch and appends them to the history. 15 | """ 16 | 17 | def __init__(self, log_every_n_steps=50): 18 | self.log_every_n_steps = log_every_n_steps 19 | 20 | def _reduce_epoch_metrics(self, stage): 21 | metrics = stage.metrics.reduce() 22 | stage.history.append_metrics(**metrics) 23 | 24 | def _reduce_step_metrics(self, stage): 25 | metrics = stage.step_metrics.reduce() 26 | stage.step_history.append_metrics(**metrics) 27 | 28 | def post_epoch(self, stage: 'Stage'): 29 | stage.log('misc/epoch', stage.current_epoch, prefixed=False, reduction='max') 30 | self._reduce_epoch_metrics(stage) 31 | stage.step = 0 # Reset the step counter 32 | 33 | def post_step(self, stage: 'Stage'): 34 | stage.log('misc/step', stage.global_step, prefixed=False, reduction='max') 35 | 36 | if stage.global_step % self.log_every_n_steps == 0: 37 | self._reduce_step_metrics(stage) 38 | 39 | stage.step += 1 40 | stage.global_step += 1 41 | 42 | def post_stage(self, stage): 43 | has_unreduced_metrics = False 44 | for metric in stage.step_metrics.metrics.values(): 45 | if metric.update_called: 46 | has_unreduced_metrics = True 47 | break 48 | 49 | # need to check global_step > 0 to avoid reducing when finish_step() was never called once 50 | if has_unreduced_metrics and stage.global_step > 0: 51 | self._reduce_step_metrics(stage) 52 | 53 | 54 | class CsvCallback(Callback): 55 | """ 56 | Saves metrics to a CSV file at the end of each epoch. 57 | """ 58 | 59 | def __init__(self, directory: Union[str, Path]): 60 | """ 61 | Initialize the callback with the given path. 62 | 63 | Args: 64 | directory (Union[str, Path]): The path to the directory where the CSV files will be saved. 
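
Note:
    One ``epoch_metrics_<stage>.csv`` file is written per stage (plus a
    ``step_metrics_<stage>.csv`` file when step metrics are logged); duplicate
    stage names receive a numeric suffix.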
65 | """ 66 | self.directory = Path(directory) 67 | self.last_steps = {} 68 | 69 | def _build_name(self, stage: 'Stage', prefix: str): 70 | duplicate_stages = [s for s in stage.pipe.stages if s.name == stage.name] 71 | idx = duplicate_stages.index(stage) 72 | if len(duplicate_stages) > 1: 73 | return self.directory / f'{prefix}_{stage.name}_{idx + 1}.csv' 74 | else: 75 | return self.directory / f'{prefix}_{stage.name}.csv' 76 | 77 | def epoch_path(self, stage: 'Stage'): 78 | return self._build_name(stage, 'epoch_metrics') 79 | 80 | def step_path(self, stage: 'Stage'): 81 | return self._build_name(stage, 'step_metrics') 82 | 83 | def pre_stage(self, stage: 'Stage'): 84 | # If for some reason we can't write to the file or it exists already, its better to fail early 85 | with open(self.epoch_path(stage), 'x'): 86 | pass 87 | 88 | def _write_history(self, file, history, step_metric, step_name): 89 | writer = csv.writer(file) 90 | 91 | metric_names = list(history.keys()) 92 | metric_names.remove(step_metric) 93 | 94 | writer.writerow([step_name] + metric_names) # Header 95 | for row in history.rows(): 96 | csv_row = [row[step_metric]] + [row[name] for name in metric_names] 97 | writer.writerow(csv_row) 98 | 99 | def _maybe_write_step_metrics(self, stage: 'Stage'): 100 | if stage.step_history.num_steps > self.last_steps.get(stage, 0): 101 | self.last_steps[stage] = stage.step_history.num_steps 102 | with open(self.step_path(stage), 'w') as f: 103 | self._write_history(f, stage.step_history, 'misc/step', 'step') 104 | 105 | def post_epoch(self, stage: 'Stage'): 106 | with open(self.epoch_path(stage), 'w') as f: 107 | self._write_history(f, stage.history, 'misc/epoch', 'epoch') 108 | 109 | def post_step(self, stage: 'Stage'): 110 | self._maybe_write_step_metrics(stage) 111 | 112 | def post_stage(self, stage): 113 | self._maybe_write_step_metrics(stage) # edge case: last steps of training 114 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/profiler.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from torch.profiler import profile, ProfilerActivity 4 | 5 | from .common import Callback 6 | 7 | 8 | if TYPE_CHECKING: 9 | from dmlcloud.core.stage import Stage 10 | 11 | 12 | class ProfilerCallback(Callback): 13 | """ 14 | A callback that profiles the training process and saves the results to a file. 
15 | """ 16 | 17 | def __init__(self, epochs=None, record_shapes=False, schedule=None): 18 | self.epochs = epochs 19 | self.record_shapes = record_shapes 20 | self.schedule = schedule 21 | 22 | self.profiler = None 23 | self._capturing = False 24 | 25 | def pre_epoch(self, stage: 'Stage'): 26 | if self.epochs and stage.current_epoch not in self.epochs: 27 | return 28 | 29 | self.profiler = profile( 30 | activities=[ 31 | ProfilerActivity.CPU, 32 | ProfilerActivity.CUDA, 33 | ], 34 | record_shapes=self.record_shapes, 35 | schedule=self.schedule, 36 | ) 37 | self.profiler.__enter__() 38 | self._capturing = True 39 | 40 | def post_epoch(self, stage): 41 | if self.epochs and (stage.current_epoch - 1) not in self.epochs: 42 | return 43 | 44 | self.profiler.__exit__(None, None, None) 45 | self._capturing = False 46 | 47 | if stage.run_dir: 48 | outfile = str(stage.run_dir / f'{stage.name}_epoch{stage.current_epoch - 1}_trace.json') 49 | self.profiler.export_chrome_trace(outfile) 50 | 51 | def cleanup(self, pipe, exc_type, exc_value, traceback): 52 | if self._capturing: 53 | self.profiler.__exit__(exc_type, exc_value, traceback) 54 | self._capturing = False 55 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/table.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Callable, Optional, TYPE_CHECKING 3 | 4 | from progress_table import ProgressTable 5 | 6 | from dmlcloud.core.distributed import is_root 7 | from dmlcloud.util.logging import DevNullIO, TimedeltaFormatter 8 | from .common import Callback 9 | 10 | if TYPE_CHECKING: 11 | from dmlcloud.core.stage import Stage 12 | 13 | 14 | class TableCallback(Callback): 15 | """ 16 | A callback that updates a table with the latest metrics from a stage. 17 | """ 18 | 19 | def __init__(self): 20 | self._table = None 21 | self.tracked_metrics = {} 22 | self.formatters = {} 23 | 24 | def get_table(self, stage: 'Stage'): 25 | if self._table is None: 26 | self._table = ProgressTable(file=sys.stdout if is_root() else DevNullIO(), interactive=0) 27 | self.track_metric(stage, 'Epoch', width=5) 28 | self.track_metric(stage, 'Took', 'misc/epoch_time', formatter=TimedeltaFormatter(), width=7) 29 | if stage._run_epoch_overridden: 30 | self.track_metric(stage, 'ETA', 'misc/eta', formatter=TimedeltaFormatter(), width=7) 31 | return self._table 32 | 33 | def set_table(self, value): 34 | self._table = value 35 | 36 | def track_metric( 37 | self, 38 | stage: 'Stage', 39 | name: str, 40 | metric: Optional[str] = None, 41 | formatter: Optional[Callable] = None, 42 | width: Optional[int] = None, 43 | color: Optional[str] = None, 44 | alignment: Optional[str] = None, 45 | ): 46 | """ 47 | Track a metric in the table. 48 | 49 | If no metric name is provided, only a column is created and the caller must update the value manually. 50 | If a formatter is provided, the metric value will be passed through the formatter before being displayed. 51 | 52 | For a detailed description of width, color, and alignment, see `ProgressTable.add_column`. 53 | 54 | Args: 55 | name (str): The name of the column. 56 | metric (str, optional): The name of the metric to track. Defaults to None. 57 | formatter (Callable, optional): A function that takes the metric value and returns a string. Defaults to None. 58 | width (int, optional): The width of the column. Defaults to None. 59 | color (str, optional): The color of the column. Defaults to None. 
60 |         alignment (str, optional): The alignment of the column. Defaults to None.
61 |         """
62 |         if formatter and not metric:
63 |             raise ValueError('Cannot provide a formatter without a metric name')
64 |
65 |         table = self.get_table(stage)
66 |         table.add_column(name, width=width, color=color, alignment=alignment)
67 |
68 |         if metric:
69 |             self.tracked_metrics[name] = metric
70 |             self.formatters[name] = formatter
71 |
72 |     def pre_stage(self, stage: 'Stage'):
73 |         self.get_table(stage)  # Ensure the table has been created at this point
74 |
75 |     def post_stage(self, stage: 'Stage'):
76 |         table = self.get_table(stage)
77 |         table.close()
78 |
79 |     def pre_epoch(self, stage: 'Stage'):
80 |         table = self.get_table(stage)
81 |         if 'Epoch' in table.column_names:
82 |             table['Epoch'] = stage.current_epoch
83 |
84 |     def post_epoch(self, stage: 'Stage'):
85 |         table = self.get_table(stage)
86 |         metrics = stage.history.last()
87 |
88 |         for column_name, metric_name in self.tracked_metrics.items():
89 |             if column_name not in table.column_names:  # When does this happen?
90 |                 continue
91 |
92 |             if metric_name in metrics:
93 |                 value = metrics[metric_name]
94 |                 formatter = self.formatters[column_name]
95 |                 if formatter is not None:
96 |                     value = formatter(value)
97 |                 table.update(column_name, value)
98 |             else:
99 |                 pass  # don't update -> empty cell
100 |
101 |         table.next_row()
--------------------------------------------------------------------------------
/dmlcloud/core/callbacks/tensorboard.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import TYPE_CHECKING, Union
3 |
4 | from dmlcloud.core.callbacks import Callback
5 |
6 | if TYPE_CHECKING:
7 |     from dmlcloud.core.stage import Stage
8 |
9 |
10 | class TensorboardCallback(Callback):
11 |     """
12 |     A callback that logs metrics to Tensorboard.
13 |     """
14 |
15 |     def __init__(self, log_dir: Union[str, Path]):
16 |         self.log_dir = Path(log_dir)
17 |         self.writer = None
18 |         try:
19 |             from torch.utils.tensorboard import SummaryWriter  # noqa: F401
20 |         except ImportError:
21 |             raise ImportError('tensorboard is required for the TensorboardCallback')
22 |
23 |     def pre_run(self, pipe):
24 |         from torch.utils.tensorboard import SummaryWriter
25 |
26 |         self.writer = SummaryWriter(log_dir=self.log_dir)
27 |
28 |     def post_epoch(self, stage: 'Stage'):
29 |         metrics = stage.history.last()
30 |         for key, value in metrics.items():
31 |             self.writer.add_scalar(key, value, stage.current_epoch)
32 |
33 |     def cleanup(self, pipe, exc_type, exc_value, traceback):
34 |         if self.writer is not None:
35 |             self.writer.close()
--------------------------------------------------------------------------------
/dmlcloud/core/callbacks/timer.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import TYPE_CHECKING
3 |
4 | from .common import Callback
5 |
6 |
7 | if TYPE_CHECKING:
8 |     from dmlcloud.core.stage import Stage
9 |
10 |
11 | class TimerCallback(Callback):
12 |     """
13 |     A callback that logs the time taken for each epoch.
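Concretely, it logs ``misc/epoch_time`` and ``misc/total_time`` after every epoch, plus a ``misc/eta`` estimate for stages that override ``run_epoch``.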
14 | """ 15 | 16 | def __init__(self): 17 | self.start_time = None 18 | self.end_time = None 19 | self.epoch_start_time = None 20 | self.epoch_end_time = None 21 | 22 | def pre_stage(self, stage: 'Stage'): 23 | self.start_time = datetime.now() 24 | 25 | def post_stage(self, stage: 'Stage'): 26 | self.end_time = datetime.now() 27 | 28 | def pre_epoch(self, stage: 'Stage'): 29 | self.epoch_start_time = datetime.now() 30 | 31 | def post_epoch(self, stage: 'Stage'): 32 | self.epoch_end_time = datetime.now() 33 | 34 | epoch_time = (stage.epoch_end_time - self.epoch_start_time).total_seconds() 35 | total_time = (stage.epoch_end_time - self.start_time).total_seconds() 36 | stage.log('misc/epoch_time', epoch_time, prefixed=False, log_step=False) 37 | stage.log('misc/total_time', total_time, prefixed=False, log_step=False) 38 | 39 | if stage._run_epoch_overridden: 40 | average_epoch_time = (stage.epoch_end_time - self.start_time) / (stage.current_epoch + 1) 41 | eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1) 42 | stage.log('misc/eta', eta.total_seconds(), prefixed=False, log_step=False) 43 | -------------------------------------------------------------------------------- /dmlcloud/core/callbacks/wandb.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | from omegaconf import OmegaConf 4 | 5 | from dmlcloud.core.callbacks import Callback 6 | from dmlcloud.util.wandb import wandb_is_initialized, wandb_set_startup_timeout 7 | 8 | 9 | if TYPE_CHECKING: 10 | from dmlcloud.core.pipeline import Pipeline 11 | from dmlcloud.core.stage import Stage 12 | 13 | 14 | class WandbInitCallback(Callback): 15 | """ 16 | A callback that initializes Weights & Biases and closes it at the end. 17 | This is separated from the WandbLoggerCallback to ensure it is called right at the beginning of training. 18 | """ 19 | 20 | def __init__(self, project, entity, group, tags, startup_timeout, **kwargs): 21 | try: 22 | import wandb 23 | except ImportError: 24 | raise ImportError('wandb is required for the WandbInitCallback') 25 | 26 | self.wandb = wandb 27 | self.project = project 28 | self.entity = entity 29 | self.group = group 30 | self.tags = tags 31 | self.startup_timeout = startup_timeout 32 | self.kwargs = kwargs 33 | 34 | def pre_run(self, pipe: 'Pipeline'): 35 | wandb_set_startup_timeout(self.startup_timeout) 36 | self.wandb.init( 37 | config=OmegaConf.to_container(pipe.config, resolve=True), 38 | name=pipe.name, 39 | project=self.project, 40 | entity=self.entity, 41 | group=self.group, 42 | tags=self.tags, 43 | **self.kwargs, 44 | ) 45 | 46 | def cleanup(self, pipe, exc_type, exc_value, traceback): 47 | if wandb_is_initialized(): 48 | self.wandb.finish(exit_code=0 if exc_type is None else 1) 49 | 50 | 51 | class WandbLoggerCallback(Callback): 52 | """ 53 | A callback that logs metrics to Weights & Biases. 
54 | """ 55 | 56 | def __init__(self): 57 | try: 58 | import wandb 59 | except ImportError: 60 | raise ImportError('wandb is required for the WandbLoggerCallback') 61 | 62 | self.wandb = wandb 63 | 64 | def post_epoch(self, stage: 'Stage'): 65 | metrics = stage.history.last() 66 | self.wandb.log(metrics, commit=True, step=stage.current_epoch) 67 | -------------------------------------------------------------------------------- /dmlcloud/core/checkpoint.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import secrets 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | from omegaconf import OmegaConf 7 | 8 | from dmlcloud.slurm import slurm_job_id 9 | 10 | 11 | __all__ = [ 12 | 'generate_checkpoint_path', 13 | 'is_valid_checkpoint_dir', 14 | 'create_checkpoint_dir', 15 | 'find_slurm_checkpoint', 16 | 'read_slurm_id', 17 | 'save_config', 18 | 'read_config', 19 | ] 20 | 21 | 22 | def sanitize_filename(filename: str) -> str: 23 | return filename.replace('/', '_') 24 | 25 | 26 | def generate_id() -> str: 27 | s = secrets.token_urlsafe(5) 28 | return s.replace('-', 'a').replace('_', 'b') 29 | 30 | 31 | def generate_checkpoint_path( 32 | root: Path | str, name: Optional[str] = None, creation_time: Optional[datetime.datetime] = None 33 | ) -> Path: 34 | root = Path(root) 35 | 36 | if name is None: 37 | name = 'run' 38 | 39 | if creation_time is None: 40 | creation_time = datetime.datetime.now() 41 | 42 | dt = datetime.datetime.now().strftime('%Y.%m.%d-%H.%M') 43 | name = sanitize_filename(name) 44 | return root / f'{name}-{dt}-{generate_id()}' 45 | 46 | 47 | def is_valid_checkpoint_dir(path: Path) -> bool: 48 | if not path.exists() or not path.is_dir(): 49 | return False 50 | 51 | if not (path / '.dmlcloud').exists(): 52 | return False 53 | 54 | return True 55 | 56 | 57 | def create_checkpoint_dir(path: Path | str, name: Optional[str] = None) -> Path: 58 | path.mkdir(parents=True, exist_ok=True) 59 | (path / '.dmlcloud').touch() 60 | (path / 'log.txt').touch() 61 | if slurm_job_id() is not None: 62 | with open(path / '.slurm-jobid', 'w') as f: 63 | f.write(slurm_job_id()) 64 | 65 | 66 | def read_slurm_id(path: Path) -> Optional[str]: 67 | if is_valid_checkpoint_dir(path): 68 | return None 69 | 70 | if not (path / '.slurm-jobid').exists(): 71 | return None 72 | 73 | with open(path / '.slurm-jobid') as f: 74 | return f.read() 75 | 76 | 77 | def find_slurm_checkpoint(root: Path | str) -> Optional[Path]: 78 | root = Path(root) 79 | 80 | job_id = slurm_job_id() 81 | if job_id is None: 82 | return None 83 | 84 | for child in root.iterdir(): 85 | if read_slurm_id(child) == job_id: 86 | return child 87 | 88 | return None 89 | 90 | 91 | def save_config(config: OmegaConf, run_dir: Path): 92 | with open(run_dir / 'config.yaml', 'w') as f: 93 | OmegaConf.save(config, f) 94 | 95 | 96 | def read_config(run_dir: Path) -> OmegaConf: 97 | with open(run_dir / 'config.yaml') as f: 98 | return OmegaConf.load(f) 99 | -------------------------------------------------------------------------------- /dmlcloud/core/config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from functools import partial, update_wrapper 3 | from typing import Any, Callable, Mapping 4 | 5 | 6 | __all__ = [ 7 | 'import_object', 8 | 'factory_from_cfg', 9 | 'obj_from_cfg', 10 | ] 11 | 12 | 13 | def import_object(object_path: str) -> Any: 14 | """ 15 | Imports an object from a module. 
16 | 17 | The object path should be in the form of "module.submodule.object". 18 | The function imports the module and returns the object. 19 | 20 | Args: 21 | object_path (str): The path to the object to import. 22 | 23 | Returns: 24 | Any: The imported object. 25 | 26 | Raises: 27 | ImportError: If the module containing the object cannot be imported or the object cannot be found. 28 | 29 | Example: 30 | >>> import_object("dmlcloud.core.stage.Stage") 31 | 32 | """ 33 | 34 | module_name, obj_name = object_path.rsplit(".", 1) 35 | module = importlib.import_module(module_name) 36 | 37 | try: 38 | return getattr(module, obj_name) 39 | except AttributeError as e: 40 | raise ImportError(f"Object '{obj_name}' not found in module '{module_name}'") from e 41 | 42 | 43 | def factory_from_cfg(config: Mapping | str, *args, **kwargs) -> Callable: 44 | """ 45 | Creates a factory function from a configuration dictionary or a string. 46 | 47 | If a string is provided, it is assumed to be the path to the factory function (or class). 48 | 49 | If a dictionary is provided, it must contain a "factory" key with the path to the factory function. 50 | Additional keys in the dictionary are passed as keyword arguments to the factory function. 51 | 52 | Args: 53 | config (Mapping | str): Configuration dictionary or string with the path to the factory function. 54 | *args: Additional positional arguments to pass to the factory function. 55 | **kwargs: Additional keyword arguments to pass to the factory function. 56 | 57 | Returns: 58 | Callable: A factory function with the provided configuration and arguments. 59 | 60 | Raises: 61 | ImportError: If the factory function cannot be imported. 62 | KeyError: If the configuration dictionary does not contain the "factory" key. 63 | 64 | Example: 65 | >>> factory = dml.factory_from_cfg('datetime.date', 2025, month=1, day=1) 66 | >>> factory 67 | 68 | >>> factory() 69 | datetime.date(2025, 1, 1) 70 | >>> factory(month=12, day=31) 71 | datetime.date(2025, 12, 31) 72 | 73 | Instead of providing a string, you can also use a configuration dictionary: 74 | 75 | >>> config = {'factory': 'datetime.date', 'year': 2025, 'month': 1, 'day': 1} 76 | >>> factory = dml.factory_from_cfg(config) 77 | >>> factory() 78 | datetime.date(2025, 1, 1) 79 | """ 80 | 81 | if isinstance(config, str): 82 | factory = import_object(config) 83 | kwargs = kwargs.copy() 84 | else: 85 | factory = import_object(config['factory']) 86 | merged_kwargs = config.copy() 87 | merged_kwargs.update(kwargs) 88 | kwargs = merged_kwargs 89 | del kwargs['factory'] 90 | 91 | wrapper = partial(factory, *args, **kwargs) 92 | return update_wrapper(wrapper, factory) 93 | 94 | 95 | def obj_from_cfg(config: Mapping | str, *args, **kwargs) -> Any: 96 | """ 97 | Creates an object from a configuration dictionary or a string. 98 | 99 | This is equivalent to calling `factory_from_cfg(config)(*args, **kwargs)`. 100 | 101 | If a string is provided, it is assumed to be the path to the object (class). 102 | If a dictionary is provided, it must contain a "factory" key with the path to the object. 103 | 104 | Additional keys in the dictionary are passed as keyword arguments to the object constructor. 105 | 106 | Args: 107 | config (Mapping | str): Configuration dictionary or string with the path to the object. 108 | *args: Additional positional arguments to pass to the object constructor. 109 | **kwargs: Additional keyword arguments to pass to the object constructor. 110 | 111 | Returns: 112 | Any: The created object. 
113 | 
114 |     Raises:
115 |         ImportError: If the object cannot be imported.
116 |         KeyError: If the configuration dictionary does not contain the "factory" key.
117 | 
118 |     Example:
119 |         >>> dml.obj_from_cfg('datetime.date', 2025, month=1, day=1)
120 |         datetime.date(2025, 1, 1)
121 |     """
122 | 
123 |     return factory_from_cfg(config)(*args, **kwargs)
124 | 
-------------------------------------------------------------------------------- /dmlcloud/core/logging.py: --------------------------------------------------------------------------------
1 | """
2 | Provides a simple logging interface for dmlcloud.
3 | 
4 | The dmlcloud logger is set up so that only the root process logs messages below severity 'WARNING'.
5 | All processes log messages with severity 'WARNING' or higher.
6 | 
7 | Attributes:
8 |     logger (logging.Logger): The dmlcloud logger. Messages below severity 'WARNING' are only logged on the root process.
9 | """
10 | 
11 | import logging
12 | import sys
13 | import warnings
14 | 
15 | import torch
16 | import torch.distributed
17 | 
18 | from . import distributed as dml_distributed
19 | 
20 | 
21 | logger = logging.getLogger('dmlcloud')
22 | 
23 | 
24 | __all__ = [
25 |     'logger',
26 |     'log',
27 |     'debug',
28 |     'info',
29 |     'warning',
30 |     'error',
31 |     'critical',
32 |     'setup_logger',
33 |     'reset_logger',
34 |     'flush_logger',
35 |     'print_worker',
36 |     'print_root',
37 | ]
38 | 
39 | 
40 | def _distributed_filter(record):
41 |     if not torch.distributed.is_initialized():
42 |         return True
43 |     elif torch.distributed.get_rank() == 0:
44 |         return True
45 |     else:
46 |         return False
47 | 
48 | 
49 | def setup_logger():
50 |     """
51 |     Set up the dmlcloud logger.
52 | 
53 |     If torch.distributed is initialized, only the root rank logs messages below severity 'WARNING'. Otherwise, all processes log all messages.
54 |     Non-root processes will always log messages with severity 'WARNING' or higher to ensure important messages are not missed.
55 | 
56 |     Usually, this function is called automatically when logging a message, and should not be called manually.
57 |     """
58 |     if logger.hasHandlers():
59 |         warnings.warn('Logger already set up. Ignoring call to setup_logger().')
60 |         return
61 | 
62 |     logger.setLevel(logging.DEBUG)
63 | 
64 |     stdout_handler = logging.StreamHandler(sys.stdout)
65 |     stdout_handler.addFilter(_distributed_filter)
66 |     stdout_handler.addFilter(lambda record: record.levelno < logging.WARNING)
67 |     stdout_handler.setFormatter(logging.Formatter())
68 |     stdout_handler.setLevel(logging.DEBUG)
69 |     logger.addHandler(stdout_handler)
70 | 
71 |     stderr_handler = logging.StreamHandler()
72 |     # No rank filter here: non-root processes must still emit warnings and errors.
73 |     stderr_handler.setFormatter(logging.Formatter())
74 |     stderr_handler.setLevel(logging.WARNING)
75 |     logger.addHandler(stderr_handler)
76 | 
77 | 
78 | def reset_logger():
79 |     """
80 |     Reset the dmlcloud logger to its initial state.
81 | 
82 |     This will remove all handlers from the logger and set its level to NOTSET.
83 |     """
84 |     logger.setLevel(logging.NOTSET)
85 |     to_remove = list(logger.handlers)
86 |     for handler in to_remove:
87 |         logger.removeHandler(handler)
88 | 
89 | 
90 | def flush_logger(logger: logging.Logger = None):
91 |     """
92 |     Flushes all handlers of the given logger.
93 | 
94 |     Args:
95 |         logger (logging.Logger, optional): The logger to flush. Default is the dmlcloud logger. 
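
    Example:
        A minimal sketch, assuming the default dmlcloud logger is in use:

        >>> info('about to hand off to an external process')
        >>> flush_logger()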
96 | """ 97 | if logger is None: 98 | logger = sys.modules[__name__].logger 99 | 100 | for handler in logger.handlers: 101 | handler.flush() 102 | 103 | 104 | def log(level, msg, *args, exc_info=None, stack_info=False, extra=None): 105 | """ 106 | Log 'msg % args' with severity 'level' on the dmlcloud logger. 107 | 108 | If the dmlcloud logger was not already setup, this function will setup the logger with the default configuration. 109 | """ 110 | if not logger.hasHandlers(): 111 | setup_logger() 112 | 113 | logger.log(level, msg, *args, exc_info=exc_info, stack_info=stack_info, extra=extra) 114 | 115 | 116 | def debug(msg, *args, exc_info=None, stack_info=False, extra=None): 117 | """ 118 | Log 'msg % args' with severity 'TRACE' on the dmlcloud logger. 119 | 120 | If the dmlcloud logger was not already setup, this function will setup the logger with the default configuration. 121 | """ 122 | log(logging.DEBUG, msg, *args, exc_info=exc_info, stack_info=stack_info, extra=extra) 123 | 124 | 125 | def info(msg, *args, exc_info=None, stack_info=False, extra=None): 126 | """ 127 | Log 'msg % args' with severity 'INFO' on the dmlcloud logger. 128 | 129 | If the dmlcloud logger was not already setup, this function will setup the logger with the default configuration. 130 | """ 131 | log(logging.INFO, msg, *args, exc_info=exc_info, stack_info=stack_info, extra=extra) 132 | 133 | 134 | def warning(msg, *args, exc_info=None, stack_info=False, extra=None): 135 | """ 136 | Log 'msg % args' with severity 'WARNING' on the dmlcloud logger. 137 | 138 | If the dmlcloud logger was not already setup, this function will setup the logger with the default configuration. 139 | """ 140 | log(logging.WARNING, msg, *args, exc_info=exc_info, stack_info=stack_info, extra=extra) 141 | 142 | 143 | def error(msg, *args, exc_info=None, stack_info=False, extra=None): 144 | """ 145 | Log 'msg % args' with severity 'ERROR' on the dmlcloud logger. 146 | 147 | If the dmlcloud logger was not already setup, this function will setup the logger with the default configuration. 148 | """ 149 | log(logging.ERROR, msg, *args, exc_info=exc_info, stack_info=stack_info, extra=extra) 150 | 151 | 152 | def critical(msg, *args, exc_info=None, stack_info=False, extra=None): 153 | """ 154 | Log 'msg % args' with severity 'CRITICAL' on the dmlcloud logger. 155 | 156 | If the dmlcloud logger was not already setup, this function will setup the logger with the default configuration. 157 | """ 158 | log(logging.CRITICAL, msg, *args, exc_info=exc_info, stack_info=stack_info, extra=extra) 159 | 160 | 161 | def print_worker(*values, sep=' ', end="\n", file=None, flush=True, barrier=False): 162 | """ 163 | Print the values to a stream, default sys.stdout, with additional information about the worker. 164 | 165 | Args: 166 | values (Any): The values to print. 167 | sep (str, optional): The separator between arguments. Default is a space. 168 | end (str, optional): The string to append at the end of the message. Default is a newline. 169 | file (file, optional): The file to write the message to. Default is None. 170 | flush (bool, optional): If True, the output buffer is flushed. Default is True. 171 | barrier (bool, optional): If True, a barrier is inserted before and after printing. Default is False. 
172 | """ 173 | 174 | if barrier: 175 | torch.distributed.barrier() 176 | modified_values = [f'Worker {dml_distributed.rank()}'] 177 | if dml_distributed.local_node() is not None: 178 | modified_values += [f'({dml_distributed.local_node()}.{dml_distributed.local_rank()})'] 179 | modified_values.extend(values) 180 | print(*modified_values, sep=sep, end=end, file=file, flush=flush) 181 | if barrier: 182 | torch.distributed.barrier() 183 | 184 | 185 | @dml_distributed.root_only 186 | def print_root(*values, sep=' ', end="\n", file=None, flush=True): 187 | """ 188 | Print the values to a stream if the current rank is the root rank. 189 | 190 | Default is to print to the standard output stream. 191 | 192 | Args: 193 | msg (str): The message to print. 194 | sep (str, optional): The separator between arguments. Default is a space. 195 | end (str, optional): The string to append at the end of the message. Default is a newline. 196 | file (file, optional): The file to write the message to. Default is None. 197 | flush (bool, optional): If True, the output buffer is flushed. Default is True. 198 | """ 199 | 200 | print(*values, sep=sep, end=end, file=file, flush=flush) 201 | 202 | 203 | if __name__ == '__main__': 204 | from .distributed import init 205 | 206 | info("HELOOO") 207 | 208 | init() 209 | 210 | debug('[A] This is a debug message') 211 | info('[A] This is an info message') 212 | warning('[A] This is a warning message') 213 | error('[A] This is an error message') 214 | critical('[A] This is a critical message') 215 | 216 | reset_logger() 217 | torch.distributed.destroy_process_group() 218 | 219 | debug('[B] This is a debug message') 220 | info('[B] This is an info message') 221 | warning('[B] This is a warning message') 222 | error('[B] This is an error message') 223 | critical('[B] This is a critical message') 224 | -------------------------------------------------------------------------------- /dmlcloud/core/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Any, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torchmetrics 7 | from numpy.typing import ArrayLike 8 | 9 | 10 | __all__ = [ 11 | 'TrainingHistory', 12 | 'Tracker', 13 | ] 14 | 15 | 16 | class TrainingHistory: 17 | """ 18 | Stores the training history of a model. 19 | 20 | Metrics can either be ArrayLike objects or any pickleable object. 
21 | 
22 |     Usage:
23 |         history = TrainingHistory()
24 |         history.append_metrics(loss=0.5, accuracy=0.99)
25 |         history.append_metrics(loss=0.25, accuracy=0.999)
26 | 
27 | 
28 |         for metric in history:
29 |             print(f'{metric}: {history[metric]}')
30 |     """
31 | 
32 |     max_return_type = namedtuple('Max', ['value', 'step'])
33 |     min_return_type = namedtuple('Min', ['value', 'step'])
34 | 
35 |     def __init__(self):
36 |         self.num_steps = 0
37 |         self._metrics = {}
38 |         self._dtypes = {}
39 | 
40 |     def __getitem__(self, name: str):
41 |         if name not in self._metrics:
42 |             raise KeyError(f'Metric {name} does not exist')
43 | 
44 |         return np.stack(self._metrics[name], axis=0, dtype=self._dtypes[name])
45 | 
46 |     def __delitem__(self, name):
47 |         del self._metrics[name]
48 | 
49 |     def __contains__(self, name: str):
50 |         return name in self._metrics
51 | 
52 |     def __len__(self):
53 |         return len(self._metrics)
54 | 
55 |     def __iter__(self):
56 |         return iter(self._metrics)
57 | 
58 |     def keys(self):
59 |         return self._metrics.keys()
60 | 
61 |     def values(self):
62 |         return [self[name] for name in self._metrics]
63 | 
64 |     def items(self):
65 |         return [(name, self[name]) for name in self._metrics]
66 | 
67 |     def rows(self):
68 |         for i in range(self.num_steps):
69 |             yield {name: self._metrics[name][i] for name in self._metrics}
70 | 
71 |     def append_metric(self, name: str, value: Union[ArrayLike, Any]):
72 |         """
73 |         Adds a value for a single metric and advances to the next step.
74 |         Equivalent to ``append_metrics(**{name: value})``.
75 | 
76 |         Args:
77 |             name (str): The name of the metric.
78 |             value (ArrayLike, Any): The value of the metric. Must be an ArrayLike or pickleable object.
79 |         """
80 |         self.append_metrics(**{name: value})
81 | 
82 |     def append_metrics(self, **metrics):
83 |         """
84 |         Adds multiple metrics at the current step.
85 | 
86 |         Args:
87 |             **metrics: The metrics to add.
88 |         """
89 |         for name, value in metrics.items():
90 |             dtype = value.dtype if isinstance(value, np.ndarray) else object
91 |             if isinstance(value, torch.Tensor) or isinstance(value, np.ndarray):
92 |                 value = value.item()
93 | 
94 |             if name not in self._metrics:
95 |                 self._metrics[name] = ([None] * self.num_steps) + [value]
96 |                 self._dtypes[name] = dtype
97 |             else:
98 |                 self._metrics[name].append(value)
99 | 
100 |         self.num_steps += 1
101 | 
102 |     def last(self) -> dict[str, Any]:
103 |         """
104 |         Returns the last value for each metric.
105 | 
106 |         Returns:
107 |             dict[str, Any]: The last value for each metric.
108 |         """
109 | 
110 |         return {name: values[-1] for name, values in self._metrics.items()}
111 | 
112 |     def min(self) -> dict[str, min_return_type]:
113 |         """
114 |         Returns a namedtuple (value, step) containing the minimum value and the corresponding step for each metric across all steps.
115 | 
116 |         Returns:
117 |             dict[str, namedtuple]: The minimum value and the corresponding step for each metric.
118 |         """
119 |         argmin = {name: np.argmin(values, axis=0) for name, values in self._metrics.items()}
120 |         return {name: self.min_return_type(self._metrics[name][idx], idx) for name, idx in argmin.items()}
121 | 
122 |     def max(self) -> dict[str, max_return_type]:
123 |         """
124 |         Returns a namedtuple (value, step) containing the maximum value and the corresponding step for each metric across all steps.
125 | 
126 |         Returns:
127 |             dict[str, namedtuple]: The maximum value and the corresponding step for each metric. 
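
        Example:
            A sketch of the intended usage (exact scalar reprs may vary with the numpy version):

            >>> history = TrainingHistory()
            >>> history.append_metrics(loss=0.5)
            >>> history.append_metrics(loss=0.25)
            >>> history.max()['loss'].value
            0.5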
128 | """ 129 | argmax = {name: np.argmax(values, axis=0) for name, values in self._metrics.items()} 130 | return {name: self.max_return_type(self._metrics[name][idx], idx) for name, idx in argmax.items()} 131 | 132 | 133 | class Tracker(torch.nn.Module): 134 | """ 135 | Keeps track of multiple metrics and reduces them at the end of each epoch. 136 | """ 137 | 138 | def __init__(self): 139 | super().__init__() 140 | 141 | self.metrics = torch.nn.ModuleDict() 142 | 143 | def add_metric(self, name: str, metric: torchmetrics.Metric): 144 | if name in self.metrics: 145 | raise ValueError(f'Metric {name} already exists') 146 | 147 | self.metrics[name] = metric 148 | 149 | def log(self, name: str, value: Any, reduction: str = 'mean', **kwargs): 150 | if reduction not in ['mean', 'sum', 'min', 'max', 'cat']: 151 | raise ValueError(f'Invalid reduction {reduction}. Must be one of mean, sum, min, max, cat') 152 | 153 | if not torch.is_tensor(value): 154 | value = torch.tensor(value) 155 | value = value.cpu() 156 | dtype = value.dtype 157 | 158 | if name not in self.metrics: 159 | if reduction == 'mean': 160 | metric = torchmetrics.MeanMetric(**kwargs) 161 | dtype = torch.float32 162 | elif reduction == 'sum': 163 | metric = torchmetrics.SumMetric(**kwargs) 164 | elif reduction == 'min': 165 | metric = torchmetrics.MinMetric(**kwargs) 166 | elif reduction == 'max': 167 | metric = torchmetrics.MaxMetric(**kwargs) 168 | elif reduction == 'cat': 169 | metric = torchmetrics.CatMetric(**kwargs) 170 | metric = metric.cpu().set_dtype(dtype) 171 | self.add_metric(name, metric) 172 | 173 | self.metrics[name].update(value) 174 | 175 | def reduce(self, reset: bool = True): 176 | values = {} 177 | for name, metric in self.metrics.items(): 178 | if metric.update_called: 179 | values[name] = metric.compute() 180 | if reset: 181 | metric.reset() 182 | else: 183 | values[name] = None 184 | return values 185 | 186 | def clear(self): 187 | for metric in self.metrics.values(): 188 | metric.reset() 189 | self.metrics.clear() 190 | 191 | def __getitem__(self, name: str): 192 | return self.metrics[name] 193 | 194 | def __setitem__(self, name: str, metric: torchmetrics.Metric): 195 | self.add_metric(name, metric) 196 | -------------------------------------------------------------------------------- /dmlcloud/core/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from . import distributed as dml_distributed, logging as dml_logging 5 | 6 | 7 | __all__ = [ 8 | 'count_parameters', 9 | 'wrap_ddp', 10 | 'scale_lr', 11 | ] 12 | 13 | 14 | def count_parameters(module: nn.Module) -> int: 15 | """ 16 | Returns the number of trainable parameters in a module. 17 | 18 | Args: 19 | module (nn.Module): The module to count the parameters of. 20 | 21 | Returns: 22 | int: The number of trainable parameters. 23 | """ 24 | return sum(p.numel() for p in module.parameters() if p.requires_grad) 25 | 26 | 27 | def wrap_ddp( 28 | module: nn.Module, 29 | device: torch.device, 30 | sync_bn: bool = False, 31 | find_unused_parameters: bool = False, 32 | verbose: bool = True, 33 | ) -> nn.Module: 34 | """ 35 | Wraps a module with DistributedDataParallel. 36 | 37 | Args: 38 | module (nn.Module): The module to wrap. 39 | device (torch.device): The device to use. 40 | sync_bn (bool, optional): If True, uses SyncBatchNorm. Default is False. 41 | find_unused_parameters (bool, optional): If True, finds unused parameters. Default is False. 
42 |         verbose (bool, optional): If True, prints information about the model. Default is True.
43 | 
44 |     Returns:
45 |         nn.Module: The wrapped module.
46 |     """
47 | 
48 |     if not torch.distributed.is_available() or not torch.distributed.is_initialized():
49 |         raise RuntimeError('DistributedDataParallel requires torch.distributed to be initialized.')
50 | 
51 |     module = module.to(device)  # Doing it in this order is important for SyncBN
52 |     if sync_bn:
53 |         module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
54 | 
55 |     device_ids = [device] if device.type == 'cuda' else None  # Must be None for cpu devices
56 |     ddp = nn.parallel.DistributedDataParallel(
57 |         module, broadcast_buffers=False, device_ids=device_ids, find_unused_parameters=find_unused_parameters
58 |     )
59 |     if verbose:
60 |         msg = '* MODEL:\n'
61 |         msg += f' - Parameters: {count_parameters(module) / 1e6:.1f} M\n'
62 |         msg += f' - {module}'
63 |         dml_logging.info(msg)
64 | 
65 |     return ddp
66 | 
67 | 
68 | def scale_lr(base_lr: float, world_size: int = None) -> float:
69 |     """
70 |     Scales the learning rate based on the world size.
71 | 
72 |     Args:
73 |         base_lr (float): The base learning rate.
74 |         world_size (int, optional): The number of processes. Default is the global world size.
75 | 
76 |     Returns:
77 |         float: The scaled learning rate.
78 |     """
79 |     if world_size is None:
80 |         world_size = dml_distributed.world_size()
81 | 
82 |     return base_lr * world_size
83 | 
-------------------------------------------------------------------------------- /dmlcloud/core/pipeline.py: --------------------------------------------------------------------------------
1 | import warnings
2 | from datetime import datetime, timedelta
3 | from functools import cached_property
4 | from pathlib import Path
5 | from typing import Any, Dict, List, Optional, Union
6 | 
7 | import torch
8 | import torch.distributed as dist
9 | from omegaconf import OmegaConf
10 | 
11 | from . 
import logging as dml_logging 12 | from .callbacks import ( 13 | Callback, 14 | CallbackList, 15 | CbPriority, 16 | CheckpointCallback, 17 | CsvCallback, 18 | CudaCallback, 19 | DiagnosticsCallback, 20 | GitDiffCallback, 21 | TensorboardCallback, 22 | WandbInitCallback, 23 | WandbLoggerCallback, 24 | ) 25 | from .checkpoint import find_slurm_checkpoint, generate_checkpoint_path, is_valid_checkpoint_dir 26 | from .distributed import broadcast_object, init, is_root, local_rank 27 | from .stage import Stage 28 | 29 | 30 | __all__ = [ 31 | 'Pipeline', 32 | 'current_pipe', 33 | 'current_stage', 34 | 'log_metric', 35 | ] 36 | 37 | _CURRENT_PIPE = None 38 | 39 | 40 | def _set_current_pipe(pipe): 41 | global _CURRENT_PIPE 42 | if pipe is None: 43 | if _CURRENT_PIPE is None: 44 | raise ValueError('Can not reset current pipe if there is none') 45 | _CURRENT_PIPE = None 46 | else: 47 | if _CURRENT_PIPE is not None: 48 | raise ValueError('Pipe already set') 49 | _CURRENT_PIPE = pipe 50 | 51 | 52 | def current_pipe() -> 'Pipeline | None': 53 | """ 54 | Returns the current running pipeline or None if no pipeline is running 55 | """ 56 | return _CURRENT_PIPE 57 | 58 | 59 | def current_stage() -> Stage | None: 60 | """ 61 | Returns the current running stage or None if no pipeline is running 62 | """ 63 | if current_pipe() is None: 64 | return None 65 | else: 66 | return current_pipe().current_stage 67 | 68 | 69 | def log_metric(name: str, value: Any, reduction: str = 'mean', prefixed: bool = True): 70 | """ 71 | Shorthand for current_stage().log 72 | """ 73 | return current_stage().log(name, value, reduction=reduction, prefixed=prefixed) 74 | 75 | 76 | class _RunGuard: 77 | """ 78 | Context manager that ensures that the pipeline is properly cleaned up in case of an exception or interruption. 79 | """ 80 | 81 | def __init__(self, pipe): 82 | self.pipe = pipe 83 | 84 | def __enter__(self): 85 | _set_current_pipe(self.pipe) 86 | return self 87 | 88 | def __exit__(self, exc_type, exc_value, traceback): 89 | _set_current_pipe(None) 90 | 91 | suppress_exception = False 92 | if exc_type is KeyboardInterrupt: 93 | dml_logging.info('------- Training interrupted by user -------') 94 | suppress_exception = True 95 | elif exc_type is not None: 96 | dml_logging.error( 97 | '------- Training failed with an exception -------', exc_info=(exc_type, exc_value, traceback) 98 | ) 99 | 100 | callbacks = [] 101 | if self.pipe.current_stage is not None: 102 | callbacks += self.pipe.current_stage.callbacks 103 | callbacks += self.pipe.callbacks 104 | 105 | for callback in reversed(callbacks): 106 | callback.cleanup(self.pipe, exc_type, exc_value, traceback) 107 | 108 | return suppress_exception 109 | 110 | 111 | class _ForwardCallback(Callback): 112 | """ 113 | Invokes the pre_run, post_run methods of the Pipeline. 114 | Stage-specific callbacks are managed by the Stage object. 115 | """ 116 | 117 | def pre_run(self, pipe): 118 | pipe.pre_run() 119 | 120 | def post_run(self, pipe): 121 | pipe.post_run() 122 | 123 | 124 | class Pipeline: 125 | """ 126 | A training pipeline that consists of multiple stages. 127 | 128 | This is the main entry point for training with dmlcloud. The pipeline manages the training process and 129 | orchestrates the execution of multiple stages. It also provides a way to add callbacks that are executed at 130 | different points during the training process. 131 | 132 | Use the `append()` method to add stages to the pipeline and `add_callback()` to add callbacks. 
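
    A typical setup might look like this (illustrative; ``TrainStage`` stands in for your own :class:`Stage` subclass):

        pipe = Pipeline(config, name='my-run')
        pipe.append(TrainStage(epochs=10))
        pipe.enable_checkpointing('checkpoints/')
        pipe.run()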
133 | 
134 |     Checkpointing can be enabled with `enable_checkpointing()` and Weights & Biases integration with `enable_wandb()`.
135 | 
136 |     Once the pipeline is set up, call `run()` to start the training process.
137 |     """
138 | 
139 |     def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None):
140 |         # Auto-init torch.distributed if not already initialized
141 |         if not dist.is_initialized():
142 |             init()
143 | 
144 |         if config is None:
145 |             self.config = OmegaConf.create()
146 |         elif not isinstance(config, OmegaConf):
147 |             self.config = OmegaConf.create(config)
148 |         else:
149 |             self.config = config
150 | 
151 |         self.name = name
152 | 
153 |         self.run_dir: Path | None = None
154 |         self.resumed = None
155 |         self.start_time = None
156 |         self.stop_time = None
157 |         self.current_stage = None
158 | 
159 |         self._wandb = False
160 | 
161 |         self.stages = []
162 |         self.callbacks = CallbackList()
163 | 
164 |         self.add_callback(DiagnosticsCallback(), CbPriority.DIAGNOSTICS)
165 |         self.add_callback(GitDiffCallback(), CbPriority.GIT)
166 |         self.add_callback(_ForwardCallback(), CbPriority.OBJECT_METHODS)  # methods have priority 0
167 |         if self.device.type == 'cuda':
168 |             self.add_callback(CudaCallback(), CbPriority.CUDA)
169 | 
170 |         self.gloo_group = dist.new_group(backend='gloo') if dist.is_gloo_available() else None
171 |         if self.gloo_group is None:  # barrier() checks for this, so it must always be set
172 |             warnings.warn('Gloo backend not available. Barriers will not use custom timeouts.')
173 | 
174 | 
175 |     @property
176 |     def has_checkpointing(self):
177 |         return self.run_dir is not None
178 | 
179 |     @property
180 |     def has_wandb(self):
181 |         return self._wandb
182 | 
183 |     @property
184 |     def has_tensorboard(self):
185 |         return self.has_checkpointing
186 | 
187 |     def add_callback(self, callback: Callback, priority: int = 1):
188 |         """
189 |         Adds a callback to this pipeline.
190 | 
191 |         Callbacks added to the pipeline and not to individual stages are executed for all stages in the pipeline.
192 |         Callbacks are executed based on their priority, with lower values being executed first.
193 |         Callbacks with the same priority are executed in the order they were added.
194 | 
195 |         Methods of the stage and pipeline objects, e.g. pre_run(), have priority 0.
196 | 
197 |         Args:
198 |             callback (Callback): The callback to add.
199 |             priority (int, optional): The priority of the callback. Defaults to 1. 
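
        Example:
            Illustrative; ``MyCallback`` stands in for any Callback subclass:

            >>> pipe.add_callback(MyCallback(), priority=2)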
200 | """ 201 | self.callbacks.append(callback, priority) 202 | 203 | def append(self, stage: Stage): 204 | if not isinstance(stage, Stage): 205 | raise ValueError('stage must be a Stage object') 206 | 207 | stage.pipe = self 208 | self.stages.append(stage) 209 | 210 | def enable_checkpointing( 211 | self, 212 | root: str, 213 | resume: bool = False, 214 | ): 215 | if self.has_checkpointing: 216 | raise ValueError('Checkpointing already enabled') 217 | 218 | if resume and is_valid_checkpoint_dir(root): 219 | self.run_dir = root 220 | self.resumed = True 221 | elif resume and find_slurm_checkpoint(root): 222 | self.run_dir = find_slurm_checkpoint(root) 223 | self.resumed = True 224 | 225 | if self.run_dir is None: # no need for a barrier here, dir creation happens in _pre_run() 226 | path = generate_checkpoint_path(root=root, name=self.name, creation_time=self.start_time) 227 | self.run_dir = broadcast_object(path) 228 | self.resumed = False 229 | 230 | if is_root(): 231 | self.add_callback(CheckpointCallback(self.run_dir), CbPriority.CHECKPOINT) 232 | self.add_callback(CsvCallback(self.run_dir), CbPriority.CSV) 233 | self.add_callback(TensorboardCallback(self.run_dir), CbPriority.TENSORBOARD) 234 | 235 | def enable_wandb( 236 | self, 237 | project: str | None = None, 238 | entity: str | None = None, 239 | group: str | None = None, 240 | tags: List[str] | None = None, 241 | startup_timeout: int = 360, 242 | **kwargs, 243 | ): 244 | if self._wandb: 245 | raise ValueError('Wandb already enabled') 246 | 247 | import wandb # import now to avoid potential long import times later on # noqa 248 | 249 | if is_root(): 250 | init_callback = WandbInitCallback( 251 | project=project, 252 | entity=entity, 253 | group=group, 254 | tags=tags, 255 | startup_timeout=startup_timeout, 256 | **kwargs, 257 | ) 258 | self.add_callback(init_callback, CbPriority.WANDB_INIT) 259 | self.add_callback(WandbLoggerCallback(), CbPriority.WANDB_LOGGER) 260 | 261 | self._wandb = True 262 | 263 | def barrier(self, timeout=None): 264 | if self.gloo_group is None: 265 | dist.barrier() 266 | else: 267 | timeout = timedelta(seconds=timeout) if timeout is not None else None 268 | dist.monitored_barrier(self.gloo_group, timeout=timeout, wait_all_ranks=True) 269 | 270 | def run(self): 271 | """ 272 | Starts the training and runs all registered stages. 273 | """ 274 | if len(self.stages) == 0: 275 | raise ValueError('No stages defined. Use append() to add stages to the pipeline.') 276 | 277 | # make sure everything is set up before starting the run 278 | # important to prevent checkpoint dir creation before all processes searched for it 279 | self.barrier(timeout=10 * 60) 280 | 281 | with _RunGuard(self): 282 | self._pre_run() 283 | for stage in self.stages: 284 | self.current_stage = stage 285 | stage._run() 286 | self._post_run() 287 | 288 | def pre_run(self): 289 | pass 290 | 291 | def post_run(self): 292 | pass 293 | 294 | def resume_run(self): 295 | pass 296 | 297 | @cached_property 298 | def device(self): 299 | if torch.cuda.is_available(): 300 | if local_rank() is None: 301 | warnings.warn( 302 | 'CUDA is available but no local rank found. Make sure to set CUDA_VISIBLE_DEVICES manually for each rank.' 303 | ) 304 | return torch.device('cuda') 305 | else: 306 | return torch.device('cuda', local_rank()) 307 | else: 308 | warnings.warn('CUDA is not available. 
Running on CPU.') 309 | return torch.device('cpu') 310 | 311 | def _pre_run(self): 312 | self.start_time = datetime.now() 313 | 314 | if self.resumed: 315 | self._resume_run() 316 | 317 | for callback in self.callbacks: 318 | callback.pre_run(self) 319 | 320 | def _resume_run(self): 321 | dml_logging.info(f'Resuming training from checkpoint: {self.run_dir}') 322 | self.resume_run() 323 | 324 | def _post_run(self): 325 | self.stop_time = datetime.now() 326 | 327 | for callback in self.callbacks: 328 | callback.post_run(self) 329 | -------------------------------------------------------------------------------- /dmlcloud/core/stage.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import torch 4 | 5 | from . import logging as dml_logging 6 | from .callbacks import ( 7 | Callback, 8 | CallbackList, 9 | CbPriority, 10 | ProfilerCallback, 11 | ReduceMetricsCallback, 12 | TableCallback, 13 | TimerCallback, 14 | ) 15 | from .distributed import is_root 16 | from .metrics import Tracker, TrainingHistory 17 | 18 | 19 | __all__ = [ 20 | 'Stage', 21 | ] 22 | 23 | 24 | class _ForwardCallback(Callback): 25 | """ 26 | Invokes the pre_stage, post_stage, pre_epoch, and post_epoch methods of the Stage. 27 | """ 28 | 29 | def pre_stage(self, stage): 30 | stage.pre_stage() 31 | 32 | def post_stage(self, stage): 33 | stage.post_stage() 34 | 35 | def pre_epoch(self, stage): 36 | stage.pre_epoch() 37 | 38 | def post_epoch(self, stage): 39 | stage.post_epoch() 40 | 41 | def post_step(self, stage): 42 | stage.post_step() 43 | 44 | 45 | class Stage: 46 | """ 47 | Hook Points: 48 | - pre_stage() 49 | - post_stage() 50 | - pre_epoch() 51 | - post_epoch() 52 | """ 53 | 54 | def __init__(self, name: str = None, epochs: int | None = 1): 55 | self.name = name or self.__class__.__name__ 56 | self.max_epochs = epochs 57 | 58 | self.callbacks = CallbackList() 59 | 60 | self.pipe = None # set by the pipeline 61 | 62 | self.history = TrainingHistory() 63 | self.step_history = TrainingHistory() 64 | self.metrics = Tracker() 65 | self.step_metrics = Tracker() 66 | 67 | self.step = 0 68 | self.global_step = 0 69 | 70 | self.metric_prefix = None 71 | self.barrier_timeout = None 72 | 73 | self._timer = TimerCallback() 74 | self._table_callback = TableCallback() 75 | self._reduce_metrics_callback = ReduceMetricsCallback() 76 | self._forward_callback = _ForwardCallback() 77 | self._profiler_callback = None 78 | self.add_callback(self._timer, CbPriority.STAGE_TIMER) 79 | self.add_callback(self._reduce_metrics_callback, CbPriority.METRIC_REDUCTION) 80 | self.add_callback(self._table_callback, CbPriority.TABLE) 81 | self.add_callback(self._forward_callback, CbPriority.OBJECT_METHODS) # methods have priority 0 82 | 83 | @property 84 | def device(self): 85 | """ 86 | Same as :attr:`Pipeline.device`. 87 | """ 88 | return self.pipe.device 89 | 90 | @property 91 | def config(self): 92 | """ 93 | Same as :attr:`Pipeline.config`. 94 | """ 95 | return self.pipe.config 96 | 97 | @property 98 | def run_dir(self): 99 | """ 100 | Same as :attr:`Pipeline.run_dir`. 
101 | """ 102 | return self.pipe.run_dir 103 | 104 | @property 105 | def current_epoch(self): 106 | return self.history.num_steps 107 | 108 | @property 109 | def start_time(self): 110 | return self._timer.start_time 111 | 112 | @property 113 | def end_time(self): 114 | return self._timer.end_time 115 | 116 | @property 117 | def epoch_start_time(self): 118 | return self._timer.epoch_start_time 119 | 120 | @property 121 | def epoch_end_time(self): 122 | return self._timer.epoch_end_time 123 | 124 | @property 125 | def table(self): 126 | return self._table_callback.get_table(self) 127 | 128 | @property 129 | def has_profiler(self) -> bool: 130 | """ 131 | Returns True if the profiler is enabled for this stage, otherwise False. 132 | """ 133 | return self._profiler_callback is not None 134 | 135 | @property 136 | def profiler(self) -> torch.profiler.profile | None: 137 | """ 138 | If enabled, returns the profiler object associated with this stage, otherwise None. 139 | 140 | Returns: 141 | torch.profiler.profile or None: The profiler object. 142 | """ 143 | if not self.has_profiler: 144 | return None 145 | return self._profiler_callback.profiler 146 | 147 | @property 148 | def _run_overridden(self): 149 | return type(self).run != Stage.run 150 | 151 | @property 152 | def _run_epoch_overridden(self): 153 | return type(self).run_epoch != Stage.run_epoch 154 | 155 | def add_callback(self, callback: 'Callback', priority: int = 1): 156 | """ 157 | Adds a callback to this stage. 158 | 159 | Callbacks are executed based on their priority, with lower values being executed first. 160 | Callbacks with the same priority are executed in the order they were added. 161 | 162 | The pre_stage, post_stage, pre_epoch, and post_epoch methods are treated as callbacks with priority 0. 163 | 164 | Args: 165 | callback (StageCallback): The callback to add. 166 | priority (int, optional): The priority of the callback. Defaults to 1. 167 | """ 168 | self.callbacks.append(callback, priority) 169 | 170 | def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True, log_step: bool = True): 171 | if prefixed and self.metric_prefix: 172 | name = f'{self.metric_prefix}/{name}' 173 | self.metrics.log(name, value, reduction) 174 | if log_step: 175 | self.step_metrics.log(name, value, reduction) 176 | 177 | def add_metric(self, name, metric): 178 | metric = metric.to(self.device) 179 | self.metrics.add_metric(name, metric) 180 | return metric 181 | 182 | def add_column( 183 | self, 184 | name: str, 185 | metric: str | None = None, 186 | formatter: Callable | None = None, 187 | width: int | None = None, 188 | color: str | None = None, 189 | alignment: str | None = None, 190 | ): 191 | """ 192 | Adds a column to the table. 193 | 194 | If metric is provided, the column will be updated with the latest value of the metric. 195 | Otherwise,the caller must update the value manually using `table.update`. 196 | 197 | If a formatter is provided, the metric value will be passed through the formatter before being displayed. 198 | 199 | For a detailed description of width, color, and alignment, see `ProgressTable.add_column`. 200 | 201 | Args: 202 | name (str): The name of the column. 203 | metric (str, optional): The name of the metric to track. Defaults to None. 204 | formatter (Callable, optional): A function that takes the metric value and returns a string. Defaults to None. 205 | width (int, optional): The width of the column. Defaults to None. 206 | color (str, optional): The color of the column. 
Defaults to None.
207 |             alignment (str, optional): The alignment of the column. Defaults to None.
208 |         """
209 |         self._table_callback.track_metric(
210 |             self, name, metric=metric, formatter=formatter, width=width, color=color, alignment=alignment
211 |         )
212 | 
213 |     def enable_profiler(self, epochs: list | None = [0], schedule=None):
214 |         """
215 |         Enables the profiler for this stage.
216 | 
217 |         If the `schedule` argument is not provided, the following default schedule is used:
218 |         ```
219 |         schedule = torch.profiler.schedule(
220 |             wait=10,
221 |             warmup=10,
222 |             active=5,
223 |             repeat=1,
224 |         )
225 |         ```
226 | 
227 |         The user must call `self.profiler.step()` on the root rank at the end of each iteration to advance the profiler.
228 | 
229 |         Args:
230 |             epochs (list, optional): The epochs to profile. Defaults to [0]. If None, the profiler is enabled for all epochs.
231 |             schedule: The schedule for the profiler, i.e. the object returned by torch.profiler.schedule(). If None, a default schedule is used. Defaults to None.
232 |         """
233 |         if not is_root():
234 |             return
235 | 
236 |         if self.has_profiler:
237 |             raise ValueError('Profiler is already enabled for this stage.')
238 | 
239 |         if schedule is None:
240 |             schedule = torch.profiler.schedule(
241 |                 wait=10,
242 |                 warmup=10,
243 |                 active=5,
244 |                 repeat=1,
245 |             )
246 | 
247 |         self._profiler_callback = ProfilerCallback(epochs=epochs, schedule=schedule)
248 |         self.add_callback(self._profiler_callback, CbPriority.PROFILER)
249 | 
250 |     def pre_stage(self):
251 |         """
252 |         Executed before the stage starts.
253 |         Use this method to set up any stage-specific datasets or models.
254 |         """
255 |         pass
256 | 
257 |     def post_stage(self):
258 |         """
259 |         Executed after the stage finishes.
260 |         Use this method to clean up any stage-specific resources or to save any intermediate results/artifacts.
261 |         """
262 |         pass
263 | 
264 |     def pre_epoch(self):
265 |         """
266 |         Executed before each epoch.
267 |         """
268 |         pass
269 | 
270 |     def post_epoch(self):
271 |         """
272 |         Executed after each epoch.
273 |         """
274 |         pass
275 | 
276 |     def post_step(self):
277 |         """
278 |         Executed after each step. The :meth:`run` implementation must call `finish_step()` at the end of each step for this hook to fire.
279 |         """
280 |         pass
281 | 
282 |     def run(self):
283 |         """
284 |         Override this method to implement the main logic of the stage and do manual epoch management.
285 | 
286 |         Either this method or :meth:`run_epoch` must be implemented by subclasses.
287 |         Unlike :meth:`run_epoch`, this method is called only once per stage, and the implementation is responsible for
288 |         managing the epochs and calling :meth:`next_epoch` when appropriate.
289 |         """
290 |         raise NotImplementedError()
291 | 
292 |     def next_epoch(self):
293 |         """
294 |         Advances the stage to the next epoch.
295 | 
296 |         This method must only be called by the implementation of :meth:`run` when the stage finishes an epoch.
297 |         """
298 |         if self._run_epoch_overridden:
299 |             raise ValueError('next_epoch() must not be called when run_epoch() is implemented.')
300 | 
301 |         self._post_epoch()
302 |         self._pre_epoch()
303 | 
304 |     def finish_step(self):
305 |         self._post_step()
306 | 
307 |     def run_epoch(self):
308 |         """
309 |         Override this method to implement the main logic of the stage for a single epoch.
310 | 
311 |         Either this method or :meth:`run` must be implemented by subclasses.
312 |         Unlike :meth:`run`, this method is called automatically by the stage and does not need to manage the epochs. 
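
        Example:
            A minimal sketch; ``make_batches`` and ``train_step`` are placeholders for your own data loading and optimization logic:

            >>> class TrainStage(Stage):
            ...     def run_epoch(self):
            ...         for batch in make_batches():
            ...             loss = train_step(batch)
            ...             self.log('loss', loss)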
313 | """ 314 | raise NotImplementedError() 315 | 316 | def _run(self): 317 | """ 318 | Runs this stage. Either until max_epochs are reached, or until stop_stage() is called. 319 | """ 320 | if self._run_overridden and self._run_epoch_overridden: 321 | raise ValueError('Only one of run() or run_epoch() must be implemented.') 322 | elif not self._run_overridden and not self._run_epoch_overridden: 323 | raise ValueError('Either run() or run_epoch() must be implemented.') 324 | elif self._run_epoch_overridden: 325 | self._pre_stage() 326 | while self.max_epochs is None or self.current_epoch < self.max_epochs: 327 | self._pre_epoch() 328 | self.run_epoch() 329 | self._post_epoch() 330 | self._post_stage() 331 | else: 332 | self._pre_stage() 333 | self._pre_epoch() 334 | self.run() 335 | self._post_epoch() 336 | self._post_stage() 337 | 338 | def _pre_stage(self): 339 | if len(self.pipe.stages) > 1: 340 | dml_logging.info(f'\n========== STAGE: {self.name} ==========') 341 | 342 | callbacks = self.callbacks + self.pipe.callbacks 343 | for callback in callbacks: 344 | callback.pre_stage(self) 345 | 346 | dml_logging.flush_logger() 347 | self.pipe.barrier(self.barrier_timeout) 348 | 349 | def _post_stage(self): 350 | callbacks = self.callbacks + self.pipe.callbacks 351 | for callback in callbacks: 352 | callback.post_stage(self) 353 | self.pipe.barrier(self.barrier_timeout) 354 | 355 | def _pre_epoch(self): 356 | callbacks = self.callbacks + self.pipe.callbacks 357 | for callback in callbacks: 358 | callback.pre_epoch(self) 359 | 360 | def _post_epoch(self): 361 | callbacks = self.callbacks + self.pipe.callbacks 362 | for callback in callbacks: 363 | callback.post_epoch(self) 364 | 365 | def _post_step(self): 366 | callbacks = self.callbacks + self.pipe.callbacks 367 | for callback in callbacks: 368 | callback.post_step(self) 369 | -------------------------------------------------------------------------------- /dmlcloud/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains helpers for distributed data processing and loading.""" 2 | 3 | __all__ = [] 4 | 5 | # Sharding 6 | 7 | from .sharding import chunk_and_shard_indices, shard_indices, shard_sequence 8 | 9 | __all__ += [ 10 | 'shard_indices', 11 | 'shard_sequence', 12 | 'chunk_and_shard_indices', 13 | ] 14 | 15 | # Dataset 16 | 17 | from .dataset import BatchDataset, DownstreamDataset, PrefetchDataset, ShardedSequenceDataset 18 | 19 | __all__ += [ 20 | 'ShardedSequenceDataset', 21 | 'DownstreamDataset', 22 | 'PrefetchDataset', 23 | 'BatchDataset', 24 | ] 25 | 26 | # Interleave 27 | 28 | from .interleave import interleave_batches, interleave_dict_batches 29 | 30 | __all__ += [ 31 | 'interleave_batches', 32 | 'interleave_dict_batches', 33 | ] 34 | 35 | 36 | # Xarray 37 | 38 | from .xarray import sharded_xr_dataset, ShardedXrDataset 39 | 40 | __all__ += [ 41 | 'sharded_xr_dataset', 42 | 'ShardedXrDataset', 43 | ] 44 | -------------------------------------------------------------------------------- /dmlcloud/data/dataset.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | from typing import Iterable, Sequence 3 | 4 | import torch.distributed as dist 5 | from torch.utils.data import get_worker_info, IterableDataset 6 | 7 | from .sharding import shard_sequence 8 | 9 | 10 | __all__ = [ 11 | 'ShardedSequenceDataset', 12 | 'DownstreamDataset', 13 | 'PrefetchDataset', 14 | 'BatchDataset', 15 | ] 16 | 17 | 
18 | class ShardedSequenceDataset(IterableDataset): 19 | def __init__( 20 | self, 21 | sequence: Sequence, 22 | shuffle: bool = False, 23 | even_shards: bool = True, 24 | seed: int = 0, 25 | rank: int | None = None, 26 | world_size: int | None = None, 27 | ): 28 | self.sequence = sequence 29 | self.shuffle = shuffle 30 | self.even_shards = even_shards 31 | self.seed = seed 32 | self.rank = rank if rank is not None else dist.get_rank() 33 | self.world_size = world_size if world_size is not None else dist.get_world_size() 34 | self.epoch = 0 35 | 36 | def set_epoch(self, epoch: int): 37 | self.epoch = epoch 38 | 39 | def __iter__(self): 40 | worker_info = get_worker_info() 41 | if worker_info is None: 42 | rank = self.rank 43 | world_size = self.world_size 44 | else: 45 | rank = self.rank * worker_info.num_workers + worker_info.id 46 | world_size = self.world_size * worker_info.num_workers 47 | shards = shard_sequence( 48 | self.sequence, 49 | rank, 50 | world_size, 51 | shuffle=self.shuffle, 52 | even_shards=self.even_shards, 53 | seed=self.seed + self.epoch, 54 | ) 55 | return iter(shards) 56 | 57 | 58 | class DownstreamDataset(IterableDataset): 59 | def __init__(self, source_ds: Iterable): 60 | self.source_ds = source_ds 61 | 62 | def set_epoch(self, epoch: int): 63 | if hasattr(self.source_ds, 'set_epoch'): 64 | self.source_ds.set_epoch(epoch) 65 | 66 | def __len__(self): 67 | return len(self.source_ds) 68 | 69 | 70 | class PrefetchDataset(DownstreamDataset): 71 | def __init__(self, source_ds: Iterable, num_elements: int): 72 | super().__init__(source_ds) 73 | self.num_elements = num_elements 74 | 75 | def __iter__(self): 76 | pool = ThreadPoolExecutor(max_workers=1) 77 | iter_ = iter(self.source_ds) 78 | 79 | with pool: 80 | futures = [pool.submit(next, iter_) for _ in range(self.num_elements)] 81 | while True: 82 | future = futures.pop(0) 83 | try: 84 | element = future.result() 85 | except StopIteration: 86 | return 87 | futures += [pool.submit(next, iter_)] 88 | yield element 89 | 90 | 91 | class BatchDataset(DownstreamDataset): 92 | def __init__(self, source_ds: Iterable, batch_size: int, drop_remainder: bool = False): 93 | super().__init__(source_ds) 94 | self.batch_size = batch_size 95 | self.drop_remainder = drop_remainder 96 | 97 | def __len__(self): 98 | if self.drop_remainder: 99 | return len(self.source_ds) // self.batch_size 100 | else: 101 | return (len(self.source_ds) + self.batch_size - 1) // self.batch_size 102 | 103 | def __iter__(self): 104 | batch = [] 105 | for element in self.source_ds: 106 | batch.append(element) 107 | if len(batch) == self.batch_size: 108 | yield batch 109 | batch = [] 110 | if batch and not self.drop_remainder: 111 | yield batch 112 | -------------------------------------------------------------------------------- /dmlcloud/data/interleave.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch 4 | 5 | __all__ = [ 6 | 'interleave_batches', 7 | 'interleave_dict_batches', 8 | ] 9 | 10 | 11 | def interleave_batches( 12 | iterable: Iterable[torch.Tensor], num_batches: int, pin_memory: bool = False 13 | ) -> Iterable[torch.Tensor]: 14 | """ 15 | Interleaves batches from an iterable of batches. 16 | Important: Returned batches must be used immediately or copied to avoid overwriting. 
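
    Example:
        With num_batches=2 and per-batch size 4, each yielded batch mixes one half of two consecutive source batches:

        >>> batches = iter([torch.arange(4), torch.arange(4, 8)])
        >>> [b.tolist() for b in interleave_batches(batches, num_batches=2)]
        [[0, 1, 4, 5], [2, 3, 6, 7]]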
17 | """ 18 | if num_batches < 1: 19 | raise ValueError('num_batches must be greater than 0') 20 | 21 | if num_batches == 1: 22 | yield from iterable 23 | 24 | batches = [] 25 | memory = None 26 | batch_size = None 27 | slice_size = None 28 | for batch in iterable: 29 | if memory is None: 30 | batch_size = batch.shape[0] 31 | slice_size = batch_size // num_batches 32 | if batch_size % num_batches != 0: 33 | raise ValueError(f'Batch dimension ({batch_size}) must be divisible by num_batches={num_batches}') 34 | memory = torch.empty( 35 | (num_batches, *batch.shape), dtype=batch.dtype, device=batch.device, pin_memory=pin_memory 36 | ) 37 | 38 | batches.append(batch) 39 | 40 | if len(batches) == num_batches: 41 | for i in range(num_batches): 42 | for j in range(num_batches): 43 | memory[i, j * slice_size : (j + 1) * slice_size] = batches[j][i * slice_size : (i + 1) * slice_size] 44 | batches = [] 45 | for i in range(num_batches): 46 | yield memory[i] 47 | 48 | 49 | def interleave_dict_batches( 50 | iterable: Iterable[torch.Tensor], num_batches: int, pin_memory: bool = False 51 | ) -> Iterable[torch.Tensor]: 52 | """ 53 | Interleaves batches from an iterable of batches. 54 | Important: Returned batches must be used immediately or copied to avoid overwriting. 55 | """ 56 | if num_batches < 1: 57 | raise ValueError('num_batches must be greater than 0') 58 | 59 | if num_batches == 1: 60 | yield from iterable 61 | 62 | batches = [] 63 | memory = {} 64 | slice_size = {} 65 | for batch in iterable: 66 | if not memory: 67 | for k, tensor in batch.items(): 68 | batch_size = tensor.shape[0] 69 | if batch_size % num_batches != 0: 70 | raise ValueError(f'Batch dimension ({batch_size}) must be divisible by num_batches={num_batches}') 71 | slice_size[k] = batch_size // num_batches 72 | memory[k] = torch.empty( 73 | (num_batches, *tensor.shape), dtype=tensor.dtype, device=tensor.device, pin_memory=pin_memory 74 | ) 75 | 76 | batches.append(batch) 77 | 78 | if len(batches) == num_batches: 79 | for k in memory: 80 | for i in range(num_batches): 81 | for j in range(num_batches): 82 | source = batches[j][k][i * slice_size[k] : (i + 1) * slice_size[k]] 83 | memory[k][i, j * slice_size[k] : (j + 1) * slice_size[k]] = source 84 | batches = [] 85 | for i in range(num_batches): 86 | yield {k: memory[k][i] for k in memory.keys()} 87 | -------------------------------------------------------------------------------- /dmlcloud/data/sharding.py: -------------------------------------------------------------------------------- 1 | """Utilities for sharding data across multiple workers.""" 2 | 3 | 4 | from typing import Sequence 5 | 6 | import numpy as np 7 | 8 | __all__ = [ 9 | 'shard_indices', 10 | 'shard_sequence', 11 | 'chunk_and_shard_indices', 12 | ] 13 | 14 | 15 | def shard_indices( 16 | num_elements: int, 17 | rank: int, 18 | world_size: int, 19 | shuffle: bool = False, 20 | even_shards: bool = True, 21 | seed: int = 0, 22 | ) -> list[int]: 23 | """ 24 | even_shards: If True, every worker receives the same number of shards, and the last shards are dropped. 
25 | """ 26 | indices = np.arange(num_elements) 27 | 28 | if shuffle: 29 | np.random.Generator(np.random.MT19937(seed)).shuffle(indices) 30 | 31 | if even_shards: 32 | indices = indices[: num_elements - num_elements % world_size] 33 | 34 | return indices[rank::world_size].tolist() # this also converts np.int64 to python's int 35 | 36 | 37 | def shard_sequence( 38 | sequence: Sequence, 39 | rank: int, 40 | world_size: int, 41 | shuffle: bool = False, 42 | even_shards: bool = True, 43 | seed: int = 0, 44 | ): 45 | indices = shard_indices(len(sequence), rank, world_size, shuffle=shuffle, even_shards=even_shards, seed=seed) 46 | return [sequence[i] for i in indices] 47 | 48 | 49 | def chunk_and_shard_indices( 50 | num_elements: int, 51 | chunk_size: int, 52 | rank: int, 53 | world_size: int, 54 | chunk_overlap: int = 0, 55 | even_shards: bool = True, 56 | equal_chunks: bool = True, 57 | shuffle: bool = False, 58 | seed: int = 0, 59 | ): 60 | if equal_chunks: 61 | num_chunks = num_elements // chunk_size 62 | else: 63 | num_chunks = (num_elements + chunk_size - 1) // chunk_size 64 | 65 | chunk_indices = shard_indices(num_chunks, rank, world_size, shuffle=shuffle, even_shards=even_shards, seed=seed) 66 | chunks = [] 67 | for chunk_idx in chunk_indices: 68 | start = chunk_idx * chunk_size 69 | end = start + chunk_size + chunk_overlap 70 | chunks.append((start, end)) 71 | return chunks 72 | -------------------------------------------------------------------------------- /dmlcloud/data/xarray.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch.distributed as dist 4 | import xarray as xr 5 | from torch.utils.data import get_worker_info, IterableDataset 6 | 7 | from .sharding import chunk_and_shard_indices 8 | 9 | 10 | __all__ = [ 11 | 'sharded_xr_dataset', 12 | 'ShardedXrDataset', 13 | ] 14 | 15 | 16 | def sharded_xr_dataset( 17 | ds: xr.Dataset | xr.DataArray, 18 | dim: str, 19 | chunk_size: int, 20 | chunk_overlap: int = 0, 21 | even_shards: bool = True, 22 | equal_chunks: bool = True, 23 | shuffle: bool = False, 24 | seed: int = 0, 25 | rank: int | None = None, 26 | world_size: int | None = None, 27 | process_group: dist.ProcessGroup | None = None, 28 | load: bool = False, 29 | load_kwargs: dict | None = None, 30 | ) -> Iterable[xr.Dataset | xr.DataArray]: 31 | if rank is None: 32 | rank = dist.get_rank(process_group) 33 | if world_size is None: 34 | world_size = dist.get_world_size(process_group) 35 | 36 | num_elements = len(ds[dim]) 37 | chunks = chunk_and_shard_indices( 38 | num_elements, 39 | chunk_size, 40 | rank, 41 | world_size, 42 | chunk_overlap=chunk_overlap, 43 | even_shards=even_shards, 44 | equal_chunks=equal_chunks, 45 | shuffle=shuffle, 46 | seed=seed, 47 | ) 48 | for start, end in chunks: 49 | chunk = ds.isel({dim: slice(start, end)}) 50 | if load: 51 | kwargs = load_kwargs or {} 52 | chunk.load(**kwargs) 53 | yield chunk 54 | 55 | 56 | class ShardedXrDataset(IterableDataset): 57 | def __init__( 58 | self, 59 | ds: xr.Dataset | xr.DataArray, 60 | dim: str, 61 | chunk_size: int, 62 | chunk_overlap: int = 0, 63 | even_shards: bool = True, 64 | equal_chunks: bool = True, 65 | shuffle: bool = False, 66 | seed: int = 0, 67 | rank: int | None = None, 68 | world_size: int | None = None, 69 | process_group: dist.ProcessGroup | None = None, 70 | load: bool = False, 71 | load_kwargs: dict | None = None, 72 | ): 73 | self.ds = ds 74 | self.dim = dim 75 | self.chunk_size = chunk_size 76 | self.chunk_overlap = 
chunk_overlap 77 | self.even_shards = even_shards 78 | self.equal_chunks = equal_chunks 79 | self.shuffle = shuffle 80 | self.seed = seed 81 | self.load = load 82 | self.load_kwargs = load_kwargs 83 | 84 | self.rank = rank if rank is not None else dist.get_rank(process_group) 85 | self.world_size = world_size if world_size is not None else dist.get_world_size(process_group) 86 | self._num_iters = 0 87 | 88 | def set_epoch(self, epoch: int): 89 | self._num_iters = epoch 90 | 91 | def __iter__(self): 92 | worker_info = get_worker_info() 93 | if worker_info is None: 94 | rank = self.rank 95 | world_size = self.world_size 96 | else: 97 | rank = self.rank * worker_info.num_workers + worker_info.id 98 | world_size = self.world_size * worker_info.num_workers 99 | 100 | return sharded_xr_dataset( 101 | self.ds, 102 | self.dim, 103 | self.chunk_size, 104 | chunk_overlap=self.chunk_overlap, 105 | even_shards=self.even_shards, 106 | equal_chunks=self.equal_chunks, 107 | shuffle=self.shuffle, 108 | seed=self.seed + self._num_iters, 109 | rank=rank, 110 | world_size=world_size, 111 | load=self.load, 112 | load_kwargs=self.load_kwargs, 113 | ) 114 | -------------------------------------------------------------------------------- /dmlcloud/git.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides functions to interact with git 3 | """ 4 | 5 | import subprocess 6 | import sys 7 | import traceback 8 | from pathlib import Path 9 | 10 | 11 | def is_setuptools_cli_script(module): 12 | """ 13 | Heuristically checks if the given module is a cli script generated by setuptools. 14 | """ 15 | if not hasattr(module, '__file__'): 16 | return False 17 | try: 18 | with open(module.__file__) as f: 19 | lines = f.readlines(4089) 20 | except OSError: 21 | return False 22 | 23 | if len(lines) != 8: 24 | return False 25 | if not lines[0].startswith('#!'): 26 | return False 27 | if lines[2] != 'import re\n': 28 | return False 29 | if lines[3] != 'import sys\n': 30 | return False 31 | if lines[5] != "if __name__ == '__main__':\n": 32 | return False 33 | if 'sys.exit(' not in lines[7]: 34 | return False 35 | return True 36 | 37 | 38 | def script_path() -> Path | None: 39 | """ 40 | Returns the path to the script or module that was executed. 41 | 42 | Returns None if python runs in interactive mode, or if "-c" command line option was used. 43 | 44 | Returns: 45 | Path to the script or module that was executed or None if not available. 46 | """ 47 | main = sys.modules['__main__'] 48 | if not hasattr(main, '__file__'): 49 | return None # interactive mode 50 | 51 | if is_setuptools_cli_script(main): 52 | stack = traceback.extract_stack() 53 | if len(stack) < 2: 54 | return Path(main.__file__).resolve() 55 | return Path(stack[1].filename).resolve() 56 | 57 | else: 58 | return Path(main.__file__).resolve() 59 | 60 | 61 | def script_directory() -> Path | None: 62 | """ 63 | Returns the directory containing the script or module that was executed. 64 | 65 | Returns None if python runs in interactive mode, or if "-c" command line option was used. 66 | 67 | Returns: 68 | Directory containing the script or module that was executed or None if not available. 69 | """ 70 | file = script_path() 71 | if file is None: 72 | return None 73 | else: 74 | return file.parent 75 | 76 | 77 | def project_directory() -> Path | None: 78 | """ 79 | Returns the top-level directory containing the script or module that was executed. 
80 | 
81 |     Returns None if python runs in interactive mode, or if "-c" command line option was used.
82 | 
83 |     Returns:
84 |         Top-level directory containing the script or module that was executed or None if not available.
85 |     """
86 |     cur_dir = script_directory()
87 |     if cur_dir is None:
88 |         return None
89 | 
90 |     while (cur_dir / '__init__.py').exists():
91 |         cur_dir = cur_dir.parent
92 |     return cur_dir
93 | 
94 | 
95 | def run_in_project(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) -> subprocess.CompletedProcess:
96 |     """
97 |     Runs a command in the project directory and returns the output.
98 | 
99 |     Raises:
100 |         RuntimeError: If the project directory could not be determined.
101 |     """
102 |     cwd = project_directory()
103 |     if cwd is None:
104 |         raise RuntimeError('Could not determine project directory')
105 |     return subprocess.run(cmd, cwd=cwd, stdout=stdout, stderr=stderr, **kwargs)
106 | 
107 | 
108 | def git_hash(short=False) -> str | None:
109 |     """
110 |     Returns the git hash of the current commit.
111 | 
112 |     If git is not available or the project is not a git repository, None is returned.
113 | 
114 |     Args:
115 |         short: If True, the short hash is returned.
116 | 
117 |     Returns:
118 |         The git hash of the current commit or None if not available.
119 |     """
120 |     try:
121 |         if short:
122 |             process = run_in_project(['git', 'rev-parse', '--short', 'HEAD'])
123 |         else:
124 |             process = run_in_project(['git', 'rev-parse', 'HEAD'])
125 |         return process.stdout.decode('utf-8').strip()
126 |     except RuntimeError:
127 |         return None
128 | 
129 | 
130 | def git_diff() -> str | None:
131 |     """
132 |     Returns the output of `git diff -U0 --no-color HEAD`
133 | 
134 |     If git is not available or the project is not a git repository, None is returned.
135 | 
136 |     Returns:
137 |         The output of `git diff -U0 --no-color HEAD` or None if not available.
138 |     """
139 | 
140 |     try:
141 |         process = run_in_project(['git', 'diff', '-U0', '--no-color', 'HEAD'])
142 |         return process.stdout.decode('utf-8').strip()
143 |     except RuntimeError:
144 |         return None
145 | 
-------------------------------------------------------------------------------- /dmlcloud/run.py: --------------------------------------------------------------------------------
1 | """
2 | usage: dmlrun [-h] [--gpus GPUS] [--nprocs NPROCS] script ...
3 | 
4 | dmlrun is a thin wrapper around torch.distributed.launch that provides a more user-friendly interface.
5 | 
6 | While torchrun is a powerful tool, it can be a bit clunky to use for testing and debugging. dmlrun aims to make it easier to launch distributed training jobs on a single node. For serious multi-node training, we recommend using srun or torchrun directly.
7 | 
8 | positional arguments:
9 |   script                Path to the script to run.
10 |   args                  Arguments to pass to the script.
11 | 
12 | options:
13 |   -h, --help            show this help message and exit
14 |   --gpus GPUS, -g GPUS  Comma-separated list of GPU IDs to use for training. Overrides CUDA_VISIBLE_DEVICES.
15 |   --nprocs NPROCS, -n NPROCS
16 |                         Number of GPUs to use for training.
17 | 
18 | Example:
19 |     dmlrun --gpus 3,7 train.py
20 |     dmlrun --nprocs 2 train.py --batch-size 64
21 | """
22 | 
23 | import argparse
24 | import os
25 | 
26 | 
27 | def main():
28 |     description = (
29 |         'dmlrun is a thin wrapper around torch.distributed.launch that provides a more user-friendly interface.\n\n'
30 |         'While torchrun is a powerful tool, it can be a bit clunky to use for testing and debugging. 
dmlrun aims to make it easier to launch distributed training jobs on a single node.' 31 | 'For serious multi-node training, we recommend using srun or torchrun directly.' 32 | ) 33 | epilog = 'Example:\n' ' dmlrun --gpus 3,7 train.py\n' ' dmlrun --nprocs 2 train.py --batch-size 64' 34 | parser = argparse.ArgumentParser( 35 | prog='dmlrun', description=description, epilog=epilog, formatter_class=argparse.RawDescriptionHelpFormatter 36 | ) 37 | parser.add_argument( 38 | '--gpus', '-g', help='Comma-separated list of GPU IDs to use for training. Overrides CUDA_VISIBLE_DEVICES.' 39 | ) 40 | parser.add_argument('--nprocs', '-n', type=int, help='Number of GPUs to use for training.') 41 | parser.add_argument('script', type=str, help='Path to the script to run.') 42 | parser.add_argument('args', nargs=argparse.REMAINDER, help='Arguments to pass to the script.') 43 |  44 | args = parser.parse_args() 45 |  46 | if args.gpus and args.nprocs: 47 | raise ValueError('Only one of --gpus or --nprocs can be specified.') 48 |  49 | if args.gpus: 50 | ids = args.gpus.split(',') 51 | if not all(id.isdigit() for id in ids): 52 | raise ValueError('GPU IDs must be integers.') 53 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 54 | nprocs = len(ids) 55 | elif args.nprocs: 56 | nprocs = args.nprocs 57 | else: 58 | nprocs = 1 59 |  60 | import torch.distributed.run 61 |  62 | cmdline = [ 63 | '--standalone', 64 | '--nproc_per_node', 65 | f'{nprocs}', 66 | '--no-python', 67 | ] 68 |  69 | cmdline += [args.script] + args.args 70 | print('Executing: torchrun', ' '.join(cmdline), flush=True) 71 | torch.distributed.run.main(cmdline) 72 |  73 |  74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /dmlcloud/slurm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides functions to interact with slurm 3 | """ 4 |  5 |  6 | import os 7 |  8 |  9 | def slurm_job_id(): 10 | """Returns the SLURM job id (SLURM_JOB_ID) or None if not running under slurm.""" 11 | return os.environ.get('SLURM_JOB_ID') 12 |  13 |  14 | def slurm_step_id(): 15 | """Returns the SLURM step id (SLURM_STEP_ID) or None if not running under slurm.""" 16 | return os.environ.get('SLURM_STEP_ID') 17 |  18 |  19 | def slurm_available(): 20 | """Returns True if the process runs inside a slurm job, i.e. if SLURM_JOB_ID is set.""" 21 | return slurm_job_id() is not None 22 | -------------------------------------------------------------------------------- /dmlcloud/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sehoffmann/dmlcloud/9aba8f3c62e3ca52852b7d5334902e52430677ea/dmlcloud/util/__init__.py -------------------------------------------------------------------------------- /dmlcloud/util/argparse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import enum 3 |  4 |  5 | class EnumAction(argparse.Action): 6 | """ 7 | Argparse action for handling Enums 8 | From https://stackoverflow.com/a/60750535/4546885 9 | """ 10 |  11 | def __init__(self, **kwargs): 12 | # Pop off the type value 13 | enum_type = kwargs.pop("type", None) 14 |  15 | # Ensure an Enum subclass is provided 16 | if enum_type is None: 17 | raise ValueError("type must be assigned an Enum when using EnumAction") 18 | if not issubclass(enum_type, enum.Enum): 19 | raise TypeError("type must be an Enum when using EnumAction") 20 |  21 | # Generate choices from the Enum 22 | kwargs.setdefault("choices", tuple(e.value for e in enum_type)) 23 |  24 | super().__init__(**kwargs) 25 |  26 | self._enum = enum_type 27 |  28 | def __call__(self, parser, namespace, values, option_string=None): 29 | # Convert value 
back into an Enum 30 | value = self._enum(values) 31 | setattr(namespace, self.dest, value) 32 | -------------------------------------------------------------------------------- /dmlcloud/util/logging.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | from datetime import timedelta 4 | from pathlib import Path 5 |  6 |  7 | class TimedeltaFormatter: 8 | """ 9 | A formatter that converts a number of seconds to a human-readable string. 10 | """ 11 |  12 | def __init__(self, microseconds=False): 13 | self.microseconds = microseconds 14 |  15 | def __call__(self, seconds: float) -> str: 16 | delta = timedelta(seconds=seconds) 17 | if not self.microseconds: 18 | delta -= timedelta(microseconds=delta.microseconds) 19 | return str(delta) 20 |  21 |  22 | class IORedirector: 23 | """ 24 | Context manager to redirect stdout and stderr to a file. 25 | Data is written to the file and the original streams. 26 | """ 27 |  28 | # Caveats: 29 | # * We always need to forward the current stdout/stderr. People can change them. 30 | # * Even after uninstall, people can still hold references to the redirected streams. 31 | # Hence, we must be fault tolerant and not crash if the file is closed or the streams are changed. 32 |  33 | class RedirectedStream: 34 | def __init__(self, io_redirector, stream_name): 35 | self.__io_redirector = io_redirector 36 | self.__stream_name = stream_name 37 | self.__org_stream = None 38 |  39 | @property 40 | def __file(self): 41 | return self.__io_redirector.file 42 |  43 | @property 44 | def __current_stream(self): 45 | return getattr(sys, self.__stream_name) 46 |  47 | def install(self): 48 | self.__org_stream = getattr(sys, self.__stream_name) 49 | setattr(sys, self.__stream_name, self) 50 |  51 | def uninstall(self): 52 | setattr(sys, self.__stream_name, self.__org_stream) 53 | self.__org_stream = None 54 |  55 | def write(self, data): 56 | if self.__file is not None: 57 | self.__file.write(data) 58 |  59 | if self.__current_stream is self: # Avoid infinite recursion 60 | self.__org_stream.write(data) 61 | else: 62 | self.__current_stream.write(data) 63 |  64 | def flush(self): 65 | if self.__file is not None: 66 | self.__file.flush() 67 |  68 | if self.__current_stream is self: # Avoid infinite recursion 69 | self.__org_stream.flush() 70 | else: 71 | self.__current_stream.flush() 72 |  73 | def __getattr__(self, name): 74 | if self.__current_stream is self: 75 | return getattr(self.__org_stream, name) 76 | else: 77 | raise AttributeError(obj=self, name=name) 78 |  79 |  80 | def __init__(self, log_file: Path): 81 | self.path = log_file 82 | self.file = None 83 | self.stdout = None 84 | self.stderr = None 85 |  86 | def install(self): 87 | if self.file is not None: 88 | return 89 |  90 | self.file = self.path.open('a', encoding='utf-8', errors='replace') 91 |  92 | self.stdout = self.RedirectedStream(self, 'stdout') 93 | self.stdout.install() 94 |  95 | self.stderr = self.RedirectedStream(self, 'stderr') 96 | self.stderr.install() 97 |  98 | def uninstall(self): 99 | if self.file is None: 100 | raise ValueError('IORedirector is not installed') 101 |  102 | self.stdout.uninstall() 103 | self.stderr.uninstall() 104 |  105 | file = self.file 106 | self.file = None # Prevent further writes 107 | file.close() 108 |  109 | self.stdout = None 110 | self.stderr = None 111 |  112 | def __enter__(self): 113 | self.install() 114 | return self 115 |  116 | def __exit__(self, exc_type, exc_value, traceback): 117 | self.uninstall() 118 |  119 |  120 | 
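# A minimal usage sketch for IORedirector (the log file name is hypothetical):
#
#     with IORedirector(Path('train.log')):
#         print('hello')                    # written to train.log and the original stdout
#         print('oops', file=sys.stderr)    # written to train.log and the original stderr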
class DevNullIO(io.TextIOBase): 121 | """ 122 | Dummy TextIOBase that simply ignores anything written to it, similar to /dev/null. 123 | """ 124 |  125 | def write(self, msg): 126 | pass 127 | -------------------------------------------------------------------------------- /dmlcloud/util/seed.py: -------------------------------------------------------------------------------- 1 | import random 2 |  3 | import numpy as np 4 | import torch 5 |  6 |  7 | def seed_all(seed: int): 8 | """Seeds torch, numpy, and the stdlib random module with the given seed.""" 9 | torch.manual_seed(seed) 10 | np.random.seed(seed) 11 | random.seed(seed) 12 |  13 |  14 | def enable_determinism(): 15 | """Disables cuDNN benchmarking and forces torch to use deterministic algorithms.""" 16 | torch.backends.cudnn.benchmark = False 17 | torch.use_deterministic_algorithms(True) 18 | -------------------------------------------------------------------------------- /dmlcloud/util/tcp.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import subprocess 3 |  4 |  5 | def find_free_port(): 6 | """ 7 | Returns a free port on the local machine. 8 | """ 9 | with socket.socket() as s: 10 | s.bind(('', 0)) 11 | return s.getsockname()[1] 12 |  13 |  14 | def get_local_ips(use_hostname=True): 15 | """ 16 | Returns the IP addresses of the local machine. 17 | """ 18 | if use_hostname: 19 | proc = subprocess.run(['hostname', '-I'], capture_output=True, text=True) 20 | if proc.returncode == 0: 21 | return proc.stdout.strip().split(' ') 22 | else: 23 | err = proc.stderr.strip() 24 | raise RuntimeError(err) 25 | else: 26 | hostname = socket.gethostname() 27 | return socket.gethostbyname_ex(hostname)[2] 28 | -------------------------------------------------------------------------------- /dmlcloud/util/thirdparty.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | from types import ModuleType 4 | from typing import Optional 5 |  6 |  7 | ML_MODULES = [ 8 | 'torch', 9 | 'torchvision', 10 | 'torchtext', 11 | 'torchaudio', 12 | 'einops', 13 | 'numpy', 14 | 'pandas', 15 | 'xarray', 16 | 'sklearn', 17 | ] 18 |  19 |  20 | def is_imported(name: str) -> bool: 21 | return name in sys.modules 22 |  23 |  24 | def try_import(name: str) -> Optional[ModuleType]: 25 | try: 26 | return importlib.import_module(name) 27 | except ImportError: 28 | return None 29 |  30 |  31 | def try_get_version(name: str) -> Optional[str]: 32 | module = try_import(name) 33 | if module is not None: 34 | return str(module.__version__) 35 | else: 36 | return None 37 | -------------------------------------------------------------------------------- /dmlcloud/util/wandb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 |  4 |  5 | class WandbModuleWrapper: 6 | def __getattr__(self, name): 7 | import wandb 8 |  9 | return getattr(wandb, name) 10 |  11 | def __setattr__(self, name, value): 12 | import wandb 13 |  14 | setattr(wandb, name, value) 15 |  16 |  17 | wandb = WandbModuleWrapper() 18 |  19 |  20 | def wandb_set_startup_timeout(seconds: int): 21 | assert isinstance(seconds, int) 22 | os.environ['WANDB__SERVICE_WAIT'] = f'{seconds}' 23 |  24 |  25 | def wandb_is_imported(): 26 | return 'wandb' in sys.modules 27 |  28 |  29 | def wandb_is_initialized(): 30 | return wandb.run is not None 31 | -------------------------------------------------------------------------------- /dmlcloud/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.4' 2 | 
-------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 |  4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 |  11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 |  15 | .PHONY: help Makefile 16 |  17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 |  6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 |  9 | project = 'dmlcloud' 10 | copyright = '2024, Sebastian Hoffmann' 11 | author = 'Sebastian Hoffmann' 12 | release = 'v0.4' 13 |  14 | # -- General configuration --------------------------------------------------- 15 | extensions = [ 16 | 'sphinx.ext.autodoc', 17 | 'sphinx.ext.napoleon', 18 | 'sphinx_autodoc_typehints', 19 | 'sphinx.ext.duration', 20 | 'sphinx.ext.autosummary', 21 | 'sphinx.ext.coverage', 22 | 'sphinx.ext.intersphinx', 23 | ] 24 |  25 | templates_path = ['_templates'] 26 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 27 |  28 |  29 | # -- autodoc ----------------------------------------------------------------- 30 | # autodoc_typehints = "description" 31 |  32 | # -- Napoleon ---------------------------------------------------------------- 33 | # napoleon_use_param = False 34 | # napoleon_use_rtype = False 35 | # napoleon_preprocess_types = True 36 |  37 | # -- External documentation (intersphinx) ------------------------------------ 38 | intersphinx_mapping = { 39 | 'python': ('https://docs.python.org/3', None), 40 | 'torch': ('https://pytorch.org/docs/stable/', None), 41 | } 42 |  43 |  44 | # -- Options for HTML output ------------------------------------------------- 45 | html_theme = 'sphinx_rtd_theme' 46 | html_static_path = ['_static'] 47 | html_logo = '../misc/logo/dmlcloud_light.png' 48 | -------------------------------------------------------------------------------- /doc/dmlcloud.data.rst: -------------------------------------------------------------------------------- 1 | dmlcloud.data 2 | ============= 3 |  4 | .. automodule:: dmlcloud.data 5 | -------------------------------------------------------------------------------- /doc/dmlcloud.git.rst: -------------------------------------------------------------------------------- 1 | dmlcloud.git 2 | ============ 3 |  4 | .. automodule:: dmlcloud.git 5 |  6 |  7 | .. rubric:: Functions 8 |  9 | .. 
autosummary:: 10 |  11 | git_diff 12 | git_hash 13 | is_setuptools_cli_script 14 | script_path 15 | script_directory 16 | project_directory 17 | run_in_project 18 | -------------------------------------------------------------------------------- /doc/dmlcloud.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: dmlcloud 2 |  3 | dmlcloud 4 | ======== 5 |  6 | This is the API reference for the dmlcloud package. 7 |  8 | .. autosummary:: 9 | :toctree: generated 10 |  11 | Pipeline 12 | Stage 13 | current_pipe 14 | current_stage 15 | log_metric 16 |  17 |  18 | torch.distributed Helpers 19 | ------------------------- 20 | dmlcloud provides a set of helper functions to simplify the use of torch.distributed. 21 |  22 | .. autosummary:: 23 | :toctree: generated 24 |  25 | init 26 | seed 27 | deinitialize_torch_distributed 28 |  29 | is_root 30 | root_only 31 | root_first 32 |  33 | rank 34 | world_size 35 | local_rank 36 | local_world_size 37 | local_node 38 |  39 | all_gather_object 40 | gather_object 41 | broadcast_object 42 |  43 | has_slurm 44 | has_environment 45 | has_mpi 46 |  47 |  48 | Logging 49 | ------- 50 | dmlcloud provides a set of logging utilities to simplify logging in a distributed environment. 51 | In particular, it lazily sets up a logger ('dmlcloud') that only logs on the root process. 52 | Users are encouraged to use the provided log functions instead of print statements to prevent duplicated logs. 53 |  54 | .. autosummary:: 55 | :toctree: generated 56 |  57 | logger 58 | log 59 | debug 60 | info 61 | warning 62 | error 63 | critical 64 | print_worker 65 | print_root 66 | setup_logger 67 | reset_logger 68 |  69 |  70 |  71 |  72 | Metric Tracking 73 | --------------- 74 | .. autosummary:: 75 | :toctree: generated 76 |  77 | TrainingHistory 78 | Tracker 79 |  80 |  81 | Model Creation 82 | -------------- 83 | .. autosummary:: 84 | :toctree: generated 85 |  86 | scale_lr 87 | wrap_ddp 88 | count_parameters 89 |  90 |  91 | Config Helpers 92 | --------------- 93 | These functions can be used to create objects from configuration files. 94 |  95 | .. autosummary:: 96 | :toctree: generated 97 |  98 | import_object 99 | factory_from_cfg 100 | obj_from_cfg 101 | -------------------------------------------------------------------------------- /doc/dmlcloud.slurm.rst: -------------------------------------------------------------------------------- 1 | dmlcloud.slurm 2 | ============== 3 |  4 | .. automodule:: dmlcloud.slurm 5 |  6 |  7 | .. rubric:: Functions 8 |  9 | .. autosummary:: 10 |  11 | slurm_available 12 | slurm_job_id 13 | slurm_step_id 14 | -------------------------------------------------------------------------------- /doc/getting_started/index.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 |  4 | Installation 5 | ------------ 6 | Dmlcloud can be installed directly from PyPI:: 7 |  8 | pip install dmlcloud 9 |  10 | The latest development version can be installed from GitHub:: 11 |  12 | pip install git+https://github.com/sehoffmann/dmlcloud.git 13 |  14 |  15 | Optional Dependencies 16 | --------------------- 17 | * wandb 18 | * xarray 19 | * mpi4py 20 |  21 | .. 
toctree:: 22 | :maxdepth: 2 23 | :hidden: 24 | 25 | mnist 26 | -------------------------------------------------------------------------------- /doc/getting_started/mnist.rst: -------------------------------------------------------------------------------- 1 | Training on MNIST 2 | ================= 3 | 4 | A simple example on MNIST. 5 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | dmlcloud documentation 2 | ====================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | Getting Started 8 | User Guide 9 | 10 | 11 | .. toctree:: 12 | :glob: 13 | :maxdepth: 2 14 | :caption: Python API 15 | 16 | dmlcloud 17 | dmlcloud.data 18 | dmlcloud.git 19 | dmlcloud.slurm 20 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | -r ../ci_requirements.txt 2 | -e ./ 3 | -r ../requirements.txt 4 | -------------------------------------------------------------------------------- /doc/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | User Guide 2 | ========== 3 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains multiple examples that demonstrate the usage of dmlcloud and its features. 4 | The `mnist` example is a good starting point for beginners. It demonstrates how to train a simple neural network 5 | on the MNIST dataset using dmlcloud. 6 | 7 | | Example | Description | 8 | | --- | --- | 9 | | [mnist.py](mnist.py) | Minimal example that demonstrates how to train a simple neural network on the MNIST dataset using dmlcloud. | 10 | | [custom_epochs.py](custom_epochs.py) | Demonstrates how to fully control when "epochs" start and end, e.g. for reinforcement learning or LLM training. 
| 11 | -------------------------------------------------------------------------------- /examples/custom_epochs.py: -------------------------------------------------------------------------------- 1 | import dmlcloud as dml 2 | import torch 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets, transforms 6 | 7 | 8 | class CustomEpochStage(dml.Stage): 9 | def pre_stage(self): 10 | with dml.root_first(): 11 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 12 | train_dataset = datasets.MNIST(root='data', train=True, download=dml.is_root(), transform=transform) 13 | 14 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 15 | self.train_loader = DataLoader(train_dataset, batch_size=32, sampler=self.train_sampler) 16 | 17 | model = nn.Sequential( 18 | nn.Conv2d(1, 16, 3, padding=1), 19 | nn.ReLU(), 20 | nn.MaxPool2d(2), 21 | nn.Conv2d(16, 16, 3, padding=1), 22 | nn.ReLU(), 23 | nn.MaxPool2d(2), 24 | nn.Flatten(), 25 | nn.Linear(784, 10), 26 | ) 27 | self.model = dml.wrap_ddp(model, self.device) 28 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=dml.scale_lr(1e-3)) 29 | 30 | self.loss = nn.CrossEntropyLoss() 31 | 32 | # Finally, we add columns to the table to track the loss and accuracy 33 | self.add_column('# Steps', 'misc/steps') 34 | self.add_column('# Samples', 'misc/total_samples') 35 | 36 | self.add_column('Loss', 'train/loss', color='green') 37 | 38 | def run(self): 39 | MAX_STEPS = 5000 40 | LOG_PERIOD = 250 41 | 42 | num_steps = 0 43 | total_samples = 0 44 | while num_steps < MAX_STEPS: 45 | self.train_sampler.set_epoch(self.current_epoch) 46 | 47 | for img, target in self.train_loader: 48 | img, target = img.to(self.device), target.to(self.device) 49 | 50 | self.optimizer.zero_grad() 51 | output = self.model(img) 52 | loss = self.loss(output, target) 53 | loss.backward() 54 | self.optimizer.step() 55 | 56 | self.log('train/loss', loss) 57 | self.log('misc/samples', len(img), reduction='sum') 58 | 59 | num_steps += 1 60 | if num_steps % LOG_PERIOD == 0: 61 | total_samples += self.metrics['misc/samples'].compute() 62 | self.log('misc/total_samples', total_samples) 63 | self.log('misc/steps', num_steps) 64 | if num_steps < MAX_STEPS: 65 | self.next_epoch() 66 | else: 67 | break 68 | 69 | 70 | def main(): 71 | pipe = dml.Pipeline(name='custom-epochs') 72 | pipe.append(CustomEpochStage()) 73 | pipe.enable_checkpointing('checkpoints') 74 | pipe.enable_wandb() 75 | pipe.run() 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import dmlcloud as dml 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from torchvision import datasets, transforms 8 | 9 | 10 | class MNISTStage(dml.Stage): 11 | # The pre_stage method is called before the first epoch 12 | # It's a good place to load the dataset, create the model, and set up the optimizer 13 | def pre_stage(self): 14 | # Load the MNIST dataset 15 | # The root_first context manager ensures the root process downloads the dataset before the other processes 16 | with dml.root_first(): 17 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 18 | train_dataset = datasets.MNIST(root='data', train=True, 
download=dml.is_root(), transform=transform) 19 | val_dataset = datasets.MNIST(root='data', train=False, download=dml.is_root(), transform=transform) 20 |  21 | # For distributed training, we need to shard our dataset across all processes 22 | # Here we use the DistributedSampler to do this 23 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 24 | self.train_loader = DataLoader(train_dataset, batch_size=32, sampler=self.train_sampler) 25 |  26 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) 27 | self.val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler) 28 |  29 | # We create our model regularly... 30 | model = nn.Sequential( 31 | nn.Conv2d(1, 16, 3, padding=1), 32 | nn.ReLU(), 33 | nn.MaxPool2d(2), 34 | nn.Conv2d(16, 16, 3, padding=1), 35 | nn.ReLU(), 36 | nn.MaxPool2d(2), 37 | nn.Flatten(), 38 | nn.Linear(784, 10), 39 | ) 40 |  41 | # ...and then wrap it with dml.wrap_ddp to enable distributed training 42 | self.model = dml.wrap_ddp(model, self.device) 43 |  44 | # It's also important to scale the learning rate based on the number of GPUs; dml.scale_lr does this for us 45 | # Otherwise, we wouldn't benefit from the increased batch size 46 | # In practice, you would likely want to combine this with a linear lr ramp-up during the very first steps as well 47 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=dml.scale_lr(1e-3)) 48 |  49 | self.loss = nn.CrossEntropyLoss() 50 |  51 | # Finally, we add columns to the table to track the loss and accuracy 52 | self.add_column('[Train] Loss', 'train/loss', color='green') 53 | self.add_column('[Train] Acc.', 'train/accuracy', formatter=lambda acc: f'{100 * acc:.2f}%', color='green') 54 | self.add_column('[Val] Loss', 'val/loss', color='cyan') 55 | self.add_column('[Val] Acc.', 'val/accuracy', formatter=lambda acc: f'{100 * acc:.2f}%', color='cyan') 56 |  57 | # The run_epoch method is called once per epoch 58 | def run_epoch(self): 59 | self._train_epoch() 60 | self._val_epoch() 61 |  62 | def _train_epoch(self): 63 | self.model.train() 64 | self.metric_prefix = 'train' # This is used to prefix the metrics in the table 65 | self.train_sampler.set_epoch(self.current_epoch) 66 |  67 | for img, target in self.train_loader: 68 | img, target = img.to(self.device), target.to(self.device) 69 |  70 | self.optimizer.zero_grad() 71 | output = self.model(img) 72 | loss = self.loss(output, target) 73 | loss.backward() 74 | self.optimizer.step() 75 |  76 | self.log('loss', loss) 77 | self.log('accuracy', (output.argmax(1) == target).float().mean()) 78 |  79 | self.finish_step() # optional, but useful to get step-wise metrics 80 |  81 | @torch.no_grad() 82 | def _val_epoch(self): 83 | self.model.eval() 84 | self.metric_prefix = 'val' 85 |  86 | for img, target in self.val_loader: 87 | img, target = img.to(self.device), target.to(self.device) 88 |  89 | output = self.model(img) 90 | loss = self.loss(output, target) 91 |  92 | self.log('loss', loss) 93 | self.log('accuracy', (output.argmax(1) == target).float().mean()) 94 |  95 |  96 | def main(): 97 | dml.init() 98 |  99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--epochs', type=int, default=4) 101 | parser.add_argument('--seed', type=int) 102 | args = parser.parse_args() 103 |  104 | seed = dml.seed(args.seed) # This is a helper function to set the seed for all devices 105 | config = { 106 | 'seed': seed, 107 | 'epochs': args.epochs, 108 | } 109 |  110 | pipe = dml.Pipeline(config, name='MNIST') 111 | 
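# The pipeline executes its stages in order; enable_checkpointing() and
# enable_wandb() opt into checkpoint saving and Weights & Biases logging
# before run() starts the training loop.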
pipe.append(MNISTStage(epochs=args.epochs)) 112 | pipe.enable_checkpointing('checkpoints') 113 | pipe.enable_wandb() 114 | pipe.run() 115 |  116 |  117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /misc/logo/dmlcloud_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sehoffmann/dmlcloud/9aba8f3c62e3ca52852b7d5334902e52430677ea/misc/logo/dmlcloud_color.png -------------------------------------------------------------------------------- /misc/logo/dmlcloud_color.svg: -------------------------------------------------------------------------------- [SVG markup not preserved by this dump; the file renders the "dmlcloud" logo, color variant] -------------------------------------------------------------------------------- /misc/logo/dmlcloud_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sehoffmann/dmlcloud/9aba8f3c62e3ca52852b7d5334902e52430677ea/misc/logo/dmlcloud_dark.png -------------------------------------------------------------------------------- /misc/logo/dmlcloud_dark.svg: -------------------------------------------------------------------------------- [SVG markup not preserved by this dump; the file renders the "dmlcloud" logo, dark variant] -------------------------------------------------------------------------------- /misc/logo/dmlcloud_dark2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sehoffmann/dmlcloud/9aba8f3c62e3ca52852b7d5334902e52430677ea/misc/logo/dmlcloud_dark2.png -------------------------------------------------------------------------------- /misc/logo/dmlcloud_dark2.svg: -------------------------------------------------------------------------------- [SVG markup not preserved by this dump; the file renders the "dmlcloud" logo, alternate dark variant] -------------------------------------------------------------------------------- /misc/logo/dmlcloud_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sehoffmann/dmlcloud/9aba8f3c62e3ca52852b7d5334902e52430677ea/misc/logo/dmlcloud_light.png -------------------------------------------------------------------------------- /misc/logo/dmlcloud_light.svg: -------------------------------------------------------------------------------- [SVG markup not preserved by this dump; the file renders the "dmlcloud" logo, light variant] -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "packaging", "wheel", "build", "pre-commit", "pytest"] 3 | build-backend = "setuptools.build_meta" 4 |  5 | [project] 6 | name = "dmlcloud" 7 | authors = [ 8 | {name = "Sebastian Hoffmann"} 9 | ] 10 | description = "Distributed torch training using torch.distributed and slurm" 11 | requires-python = ">=3.10" 12 | license = {file = "LICENSE"} 13 | keywords = ["pytorch", "torch.distributed", "slurm", "distributed training", "deep learning"] 14 | classifiers = [ 15 | "Development Status :: 3 - Alpha", 16 | "License :: OSI Approved :: BSD License", 17 | "Operating System :: MacOS", 18 | "Operating System :: POSIX :: Linux", 19 | "Programming Language :: Python :: 3.10", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 21 | 
] 22 | dynamic = ["version", "readme", "dependencies"] 23 | 24 | [project.urls] 25 | Repository = "https://github.com/sehoffmann/dmlcloud" 26 | 27 | [project.scripts] 28 | dmlrun = "dmlcloud.run:main" 29 | 30 | [tool.setuptools.packages.find] 31 | include = ["dmlcloud*"] 32 | namespaces = false 33 | 34 | [tool.setuptools.dynamic] 35 | version = {attr = "dmlcloud.__version__"} 36 | readme = {file = ["README.md"], content-type = "text/markdown"} 37 | dependencies = {file = ["requirements.txt"]} 38 | 39 | [tool.black] 40 | skip-string-normalization = true 41 | line-length = 120 42 | target-version = ["py310"] 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | xarray 4 | progress_table>=2.2.0 5 | omegaconf 6 | torchmetrics 7 | nvidia-ml-py 8 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | import pytest 4 | import torch 5 | import torch.distributed 6 | from dmlcloud.core.distributed import init 7 | 8 | 9 | @pytest.fixture 10 | def torch_distributed(): 11 | init(kind='dummy') 12 | yield 13 | torch.distributed.destroy_process_group() 14 | 15 | 16 | class DistributedEnvironment: 17 | @staticmethod 18 | def bind(tmpdir): 19 | def init(world_size, timeout=5 * 60, daemon=True): 20 | return DistributedEnvironment(world_size, timeout, daemon, str(tmpdir / 'filestore')) 21 | 22 | return init 23 | 24 | @staticmethod # important to be staticmethod, otherwise pickle will fail 25 | def _run(rank, world_size, file, conn, func, *args, **kwargs): 26 | store = torch.distributed.FileStore(file, world_size) 27 | torch.distributed.init_process_group(backend='gloo', world_size=world_size, rank=rank, store=store) 28 | 29 | torch.distributed.barrier() 30 | ret = func(*args, **kwargs) # TODO: need to handle exceptions 31 | torch.distributed.barrier() 32 | 33 | conn.send(ret) 34 | 35 | torch.distributed.destroy_process_group() 36 | 37 | def __init__(self, world_size: int, timeout: int = 5 * 60, daemon: bool = True, file: str = None): 38 | self.world_size = world_size 39 | self.timeout = timeout 40 | self.daemon = daemon 41 | self.file = str(file) 42 | 43 | def start(self, func, *args, **kwargs): 44 | ctx = mp.get_context('spawn') 45 | 46 | self.processes = [] 47 | self.conns = [] 48 | for rank in range(self.world_size): 49 | recv_conn, send_conn = ctx.Pipe() 50 | process_args = (rank, self.world_size, self.file, send_conn, func) + args 51 | process_kwargs = dict(kwargs) 52 | process = ctx.Process( 53 | target=DistributedEnvironment._run, args=process_args, kwargs=process_kwargs, daemon=self.daemon 54 | ) 55 | self.conns.append(recv_conn) 56 | self.processes.append(process) 57 | 58 | for process in self.processes: 59 | process.start() 60 | 61 | return_values = [] 62 | for process, conn in zip(self.processes, self.conns): # TODO: should probably be a context manager 63 | ret = conn.recv() 64 | return_values.append(ret) 65 | process.join(self.timeout) 66 | 67 | return return_values 68 | 69 | 70 | @pytest.fixture 71 | def distributed_environment(tmp_path): 72 | return DistributedEnvironment.bind(tmp_path) 73 | -------------------------------------------------------------------------------- /test/test_callback.py: -------------------------------------------------------------------------------- 1 
| import sys 2 | import time 3 | 4 | import dmlcloud as dml 5 | import pytest 6 | from dmlcloud.core.callbacks import CallbackList 7 | 8 | 9 | class DummyCallback(dml.Callback): 10 | def __init__(self, idx): 11 | super().__init__() 12 | self.idx = idx 13 | 14 | self.t_pre_run = [] 15 | self.t_post_run = [] 16 | self.t_pre_stage = [] 17 | self.t_post_stage = [] 18 | self.t_cleanup = [] 19 | self.t_pre_epoch = [] 20 | self.t_post_epoch = [] 21 | 22 | def pre_run(self, pipe): 23 | self.t_pre_run.append(time.monotonic_ns()) 24 | 25 | def post_run(self, pipe): 26 | self.t_post_run.append(time.monotonic_ns()) 27 | 28 | def pre_stage(self, stage): 29 | self.t_pre_stage.append(time.monotonic_ns()) 30 | 31 | def post_stage(self, stage): 32 | self.t_post_stage.append(time.monotonic_ns()) 33 | 34 | def cleanup(self, pipe, exc_type, exc_value, traceback): 35 | self.t_cleanup.append(time.monotonic_ns()) 36 | 37 | def pre_epoch(self, stage): 38 | self.t_pre_epoch.append(time.monotonic_ns()) 39 | 40 | def post_epoch(self, stage): 41 | self.t_post_epoch.append(time.monotonic_ns()) 42 | 43 | 44 | class DummyStage(dml.Stage): 45 | def __init__(self, name, epochs): 46 | super().__init__(name, epochs) 47 | self.t_pre_stage = [] 48 | self.t_post_stage = [] 49 | self.t_pre_epoch = [] 50 | self.t_post_epoch = [] 51 | 52 | def pre_stage(self): 53 | self.t_pre_stage.append(time.monotonic_ns()) 54 | 55 | def post_stage(self): 56 | self.t_post_stage.append(time.monotonic_ns()) 57 | 58 | def pre_epoch(self): 59 | self.t_pre_epoch.append(time.monotonic_ns()) 60 | 61 | def post_epoch(self): 62 | self.t_post_epoch.append(time.monotonic_ns()) 63 | 64 | def run_epoch(self): 65 | pass 66 | 67 | 68 | class TestCallbackList: 69 | def test_priorities(self): 70 | cb_list = CallbackList() 71 | cb_list.append(DummyCallback(0), 100) 72 | cb_list.append(DummyCallback(1), 50) 73 | cb_list.append(DummyCallback(2), 200) 74 | cb_list.append(DummyCallback(3), -100) 75 | cb_list.append(DummyCallback(4), 100) 76 | 77 | indices = [cb.idx for cb in cb_list] 78 | assert indices == [3, 1, 0, 4, 2] 79 | 80 | def test_combining(self): 81 | cb_list1 = CallbackList() 82 | cb_list1.append(DummyCallback(0), 100) 83 | cb_list1.append(DummyCallback(1), 50) 84 | 85 | cb_list2 = CallbackList() 86 | cb_list2.append(DummyCallback(2), 200) 87 | cb_list2.append(DummyCallback(3), -100) 88 | cb_list2.append(DummyCallback(4), 50) 89 | 90 | combined1 = cb_list1 + cb_list2 91 | indices = [cb.idx for cb in combined1] 92 | assert indices == [3, 1, 4, 0, 2] 93 | 94 | # Order for same-priority depends on the order of the operands 95 | combined2 = cb_list2 + cb_list1 96 | indices = [cb.idx for cb in combined2] 97 | assert indices == [3, 4, 1, 0, 2] 98 | 99 | def test_len(self): 100 | cb_list = CallbackList() 101 | assert len(cb_list) == 0 102 | 103 | cb_list.append(DummyCallback(0), 100) 104 | assert len(cb_list) == 1 105 | 106 | cb_list.append(DummyCallback(1), 50) 107 | assert len(cb_list) == 2 108 | 109 | cb_list.append(DummyCallback(2), 200) 110 | assert len(cb_list) == 3 111 | 112 | 113 | class TestCallback: 114 | def test_stage_methods(self, torch_distributed): 115 | pipe = dml.Pipeline() 116 | stage1 = DummyStage('stage1', 2) 117 | pipe.append(stage1) 118 | pipe.run() 119 | 120 | assert len(stage1.t_pre_stage) == 1 121 | assert len(stage1.t_post_stage) == 1 122 | assert len(stage1.t_pre_epoch) == 2 123 | assert len(stage1.t_post_epoch) == 2 124 | 125 | assert stage1.t_pre_stage[0] <= stage1.t_pre_epoch[0] 126 | assert stage1.t_pre_epoch[0] <= 
stage1.t_post_epoch[0] 127 | assert stage1.t_post_epoch[0] <= stage1.t_pre_epoch[1] 128 | assert stage1.t_pre_epoch[1] <= stage1.t_post_epoch[1] 129 | assert stage1.t_post_epoch[1] <= stage1.t_post_stage[0] 130 | 131 | def test_stage_callback(self, torch_distributed): 132 | pipe = dml.Pipeline() 133 | stage1 = DummyStage('stage1', 1) 134 | stage2 = DummyStage('stage2', 1) 135 | cb = DummyCallback(0) 136 | 137 | pipe.append(stage1) 138 | pipe.append(stage2) 139 | 140 | stage1.add_callback(cb) 141 | 142 | pipe.run() 143 | 144 | assert len(cb.t_pre_stage) == 1 145 | assert len(cb.t_post_stage) == 1 146 | assert len(cb.t_pre_epoch) == 1 147 | assert len(cb.t_post_epoch) == 1 148 | assert len(cb.t_pre_run) == 0 149 | assert len(cb.t_post_run) == 0 150 | 151 | assert stage1.t_pre_stage[0] <= cb.t_pre_stage[0] 152 | assert stage1.t_post_stage[0] <= cb.t_post_stage[0] 153 | 154 | def test_stage_callback_priority(self, torch_distributed): 155 | pipe = dml.Pipeline() 156 | stage1 = DummyStage('stage1', 1) 157 | stage2 = DummyStage('stage2', 1) 158 | cb = DummyCallback(0) 159 | 160 | pipe.append(stage1) 161 | pipe.append(stage2) 162 | 163 | stage1.add_callback(cb, priority=-1) 164 | 165 | pipe.run() 166 | 167 | assert len(cb.t_pre_stage) == 1 168 | assert len(cb.t_post_stage) == 1 169 | assert len(cb.t_pre_epoch) == 1 170 | assert len(cb.t_post_epoch) == 1 171 | assert len(cb.t_pre_run) == 0 172 | assert len(cb.t_post_run) == 0 173 | 174 | assert cb.t_pre_stage[0] <= stage1.t_pre_stage[0] 175 | assert cb.t_post_stage[0] <= stage1.t_post_stage[0] 176 | 177 | def test_pipeline_callback(self, torch_distributed): 178 | pipe = dml.Pipeline() 179 | stage1 = DummyStage('stage1', 1) 180 | stage2 = DummyStage('stage2', 1) 181 | cb = DummyCallback(0) 182 | 183 | pipe.append(stage1) 184 | pipe.append(stage2) 185 | pipe.add_callback(cb) 186 | 187 | pipe.run() 188 | 189 | assert len(cb.t_pre_run) == 1 190 | assert len(cb.t_post_run) == 1 191 | assert len(cb.t_cleanup) == 1 192 | assert len(cb.t_pre_stage) == 2 193 | assert len(cb.t_post_stage) == 2 194 | assert len(cb.t_pre_epoch) == 2 195 | assert len(cb.t_post_epoch) == 2 196 | 197 | assert cb.t_pre_run[0] <= cb.t_pre_stage[0] 198 | assert cb.t_post_stage[0] <= cb.t_post_run[0] 199 | assert cb.t_post_run[0] <= cb.t_cleanup[0] 200 | 201 | 202 | if __name__ == '__main__': 203 | sys.exit(pytest.main([__file__])) 204 | -------------------------------------------------------------------------------- /test/test_config.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from datetime import date 3 | 4 | import dmlcloud as dml 5 | import pytest 6 | 7 | 8 | class TestConfig: 9 | def test_import_object(self): 10 | obj = dml.import_object('datetime.date') 11 | assert obj is date 12 | 13 | def test_factory_from_cfg(self): 14 | factory = dml.factory_from_cfg('datetime.date') 15 | assert factory(2025, 1, day=1) == date(2025, 1, 1) 16 | 17 | factory2 = dml.factory_from_cfg('datetime.date', 2025, 1, 1) 18 | assert factory2() == date(2025, 1, 1) 19 | 20 | factory3 = dml.factory_from_cfg('datetime.date', 2025, day=31) 21 | assert factory3(12) == date(2025, 12, 31) 22 | assert factory3(month=12) == date(2025, 12, 31) 23 | 24 | def test_factory_from_cfg_mapping(self): 25 | config = {'factory': 'datetime.date', 'year': 2025, 'month': 1, 'day': 1} 26 | factory = dml.factory_from_cfg(config) 27 | assert factory() == date(2025, 1, 1) 28 | 29 | config = {'factory': 'datetime.date', 'month': 1, 'day': 1} 30 | factory = 
dml.factory_from_cfg(config) 31 | assert factory(1990) == date(1990, 1, 1) 32 | 33 | def test_obj_from_cfg(self): 34 | assert dml.obj_from_cfg('datetime.date', 2025, 1, 1) == date(2025, 1, 1) 35 | assert dml.obj_from_cfg('datetime.date', year=2025, month=1, day=1) == date(2025, 1, 1) 36 | assert dml.obj_from_cfg('datetime.date', 2025, month=1, day=1) == date(2025, 1, 1) 37 | 38 | def test_obj_from_cfg_mapping(self): 39 | config = {'factory': 'datetime.date', 'year': 2025, 'month': 1, 'day': 1} 40 | assert dml.obj_from_cfg(config) == date(2025, 1, 1) 41 | 42 | config = {'factory': 'datetime.date', 'month': 1, 'day': 1} 43 | assert dml.obj_from_cfg(config, 1990) == date(1990, 1, 1) 44 | 45 | 46 | if __name__ == '__main__': 47 | sys.exit(pytest.main([__file__])) 48 | -------------------------------------------------------------------------------- /test/test_csv.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | 4 | import dmlcloud as dml 5 | import pytest 6 | from dmlcloud.core.callbacks import CsvCallback 7 | 8 | 9 | class DummyStage(dml.Stage): 10 | def run_epoch(self): 11 | self.log('train/loss', 10 - self.current_epoch) 12 | self.log('train/acc', 90 + self.current_epoch) 13 | 14 | 15 | class TestCsvCallback: 16 | def test_basic_metrics(self, torch_distributed, tmp_path): 17 | metrics_file = tmp_path / 'epoch_metrics_DummyStage.csv' 18 | 19 | pipe = dml.Pipeline() 20 | pipe.append(DummyStage(epochs=5)) 21 | pipe.add_callback(CsvCallback(tmp_path)) 22 | pipe.run() 23 | 24 | assert metrics_file.exists() 25 | 26 | with open(metrics_file) as f: 27 | reader = csv.reader(f) 28 | rows = list(reader) 29 | 30 | assert len(rows) == 6 31 | assert rows[0][:3] == ['epoch', 'train/loss', 'train/acc'] 32 | assert rows[1][:3] == ['0', '10.0', '90.0'] 33 | assert rows[2][:3] == ['1', '9.0', '91.0'] 34 | assert rows[3][:3] == ['2', '8.0', '92.0'] 35 | assert rows[4][:3] == ['3', '7.0', '93.0'] 36 | assert rows[5][:3] == ['4', '6.0', '94.0'] 37 | 38 | # misc metrics 39 | assert 'misc/epoch_time' in rows[0] 40 | assert 'misc/total_time' in rows[0] 41 | assert 'misc/eta' in rows[0] 42 | 43 | def test_stage_name(self, torch_distributed, tmp_path): 44 | pipe = dml.Pipeline() 45 | pipe.append(DummyStage(epochs=5)) 46 | pipe.add_callback(CsvCallback(tmp_path)) 47 | pipe.run() 48 | 49 | assert (tmp_path / 'epoch_metrics_DummyStage.csv').exists() 50 | 51 | def test_duplicate_names(self, torch_distributed, tmp_path): 52 | pipe = dml.Pipeline() 53 | pipe.append(DummyStage(epochs=5)) 54 | pipe.append(DummyStage(epochs=5)) 55 | pipe.add_callback(CsvCallback(tmp_path)) 56 | pipe.run() 57 | 58 | assert (tmp_path / 'epoch_metrics_DummyStage_1.csv').exists() 59 | assert (tmp_path / 'epoch_metrics_DummyStage_2.csv').exists() 60 | 61 | 62 | if __name__ == '__main__': 63 | sys.exit(pytest.main([__file__])) 64 | -------------------------------------------------------------------------------- /test/test_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from functools import partial 3 | 4 | import numpy as np 5 | import pytest 6 | import torch 7 | import xarray as xr 8 | from dmlcloud.data import interleave_batches, shard_indices, sharded_xr_dataset, ShardedXrDataset 9 | from numpy.testing import assert_array_equal 10 | from torch.utils.data import DataLoader, IterableDataset 11 | 12 | 13 | class _Unzip(IterableDataset): 14 | def __init__(self, ds): 15 | self.ds = ds 16 | 17 | def __iter__(self): 18 | for 
chunk in self.ds: 19 | arr = chunk.to_array().values[0] 20 | yield from arr 21 | 22 | 23 | class TestSharding: 24 | def test_types(self): 25 | indices = shard_indices(10, 0, 2, shuffle=False, even_shards=False) 26 | assert isinstance(indices, list) 27 | assert all(isinstance(i, int) for i in indices) 28 | 29 | def test_even(self): 30 | assert shard_indices(10, 0, 2, shuffle=False, even_shards=False) == [0, 2, 4, 6, 8] 31 | assert shard_indices(10, 1, 2, shuffle=False, even_shards=False) == [1, 3, 5, 7, 9] 32 | 33 | def test_uneven(self): 34 | assert shard_indices(10, 0, 3, shuffle=False, even_shards=False) == [0, 3, 6, 9] 35 | assert shard_indices(10, 1, 3, shuffle=False, even_shards=False) == [1, 4, 7] 36 | assert shard_indices(10, 2, 3, shuffle=False, even_shards=False) == [2, 5, 8] 37 | 38 | assert shard_indices(11, 0, 2, shuffle=False, even_shards=False) == [0, 2, 4, 6, 8, 10] 39 | assert shard_indices(11, 1, 2, shuffle=False, even_shards=False) == [1, 3, 5, 7, 9] 40 | 41 | def test_dropping(self): 42 | assert shard_indices(10, 0, 3, shuffle=False, even_shards=True) == [0, 3, 6] 43 | assert shard_indices(10, 1, 3, shuffle=False, even_shards=True) == [1, 4, 7] 44 | assert shard_indices(10, 2, 3, shuffle=False, even_shards=True) == [2, 5, 8] 45 | 46 | assert shard_indices(11, 0, 2, shuffle=False, even_shards=True) == [0, 2, 4, 6, 8] 47 | assert shard_indices(11, 1, 2, shuffle=False, even_shards=True) == [1, 3, 5, 7, 9] 48 | 49 | def test_shuffling(self): 50 | indices = shard_indices(10, 0, 2, shuffle=True, even_shards=False, seed=0) 51 | assert len(indices) == 5 52 | assert len(np.unique(indices)) == 5 53 | assert indices != list(sorted(indices)) 54 | assert (np.array(indices) >= 0).all() and (np.array(indices) <= 9).all() 55 | 56 | 57 | class TestShardedXr: 58 | def test_basic(self): 59 | ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() 60 | world_size = 3 61 | chunk_size = 15 62 | 63 | shard = partial(sharded_xr_dataset, ds, 'x', chunk_size, world_size=world_size, shuffle=False) 64 | chunks_1 = list(shard(rank=0)) 65 | chunks_2 = list(shard(rank=1)) 66 | chunks_3 = list(shard(rank=2)) 67 | 68 | assert len(chunks_1) == 2 69 | assert len(chunks_2) == 2 70 | assert len(chunks_3) == 2 71 | 72 | assert isinstance(chunks_1[0], xr.Dataset) 73 | 74 | assert chunks_1[0].x.size == 15 75 | assert chunks_1[1].x.size == 15 76 | assert chunks_2[0].x.size == 15 77 | assert chunks_2[1].x.size == 15 78 | assert chunks_3[0].x.size == 15 79 | assert chunks_3[1].x.size == 15 80 | 81 | assert_array_equal(chunks_1[0]['var'], np.arange(0, 15)) 82 | assert_array_equal(chunks_2[0]['var'], np.arange(15, 30)) 83 | assert_array_equal(chunks_3[0]['var'], np.arange(30, 45)) 84 | assert_array_equal(chunks_1[1]['var'], np.arange(45, 60)) 85 | assert_array_equal(chunks_2[1]['var'], np.arange(60, 75)) 86 | assert_array_equal(chunks_3[1]['var'], np.arange(75, 90)) 87 | 88 | def test_uneven(self): 89 | ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() 90 | world_size = 3 91 | chunk_size = 20 92 | 93 | shard = partial( 94 | sharded_xr_dataset, ds, 'x', chunk_size, even_shards=False, world_size=world_size, shuffle=False 95 | ) 96 | chunks_1 = list(shard(rank=0)) 97 | chunks_2 = list(shard(rank=1)) 98 | chunks_3 = list(shard(rank=2)) 99 | 100 | assert len(chunks_1) == 2 101 | assert len(chunks_2) == 2 102 | assert len(chunks_3) == 1 103 | 104 | assert isinstance(chunks_1[0], xr.Dataset) 105 | 106 | assert chunks_1[0].x.size == 20 107 | assert chunks_1[1].x.size == 20 108 | assert 
chunks_2[0].x.size == 20 109 | assert chunks_2[1].x.size == 20 110 | assert chunks_3[0].x.size == 20 111 | 112 | assert_array_equal(chunks_1[0]['var'], np.arange(0, 20)) 113 | assert_array_equal(chunks_2[0]['var'], np.arange(20, 40)) 114 | assert_array_equal(chunks_3[0]['var'], np.arange(40, 60)) 115 | assert_array_equal(chunks_1[1]['var'], np.arange(60, 80)) 116 | assert_array_equal(chunks_2[1]['var'], np.arange(80, 100)) 117 | 118 | def test_unequal(self): 119 | ds = xr.DataArray(np.arange(110), dims=['x'], name='var').to_dataset() 120 | world_size = 3 121 | chunk_size = 20 122 | 123 | shard = partial( 124 | sharded_xr_dataset, ds, 'x', chunk_size, equal_chunks=False, world_size=world_size, shuffle=False 125 | ) 126 | chunks_1 = list(shard(rank=0)) 127 | chunks_2 = list(shard(rank=1)) 128 | chunks_3 = list(shard(rank=2)) 129 | 130 | assert len(chunks_1) == 2 131 | assert len(chunks_2) == 2 132 | assert len(chunks_3) == 2 133 | 134 | assert isinstance(chunks_1[0], xr.Dataset) 135 | 136 | assert chunks_1[0].x.size == 20 137 | assert chunks_1[1].x.size == 20 138 | assert chunks_2[0].x.size == 20 139 | assert chunks_2[1].x.size == 20 140 | assert chunks_3[0].x.size == 20 141 | assert chunks_3[1].x.size == 10 142 | 143 | assert_array_equal(chunks_1[0]['var'], np.arange(0, 20)) 144 | assert_array_equal(chunks_2[0]['var'], np.arange(20, 40)) 145 | assert_array_equal(chunks_3[0]['var'], np.arange(40, 60)) 146 | assert_array_equal(chunks_1[1]['var'], np.arange(60, 80)) 147 | assert_array_equal(chunks_2[1]['var'], np.arange(80, 100)) 148 | assert_array_equal(chunks_3[1]['var'], np.arange(100, 110)) 149 | 150 | def test_shuffled(self): 151 | ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() 152 | world_size = 3 153 | chunk_size = 15 154 | 155 | shard = partial(sharded_xr_dataset, ds, 'x', chunk_size, world_size=world_size, shuffle=True, seed=0) 156 | chunks_1 = list(shard(rank=0)) 157 | chunks_2 = list(shard(rank=1)) 158 | chunks_3 = list(shard(rank=2)) 159 | 160 | assert len(chunks_1) == 2 161 | assert len(chunks_2) == 2 162 | assert len(chunks_3) == 2 163 | 164 | catted = xr.concat(chunks_1 + chunks_2 + chunks_3, dim='x')['var'].values 165 | assert catted.tolist() != list(range(90)) 166 | assert list(sorted(catted.tolist())) == list(range(90)) 167 | 168 | chunk = chunks_1[0]['var'].values 169 | assert chunk.tolist() == list(range(chunk[0], chunk[-1] + 1)) 170 | 171 | def test_XrShardedDataset_multiprocessing(self): 172 | xr_ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() 173 | 174 | # Simple case: 2 workers, world_size=1 175 | # Workers act as additional processes and we expect interleaved chunks 176 | torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=1, rank=0, shuffle=False) 177 | torch_ds = _Unzip(torch_ds) 178 | dataloader = DataLoader( 179 | torch_ds, 180 | num_workers=2, 181 | batch_size=1, 182 | prefetch_factor=1, 183 | ) 184 | results = list(batch.item() for batch in dataloader) 185 | assert results == [ 186 | 0, 187 | 15, 188 | 1, 189 | 16, 190 | 2, 191 | 17, 192 | 3, 193 | 18, 194 | 4, 195 | 19, 196 | 5, 197 | 20, 198 | 6, 199 | 21, 200 | 7, 201 | 22, 202 | 8, 203 | 23, 204 | 9, 205 | 24, 206 | 10, 207 | 25, 208 | 11, 209 | 26, 210 | 12, 211 | 27, 212 | 13, 213 | 28, 214 | 14, 215 | 29, 216 | 30, 217 | 45, 218 | 31, 219 | 46, 220 | 32, 221 | 47, 222 | 33, 223 | 48, 224 | 34, 225 | 49, 226 | 35, 227 | 50, 228 | 36, 229 | 51, 230 | 37, 231 | 52, 232 | 38, 233 | 53, 234 | 39, 235 | 54, 236 | 40, 237 | 55, 238 | 41, 239 | 56, 
240 | 42, 241 | 57, 242 | 43, 243 | 58, 244 | 44, 245 | 59, 246 | 60, 247 | 75, 248 | 61, 249 | 76, 250 | 62, 251 | 77, 252 | 63, 253 | 78, 254 | 64, 255 | 79, 256 | 65, 257 | 80, 258 | 66, 259 | 81, 260 | 67, 261 | 82, 262 | 68, 263 | 83, 264 | 69, 265 | 84, 266 | 70, 267 | 85, 268 | 71, 269 | 86, 270 | 72, 271 | 87, 272 | 73, 273 | 88, 274 | 74, 275 | 89, 276 | ] 277 | 278 | # Advanced case: 2 workers, world_size=2 279 | # Each rank gets consecutive chunks and splits them between workers (which interleave again) 280 | # Since the effective world size is now 4, and the dataset has 6 chunks in total, we will only get 4 chunks (up to 60) 281 | torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=2, rank=0, shuffle=False) 282 | torch_ds = _Unzip(torch_ds) 283 | dataloader = DataLoader( 284 | torch_ds, 285 | num_workers=2, 286 | batch_size=1, 287 | prefetch_factor=1, 288 | ) 289 | results = list(batch.item() for batch in dataloader) 290 | assert results == [ 291 | 0, 292 | 15, 293 | 1, 294 | 16, 295 | 2, 296 | 17, 297 | 3, 298 | 18, 299 | 4, 300 | 19, 301 | 5, 302 | 20, 303 | 6, 304 | 21, 305 | 7, 306 | 22, 307 | 8, 308 | 23, 309 | 9, 310 | 24, 311 | 10, 312 | 25, 313 | 11, 314 | 26, 315 | 12, 316 | 27, 317 | 13, 318 | 28, 319 | 14, 320 | 29, 321 | ] 322 | 323 | torch_ds = ShardedXrDataset(xr_ds, chunk_size=15, dim='x', world_size=2, rank=1, shuffle=False) 324 | torch_ds = _Unzip(torch_ds) 325 | dataloader = DataLoader( 326 | torch_ds, 327 | num_workers=2, 328 | batch_size=1, 329 | prefetch_factor=1, 330 | ) 331 | results = list(batch.item() for batch in dataloader) 332 | assert results == [ 333 | 30, 334 | 45, 335 | 31, 336 | 46, 337 | 32, 338 | 47, 339 | 33, 340 | 48, 341 | 34, 342 | 49, 343 | 35, 344 | 50, 345 | 36, 346 | 51, 347 | 37, 348 | 52, 349 | 38, 350 | 53, 351 | 39, 352 | 54, 353 | 40, 354 | 55, 355 | 41, 356 | 56, 357 | 42, 358 | 57, 359 | 43, 360 | 58, 361 | 44, 362 | 59, 363 | ] 364 | 365 | def test_overlap(self): 366 | ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() 367 | world_size = 3 368 | chunk_size = 15 369 | overlap = 5 370 | 371 | shard = partial( 372 | sharded_xr_dataset, ds, 'x', chunk_size, chunk_overlap=overlap, world_size=world_size, shuffle=False 373 | ) 374 | 375 | chunks_1 = list(shard(rank=0)) 376 | chunks_2 = list(shard(rank=1)) 377 | chunks_3 = list(shard(rank=2)) 378 | 379 | assert len(chunks_1) == 2 380 | assert len(chunks_2) == 2 381 | assert len(chunks_3) == 2 382 | 383 | assert isinstance(chunks_1[0], xr.Dataset) 384 | 385 | assert chunks_1[0].x.size == 20 386 | assert chunks_1[1].x.size == 20 387 | assert chunks_2[0].x.size == 20 388 | assert chunks_2[1].x.size == 20 389 | assert chunks_3[0].x.size == 20 390 | assert chunks_3[1].x.size == 20 391 | 392 | assert_array_equal(chunks_1[0]['var'], np.arange(0, 20)) 393 | assert_array_equal(chunks_2[0]['var'], np.arange(15, 35)) 394 | assert_array_equal(chunks_3[0]['var'], np.arange(30, 50)) 395 | assert_array_equal(chunks_1[1]['var'], np.arange(45, 65)) 396 | assert_array_equal(chunks_2[1]['var'], np.arange(60, 80)) 397 | assert_array_equal(chunks_3[1]['var'], np.arange(75, 95)) 398 | 399 | def test_overlap_unequal_uneven(self): 400 | ds = xr.DataArray(np.arange(100), dims=['x'], name='var').to_dataset() 401 | world_size = 3 402 | chunk_size = 15 403 | overlap = 5 404 | 405 | shard = partial( 406 | sharded_xr_dataset, 407 | ds, 408 | 'x', 409 | chunk_size, 410 | chunk_overlap=overlap, 411 | even_shards=False, 412 | equal_chunks=False, 413 | world_size=world_size, 414 | 
shuffle=False, 415 | ) 416 | 417 | chunks_1 = list(shard(rank=0)) 418 | chunks_2 = list(shard(rank=1)) 419 | chunks_3 = list(shard(rank=2)) 420 | 421 | assert len(chunks_1) == 3 422 | assert len(chunks_2) == 2 423 | assert len(chunks_3) == 2 424 | 425 | assert isinstance(chunks_1[0], xr.Dataset) 426 | 427 | assert chunks_1[0].x.size == 20 428 | assert chunks_1[1].x.size == 20 429 | assert chunks_2[0].x.size == 20 430 | assert chunks_2[1].x.size == 20 431 | assert chunks_3[0].x.size == 20 432 | assert chunks_3[1].x.size == 20 433 | assert chunks_1[2].x.size == 10 434 | 435 | assert_array_equal(chunks_1[0]['var'], np.arange(0, 20)) 436 | assert_array_equal(chunks_2[0]['var'], np.arange(15, 35)) 437 | assert_array_equal(chunks_3[0]['var'], np.arange(30, 50)) 438 | assert_array_equal(chunks_1[1]['var'], np.arange(45, 65)) 439 | assert_array_equal(chunks_2[1]['var'], np.arange(60, 80)) 440 | assert_array_equal(chunks_3[1]['var'], np.arange(75, 95)) 441 | assert_array_equal(chunks_1[2]['var'], np.arange(90, 100)) 442 | 443 | 444 | class TestInterleaveBatches: 445 | def test_basic(self): 446 | batches = [ 447 | torch.arange(0, 8), 448 | torch.arange(8, 16), 449 | torch.arange(16, 24), 450 | torch.arange(24, 32), 451 | ] 452 | interleaved_batches = list(t.clone() for t in interleave_batches(batches, num_batches=2)) 453 | assert len(interleaved_batches) == 4 454 | assert {t.item() for t in interleaved_batches[0]} == {0, 1, 2, 3, 8, 9, 10, 11} 455 | assert {t.item() for t in interleaved_batches[1]} == {4, 5, 6, 7, 12, 13, 14, 15} 456 | assert {t.item() for t in interleaved_batches[2]} == {16, 17, 18, 19, 24, 25, 26, 27} 457 | assert {t.item() for t in interleaved_batches[3]} == {20, 21, 22, 23, 28, 29, 30, 31} 458 | 459 | 460 | if __name__ == '__main__': 461 | sys.exit(pytest.main([__file__])) 462 | -------------------------------------------------------------------------------- /test/test_global_accessors.py: -------------------------------------------------------------------------------- 1 | import dmlcloud as dml 2 | 3 | 4 | class DummyStage(dml.Stage): 5 | def run_epoch(self): 6 | pass 7 | 8 | 9 | class ProbingCallback(dml.Callback): 10 | def __init__(self, pipe=None, stage=None): 11 | self.pipe = pipe 12 | self.stage = stage 13 | self.pipe_test = False 14 | self.stage_test = False 15 | 16 | def pre_run(self, pipe): 17 | self.pipe_test = dml.current_pipe() is self.pipe 18 | 19 | def pre_stage(self, stage): 20 | self.stage_test = dml.current_stage() is self.stage 21 | 22 | 23 | class LogCallback(dml.Callback): 24 | def __init__(self): 25 | self.i = 0 26 | 27 | def pre_epoch(self, stage): 28 | dml.log_metric('test', self.i) 29 | self.i += 1 30 | 31 | 32 | class TestGlobalAccessors: 33 | def test_accessors(self, torch_distributed): 34 | pipe = dml.Pipeline() 35 | stage1 = DummyStage() 36 | stage2 = DummyStage() 37 | pipe.append(stage1) 38 | pipe.append(stage2) 39 | 40 | cb1 = ProbingCallback(pipe) 41 | cb2 = ProbingCallback(stage=stage1) 42 | cb3 = ProbingCallback(stage=stage2) 43 | 44 | pipe.add_callback(cb1) 45 | stage1.add_callback(cb2) 46 | stage2.add_callback(cb3) 47 | 48 | assert dml.current_pipe() is None 49 | assert dml.current_stage() is None 50 | 51 | pipe.run() 52 | assert cb1.pipe_test 53 | assert cb2.stage_test 54 | assert cb3.stage_test 55 | 56 | assert dml.current_pipe() is None 57 | assert dml.current_stage() is None 58 | 59 | def test_logging(self, torch_distributed): 60 | pipe = dml.Pipeline() 61 | stage1 = DummyStage(epochs=3) 62 | stage2 = DummyStage(epochs=1) 63 | 
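# LogCallback invokes dml.log_metric() once per epoch; the assertions below
# verify that each logged value lands in the history of the stage that was running.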
pipe.append(stage1) 64 | pipe.append(stage2) 65 | 66 | pipe.add_callback(LogCallback()) 67 | 68 | pipe.run() 69 | 70 | assert 'test' in stage1.history 71 | assert list(stage1.history['test']) == [0, 1, 2] 72 | 73 | assert 'test' in stage2.history 74 | assert list(stage2.history['test']) == [3] 75 | -------------------------------------------------------------------------------- /test/test_import.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | 5 | 6 | class TestImport: 7 | def test_import(self): 8 | import dmlcloud # noqa: F401 9 | 10 | def test_version(self): 11 | import dmlcloud 12 | 13 | version = dmlcloud.__version__ 14 | assert isinstance(version, str) 15 | assert version > '0.0.0' 16 | 17 | 18 | if __name__ == '__main__': 19 | sys.exit(pytest.main([__file__])) 20 | -------------------------------------------------------------------------------- /test/test_io_redirector.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | from dmlcloud.util.logging import IORedirector 5 | 6 | 7 | class DummyStream: 8 | def __init__(self, stdout=True): 9 | self.data = '' 10 | self.is_stdout = stdout 11 | 12 | def write(self, data): 13 | self.data += data 14 | 15 | def flush(self): 16 | pass 17 | 18 | def __enter__(self): 19 | self._org_stream = getattr(sys, 'stdout' if self.is_stdout else 'stderr') 20 | setattr(sys, 'stdout' if self.is_stdout else 'stderr', self) 21 | return self 22 | 23 | def __exit__(self, exc_type, exc_value, traceback): 24 | setattr(sys, 'stdout' if self.is_stdout else 'stderr', self._org_stream) 25 | 26 | 27 | class TestIORedirector: 28 | def test_context_manager(self, tmp_path): 29 | org_stdout = sys.stdout 30 | org_stderr = sys.stderr 31 | 32 | with IORedirector(tmp_path / 'log.txt'): 33 | assert sys.stdout is not org_stdout 34 | assert sys.stderr is not org_stderr 35 | 36 | assert sys.stdout is org_stdout 37 | assert sys.stderr is org_stderr 38 | 39 | def test_file_creation(self, tmp_path): 40 | with IORedirector(tmp_path / 'log.txt'): 41 | pass 42 | assert (tmp_path / 'log.txt').exists() 43 | assert (tmp_path / 'log.txt').read_text() == '' 44 | 45 | def test_basic_write(self, tmp_path): 46 | with DummyStream() as out, DummyStream(stdout=False) as err: 47 | with IORedirector(tmp_path / 'log.txt'): 48 | print('Hello, world!') 49 | print('Error message', file=sys.stderr) 50 | 51 | file_content = (tmp_path / 'log.txt').read_text() 52 | assert file_content == 'Hello, world!\nError message\n' 53 | 54 | assert out.data == 'Hello, world!\n' 55 | assert err.data == 'Error message\n' 56 | 57 | assert sys.stdout is out._org_stream 58 | assert sys.stderr is err._org_stream 59 | 60 | def test_writes_after_exit(self, tmp_path): 61 | with DummyStream() as out, DummyStream(stdout=False) as err: 62 | with IORedirector(tmp_path / 'log.txt'): 63 | saved_out = sys.stdout 64 | saved_err = sys.stderr 65 | 66 | print('Test', file=saved_out) 67 | print('Error', file=saved_err) 68 | assert out.data == 'Test\n' 69 | assert err.data == 'Error\n' 70 | 71 | file_content = (tmp_path / 'log.txt').read_text() 72 | assert file_content == '' 73 | 74 | # Now we reset and replace sys.stdout and sys.stderr again, writes should go to the new streams 75 | with DummyStream() as out2, DummyStream(stdout=False) as err2: 76 | print('Test', file=saved_out) 77 | print('Error', file=saved_err) 78 | 79 | assert out.data == 'Test\n' 80 | assert err.data == 'Error\n' 81 | 82 | 
assert out2.data == 'Test\n' 83 | assert err2.data == 'Error\n' 84 | 85 | 86 | if __name__ == '__main__': 87 | sys.exit(pytest.main([__file__])) 88 | -------------------------------------------------------------------------------- /test/test_root_only.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import dmlcloud as dml 4 | import pytest 5 | 6 | 7 | @dml.root_only 8 | def return_root_rank(): 9 | """TEST_DOC_STRING""" 10 | return dml.rank() 11 | 12 | 13 | @dml.root_only 14 | class RootOnlyStage(dml.Stage): 15 | """TEST_DOC_STRING""" 16 | 17 | def __init__(self, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.cb_executed = { 20 | 'pre_stage': False, 21 | 'post_stage': False, 22 | 'pre_epoch': False, 23 | 'post_epoch': False, 24 | 'run_epoch': False, 25 | } 26 | 27 | def pre_stage(self): 28 | """TEST_DOC_STRING""" 29 | self.cb_executed['pre_stage'] = True 30 | 31 | def post_stage(self): 32 | """TEST_DOC_STRING""" 33 | self.cb_executed['post_stage'] = True 34 | 35 | def pre_epoch(self): 36 | """TEST_DOC_STRING""" 37 | self.cb_executed['pre_epoch'] = True 38 | 39 | def post_epoch(self): 40 | """TEST_DOC_STRING""" 41 | self.cb_executed['post_epoch'] = True 42 | 43 | def run_epoch(self): 44 | """TEST_DOC_STRING""" 45 | self.cb_executed['run_epoch'] = True 46 | 47 | 48 | class PartialRootOnlyStage(dml.Stage): 49 | def __init__(self, *args, **kwargs): 50 | super().__init__(*args, **kwargs) 51 | self.cb_executed = { 52 | 'pre_stage': False, 53 | 'post_stage': False, 54 | 'pre_epoch': False, 55 | 'post_epoch': False, 56 | 'run_epoch': False, 57 | } 58 | 59 | def pre_stage(self): 60 | self.cb_executed['pre_stage'] = True 61 | 62 | @dml.root_only 63 | def post_stage(self): 64 | """TEST_DOC_STRING""" 65 | self.cb_executed['post_stage'] = True 66 | 67 | @dml.root_only 68 | def pre_epoch(self): 69 | """TEST_DOC_STRING""" 70 | self.cb_executed['pre_epoch'] = True 71 | 72 | def post_epoch(self): 73 | self.cb_executed['post_epoch'] = True 74 | 75 | @dml.root_only 76 | def run_epoch(self): 77 | self.cb_executed['run_epoch'] = True 78 | 79 | 80 | @dml.root_only 81 | class RootOnlyPipeline(dml.Pipeline): 82 | """TEST_DOC_STRING""" 83 | 84 | def __init__(self, *args, **kwargs): 85 | super().__init__(*args, **kwargs) 86 | self.cb_executed = { 87 | 'pre_run': False, 88 | 'post_run': False, 89 | } 90 | 91 | def pre_run(self): 92 | """TEST_DOC_STRING""" 93 | self.cb_executed['pre_run'] = True 94 | 95 | def post_run(self): 96 | """TEST_DOC_STRING""" 97 | self.cb_executed['post_run'] = True 98 | 99 | 100 | class PartialRootOnlyPipeline(dml.Pipeline): 101 | def __init__(self, *args, **kwargs): 102 | super().__init__(*args, **kwargs) 103 | self.cb_executed = { 104 | 'pre_run': False, 105 | 'post_run': False, 106 | } 107 | 108 | @dml.root_only 109 | def pre_run(self): 110 | """TEST_DOC_STRING""" 111 | self.cb_executed['pre_run'] = True 112 | 113 | def post_run(self): 114 | self.cb_executed['post_run'] = True 115 | 116 | 117 | class TestRootOnly: 118 | def test_function(self, distributed_environment): 119 | ranks = distributed_environment(4).start(return_root_rank) 120 | assert ranks == [0, None, None, None] 121 | assert return_root_rank.__name__ == 'return_root_rank' 122 | assert return_root_rank.__doc__ == 'TEST_DOC_STRING' 123 | 124 | @staticmethod 125 | def _test_stage_run(): 126 | stage = RootOnlyStage(epochs=1) 127 | pipe = dml.Pipeline() 128 | pipe.append(stage) 129 | pipe.run() 130 | return stage.cb_executed 131 | 132 | def 
test_stage(self, distributed_environment): 133 | results = distributed_environment(3).start(TestRootOnly._test_stage_run) 134 | 135 | assert [r['pre_stage'] for r in results] == [True, False, False] 136 | assert [r['post_stage'] for r in results] == [True, False, False] 137 | assert [r['pre_epoch'] for r in results] == [True, False, False] 138 | assert [r['post_epoch'] for r in results] == [True, False, False] 139 | assert [r['run_epoch'] for r in results] == [True, False, False] 140 | 141 | assert RootOnlyStage.__name__ == 'RootOnlyStage' 142 | assert RootOnlyStage.__doc__ == 'TEST_DOC_STRING' 143 | 144 | assert RootOnlyStage.pre_stage.__name__ == 'pre_stage' 145 | assert RootOnlyStage.pre_stage.__doc__ == 'TEST_DOC_STRING' 146 | 147 | assert RootOnlyStage.post_stage.__name__ == 'post_stage' 148 | assert RootOnlyStage.post_stage.__doc__ == 'TEST_DOC_STRING' 149 | 150 | assert RootOnlyStage.pre_epoch.__name__ == 'pre_epoch' 151 | assert RootOnlyStage.pre_epoch.__doc__ == 'TEST_DOC_STRING' 152 | 153 | assert RootOnlyStage.post_epoch.__name__ == 'post_epoch' 154 | assert RootOnlyStage.post_epoch.__doc__ == 'TEST_DOC_STRING' 155 | 156 | assert RootOnlyStage.run_epoch.__name__ == 'run_epoch' 157 | assert RootOnlyStage.run_epoch.__doc__ == 'TEST_DOC_STRING' 158 | 159 | @staticmethod 160 | def _test_partial_stage_run(): 161 | stage = PartialRootOnlyStage(epochs=1) 162 | pipe = dml.Pipeline() 163 | pipe.append(stage) 164 | pipe.run() 165 | return stage.cb_executed 166 | 167 | def test_partial_stage(self, distributed_environment): 168 | results = distributed_environment(3).start(TestRootOnly._test_partial_stage_run) 169 | 170 | assert [r['pre_stage'] for r in results] == [True, True, True] 171 | assert [r['post_stage'] for r in results] == [True, False, False] 172 | assert [r['pre_epoch'] for r in results] == [True, False, False] 173 | assert [r['post_epoch'] for r in results] == [True, True, True] 174 | assert [r['run_epoch'] for r in results] == [True, False, False] 175 | 176 | assert PartialRootOnlyStage.post_stage.__name__ == 'post_stage' 177 | assert PartialRootOnlyStage.post_stage.__doc__ == 'TEST_DOC_STRING' 178 | 179 | assert PartialRootOnlyStage.pre_epoch.__name__ == 'pre_epoch' 180 | assert PartialRootOnlyStage.pre_epoch.__doc__ == 'TEST_DOC_STRING' 181 | 182 | @staticmethod 183 | def _test_pipeline_run(): 184 | pipe = RootOnlyPipeline() 185 | pipe.append(RootOnlyStage(epochs=1)) 186 | pipe.run() 187 | return pipe.cb_executed 188 | 189 | def test_pipeline(self, distributed_environment): 190 | results = distributed_environment(3).start(TestRootOnly._test_pipeline_run) 191 | 192 | assert [r['pre_run'] for r in results] == [True, False, False] 193 | assert [r['post_run'] for r in results] == [True, False, False] 194 | 195 | assert RootOnlyPipeline.__name__ == 'RootOnlyPipeline' 196 | assert RootOnlyPipeline.__doc__ == 'TEST_DOC_STRING' 197 | 198 | assert RootOnlyPipeline.pre_run.__name__ == 'pre_run' 199 | assert RootOnlyPipeline.pre_run.__doc__ == 'TEST_DOC_STRING' 200 | 201 | assert RootOnlyPipeline.post_run.__name__ == 'post_run' 202 | assert RootOnlyPipeline.post_run.__doc__ == 'TEST_DOC_STRING' 203 | 204 | @staticmethod 205 | def _test_partial_pipeline_run(): 206 | pipe = PartialRootOnlyPipeline() 207 | pipe.append(RootOnlyStage(epochs=1)) 208 | pipe.run() 209 | return pipe.cb_executed 210 | 211 | def test_partial_pipeline(self, distributed_environment): 212 | results = distributed_environment(3).start(TestRootOnly._test_partial_pipeline_run) 213 | 214 | assert [r['pre_run'] for r in 
results] == [True, False, False] 215 | assert [r['post_run'] for r in results] == [True, True, True] 216 | 217 | assert PartialRootOnlyPipeline.pre_run.__name__ == 'pre_run' 218 | assert PartialRootOnlyPipeline.pre_run.__doc__ == 'TEST_DOC_STRING' 219 | 220 | 221 | if __name__ == '__main__': 222 | sys.exit(pytest.main([__file__])) 223 | -------------------------------------------------------------------------------- /test/test_seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | 4 | import dmlcloud as dml 5 | import numpy as np 6 | import pytest 7 | import torch 8 | 9 | 10 | def seed(seed=None): 11 | seed = dml.seed(seed) 12 | state = dict( 13 | seed=seed, 14 | torch_state=np.array(torch.get_rng_state()), 15 | numpy_state=np.random.get_state()[1], 16 | random_state=np.array(random.getstate()[1]), 17 | ) 18 | return state 19 | 20 | 21 | class TestSeed: 22 | def test_single_worker_deterministic(self, torch_distributed): 23 | prev_torch_state = np.array(torch.get_rng_state()) 24 | prev_numpy_state = np.random.get_state()[1] 25 | prev_random_state = np.array(random.getstate()[1]) 26 | 27 | states = seed(42) 28 | assert states['seed'] == 42 29 | assert (states['torch_state'] != prev_torch_state).any() 30 | assert (states['numpy_state'] != prev_numpy_state).any() 31 | assert (states['random_state'] != prev_random_state).any() 32 | 33 | # advance the RNG 34 | torch.randint(0, 10, (1,)) 35 | np.random.randint(0, 10) 36 | 37 | # reseeding should reset the RNG 38 | new_states = seed(42) 39 | assert new_states['seed'] == 42 40 | assert (new_states['torch_state'] == states['torch_state']).all() 41 | assert (new_states['numpy_state'] == states['numpy_state']).all() 42 | assert (new_states['random_state'] == states['random_state']).all() 43 | 44 | def test_input_validation(self, torch_distributed): 45 | with pytest.raises(RuntimeError): 46 | dml.seed(2**80) 47 | assert dml.seed(2**64 - 1) == 2**64 - 1 48 | 49 | def test_single_worker_random(self, torch_distributed): 50 | prev_torch_state = np.array(torch.get_rng_state()) 51 | prev_numpy_state = np.random.get_state()[1] 52 | prev_random_state = np.array(random.getstate()[1]) 53 | 54 | states = seed() 55 | assert type(states['seed']) is int 56 | assert (states['torch_state'] != prev_torch_state).any() 57 | assert (states['numpy_state'] != prev_numpy_state).any() 58 | assert (states['random_state'] != prev_random_state).any() 59 | 60 | # reseeding should yield different states 61 | new_states = seed() 62 | assert new_states['seed'] != states['seed'] 63 | assert (new_states['torch_state'] != states['torch_state']).any() 64 | assert (new_states['numpy_state'] != states['numpy_state']).any() 65 | assert (new_states['random_state'] != states['random_state']).any() 66 | 67 | def test_multi_worker_deterministic(self, distributed_environment): 68 | states = distributed_environment(4).start(seed, 42) 69 | assert [s['seed'] for s in states] == [42, 42, 42, 42] 70 | 71 | # workers should have different states 72 | assert all((s['torch_state'] != states[0]['torch_state']).any() for s in states[1:]) 73 | assert all((s['numpy_state'] != states[0]['numpy_state']).any() for s in states[1:]) 74 | assert all((s['random_state'] != states[0]['random_state']).any() for s in states[1:]) 75 | 76 | # same seed should yield same states 77 | new_states = distributed_environment(4).start(seed, 42) 78 | assert [s['seed'] for s in new_states] == [42, 42, 42, 42] 79 | assert all((s1['torch_state'] == 
s2['torch_state']).all() for s1, s2 in zip(states, new_states)) 80 | assert all((s1['numpy_state'] == s2['numpy_state']).all() for s1, s2 in zip(states, new_states)) 81 | assert all((s1['random_state'] == s2['random_state']).all() for s1, s2 in zip(states, new_states)) 82 | 83 | # different seed should yield different states 84 | new_states = distributed_environment(4).start(seed, 11) 85 | assert [s['seed'] for s in new_states] == [11, 11, 11, 11] 86 | assert all((s1['torch_state'] != s2['torch_state']).any() for s1, s2 in zip(states, new_states)) 87 | assert all((s1['numpy_state'] != s2['numpy_state']).any() for s1, s2 in zip(states, new_states)) 88 | assert all((s1['random_state'] != s2['random_state']).any() for s1, s2 in zip(states, new_states)) 89 | 90 | def test_multi_worker_random(self, distributed_environment): 91 | # all workers should have same seeds 92 | states = distributed_environment(4).start(seed) 93 | assert [s['seed'] for s in states] == [states[0]['seed']] * 4 94 | 95 | # workers should have different states 96 | assert all((s['torch_state'] != states[0]['torch_state']).any() for s in states[1:]) 97 | assert all((s['numpy_state'] != states[0]['numpy_state']).any() for s in states[1:]) 98 | assert all((s['random_state'] != states[0]['random_state']).any() for s in states[1:]) 99 | 100 | # reseeding should yield different states and seeds 101 | new_states = distributed_environment(4).start(seed) 102 | assert [s['seed'] for s in new_states] != [s['seed'] for s in states] 103 | assert all((s1['torch_state'] != s2['torch_state']).any() for s1, s2 in zip(states, new_states)) 104 | assert all((s1['numpy_state'] != s2['numpy_state']).any() for s1, s2 in zip(states, new_states)) 105 | assert all((s1['random_state'] != s2['random_state']).any() for s1, s2 in zip(states, new_states)) 106 | 107 | 108 | if __name__ == '__main__': 109 | sys.exit(pytest.main([__file__])) 110 | -------------------------------------------------------------------------------- /test/test_smoke.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import dmlcloud as dml 4 | import pytest 5 | import torch 6 | 7 | 8 | class DummyDataset(torch.utils.data.Dataset): 9 | def __len__(self): 10 | return 256 11 | 12 | def __getitem__(self, idx): 13 | x = torch.randn(10) 14 | y = x.sum() * 0.1 15 | return x, y 16 | 17 | 18 | class DummyStage(dml.Stage): 19 | def pre_stage(self): 20 | self.train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=32) 21 | 22 | model = torch.nn.Sequential( 23 | torch.nn.Linear(10, 32), 24 | torch.nn.Linear(32, 1), 25 | ) 26 | self.model = dml.wrap_ddp(model, self.device) 27 | self.optim = torch.optim.Adam(self.model.parameters(), lr=dml.scale_lr(1e-2)) 28 | self.loss = torch.nn.L1Loss() 29 | 30 | def run_epoch(self): 31 | for x, y in self.train_dl: 32 | self.optim.zero_grad() 33 | 34 | x, y = x.to(self.device), y.to(self.device) 35 | output = self.model(x) 36 | loss = self.loss(output[:, 0], y) 37 | loss.backward() 38 | 39 | self.optim.step() 40 | 41 | self.log('train/loss', loss) 42 | 43 | 44 | class TestSmoke: 45 | def test_smoke(self, torch_distributed): 46 | pipe = dml.Pipeline() 47 | stage = DummyStage(epochs=3) 48 | pipe.append(stage) 49 | pipe.run() 50 | 51 | assert stage.current_epoch == 3 52 | assert 'train/loss' in stage.history 53 | assert stage.history.last()['train/loss'] < 0.1 54 | 55 | 56 | if __name__ == '__main__': 57 | sys.exit(pytest.main([__file__])) 58 | 
--------------------------------------------------------------------------------
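
The interleaving asserts in test/test_data.py follow from a small amount of chunk arithmetic. Below is a minimal, dependency-free sketch (our helper names, not dmlcloud API) that reproduces the element order asserted in the "advanced case": 100 elements with chunk_size=15 give 6 full chunks, and with world_size=2 and 2 dataloader workers per rank the effective shard count is 4, so only 4 chunks (elements 0..59) are consumed.

# Illustrative sketch, not dmlcloud's implementation: each rank takes
# consecutive chunks, and its two workers emit them in alternation,
# element by element.
chunk_size, total, world_size, workers = 15, 100, 2, 2
num_chunks = total // chunk_size                   # 6; the tail 90..99 is dropped
effective = world_size * workers                   # 4 shards in total
per_rank = (num_chunks // effective) * workers     # 2 chunks consumed per rank

def rank_order(rank):
    own = [list(range(c * chunk_size, (c + 1) * chunk_size))
           for c in range(rank * per_rank, rank * per_rank + per_rank)]
    # the dataloader alternates between its workers, interleaving the chunks
    return [x for pair in zip(*own) for x in pair]

assert rank_order(0)[:6] == [0, 15, 1, 16, 2, 17]
assert rank_order(1)[:6] == [30, 45, 31, 46, 32, 47]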
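
The expectations in test_overlap follow the same pattern. With chunk_size=15 and chunk_overlap=5, every chunk spans 20 elements, chunk starts advance by chunk_size, and full chunks are dealt round-robin over the ranks. A sketch of that arithmetic (illustrative names, not dmlcloud API):

chunk_size, overlap, total, world_size = 15, 5, 100, 3
span = chunk_size + overlap                              # 20 elements per chunk
starts = [s for s in range(0, total, chunk_size) if s + span <= total]
assert starts == [0, 15, 30, 45, 60, 75]                 # 6 full chunks, 2 per rank
assignment = {rank: [s for i, s in enumerate(starts) if i % world_size == rank]
              for rank in range(world_size)}
assert assignment[0] == [0, 45]   # chunks covering 0..19 and 45..64, as asserted
# With even_shards=False and equal_chunks=False (test_overlap_unequal_uneven),
# the trailing partial chunk 90..99 is kept as well and lands on rank 0.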
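
TestInterleaveBatches pins down the semantics of dmlcloud.data.interleave_batches only through set comparisons. The following sketch shows one behaviour consistent with those asserts (inferred semantics; the real implementation may differ): each group of num_batches incoming batches is cut into num_batches equal slices, and slice j of every batch in the group is concatenated to form output batch j, mixing batch boundaries without changing batch size.

import torch

def interleave_batches_sketch(batches, num_batches):
    group = []
    for batch in batches:
        group.append(batch)
        if len(group) == num_batches:
            k = group[0].shape[0] // num_batches     # slice length per batch
            for j in range(num_batches):
                yield torch.cat([b[j * k:(j + 1) * k] for b in group])
            group = []

out = list(interleave_batches_sketch([torch.arange(0, 8), torch.arange(8, 16)], 2))
assert {t.item() for t in out[0]} == {0, 1, 2, 3, 8, 9, 10, 11}
assert {t.item() for t in out[1]} == {4, 5, 6, 7, 12, 13, 14, 15}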
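
test_global_accessors establishes the lifecycle of dml.current_pipe() and dml.current_stage(): None outside a run, the active object inside one. One simple way to get that behaviour is a module-level slot set on entry and cleared on exit; this is an assumption about the mechanism, not dmlcloud's actual code.

_current_pipe = None

def current_pipe_sketch():
    return _current_pipe

class _pipe_scope:
    """Install a pipe as the current pipe for the duration of a run."""

    def __init__(self, pipe):
        self.pipe = pipe

    def __enter__(self):
        global _current_pipe
        _current_pipe = self.pipe      # visible to callbacks during the run
        return self.pipe

    def __exit__(self, *exc):
        global _current_pipe
        _current_pipe = None           # accessors return None again afterwards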
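
test_root_only checks two behaviours of dml.root_only: non-root ranks get None instead of the return value, and __name__/__doc__ survive decoration. A hedged sketch of a rank guard in that spirit (not the library's implementation) reproduces both; the metadata guarantee comes from functools.wraps. Decorating a whole class, as RootOnlyStage does, additionally requires wrapping the relevant hook methods; the sketch covers only the function case.

import functools
import torch.distributed as dist

def root_only_sketch(fn):
    @functools.wraps(fn)                  # keeps fn.__name__ and fn.__doc__
    def wrapper(*args, **kwargs):
        if dist.get_rank() == 0:          # only the root rank runs the body
            return fn(*args, **kwargs)
        return None                       # matches ranks == [0, None, None, None]
    return wrapper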
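
Finally, test_seed.py fixes the contract of dml.seed: every rank returns the same base seed (a random one is agreed on when none is given), while the RNGs themselves are seeded per rank so worker states differ, and out-of-range seeds raise. A hedged sketch of behaviour consistent with those tests (an illustration; dmlcloud may derive states differently):

import random
import numpy as np
import torch
import torch.distributed as dist

def seed_sketch(seed=None):
    holder = [seed if seed is not None else random.SystemRandom().getrandbits(64)]
    dist.broadcast_object_list(holder, src=0)     # all ranks adopt rank 0's value
    base = holder[0]
    derived = (base + dist.get_rank()) % 2**64    # rank-dependent, hence distinct states
    torch.manual_seed(derived)                    # raises for seeds >= 2**64, cf. test_input_validation
    np.random.seed(derived % 2**32)               # numpy only accepts 32-bit seeds
    random.seed(derived)
    return base                                   # identical on every rank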