├── .flake8 ├── .github └── workflows │ └── pre-commit.yaml ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.rst ├── pyproject.toml ├── requirements-dev.txt ├── requirements-notebook.txt ├── requirements-train.txt ├── requirements.txt ├── setup.py ├── slogs └── .placeholder ├── slurm ├── notebook.slrm ├── train.slrm └── utils │ ├── report_env_config.sh │ ├── report_repo.sh │ └── report_slurm_config.sh └── template_experiment ├── __init__.py ├── __meta__.py ├── data_transformations.py ├── datasets.py ├── encoders.py ├── evaluation.py ├── io.py ├── train.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | # E203: whitespace before ":". Sometimes violated by black. 4 | # E402: Module level import not at top of file. Violated by lazy imports. 5 | # D100-D107: Missing docstrings 6 | # D200: One-line docstring should fit on one line with quotes. 7 | extend-ignore = E203,E402,D100,D101,D102,D103,D104,D105,D106,D107,D200 8 | docstring-convention = numpy 9 | # F401: Module imported but unused. 10 | # Ignore missing docstrings within unit testing functions. 11 | per-file-ignores = **/__init__.py:F401 **/tests/:D100,D101,D102,D103,D104,D105,D106,D107 12 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | # pre-commit workflow 2 | # 3 | # Ensures the codebase passes the pre-commit stack. 4 | # We run this on GHA to catch issues in commits from contributors who haven't 5 | # set up pre-commit. 6 | 7 | name: pre-commit 8 | 9 | on: [push, pull_request] 10 | 11 | jobs: 12 | pre-commit: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.10" 19 | cache: pip 20 | - uses: pre-commit/action@v3.0.0 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PyCQA/isort 3 | rev: 6.0.0 4 | hooks: 5 | - id: isort 6 | name: isort 7 | args: ["--profile=black"] 8 | - id: isort 9 | name: isort (cython) 10 | types: [cython] 11 | args: ["--profile=black"] 12 | 13 | - repo: https://github.com/psf/black 14 | rev: 25.1.0 15 | hooks: 16 | - id: black 17 | types: [python] 18 | 19 | - repo: https://github.com/asottile/blacken-docs 20 | rev: 1.19.1 21 | hooks: 22 | - id: blacken-docs 23 | additional_dependencies: ["black==25.1.0"] 24 | 25 | - repo: https://github.com/kynan/nbstripout 26 | rev: 0.8.1 27 | hooks: 28 | - id: nbstripout 29 | 30 | - repo: https://github.com/nbQA-dev/nbQA 31 | rev: 1.9.1 32 | hooks: 33 | - id: nbqa-isort 34 | args: ["--profile=black"] 35 | - id: nbqa-black 36 | additional_dependencies: ["black==25.1.0"] 37 | - id: nbqa-flake8 38 | 39 | - repo: https://github.com/pre-commit/pygrep-hooks 40 | rev: v1.10.0 41 | hooks: 42 | - id: python-check-blanket-noqa 43 | - id: python-check-blanket-type-ignore 44 | - id: python-check-mock-methods 45 | - id: python-no-log-warn 46 | - id: rst-backticks 47 | - id: rst-directive-colons 48 | types: [text] 49 | - id: rst-inline-touching-normal 50 | types: [text] 51 | 52 | - repo: https://github.com/pre-commit/pre-commit-hooks 53 | rev: v5.0.0 54 | hooks: 55 | - id: check-added-large-files 56 | - id: check-ast 57 | - id: check-builtin-literals 
58 | - id: check-case-conflict 59 | - id: check-docstring-first 60 | - id: check-shebang-scripts-are-executable 61 | - id: check-merge-conflict 62 | - id: check-json 63 | - id: check-toml 64 | - id: check-xml 65 | - id: check-yaml 66 | - id: debug-statements 67 | - id: destroyed-symlinks 68 | - id: detect-private-key 69 | - id: end-of-file-fixer 70 | exclude: ^LICENSE|\.(html|csv|txt|svg|py)$ 71 | - id: pretty-format-json 72 | args: ["--autofix", "--no-ensure-ascii", "--no-sort-keys"] 73 | exclude_types: [jupyter] 74 | - id: requirements-txt-fixer 75 | - id: trailing-whitespace 76 | args: [--markdown-linebreak-ext=md] 77 | exclude: \.(html|svg)$ 78 | 79 | - repo: https://github.com/asottile/setup-cfg-fmt 80 | rev: v2.7.0 81 | hooks: 82 | - id: setup-cfg-fmt 83 | 84 | - repo: https://github.com/PyCQA/flake8 85 | rev: 7.1.2 86 | hooks: 87 | - id: flake8 88 | additional_dependencies: 89 | - flake8-2020 90 | - flake8-bugbear 91 | - flake8-comprehensions 92 | - flake8-implicit-str-concat 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICEN[CS]E* 2 | include pyproject.toml 3 | include README.rst 4 | include requirements*.txt 5 | 6 | global-exclude *.py[co] 7 | global-exclude __pycache__ 8 | global-exclude *~ 9 | global-exclude *.ipynb_checkpoints/* 10 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | |SLURM| |preempt| |PyTorch| |wandb| |pre-commit| |black| 2 | 3 | PyTorch Experiment Template 4 | =========================== 5 | 6 | This repository gives a fully-featured template or skeleton for new PyTorch 7 | experiments for use on the Vector Institute cluster. 
8 | It supports: 9 | 10 | - multi-node, multi-GPU jobs using DistributedDataParallel (DDP), 11 | - preemption handling which gracefully stops and resumes the job, 12 | - logging experiments to Weights & Biases, 13 | - configuration that seamlessly scales up as the amount of resources increases. 14 | 15 | If you want to run the example experiment, you can clone the repository as-is 16 | without modifying it. Then see the `Installation`_ and 17 | `Executing the example experiment`_ sections below. 18 | 19 | If you want to create a new repository from this template, you should follow 20 | the `Creating a git repository using this template`_ instructions first. 21 | 22 | 23 | Creating a git repository using this template 24 | --------------------------------------------- 25 | 26 | When creating a new repository from this template, these are the steps to follow: 27 | 28 | #. *Don't click the fork button.* 29 | The fork button is for making a new template based on this one, not for using the template to make a new repository. 30 | 31 | #. Create repository. 32 | 33 | #. **New GitHub repository**. 34 | 35 | You can create a new repository on GitHub from this template by clicking the `Use this template `_ button. 36 | 37 | Then clone your new repository to your local system [pseudocode]: 38 | 39 | .. code-block:: bash 40 | 41 | git clone git@github.com:your_org/your_repo_name.git 42 | cd your_repo_name 43 | 44 | #. **New repository not on GitHub**. 45 | 46 | Alternatively, if your new repository is not going to be on GitHub, you can download `this repo as a zip `_ and work from there. 47 | 48 | Note that this zip does not include the .gitignore and .gitattributes files (because GitHub automatically omits them, which is usually helpful but is not for our purposes). 49 | Thus you will also need to download the `.gitignore `__ and `.gitattributes `__ files. 50 | 51 | #. Delete the LICENSE file and replace it with a LICENSE file of your own choosing. 52 | If the code is intended to be freely available for anyone to use, use an `open source license`_, such as `MIT License`_ or `GPLv3`_. 53 | If you don't want your code to be used by anyone else, add a LICENSE file which just says: 54 | 55 | .. code-block:: none 56 | 57 | Copyright (c) CURRENT_YEAR, YOUR_NAME 58 | 59 | All rights reserved. 60 | 61 | Note that if you don't include a LICENSE file, you will still have copyright over your own code (this copyright is automatically granted), and your code will be private source (technically nobody else will be permitted to use it, even if you make your code publicly available). 62 | 63 | #. Edit the file ``template_experiment/__meta__.py`` to contain your author and repo details. 64 | 65 | name 66 | The name as it would be on PyPI (users will do ``pip install new_name_here``). 67 | It is `recommended `__ to use a name that is all lowercase and runtogetherwords, but if separators are needed, hyphens are preferred over underscores. 68 | 69 | path 70 | The path to the package. What you will rename the directory ``template_experiment`` to. 71 | `Should be `__ the same as ``name``, but now hyphens are disallowed and should be swapped for underscores. 72 | By default, this is automatically inferred from ``name``. 73 | 74 | license 75 | Should be the name of the license you just picked and put in the LICENSE file (e.g. ``MIT`` or ``GPLv3``). 76 | 77 | Other fields to enter should be self-explanatory. 78 | 79 | #. Rename the directory ``template_experiment`` to be the ``path`` variable you just added to ``__meta__.py``: 80 | 81 | ..
code-block:: bash 82 | 83 | # Define PROJ_HYPH as your actual project name (use hyphens instead of underscores or spaces) 84 | PROJ_HYPH=your-actual-project-name-with-hyphens-for-spaces 85 | 86 | # Automatically convert hyphens to underscores to get the directory name 87 | PROJ_DIRN="${PROJ_HYPH//-/_}" 88 | # Rename the directory 89 | mv template_experiment "$PROJ_DIRN" 90 | 91 | #. Change references to ``template_experiment`` and ``template-experiment`` 92 | to your path variable. 93 | 94 | This can be done with the sed command: 95 | 96 | .. code-block:: bash 97 | 98 | sed -i "s/template_experiment/$PROJ_DIRN/" "$PROJ_DIRN"/*.py setup.py pyproject.toml slurm/*.slrm .pre-commit-config.yaml 99 | sed -i "s/template-experiment/$PROJ_HYPH/" "$PROJ_DIRN"/*.py setup.py pyproject.toml slurm/*.slrm .pre-commit-config.yaml 100 | 101 | Which will make changes in the following places. 102 | 103 | - In ``setup.py``, `L51 `__:: 104 | 105 | exec(read("template_experiment/__meta__.py"), meta) 106 | 107 | - In ``__meta__.py``, `L2,4 `__:: 108 | 109 | name = "template-experiment" 110 | 111 | - In ``train.py``, `L18-19 `__:: 112 | 113 | from template_experiment import data_transformations, datasets, encoders, utils 114 | from template_experiment.evaluation import evaluate 115 | 116 | - In ``train.py``, `L1260 `__:: 117 | 118 | group.add_argument( 119 | "--wandb-project", 120 | type=str, 121 | default="template-experiment", 122 | help="Name of project on wandb, where these runs will be saved.", 123 | ) 124 | 125 | - In ``slurm/train.slrm``, `L19 `__:: 126 | 127 | #SBATCH --job-name=template-experiment # Set this to be a shorthand for your project's name. 128 | 129 | - In ``slurm/train.slrm``, `L23 `__:: 130 | 131 | PROJECT_NAME="template-experiment" 132 | 133 | - In ``slurm/notebook.slrm``, `L15 `__:: 134 | 135 | PROJECT_NAME="template-experiment" 136 | 137 | #. Swap out the contents of ``README.rst`` with an initial description of your project. 138 | If you prefer, you can use markdown (``README.md``) instead of rST: 139 | 140 | .. code-block:: bash 141 | 142 | git rm README.rst 143 | # touch README.rst 144 | touch README.md && sed -i "s/.rst/.md/" MANIFEST.in 145 | 146 | #. Add your changes to the repo's initial commit and force-push your changes: 147 | 148 | .. code-block:: bash 149 | 150 | git add . 151 | git commit --amend -m "Initial commit" 152 | git push --force 153 | 154 | .. _PEP-8: https://www.python.org/dev/peps/pep-0008/ 155 | .. _open source license: https://choosealicense.com/ 156 | .. _MIT License: https://choosealicense.com/licenses/mit/ 157 | .. _GPLv3: https://choosealicense.com/licenses/gpl-3.0/ 158 | 159 | 160 | Installation 161 | ------------ 162 | 163 | I recommend using miniconda to create an environment for your project. 164 | By using one virtual environment dedicated to each project, you are ensured 165 | stability - if you upgrade a package for one project, it won't affect the 166 | environments you already have established for the others. 167 | 168 | Vector one-time set-up 169 | ~~~~~~~~~~~~~~~~~~~~~~ 170 | 171 | Run this code block to install miniconda before you make your first environment 172 | (you don't need to re-run this every time you start a new project): 173 | 174 | .. code-block:: bash 175 | 176 | # Login to Vector 177 | ssh USERNAME@v.vectorinstitute.ai 178 | # Enter your password and 2FA code to login. 179 | # Run the rest of this code block on the gateway node of the cluster that 180 | # you get to after establishing the ssh connection. 
181 | 182 | # Make a screen session for us to work in 183 | screen; 184 | 185 | # Download miniconda to your ~/Downloads directory 186 | mkdir -p $HOME/Downloads; 187 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 188 | -O "$HOME/Downloads/miniconda.sh"; 189 | # Install miniconda to the home directory, if it isn't there already. 190 | if [ ! -d "$HOME/miniconda/bin" ]; then 191 | if [ -d "$HOME/miniconda" ]; then rm -r "$HOME/miniconda"; fi; 192 | bash $HOME/Downloads/miniconda.sh -b -p "$HOME/miniconda"; 193 | fi; 194 | 195 | # Add conda to the PATH environment variable 196 | export PATH="$HOME/miniconda/bin:$PATH"; 197 | 198 | # Automatically say yes to any check from conda (optional) 199 | conda config --set always_yes yes 200 | 201 | # Set the command prompt prefix to be the name of the current venv 202 | conda config --set env_prompt '({name}) ' 203 | 204 | # Add conda setup to your ~/.bashrc file 205 | conda init; 206 | 207 | # Now exit this screen session (you have to exit the current terminal 208 | # session after conda init, and exiting the screen session achieves that 209 | # without closing the ssh connection) 210 | exit; 211 | 212 | Follow this next step if you want to use `Weights and Biases`_ to log your experiments. 213 | Weights and Biases is an online service for tracking your experiments which is 214 | free for academic usage. 215 | To set this up, you need to install the wandb pip package, and you'll need to 216 | `create a Weights and Biases account `_ if you don't already have one: 217 | 218 | .. code-block:: bash 219 | 220 | # (On v.vectorinstitute.ai) 221 | # You need to run the conda setup instructions that miniconda added to 222 | # your ~/.bashrc file so that conda is on your PATH and you can run it. 223 | # Either create a new screen session - when you launch a new screen session, 224 | # bash automatically runs source ~/.bashrc 225 | screen; 226 | # Or stay in your current window and explicitly yourself run 227 | source ~/.bashrc 228 | # Either way, you'll now see "(miniconda)" at the left of your command prompt, 229 | # indicating miniconda is on your PATH and using your default conda environment. 230 | 231 | # Install wandb 232 | pip install wandb 233 | 234 | # Log in to wandb at the command prompt 235 | wandb login 236 | # wandb asks you for your username, then password 237 | # Then wandb creates a file in ~/.netrc which it uses to automatically login in the future 238 | 239 | .. _Weights and Biases: https://wandb.ai/ 240 | .. _wandb-signup: https://wandb.ai/login?signup=true 241 | 242 | 243 | Project one-time set-up 244 | ~~~~~~~~~~~~~~~~~~~~~~~ 245 | 246 | Run this code block once every time you start a new project from this template. 247 | Change ENVNAME to equal the name of your project. This code will then create a 248 | new virtual environment to use for the project: 249 | 250 | .. code-block:: bash 251 | 252 | # (On v.vectorinstitute.ai) 253 | # You need to run the conda setup instructions that miniconda added to 254 | # your ~/.bashrc file so that conda is on your PATH and you can run it. 255 | # Either create a new screen session - when you launch a new screen session, 256 | # bash automatically runs source ~/.bashrc 257 | screen; 258 | # Or stay in your current window and explicitly yourself run 259 | source ~/.bashrc 260 | # Either way, you'll now see "(miniconda)" at the left of your command prompt, 261 | # indicating miniconda is on your PATH and using your default conda environment. 
262 | 263 | # Now run the following one-time setup per virtual environment (i.e. once per project) 264 | 265 | # Pick a name for the new environment. 266 | # It should correspond to the name of your project (hyphen separated, no spaces) 267 | ENVNAME=template-experiment 268 | 269 | # Create a python3.x conda environment, with pip installed, with this name. 270 | conda create -y --name "$ENVNAME" -q python=3 pip 271 | 272 | # Activate the environment 273 | conda activate "$ENVNAME" 274 | # The command prompt should now have your environment at the left of it, e.g. 275 | # (template-experiment) slowe@v3:~$ 276 | 277 | 278 | Resuming work on an existing project 279 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 280 | 281 | Run this code block when you want to resume work on an existing project: 282 | 283 | .. code-block:: bash 284 | 285 | # (On v.vectorinstitute.ai) 286 | # Run conda setup in ~/.bashrc if it hasn't already been run in this 287 | # terminal session 288 | source ~/.bashrc 289 | # The command prompt should now say (miniconda) at the left of it. 290 | 291 | # Activate the environment 292 | conda activate template-experiment 293 | # The command prompt should now have your environment at the left of it, e.g. 294 | # (template-experiment) slowe@v3:~$ 295 | 296 | 297 | Executing the example experiment 298 | -------------------------------- 299 | 300 | The following commands describe how to set up and run the example repository 301 | in its unmodified state. 302 | 303 | To run the code in a repository you have 304 | `created from this template `_, 305 | replace ``template-experiment`` with the name of your package and 306 | ``template_experiment`` with the name of your package directory, etc. 307 | 308 | Set-up 309 | ~~~~~~ 310 | 311 | #. If you haven't already, then follow the `Vector one-time set-up`_ 312 | instructions. 313 | 314 | #. Then clone the repository: 315 | 316 | .. code-block:: bash 317 | 318 | git clone git@github.com:scottclowe/pytorch-experiment-template.git 319 | cd pytorch-experiment-template 320 | 321 | #. Run the `Project one-time set-up`_ (using ``template-experiment`` as 322 | the environment name). 323 | 324 | #. With the project's conda environment activated, install the package and its 325 | training dependencies: 326 | 327 | .. code-block:: bash 328 | 329 | pip install --editable .[train] 330 | 331 | This step will typically take 5-10 minutes to run. 332 | 333 | #. Check the installation by running the help command: 334 | 335 | .. code-block:: bash 336 | 337 | python template_experiment/train.py -h 338 | 339 | This should print the help message for the training script. 340 | 341 | 342 | Example commands 343 | ~~~~~~~~~~~~~~~~ 344 | 345 | - To run the default training command locally: 346 | 347 | .. code-block:: bash 348 | 349 | python template_experiment/train.py 350 | 351 | or alternatively:: 352 | 353 | template-experiment-train 354 | 355 | - Run the default training command on the cluster with SLURM. 356 | First, ssh into the cluster and cd to the project repository. 357 | You don't need to activate the project's conda environment. 358 | Then use sbatch to add your SLURM job to the queue: 359 | 360 | .. code-block:: bash 361 | 362 | sbatch slurm/train.slrm 363 | 364 | - You can supply arguments to sbatch by including them before the path to the 365 | SLURM script. 366 | Arguments set on the command prompt like this will override the arguments in 367 | ``slurm/train.slrm``. 368 | This is useful for customizing the job name, for example: 369 | 370 | ..
code-block:: bash 371 | 372 | sbatch --job-name=exp_cf10_rn18 slurm/train.slrm 373 | 374 | I recommend that you almost always customize the name of your job. 375 | The custom job name will be visible in the output of ``squeue -u "$USER"`` 376 | when browsing your active jobs (helpful if you have multiple jobs running 377 | and need to check on their status or cancel one of them). 378 | When using this codebase, the custom job name is also used in the path to the 379 | checkpoint, the path to the SLURM log file, and the name of the job on wandb. 380 | 381 | - Any arguments you include after ``slurm/train.slrm`` will be passed through to train.py. 382 | 383 | For example, you can specify to use a pretrained model: 384 | 385 | .. code-block:: bash 386 | 387 | sbatch --job-name=exp_cf10_rn18-pt slurm/train.slrm --dataset=cifar10 --pretrained 388 | 389 | change the architecture and dataset: 390 | 391 | .. code-block:: bash 392 | 393 | sbatch --job-name=exp_cf100_vit-pt \ 394 | slurm/train.slrm --dataset=cifar100 --model=vit_small_patch16_224 --pretrained 395 | 396 | or change the learning rate of the encoder: 397 | 398 | .. code-block:: bash 399 | 400 | sbatch --job-name=exp_cf10_rn18-pt_enc-lr-0.01 \ 401 | slurm/train.slrm --dataset=cifar10 --pretrained --lr-encoder-mult=0.01 402 | 403 | - You can trivially scale up the job to run across multiple GPUs, either by 404 | changing the gres argument to use more of the GPUs on the node (up to 8 GPUs 405 | per node on the t4v2 partition, 4 GPUs per node otherwise): 406 | 407 | .. code-block:: bash 408 | 409 | sbatch --job-name=exp_cf10_rn18-pt_4gpu --gres=gpu:4 slurm/train.slrm --pretrained 410 | 411 | or increasing the number of nodes being requested: 412 | 413 | .. code-block:: bash 414 | 415 | sbatch --job-name=exp_cf10_rn18-pt_2x1gpu --nodes=2 slurm/train.slrm --pretrained 416 | 417 | or both: 418 | 419 | .. code-block:: bash 420 | 421 | sbatch --job-name=exp_cf10_rn18-pt_2x4gpu --nodes=2 --gres=gpu:4 slurm/train.slrm --pretrained 422 | 423 | In each case, the amount of memory and CPUs requested in the SLURM job will 424 | automatically be scaled up with the number of GPUs requested. 425 | The total batch size will be scaled up by the number of GPUs requested too. 426 | 427 | As you run these commands, you can see the results logged on wandb at 428 | https://wandb.ai/your-username/template-experiment 429 | 430 | 431 | Jupyter notebook 432 | ~~~~~~~~~~~~~~~~ 433 | 434 | You can use the script ``slurm/notebook.slrm`` to launch a Jupyter notebook 435 | server on one of the interactive compute nodes. 436 | This uses the methodology of https://support.vectorinstitute.ai/jupyter_notebook 437 | 438 | You'll need to install jupyter into your conda environment to launch the notebook. 439 | After activating the environment for this project, run: 440 | 441 | .. code-block:: bash 442 | 443 | pip install -r requirements-notebook.txt 444 | 445 | To launch a notebook server and connect to it on your local machine, perform 446 | the following steps. 447 | 448 | #. Run the notebook SLURM script to launch the jupyter notebook: 449 | 450 | .. code-block:: bash 451 | 452 | sbatch slurm/notebook.slrm 453 | 454 | The job will launch on one of the interactive nodes, and will acquire a 455 | random port on that node to serve the notebook on. 456 | 457 | #. Wait for the job to start running. You can monitor it with: 458 | 459 | .. code-block:: bash 460 | 461 | squeue --me 462 | 463 | Note the job id of the notebook job. e.g.: 464 | 465 | ..
code-block:: none 466 | 467 | (template-experiment) slowe@v2:~/pytorch-experiment-template$ squeue --me 468 | JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON) 469 | 10618891 interacti jnb slowe R 1:07 1 gpu026 470 | 471 | Here we can see our JOBID is 10618891, and it is running on node gpu026. 472 | 473 | #. Inspect the output of the job with: 474 | 475 | .. code-block:: bash 476 | 477 | cat jnb_JOBID.out 478 | 479 | e.g.: 480 | 481 | .. code-block:: bash 482 | 483 | cat jnb_10618891.out 484 | 485 | The output will contain the port number that the notebook server is using, 486 | and the token as follows: 487 | 488 | .. code-block:: none 489 | 490 | To access the server, open this file in a browser: 491 | file:///ssd005/home/slowe/.local/share/jupyter/runtime/jpserver-7885-open.html 492 | Or copy and paste one of these URLs: 493 | http://gpu026:47201/tree?token=f54c10f52e3dad08e19101149a54985d1561dca7eec96b29 494 | http://127.0.0.1:47201/tree?token=f54c10f52e3dad08e19101149a54985d1561dca7eec96b29 495 | 496 | Here we can see the job is on node gpu026 and the notebook is being served 497 | on port 47201. 498 | We will need to use the token f54c10f52e3dad08e19101149a54985d1561dca7eec96b29 499 | to log in to the notebook. 500 | 501 | #. On your local machine, use ssh to forward the port from the compute node to 502 | your local machine: 503 | 504 | .. code-block:: bash 505 | 506 | ssh USERNAME@v.vectorinstitute.ai -N -L 8887:gpu026:47201 507 | 508 | You need to replace USERNAME with your Vector username, gpu026 with the node 509 | your job is running on, and 47201 with the port number from the previous 510 | step. 511 | In this example, the local port which the notebook is being forwarded to is 512 | port 8887. 513 | 514 | #. Open a browser on your local machine and navigate to http://localhost:8887 515 | (or whatever port you chose in the previous step): 516 | 517 | .. code-block:: bash 518 | 519 | sensible-browser http://localhost:8887 520 | 521 | You should see the Jupyter notebook interface. 522 | Copy the token from the URL shown in the log file and paste it into the 523 | ``Password or token: [ ] Log in`` box. 524 | You should now have access to the remote notebook server on your local 525 | machine. 526 | 527 | #. Once you are done working in your notebooks (and have saved your changes), 528 | make sure to end the job running the notebook with: 529 | 530 | .. code-block:: bash 531 | 532 | scancel JOBID 533 | 534 | e.g.: 535 | 536 | .. code-block:: bash 537 | 538 | scancel 10618891 539 | 540 | This will free up the interactive GPU node for other users to use. 541 | 542 | Note that you can skip the need to copy the access token if you 543 | `set up Jupyter notebook to use a password `_ instead. 544 | 545 | .. _jnb-password: https://saturncloud.io/blog/how-to-autoconfigure-jupyter-password-from-command-line/ 546 | 547 | 548 | Features 549 | -------- 550 | 551 | This template includes the following features. 552 | 553 | 554 | Scalable training script 555 | ~~~~~~~~~~~~~~~~~~~~~~~~ 556 | 557 | The SLURM training script ``slurm/train.slrm`` will interface with the python 558 | training script ``template_experiment/train.py`` to train a model on multiple 559 | GPUs across multiple nodes, using DistributedDataParallel_ (DDP). 560 | 561 | The SLURM script is configured to scale up the amount of RAM and CPUs requested 562 | with the GPUs requested.
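For example, with the default ``--cpus-per-gpu=4`` and ``--mem-per-gpu=10G`` settings in ``slurm/train.slrm``, the per-node CPU and RAM allocation follows the GPU count automatically (the job names below are only illustrative placeholders):

.. code-block:: bash

    # 1 node x 1 GPU: 4 CPUs and 10G of RAM are allocated in total
    sbatch --job-name=demo_1x1gpu slurm/train.slrm

    # 2 nodes x 4 GPUs: 16 CPUs and 40G of RAM are allocated on each node
    sbatch --job-name=demo_2x4gpu --nodes=2 --gres=gpu:4 slurm/train.slrm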
563 | 564 | The arguments to the python script control the batch size per GPU, and the 565 | learning rate for a fixed batch size of 128 samples. 566 | The total batch size will automatically scale up when deployed on more GPUs, 567 | and the learning rate will automatically scale up linearly with the total batch 568 | size. (This is the linear scaling rule from `Training ImageNet in 1 Hour`_.) 569 | 570 | .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html 571 | .. _Training ImageNet in 1 Hour: https://arxiv.org/abs/1706.02677 572 | 573 | 574 | Preemptable 575 | ~~~~~~~~~~~ 576 | 577 | Everything is set up to resume correctly if the job is interrupted by 578 | preemption. 579 | 580 | 581 | Checkpoints 582 | ~~~~~~~~~~~ 583 | 584 | The training script will save a checkpoint every epoch, and will resume from 585 | this if the job is interrupted by preemption. 586 | 587 | The checkpoint for a job will be saved to the directory 588 | ``/checkpoint/USERNAME/PROJECT__JOBNAME__JOBID`` (with double-underscores 589 | between each category) along with a record of the conda environment and 590 | frozen pip requirements used to run the job in ``environment.yml`` and 591 | ``frozen-requirements.txt``. 592 | 593 | 594 | Log messages 595 | ~~~~~~~~~~~~ 596 | 597 | Any print statements and error messages from the training script will be saved 598 | to the file ``slogs/JOBNAME__JOBID_ARRAYID.out``. 599 | Only the output from the rank 0 worker (the worker which saves the 600 | checkpoints and sends logs to wandb) will be saved to this file. 601 | When using multiple nodes, the output from each node will be saved to a 602 | separate file: ``slogs-inner/JOBNAME__JOBID_ARRAYID-NODERANK.out``. 603 | 604 | You can monitor the progress of a job that is currently running by watching 605 | the contents of its log file. For example: 606 | 607 | .. code-block:: bash 608 | 609 | tail -n 50 -f slogs/JOBNAME__JOBID_ARRAYID.out 610 | 611 | 612 | Weights and Biases 613 | ~~~~~~~~~~~~~~~~~~ 614 | 615 | `Weights and Biases`_ (wandb) is an online service for tracking your 616 | experiments which is free for academic usage. 617 | 618 | This template repository is set up to automatically log your experiments, using 619 | the same job label across both SLURM and wandb. 620 | 621 | If the job is preempted, the wandb logging will resume with the same wandb job 622 | ID instead of spawning a new one. 623 | 624 | 625 | Random Number Generator (RNG) state 626 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 627 | 628 | All RNG states are configured based on the overall seed that is set with the 629 | ``--seed`` argument to ``train.py``. 630 | 631 | When running ``train.py`` directly, the seed is **not** set by default, so 632 | behaviour will not be reproducible. 633 | You will need to include the argument ``--seed=0`` (for example), to make sure 634 | your experiments are reproducible. 635 | 636 | When running on SLURM with ``slurm/train.slrm``, the seed **is** set by default. 637 | The seed used is equal to the `array ID `_ of the job. 638 | This configuration lets you easily run the same job with multiple seeds in one 639 | sbatch command. 640 | Our default job array in ``slurm/train.slrm`` is ``--array=0``, so only one job 641 | will be launched, and that job will use the default seed of ``0``. 642 | 643 | To launch the same job 5 times, each with a different seed (1, 2, 3, 4, and 5): 644 | 645 | ..
code-block:: bash 646 | 647 | sbatch --array=1-5 slurm/train.slrm 648 | 649 | or to use seeds 42 and 888: 650 | 651 | .. code-block:: bash 652 | 653 | sbatch --array=42,888 slurm/train.slrm 654 | 655 | or to use a randomly selected seed: 656 | 657 | .. code-block:: bash 658 | 659 | sbatch --array="$RANDOM" slurm/train.slrm 660 | 661 | The seed is used to set the following RNG states: 662 | 663 | - Each epoch gets its own RNG seed (derived from the overall seed and the epoch 664 | number). 665 | The RNG state is set with this seed at the start of each epoch. This makes it 666 | possible to resume from preemption without needing to save all the RNG states 667 | to the model checkpoint and restore them on resume. 668 | 669 | - Each GPU gets its own RNG seed, so any random operations such as dropout 670 | or random masking in the training script itself will be different on each 671 | GPU, but deterministically so. 672 | 673 | - The dataloader workers each have distinct seeds from each other for torch, 674 | numpy and python's random module, so randomly selected augmentations won't be 675 | replicated across workers. 676 | (Pytorch only sets up its own worker seeds correctly, leaving numpy and 677 | random mirrored across all workers.) 678 | 679 | **Caution:** To get *exactly* the same model produced when training with the 680 | same seed, you will need to run the training script with the ``--deterministic`` 681 | flag to disable cuDNN's non-deterministic operations *and* use precisely the 682 | same number of GPU devices and CPU workers on each attempt. 683 | Without these steps, the model will be *almost* the same (because the initial 684 | seed for the model parameters was the same, and the training trajectory was 685 | very similar), but not *exactly* the same, due to (a) non-deterministic cuDNN 686 | operations (b) the batch size increasing with the number of devices 687 | (c) any randomized augmentation operations depending on the identity of the CPU 688 | worker, which will each have an offset seed. 689 | 690 | .. _slurm-job-array: https://slurm.schedmd.com/job_array.html 691 | 692 | 693 | Prototyping mode, with distinct val/test sets 694 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 695 | 696 | Initial experiments and hyperparameter searches should be performed without 697 | seeing the final test performance. They should be run only on a validation set. 698 | Unfortunately, many datasets do not come with a validation set, and it is easy 699 | to accidentally use the test set as a validation set, which can lead to 700 | overfitting the model selection on the test set. 701 | 702 | The image datasets implemented in ``template_experiment/datasets.py`` come with 703 | support for creating a validation set from the training set, which is separate 704 | from the test set. You should use this (with flag ``--prototyping``) during the 705 | initial model development steps and for any hyperparameter searches. 706 | 707 | Your final models should be trained without ``--prototyping`` enabled, so that 708 | the full training set is used for training and the best model is produced. 709 | 710 | 711 | Optional extra package dependencies 712 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 713 | 714 | There are several requirements files in the root directory of the repository. 715 | The idea is the requirements.txt file contains the minimal set of packages 716 | that are needed to use the models in the package. 717 | The other requirements files are for optional extra packages. 
718 | 719 | requirements-dev.txt 720 | Extra packages needed for code development (i.e. writing the codebase) 721 | 722 | requirements-notebook.txt 723 | Extra packages needed for running the notebooks. 724 | 725 | requirements-train.txt 726 | Extra packages needed for training the models. 727 | 728 | The setup.py file will automatically parse any requirements files in the 729 | root directory of the repository which are named like ``requirements-*.txt`` 730 | and make them available to ``pip`` as extras. 731 | 732 | For example, to install the repository to your virtual environment with the 733 | extra packages needed for training:: 734 | 735 | pip install --editable .[train] 736 | 737 | You can also install all the extras at once:: 738 | 739 | pip install --editable .[all] 740 | 741 | Or you can install the extras directly from the requirements files:: 742 | 743 | pip install -r requirements-train.txt 744 | 745 | As a developer of the repository, you will need to pip install the package 746 | with the ``--editable`` flag so the installed copy is updated automatically 747 | when you make changes to the codebase. 748 | 749 | 750 | Automated code checking and formatting 751 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 752 | 753 | The template repository comes with a pre-commit_ stack. 754 | This is a set of git hooks which are executed every time you make a commit. 755 | The hooks catch errors as they occur, and will automatically fix some of these errors. 756 | 757 | To set up the pre-commit hooks, run the following code from within the repo directory:: 758 | 759 | pip install -r requirements-dev.txt 760 | pre-commit install 761 | 762 | Whenever you try to commit code which is flagged by the pre-commit hooks, 763 | *the commit will not go through*. Some of the pre-commit hooks 764 | (such as black_, isort_) will automatically modify your code to fix formatting 765 | issues. When this happens, you'll have to stage the changes made by the commit 766 | hooks and then try your commit again. Other pre-commit hooks, such as flake8_, 767 | will not modify your code and will just tell you about issues in what you tried 768 | to commit (e.g. a variable was declared and never used), and you'll then have 769 | to manually fix these yourself before staging the corrected version. 770 | 771 | After installing it, the pre-commit stack will run every time you try to make 772 | a commit to this repository on that machine. 773 | You can also manually run the pre-commit stack on all the files at any time: 774 | 775 | .. code-block:: bash 776 | 777 | pre-commit run --all-files 778 | 779 | To force a commit to go through without passing the pre-commit hooks use the ``--no-verify`` flag: 780 | 781 | .. code-block:: bash 782 | 783 | git commit --no-verify 784 | 785 | The pre-commit stack which comes with the template is highly opinionated, and 786 | includes the following operations: 787 | 788 | - All **outputs in Jupyter notebooks are cleared** using nbstripout_. 789 | 790 | - Code is reformatted to use the black_ style. 791 | Any code inside docstrings will be formatted to black using blackendocs_. 792 | All code cells in Jupyter notebooks are also formatted to black using black_nbconvert_. 793 | 794 | - Imports are automatically sorted using isort_. 795 | 796 | - Entries in requirements.txt files are automatically sorted alphabetically. 
797 | 798 | - Several `hooks from pre-commit `_ are used to screen for 799 | non-language specific git issues, such as incomplete git merges, overly large 800 | files being commited to the repo, bugged JSON and YAML files. 801 | 802 | - JSON files are also prettified automatically to have standardised indentation. 803 | 804 | The pre-commit stack will also run on github with one of the action workflows, 805 | which ensures the code that is pushed is validated without relying on every 806 | contributor installing pre-commit locally. 807 | 808 | This development practice of using pre-commit_, and standardizing the 809 | code-style using black_, is popular among leading open-source python projects 810 | including numpy, scipy, sklearn, Pillow, and many others. 811 | 812 | If you want to use pre-commit, but **want to commit outputs in Jupyter notebooks** 813 | instead of stripping them, simply remove the nbstripout_ hook from the 814 | `.pre-commit-config.yaml file `__ 815 | and commit that change. 816 | 817 | If you don't want to use pre-commit at all, you can uninstall it: 818 | 819 | .. code-block:: bash 820 | 821 | pre-commit uninstall 822 | 823 | and purge it (along with black and flake8) from the repository: 824 | 825 | .. code-block:: bash 826 | 827 | git rm .pre-commit-config.yaml .flake8 .github/workflows/pre-commit.yaml 828 | git commit -m "DEV: Remove pre-commit hooks" 829 | 830 | .. _black: https://github.com/psf/black 831 | .. _black_nbconvert: https://github.com/dfm/black_nbconvert 832 | .. _blackendocs: https://github.com/asottile/blacken-docs 833 | .. _flake8: https://gitlab.com/pycqa/flake8 834 | .. _isort: https://github.com/timothycrosley/isort 835 | .. _nbstripout: https://github.com/kynan/nbstripout 836 | .. _pre-commit: https://pre-commit.com/ 837 | .. _pre-commit-hooks: https://github.com/pre-commit/pre-commit-hooks 838 | .. _pre-commit-py-hooks: https://github.com/pre-commit/pygrep-hooks 839 | 840 | 841 | Additional features 842 | ------------------- 843 | 844 | This template was forked from a more general `python template repository`_. 845 | 846 | For more information on the features of the python template repository, see 847 | `here `_. 848 | 849 | .. _`python template repository`: https://github.com/scottclowe/python-template-repo 850 | .. _`python-template-repository-features`: https://github.com/scottclowe/python-template-repo#features 851 | 852 | 853 | Contributing 854 | ------------ 855 | 856 | Contributions are welcome! If you can see a way to improve this template: 857 | 858 | - Clone this repo 859 | - Create a feature branch 860 | - Make your changes in the feature branch 861 | - Push your branch and make a pull request 862 | 863 | Or to report a bug or request something new, make an issue. 864 | 865 | 866 | .. |SLURM| image:: https://img.shields.io/badge/scheduler-SLURM-40B1EC 867 | :target: https://slurm.schedmd.com/ 868 | :alt: SLURM 869 | .. |preempt| image:: https://img.shields.io/badge/preemption-supported-brightgreen 870 | :alt: preemption 871 | .. |PyTorch| image:: https://img.shields.io/badge/PyTorch-DDP-EE4C2C?logo=pytorch&logoColor=EE4C2C 872 | :target: https://pytorch.org/ 873 | :alt: pytorch 874 | .. |wandb| image:: https://img.shields.io/badge/Weights_%26_Biases-enabled-FFCC33?logo=WeightsAndBiases&logoColor=FFCC33 875 | :target: https://wandb.ai 876 | :alt: Weights&Biases 877 | .. 
|pre-commit| image:: https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white 878 | :target: https://github.com/pre-commit/pre-commit 879 | :alt: pre-commit 880 | .. |black| image:: https://img.shields.io/badge/code%20style-black-000000.svg 881 | :target: https://github.com/psf/black 882 | :alt: black 883 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 120 7 | target-version = ["py38"] 8 | 9 | [tool.isort] 10 | src_paths = ["template_experiment"] 11 | known_first_party = ["template_experiment"] 12 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black==25.1.0 2 | identify>=1.4.20 3 | pre-commit 4 | -------------------------------------------------------------------------------- /requirements-notebook.txt: -------------------------------------------------------------------------------- 1 | jupyter[notebook] 2 | -------------------------------------------------------------------------------- /requirements-train.txt: -------------------------------------------------------------------------------- 1 | wandb 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn>=1.2.1 3 | timm>=0.6.12 4 | torch>=1.12 5 | torchvision>=0.13 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | 6 | from setuptools import find_packages, setup 7 | 8 | 9 | def read(fname): 10 | """ 11 | Read the contents of a file. 12 | 13 | Parameters 14 | ---------- 15 | fname : str 16 | Path to file. 17 | 18 | Returns 19 | ------- 20 | str 21 | File contents. 22 | """ 23 | with open(os.path.join(os.path.dirname(__file__), fname)) as f: 24 | return f.read() 25 | 26 | 27 | install_requires = read("requirements.txt").splitlines() 28 | 29 | # Dynamically determine extra dependencies 30 | extras_require = {} 31 | extra_req_files = glob.glob("requirements-*.txt") 32 | for extra_req_file in extra_req_files: 33 | name = os.path.splitext(extra_req_file)[0].replace("requirements-", "", 1) 34 | extras_require[name] = read(extra_req_file).splitlines() 35 | 36 | # If there are any extras, add a catch-all case that includes everything. 37 | # This assumes that entries in extras_require are lists (not single strings), 38 | # and that there are no duplicated packages across the extras. 39 | if extras_require: 40 | extras_require["all"] = sorted({x for v in extras_require.values() for x in v}) 41 | 42 | 43 | # Import meta data from __meta__.py 44 | # 45 | # We use exec for this because __meta__.py runs its __init__.py first, 46 | # __init__.py may assume the requirements are already present, but this code 47 | # is being run during the `python setup.py install` step, before requirements 48 | # are installed. 
49 | # https://packaging.python.org/guides/single-sourcing-package-version/ 50 | meta = {} 51 | exec(read("template_experiment/__meta__.py"), meta) 52 | 53 | 54 | # Import the README and use it as the long-description. 55 | # If your readme path is different, add it here. 56 | possible_readme_names = ["README.rst", "README.md", "README.txt", "README"] 57 | 58 | # Handle turning a README file into long_description 59 | long_description = meta["description"] 60 | readme_fname = "" 61 | for fname in possible_readme_names: 62 | try: 63 | long_description = read(fname) 64 | except IOError: 65 | # doesn't exist 66 | continue 67 | else: 68 | # exists 69 | readme_fname = fname 70 | break 71 | 72 | # Infer the content type of the README file from its extension. 73 | # If the contents of your README do not match its extension, manually assign 74 | # long_description_content_type to the appropriate value. 75 | readme_ext = os.path.splitext(readme_fname)[1] 76 | if readme_ext.lower() == ".rst": 77 | long_description_content_type = "text/x-rst" 78 | elif readme_ext.lower() == ".md": 79 | long_description_content_type = "text/markdown" 80 | else: 81 | long_description_content_type = "text/plain" 82 | 83 | 84 | setup( 85 | # Essential details on the package and its dependencies 86 | name=meta["name"], 87 | version=meta["version"], 88 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 89 | package_dir={meta["name"]: os.path.join(".", meta["path"])}, 90 | # If any package contains *.txt or *.rst files, include them: 91 | # package_data={"": ["*.txt", "*.rst"],} 92 | python_requires=">=3.8", 93 | install_requires=install_requires, 94 | extras_require=extras_require, 95 | # Metadata to display on PyPI 96 | author=meta["author"], 97 | author_email=meta["author_email"], 98 | description=meta["description"], 99 | long_description=long_description, 100 | long_description_content_type=long_description_content_type, 101 | license=meta["license"], 102 | url=meta["url"], 103 | classifiers=[ 104 | # Trove classifiers 105 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 106 | "Natural Language :: English", 107 | "Intended Audience :: Science/Research", 108 | "Programming Language :: Python", 109 | "Programming Language :: Python :: 3", 110 | "Programming Language :: Python :: 3.8", 111 | "Programming Language :: Python :: 3.9", 112 | "Programming Language :: Python :: 3.10", 113 | "Programming Language :: Python :: 3.11", 114 | "Programming Language :: Python :: 3.12", 115 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 116 | ], 117 | entry_points={ 118 | "console_scripts": [ 119 | "template-experiment-train=template_experiment.train:cli", 120 | ], 121 | }, 122 | ) 123 | -------------------------------------------------------------------------------- /slogs/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scottclowe/pytorch-experiment-template/138c3bdec82bf4fb026d458032e05ae405de1fdb/slogs/.placeholder -------------------------------------------------------------------------------- /slurm/notebook.slrm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=t4v1,t4v2 # Which node partitions to use. Use a comma-separated list if you don't mind which partition: t4v1,t4v2,rtx6000,a40 3 | #SBATCH --nodes=1 # Number of nodes to request. Usually use 1 for interactive jobs. 
4 | #SBATCH --tasks-per-node=1 # Number of processes to spawn per node. Should always be set to 1, regardless of number of GPUs! 5 | #SBATCH --gres=gpu:1 # Number of GPUs per node to request 6 | #SBATCH --cpus-per-gpu=4 # Number of CPUs to request per GPU (soft maximum of 4 per GPU requested) 7 | #SBATCH --mem-per-gpu=10G # RAM per GPU 8 | #SBATCH --time=16:00:00 # You must specify a maximum run-time if you want to run for more than 2h 9 | #SBATCH --output=jnb_%j.out # You'll need to inspect this log file to find out how to connect to the notebook 10 | #SBATCH --job-name=jnb 11 | 12 | 13 | # Manually define the project name. 14 | # This should also be the name of your conda environment used for this project. 15 | PROJECT_NAME="template-experiment" 16 | 17 | # Exit if any command hits an error 18 | set -e 19 | 20 | # Store the time at which the script was launched 21 | start_time="$SECONDS" 22 | 23 | echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) begins on $(hostname), submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)" 24 | 25 | # Print slurm config report 26 | echo "Running slurm/utils/report_slurm_config.sh" 27 | source "slurm/utils/report_slurm_config.sh" 28 | 29 | echo "" 30 | echo "-------- Activating environment ----------------------------------------" 31 | date 32 | echo "" 33 | echo "Running ~/.bashrc" 34 | source ~/.bashrc 35 | 36 | # Activate virtual environment 37 | ENVNAME="$PROJECT_NAME" 38 | echo "Activating environment $ENVNAME" 39 | conda activate "$ENVNAME" 40 | echo "" 41 | 42 | # We don't want to use a fixed port number (e.g. 8888) because that would lead 43 | # to collisions between jobs on the same node as each other. 44 | # Try to use a port number automatically selected from the job id number. 45 | JOB_SOCKET=$(( $SLURM_JOB_ID % 16384 + 49152 )) 46 | if ss -tulpn | grep -q ":$JOB_SOCKET "; 47 | then 48 | # The port we selected is in use, so we'll get a random available port instead. 49 | JOB_SOCKET="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')"; 50 | fi 51 | echo "Will use port $JOB_SOCKET to serve the notebook" 52 | echo "" 53 | echo "-------- Connection instructions ---------------------------------------" 54 | echo "" 55 | echo "You should be able to access the notebook locally with the following commands:" 56 | echo "" 57 | echo "ssh ${USER}@v.vectorinstitute.ai -N -L 8887:$(hostname):${JOB_SOCKET}" 58 | echo "sensible-browser http://localhost:8887" 59 | echo "# Then enter the token printed below, or your password if you set one" 60 | echo "" 61 | echo "-------- Starting jupyter notebook -------------------------------------" 62 | date 63 | export XDG_RUNTIME_DIR="" 64 | python -m jupyter notebook --ip 0.0.0.0 --port "$JOB_SOCKET" 65 | -------------------------------------------------------------------------------- /slurm/train.slrm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=t4v1,t4v2 # Which node partitions to use. Use a comma-separated list if you don't mind which partition: t4v1,t4v2,rtx6000,a40 3 | #SBATCH --nodes=1 # Number of nodes to request. Can increase to --nodes=2, etc, for more GPUs (spread out over different nodes). 4 | #SBATCH --tasks-per-node=1 # Number of processes to spawn per node. Should always be set to 1, regardless of number of GPUs! 5 | #SBATCH --gres=gpu:1 # Number of GPUs per node. Can increase to --gres=gpu:2, etc, for more GPUs (together on the same node). 6 | #SBATCH --cpus-per-gpu=4 # Number of CPUs per GPU. 
Soft maximum of 4 per GPU requested on t4, 8 otherwise. Hard maximum of 32 per node. 7 | #SBATCH --mem-per-gpu=10G # RAM per GPU. Soft maximum of 20G per GPU requested on t4v2, 41G otherwise. Hard maximum of 167G per node. 8 | #SBATCH --time=08:00:00 # You must specify a maximum run-time if you want to run for more than 2h 9 | #SBATCH --signal=B:USR1@120 # Send signal SIGUSR1 120 seconds before the job hits the time limit 10 | #SBATCH --output=slogs/%x__%A_%a.out 11 | # %x=job-name, %A=job ID, %a=array value, %n=node rank, %t=task rank, %N=hostname 12 | # Note: You must create the output directory "slogs" before launching the job, otherwise it will immediately 13 | # fail without an error message. 14 | # Note: If you specify --output and not --error, then STDOUT and STDERR will both be sent to the 15 | # file specified by --output. 16 | #SBATCH --array=0 # Use array to run multiple jobs that are identical except for $SLURM_ARRAY_TASK_ID. 17 | # In this example, we use this to set the seed. You can run multiple seeds with --array=0-4, for example. 18 | #SBATCH --open-mode=append # Use append mode otherwise preemption resets the checkpoint file. 19 | #SBATCH --job-name=template-experiment # Set this to be a shorthand for your project's name. 20 | 21 | # Manually define the project name. 22 | # This must also be the name of your conda environment used for this project. 23 | PROJECT_NAME="template-experiment" 24 | # Automatically convert hyphens to underscores, to get the name of the project directory. 25 | PROJECT_DIRN="${PROJECT_NAME//-/_}" 26 | 27 | # Exit the script if any command hits an error 28 | set -e 29 | 30 | # Set up a handler to requeue the job if it hits the time-limit without terminating 31 | function term_handler() 32 | { 33 | echo "** Job $SLURM_JOB_NAME ($SLURM_JOB_ID) received SIGUSR1 at $(date) **" 34 | echo "** Requeuing job $SLURM_JOB_ID so it can run for longer **" 35 | scontrol requeue "${SLURM_JOB_ID}" 36 | } 37 | # Call this term_handler function when the job receives the SIGUSR1 signal 38 | trap term_handler SIGUSR1 39 | 40 | # sbatch script for Vector 41 | # Inspired by: 42 | # https://github.com/VectorInstitute/TechAndEngineering/blob/master/benchmarks/resnet_torch/sample_script/script.sh 43 | # https://github.com/VectorInstitute/TechAndEngineering/blob/master/checkpoint_examples/PyTorch/launch_job.slrm 44 | # https://github.com/VectorInstitute/TechAndEngineering/blob/master/checkpoint_examples/PyTorch/run_train.sh 45 | # https://github.com/PrincetonUniversity/multi_gpu_training/tree/main/02_pytorch_ddp 46 | # https://pytorch.org/docs/stable/elastic/run.html 47 | # https://pytorch.org/tutorials/intermediate/ddp_tutorial.html 48 | # https://unix.stackexchange.com/a/146770/154576 49 | 50 | # Store the time at which the script was launched, so we can measure how long has elapsed. 51 | start_time="$SECONDS" 52 | 53 | echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) begins on $(hostname), submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)" 54 | echo "" 55 | # Print slurm config report (SLURM environment variables, some of which we use later in the script) 56 | # By sourcing the script, we execute it as if its code were here in the script 57 | # N.B. This script only prints things out, it doesn't assign any environment variables. 58 | echo "Running slurm/utils/report_slurm_config.sh" 59 | source "slurm/utils/report_slurm_config.sh" 60 | # Print repo status report (current branch, commit ref, where any uncommitted changes are located) 61 | # N.B.
This script only prints things out, it doesn't assign any environment variables. 62 | echo "Running slurm/utils/report_repo.sh" 63 | source "slurm/utils/report_repo.sh" 64 | echo "" 65 | if false; then 66 | # Print disk usage report, to catch errors due to lack of file space. 67 | # This is disabled by default to prevent confusing new users with too 68 | # much output. 69 | echo "------------------------------------" 70 | echo "df -h:" 71 | # Print header, then sort the rows alphabetically by mount point 72 | df -h --output=target,pcent,size,used,avail,source | head -n 1 73 | df -h --output=target,pcent,size,used,avail,source | tail -n +2 | sort -h 74 | echo "" 75 | fi 76 | echo "-------- Input handling ------------------------------------------------" 77 | date 78 | echo "" 79 | # Use the SLURM job array to select the seed for the experiment 80 | SEED="$SLURM_ARRAY_TASK_ID" 81 | if [[ "$SEED" == "" ]]; 82 | then 83 | SEED=0 84 | fi 85 | echo "SEED = $SEED" 86 | 87 | # Check if the first argument is a path to the python script to run 88 | if [[ "$1" == *.py ]]; 89 | then 90 | # If it is, we'll run this python script and remove it from the list of 91 | # arguments to pass on to the script. 92 | SCRIPT_PATH="$1" 93 | shift 94 | else 95 | # Otherwise, use our default python training script. 96 | SCRIPT_PATH="$PROJECT_DIRN/train.py" 97 | fi 98 | echo "SCRIPT_PATH = $SCRIPT_PATH" 99 | 100 | # Any arguments provided to sbatch after the name of the slurm script will be 101 | # passed through to the main script later. 102 | # (The pass-through works like *args or **kwargs in python.) 103 | echo "Pass-through args: ${@}" 104 | echo "" 105 | echo "-------- Activating environment ----------------------------------------" 106 | date 107 | echo "" 108 | echo "Running ~/.bashrc" 109 | source ~/.bashrc 110 | 111 | # Activate virtual environment 112 | ENVNAME="$PROJECT_NAME" 113 | echo "Activating conda environment $ENVNAME" 114 | conda activate "$ENVNAME" 115 | echo "" 116 | # Print env status (which packages you have installed - useful for diagnostics) 117 | # N.B. This script only prints things out, it doesn't assign any environment variables. 118 | echo "Running slurm/utils/report_env_config.sh" 119 | source "slurm/utils/report_env_config.sh" 120 | 121 | # Set the JOB_LABEL environment variable 122 | echo "-------- Setting JOB_LABEL ---------------------------------------------" 123 | echo "" 124 | # Decide the name of the paths to use for saving this job 125 | if [ "$SLURM_ARRAY_TASK_COUNT" != "" ] && [ "$SLURM_ARRAY_TASK_COUNT" -gt 1 ]; 126 | then 127 | JOB_ID="${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}"; 128 | else 129 | JOB_ID="${SLURM_JOB_ID}"; 130 | fi 131 | # Decide the name of the paths to use for saving this job 132 | JOB_LABEL="${SLURM_JOB_NAME}__${JOB_ID}"; 133 | echo "JOB_ID = $JOB_ID" 134 | echo "JOB_LABEL = $JOB_LABEL" 135 | echo "" 136 | 137 | # Set checkpoint directory ($CKPT_DIR) environment variables 138 | echo "-------- Setting checkpoint and output path variables ------------------" 139 | echo "" 140 | # Vector provides a fast parallel filesystem local to the GPU nodes, dedicated 141 | # for checkpointing. It is mounted under /checkpoint. 
It is strongly
142 | # recommended that you keep your intermediary checkpoints under this directory
143 | CKPT_DIR="/checkpoint/${SLURM_JOB_USER}/${SLURM_JOB_ID}"
144 | echo "CKPT_DIR = $CKPT_DIR"
145 | CKPT_PTH="$CKPT_DIR/checkpoint_latest.pt"
146 | echo "CKPT_PTH = $CKPT_PTH"
147 | echo ""
148 | # Ensure the checkpoint dir exists
149 | mkdir -p "$CKPT_DIR"
150 | echo "Current contents of ${CKPT_DIR}:"
151 | ls -lh "${CKPT_DIR}"
152 | echo ""
153 | # Create a symlink to the job's checkpoint directory within a subfolder of the
154 | # current directory (repository directory) named "checkpoints_working".
155 | mkdir -p "checkpoints_working"
156 | ln -sfn "$CKPT_DIR" "$PWD/checkpoints_working/$SLURM_JOB_NAME"
157 | # Specify an output directory to place checkpoints for long term storage once
158 | # the job is finished.
159 | if [[ -d "/scratch/hdd001/home/$SLURM_JOB_USER" ]];
160 | then
161 |     OUTPUT_DIR="/scratch/hdd001/home/$SLURM_JOB_USER"
162 | elif [[ -d "/scratch/ssd004/scratch/$SLURM_JOB_USER" ]];
163 | then
164 |     OUTPUT_DIR="/scratch/ssd004/scratch/$SLURM_JOB_USER"
165 | else
166 |     OUTPUT_DIR=""
167 | fi
168 | if [[ "$OUTPUT_DIR" != "" ]];
169 | then
170 |     # Directory OUTPUT_DIR will contain all completed jobs for this project.
171 |     OUTPUT_DIR="$OUTPUT_DIR/checkpoints/$PROJECT_NAME"
172 |     # Subdirectory JOB_OUTPUT_DIR will contain the outputs from this job.
173 |     JOB_OUTPUT_DIR="$OUTPUT_DIR/$JOB_LABEL"
174 |     echo "JOB_OUTPUT_DIR = $JOB_OUTPUT_DIR"
175 |     if [[ -d "$JOB_OUTPUT_DIR" ]];
176 |     then
177 |         echo "Current contents of ${JOB_OUTPUT_DIR}"
178 |         ls -lh "${JOB_OUTPUT_DIR}"
179 |     fi
180 |     echo ""
181 | fi
182 |
183 | # Save a list of installed packages and their versions to files in the checkpoint directory
184 | conda env export > "$CKPT_DIR/environment.yml"
185 | pip freeze > "$CKPT_DIR/frozen-requirements.txt"
186 |
187 | if [[ "${SLURM_RESTART_COUNT:-0}" -gt 0 && ! -f "$CKPT_PTH" ]];
188 | then
189 |     echo ""
190 |     echo "====================================================================="
191 |     echo "SLURM SCRIPT ERROR:"
192 |     echo "    Resuming after pre-emption (SLURM_RESTART_COUNT=$SLURM_RESTART_COUNT)"
193 |     echo "    but there is no checkpoint file at $CKPT_PTH"
194 |     echo "====================================================================="
195 |     exit 1;
196 | fi;
197 |
198 | echo ""
199 | echo "------------------------------------"
200 | elapsed=$(( SECONDS - start_time ))
201 | eval "echo Running total elapsed time for restart $SLURM_RESTART_COUNT: $(date -ud "@$elapsed" +'$((%s/3600/24)) days %H hr %M min %S sec')"
202 | echo ""
203 | echo "-------- Begin main script ---------------------------------------------"
204 | date
205 | echo ""
206 | # Store the master node's IP address in the MASTER_ADDR environment variable,
207 | # which torch.distributed will use to initialize DDP.
208 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
209 | echo "Rank 0 node is at $MASTER_ADDR"
210 |
211 | # Use a port number automatically selected from the job id number.
212 | # This will let us use the same port for every task in the job without having
213 | # to create a file to store the port number.
214 | # We don't want to use a fixed port number because that would lead to
215 | # collisions between jobs when they are scheduled on the same node.
216 | # We only use ports in the range 49152-65535 (inclusive), which are the
217 | # Dynamic Ports, also known as Private Ports.
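# As a worked example with a hypothetical job ID: SLURM_JOB_ID=1234567 gives 1234567 % 16384 + 49152 = 5767 + 49152 = 54919, which falls within this range.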
218 | MASTER_PORT="$(( $SLURM_JOB_ID % 16384 + 49152 ))"
219 | if ss -tulpn | grep -q ":$MASTER_PORT ";
220 | then
221 |     # The port we selected is in use, so we'll get a random available port instead.
222 |     echo "Finding a free port to use for $SLURM_NNODES node training"
223 |     MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')";
224 | fi
225 | export MASTER_PORT;
226 | echo "Will use port $MASTER_PORT for c10d communication"
227 |
228 | export WORLD_SIZE="$(($SLURM_NNODES * $SLURM_GPUS_ON_NODE))"
229 | echo "WORLD_SIZE = $WORLD_SIZE"
230 |
231 | # NCCL options ----------------------------------------------------------------
232 |
233 | # This is needed to print debug info from NCCL, can be removed if all goes well
234 | # export NCCL_DEBUG=INFO
235 |
236 | # This is needed to stop NCCL from using InfiniBand, which the cluster does not have
237 | export NCCL_IB_DISABLE=1
238 |
239 | # This is to tell NCCL to use the bond interface for network communication
240 | if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]];
241 | then
242 |     echo "Using NCCL_SOCKET_IFNAME=bond0 on ${SLURM_JOB_PARTITION}"
243 |     export NCCL_SOCKET_IFNAME=bond0
244 | fi
245 |
246 | # Set this when using the NCCL backend for inter-GPU communication.
247 | export TORCH_NCCL_BLOCKING_WAIT=1
248 | # -----------------------------------------------------------------------------
249 |
250 | # Multi-GPU configuration
251 | echo ""
252 | echo "Main script begins via torchrun with host tcp://${MASTER_ADDR}:$MASTER_PORT with backend NCCL"
253 | if [[ "$SLURM_JOB_NUM_NODES" == "1" ]];
254 | then
255 |     echo "Single ($SLURM_JOB_NUM_NODES) node training ($SLURM_GPUS_ON_NODE GPUs)"
256 | else
257 |     echo "Multiple ($SLURM_JOB_NUM_NODES) node training (x$SLURM_GPUS_ON_NODE GPUs per node)"
258 | fi
259 | echo ""
260 |
261 | # We use the torchrun command to launch our main python script.
262 | # It will automatically set up the necessary environment variables for DDP,
263 | # and will launch the script once for each GPU on each node.
264 | #
265 | # We pass the CKPT_DIR environment variable on as the output path for our
266 | # python script, and also try to resume from a checkpoint in this directory
267 | # in case of pre-emption. The python script should run from scratch if there
268 | # is no checkpoint at this path to resume from.
269 | #
270 | # We pass on to train.py an array of arbitrary extra arguments given to this
271 | # slurm script contained in the `$@` magic variable.
272 | #
273 | # We execute the srun command in the background with `&` (and then check its
274 | # process ID and wait for it to finish before continuing) so the main process
275 | # can handle the SIGUSR1 signal. Otherwise if a child process is running, the
276 | # signal will be ignored.
277 | srun -N "$SLURM_NNODES" --ntasks-per-node=1 \
278 |     torchrun \
279 |     --nnodes="$SLURM_JOB_NUM_NODES" \
280 |     --nproc_per_node="$SLURM_GPUS_ON_NODE" \
281 |     --rdzv_id="$SLURM_JOB_ID" \
282 |     --rdzv_backend=c10d \
283 |     --rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \
284 |     "$SCRIPT_PATH" \
285 |     --cpu-workers="$SLURM_CPUS_PER_GPU" \
286 |     --seed="$SEED" \
287 |     --checkpoint="$CKPT_PTH" \
288 |     --log-wandb \
289 |     --run-name="$SLURM_JOB_NAME" \
290 |     --run-id="$JOB_ID" \
291 |     "${@}" &
292 | child="$!"
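# ($! expands to the process ID of the most recently started background command, i.e. the srun we just launched.)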
293 | wait "$child" 294 | 295 | echo "" 296 | echo "------------------------------------" 297 | elapsed=$(( SECONDS - start_time )) 298 | eval "echo Running total elapsed time for restart $SLURM_RESTART_COUNT: $(date -ud "@$elapsed" +'$((%s/3600/24)) days %H hr %M min %S sec')" 299 | echo "" 300 | # Now the job is finished, remove the symlink to the job's checkpoint directory 301 | # from checkpoints_working 302 | rm "$PWD/checkpoints_working/$SLURM_JOB_NAME" 303 | # By overriding the JOB_OUTPUT_DIR environment variable, we disable saving 304 | # checkpoints to long-term storage. This is disabled by default to preserve 305 | # disk space. When you are sure your job config is correct and you are sure 306 | # you need to save your checkpoints for posterity, comment out this line. 307 | JOB_OUTPUT_DIR="" 308 | # 309 | if [[ "$CKPT_DIR" == "" ]]; 310 | then 311 | # This shouldn't ever happen, but we have a check for just in case. 312 | # If $CKPT_DIR were somehow not set, we would mistakenly try to copy far 313 | # too much data to $JOB_OUTPUT_DIR. 314 | echo "CKPT_DIR is unset. Will not copy outputs to $JOB_OUTPUT_DIR." 315 | elif [[ "$JOB_OUTPUT_DIR" == "" ]]; 316 | then 317 | echo "JOB_OUTPUT_DIR is unset. Will not copy outputs from $CKPT_DIR." 318 | else 319 | echo "-------- Saving outputs for long term storage --------------------------" 320 | date 321 | echo "" 322 | echo "Copying outputs from $CKPT_DIR to $JOB_OUTPUT_DIR" 323 | mkdir -p "$JOB_OUTPUT_DIR" 324 | rsync -rutlzv "$CKPT_DIR/" "$JOB_OUTPUT_DIR/" 325 | echo "" 326 | echo "Output contents of ${JOB_OUTPUT_DIR}:" 327 | ls -lh "$JOB_OUTPUT_DIR" 328 | # Set up a symlink to the long term storage directory 329 | ln -sfn "$OUTPUT_DIR" "checkpoints_finished" 330 | fi 331 | echo "" 332 | echo "------------------------------------------------------------------------" 333 | echo "" 334 | echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) finished, submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)" 335 | date 336 | echo "------------------------------------" 337 | elapsed=$(( SECONDS - start_time )) 338 | eval "echo Total elapsed time for restart $SLURM_RESTART_COUNT: $(date -ud "@$elapsed" +'$((%s/3600/24)) days %H hr %M min %S sec')" 339 | echo "========================================================================" 340 | -------------------------------------------------------------------------------- /slurm/utils/report_env_config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "========================================================================" 4 | echo "-------- Reporting environment configuration ---------------------------" 5 | date 6 | echo "" 7 | echo "pwd:" 8 | pwd 9 | echo "" 10 | echo "which python:" 11 | which python 12 | echo "" 13 | echo "python version:" 14 | python --version 15 | echo "" 16 | echo "which conda:" 17 | which conda 18 | echo "" 19 | echo "conda info:" 20 | conda info 21 | echo "" 22 | echo "which pip:" 23 | which pip 24 | echo "" 25 | ## Don't bother looking at system nvcc, as we have a conda installation 26 | # echo "which nvcc:" 27 | # which nvcc || echo "No nvcc" 28 | # echo "" 29 | # echo "nvcc version:" 30 | # nvcc --version || echo "No nvcc" 31 | # echo "" 32 | echo "nvidia-smi:" 33 | nvidia-smi || echo "No nvidia-smi" 34 | echo "" 35 | echo "torch info:" 36 | python -c "import torch; print(f'pytorch={torch.__version__}, cuda={torch.cuda.is_available()}, gpus={torch.cuda.device_count()}')" 37 | python -c "import torch; 
print(str(torch.ones(1, device=torch.device('cuda')))); print('able to use cuda')" 38 | echo "" 39 | echo "========================================================================" 40 | echo "" 41 | -------------------------------------------------------------------------------- /slurm/utils/report_repo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "========================================================================" 4 | echo "-------- Reporting git repo configuration ------------------------------" 5 | date 6 | echo "" 7 | echo "pwd: $(pwd)" 8 | echo "commit ref: $(git rev-parse HEAD)" 9 | echo "" 10 | git status 11 | echo "========================================================================" 12 | echo "" 13 | -------------------------------------------------------------------------------- /slurm/utils/report_slurm_config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "========================================================================" 4 | echo "-------- Reporting SLURM configuration ---------------------------------" 5 | date 6 | echo "" 7 | echo "SLURM_CLUSTER_NAME = $SLURM_CLUSTER_NAME" # Name of the cluster on which the job is executing. 8 | echo "SLURM_JOB_QOS = $SLURM_JOB_QOS" # Quality Of Service (QOS) of the job allocation. 9 | echo "SLURM_JOB_ID = $SLURM_JOB_ID" # The ID of the job allocation. 10 | echo "SLURM_RESTART_COUNT = $SLURM_RESTART_COUNT" # The number of times the job has been restarted. 11 | if [ "$SLURM_ARRAY_TASK_COUNT" != "" ]; then 12 | echo "" 13 | echo "SLURM_ARRAY_JOB_ID = $SLURM_ARRAY_JOB_ID" # Job array's master job ID number. 14 | echo "SLURM_ARRAY_TASK_COUNT = $SLURM_ARRAY_TASK_COUNT" # Total number of tasks in a job array. 15 | echo "SLURM_ARRAY_TASK_ID = $SLURM_ARRAY_TASK_ID" # Job array ID (index) number. 16 | echo "SLURM_ARRAY_TASK_MAX = $SLURM_ARRAY_TASK_MAX" # Job array's maximum ID (index) number. 17 | echo "SLURM_ARRAY_TASK_STEP = $SLURM_ARRAY_TASK_STEP" # Job array's index step size. 18 | fi; 19 | echo "" 20 | echo "SLURM_JOB_NUM_NODES = $SLURM_JOB_NUM_NODES" # Total number of nodes in the job's resource allocation. 21 | echo "SLURM_JOB_NODELIST = $SLURM_JOB_NODELIST" # List of nodes allocated to the job. 22 | echo "SLURM_TASKS_PER_NODE = $SLURM_TASKS_PER_NODE" # Number of tasks to be initiated on each node. 23 | echo "SLURM_NTASKS = $SLURM_NTASKS" # Number of tasks to spawn. 24 | echo "SLURM_PROCID = $SLURM_PROCID" # The MPI rank (or relative process ID) of the current process 25 | echo "" 26 | echo "SLURM_GPUS_ON_NODE = $SLURM_GPUS_ON_NODE" # Number of allocated GPUs per node. 27 | echo "SLURM_CPUS_ON_NODE = $SLURM_CPUS_ON_NODE" # Number of allocated CPUs per node. 28 | echo "SLURM_CPUS_PER_GPU = $SLURM_CPUS_PER_GPU" # Number of CPUs requested per GPU. Only set if the --cpus-per-gpu option is specified. 29 | echo "SLURM_MEM_PER_GPU = $SLURM_MEM_PER_GPU" # Memory per allocated GPU. Only set if the --mem-per-gpu option is specified. 
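# N.B. If the job requested memory with --mem rather than --mem-per-gpu, SLURM_MEM_PER_NODE is typically set instead and the variable above would be empty.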
30 | echo "" 31 | if [[ "$SLURM_TMPDIR" != "" ]]; 32 | then 33 | echo "------------------------------------" 34 | echo "" 35 | echo "SLURM_TMPDIR = $SLURM_TMPDIR" 36 | echo "" 37 | echo "Contents of $SLURM_TMPDIR" 38 | ls -lh "$SLURM_TMPDIR" 39 | echo "" 40 | fi; 41 | echo "========================================================================" 42 | echo "" 43 | -------------------------------------------------------------------------------- /template_experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from . import __meta__ 2 | 3 | __version__ = __meta__.version 4 | -------------------------------------------------------------------------------- /template_experiment/__meta__.py: -------------------------------------------------------------------------------- 1 | # `name` is the name of the package as used for `pip install package` (use hyphens) 2 | name = "template-experiment" 3 | # `path` is the name of the package for `import package` (use underscores) 4 | path = name.lower().replace("-", "_").replace(" ", "_") 5 | # Your version number should follow https://python.org/dev/peps/pep-0440 and 6 | # https://semver.org 7 | version = "0.1.dev0" 8 | author = "Author Name" 9 | author_email = "" 10 | description = "" # One-liner 11 | url = "" # your project homepage 12 | license = "Unlicense" # See https://choosealicense.com 13 | -------------------------------------------------------------------------------- /template_experiment/data_transformations.py: -------------------------------------------------------------------------------- 1 | import timm.data 2 | import torch 3 | from torchvision import transforms 4 | 5 | NORMALIZATION = { 6 | "imagenet": [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)], 7 | "mnist": [(0.1307,), (0.3081,)], 8 | "cifar": [(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)], 9 | } 10 | 11 | VALID_TRANSFORMS = ["imagenet", "cifar", "mnist"] 12 | 13 | 14 | def get_transform(transform_type="noaug", image_size=32, args=None): 15 | if args is None: 16 | args = {} 17 | mean, std = NORMALIZATION[args.get("normalization", "imagenet")] 18 | if "mean" in args: 19 | mean = args["mean"] 20 | if "std" in args: 21 | std = args["std"] 22 | 23 | if transform_type == "noaug": 24 | # No augmentations, just resize and normalize. 25 | # N.B. If the raw training image isn't square, there is a small 26 | # "augmentation" as we will randomly crop a square (of length equal to 27 | # the shortest side) from it. We do this because we assume inputs to 28 | # the network must be square. 29 | train_transform = transforms.Compose( 30 | [ 31 | transforms.Resize(image_size), # Resize shortest side to image_size 32 | transforms.RandomCrop(image_size), # If it is not square, *random* crop 33 | transforms.ToTensor(), 34 | transforms.Normalize(mean=mean, std=std), 35 | ] 36 | ) 37 | test_transform = transforms.Compose( 38 | [ 39 | transforms.Resize(image_size), # Resize shortest side to image_size 40 | transforms.CenterCrop(image_size), # If it is not square, center crop 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=mean, std=std), 43 | ] 44 | ) 45 | 46 | elif transform_type == "imagenet": 47 | # Appropriate for really large natual images, as in ImageNet. 
48 | # For training: 49 | # - Zoom in randomly with scale (big range of how much to zoom in) 50 | # - Stretch with random aspect ratio 51 | # - Flip horizontally 52 | # - Randomly adjust brightness/contrast/saturation 53 | # - (No rotation or skew) 54 | # - Interpolation is randomly either bicubic or bilinear 55 | train_transform = timm.data.create_transform( 56 | input_size=image_size, 57 | is_training=True, 58 | scale=(0.08, 1.0), # default imagenet scale range 59 | ratio=(3.0 / 4.0, 4.0 / 3.0), # default imagenet ratio range 60 | hflip=0.5, 61 | vflip=0.0, 62 | color_jitter=0.4, 63 | interpolation="random", 64 | mean=mean, 65 | std=std, 66 | ) 67 | # For testing: 68 | # - Zoom in 87.5% 69 | # - Center crop 70 | # - Interpolation is bilinear 71 | test_transform = transforms.Compose( 72 | [ 73 | transforms.Resize(int(image_size / 0.875)), 74 | transforms.CenterCrop(image_size), 75 | transforms.ToTensor(), 76 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 77 | ] 78 | ) 79 | 80 | elif transform_type == "cifar": 81 | # Appropriate for smaller natural images, as in CIFAR-10/100. 82 | # For training: 83 | # - Zoom in randomly with scale (small range of how much to zoom in by) 84 | # - Stretch with random aspect ratio 85 | # - Flip horizontally 86 | # - Randomly adjust brightness/contrast/saturation 87 | # - (No rotation or skew) 88 | train_transform = timm.data.create_transform( 89 | input_size=image_size, 90 | is_training=True, 91 | scale=(0.7, 1.0), # reduced scale range 92 | ratio=(3.0 / 4.0, 4.0 / 3.0), # default imagenet ratio range 93 | hflip=0.5, 94 | vflip=0.0, 95 | color_jitter=0.4, # default imagenet color jitter 96 | interpolation="random", 97 | mean=mean, 98 | std=std, 99 | ) 100 | # For testing: 101 | # - Resize to desired size only, with a center crop step included in 102 | # case the raw image was not square. 103 | test_transform = transforms.Compose( 104 | [ 105 | transforms.Resize(image_size), 106 | transforms.CenterCrop(image_size), 107 | transforms.ToTensor(), 108 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 109 | ] 110 | ) 111 | 112 | elif transform_type == "digits": 113 | # Appropriate for smaller images containing digits, as in MNIST. 114 | # - Zoom in randomly with scale (small range of how much to zoom in by) 115 | # - Stretch with random aspect ratio 116 | # - Don't flip the images (that would change the digit) 117 | # - Randomly adjust brightness/contrast/saturation 118 | # - (No rotation or skew) 119 | train_transform = timm.data.create_transform( 120 | input_size=image_size, 121 | is_training=True, 122 | scale=(0.7, 1.0), # reduced scale range 123 | ratio=(3.0 / 4.0, 4.0 / 3.0), # default imagenet ratio range 124 | hflip=0.0, 125 | vflip=0.0, 126 | color_jitter=0.4, # default imagenet color jitter 127 | interpolation="random", 128 | mean=mean, 129 | std=std, 130 | ) 131 | # For testing: 132 | # - Resize to desired size only, with a center crop step included in 133 | # case the raw image was not square. 
134 | test_transform = transforms.Compose( 135 | [ 136 | transforms.Resize(image_size), 137 | transforms.CenterCrop(image_size), 138 | transforms.ToTensor(), 139 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 140 | ] 141 | ) 142 | 143 | elif transform_type == "autoaugment-imagenet": 144 | # Augmentation policy learnt by AutoAugment, described in 145 | # https://arxiv.org/abs/1805.09501 146 | # The policies mostly concern changing the colours of the image, 147 | # but there is a little rotation and shear too. We need to include 148 | # our own random cropping, stretching, and flipping. 149 | train_transform = transforms.Compose( 150 | [ 151 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), 152 | transforms.AutoAugment( 153 | policy=transforms.AutoAugmentPolicy.IMAGENET, 154 | interpolation=transforms.InterpolationMode.BILINEAR, 155 | ), 156 | transforms.RandomHorizontalFlip(0.5), 157 | transforms.ToTensor(), 158 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 159 | ] 160 | ) 161 | # For testing: 162 | # - Zoom in 87.5% 163 | # - Center crop 164 | test_transform = transforms.Compose( 165 | [ 166 | transforms.Resize(int(image_size / 0.875)), 167 | transforms.CenterCrop(image_size), 168 | transforms.ToTensor(), 169 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 170 | ] 171 | ) 172 | 173 | elif transform_type == "autoaugment-cifar": 174 | # Augmentation policy learnt by AutoAugment, described in 175 | # https://arxiv.org/abs/1805.09501 176 | # The policies mostly concern changing the colours of the image, 177 | # but there is a little rotation and shear too. We need to include 178 | # our own random cropping, stretching, and flipping. 179 | train_transform = transforms.Compose( 180 | [ 181 | transforms.AutoAugment( 182 | policy=transforms.AutoAugmentPolicy.CIFAR10, 183 | interpolation=transforms.InterpolationMode.BILINEAR, 184 | ), 185 | transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), 186 | transforms.RandomHorizontalFlip(0.5), 187 | transforms.ToTensor(), 188 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 189 | ] 190 | ) 191 | # For testing: 192 | # - Resize to desired size only, with a center crop step included in 193 | # case the raw image was not square. 
194 | test_transform = transforms.Compose( 195 | [ 196 | transforms.Resize(image_size), 197 | transforms.CenterCrop(image_size), 198 | transforms.ToTensor(), 199 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 200 | ] 201 | ) 202 | 203 | elif transform_type == "randaugment-imagenet": 204 | # Augmentation policy learnt by RandAugment, described in 205 | # https://arxiv.org/abs/1909.13719 206 | train_transform = transforms.Compose( 207 | [ 208 | transforms.RandAugment( 209 | interpolation=transforms.InterpolationMode.BILINEAR, 210 | ), 211 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), 212 | transforms.RandomHorizontalFlip(0.5), 213 | transforms.ToTensor(), 214 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 215 | ] 216 | ) 217 | # For testing: 218 | # - Zoom in 87.5% 219 | # - Center crop 220 | test_transform = transforms.Compose( 221 | [ 222 | transforms.Resize(int(image_size / 0.875)), 223 | transforms.CenterCrop(image_size), 224 | transforms.ToTensor(), 225 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 226 | ] 227 | ) 228 | 229 | elif transform_type == "randaugment-cifar": 230 | # Augmentation policy learnt by RandAugment, described in 231 | # https://arxiv.org/abs/1909.13719 232 | train_transform = transforms.Compose( 233 | [ 234 | transforms.RandAugment( 235 | interpolation=transforms.InterpolationMode.BILINEAR, 236 | ), 237 | transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), 238 | transforms.RandomHorizontalFlip(0.5), 239 | transforms.ToTensor(), 240 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 241 | ] 242 | ) 243 | # For testing: 244 | # - Resize to desired size only, with a center crop step included in 245 | # case the raw image was not square. 
246 | test_transform = transforms.Compose( 247 | [ 248 | transforms.Resize(image_size), 249 | transforms.CenterCrop(image_size), 250 | transforms.ToTensor(), 251 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 252 | ] 253 | ) 254 | 255 | elif transform_type == "trivialaugment-imagenet": 256 | # Trivial augmentation policy, described in https://arxiv.org/abs/2103.10158 257 | train_transform = transforms.Compose( 258 | [ 259 | transforms.TrivialAugmentWide( 260 | interpolation=transforms.InterpolationMode.BILINEAR, 261 | ), 262 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), 263 | transforms.RandomHorizontalFlip(0.5), 264 | transforms.ToTensor(), 265 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 266 | ] 267 | ) 268 | # For testing: 269 | # - Zoom in 87.5% 270 | # - Center crop 271 | test_transform = transforms.Compose( 272 | [ 273 | transforms.Resize(int(image_size / 0.875)), 274 | transforms.CenterCrop(image_size), 275 | transforms.ToTensor(), 276 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 277 | ] 278 | ) 279 | 280 | elif transform_type == "trivialaugment-cifar": 281 | # Trivial augmentation policy, described in https://arxiv.org/abs/2103.10158 282 | train_transform = transforms.Compose( 283 | [ 284 | transforms.TrivialAugmentWide( 285 | interpolation=transforms.InterpolationMode.BILINEAR, 286 | ), 287 | transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), 288 | transforms.RandomHorizontalFlip(0.5), 289 | transforms.ToTensor(), 290 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 291 | ] 292 | ) 293 | # For testing: 294 | # - Resize to desired size only, with a center crop step included in 295 | # case the raw image was not square. 296 | test_transform = transforms.Compose( 297 | [ 298 | transforms.Resize(image_size), 299 | transforms.CenterCrop(image_size), 300 | transforms.ToTensor(), 301 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 302 | ] 303 | ) 304 | 305 | elif transform_type == "randomerasing-imagenet": 306 | train_transform = transforms.Compose( 307 | [ 308 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), 309 | transforms.RandomHorizontalFlip(0.5), 310 | transforms.ToTensor(), 311 | transforms.RandomErasing(), 312 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 313 | ] 314 | ) 315 | # For testing: 316 | # - Zoom in 87.5% 317 | # - Center crop 318 | test_transform = transforms.Compose( 319 | [ 320 | transforms.Resize(int(image_size / 0.875)), 321 | transforms.CenterCrop(image_size), 322 | transforms.ToTensor(), 323 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 324 | ] 325 | ) 326 | 327 | elif transform_type == "randomerasing-cifar": 328 | train_transform = transforms.Compose( 329 | [ 330 | transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)), 331 | transforms.RandomHorizontalFlip(0.5), 332 | transforms.ToTensor(), 333 | transforms.RandomErasing(), 334 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 335 | ] 336 | ) 337 | # For testing: 338 | # - Resize to desired size only, with a center crop step included in 339 | # case the raw image was not square. 
340 | test_transform = transforms.Compose( 341 | [ 342 | transforms.Resize(image_size), 343 | transforms.CenterCrop(image_size), 344 | transforms.ToTensor(), 345 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), 346 | ] 347 | ) 348 | 349 | else: 350 | raise NotImplementedError 351 | 352 | return (train_transform, test_transform) 353 | -------------------------------------------------------------------------------- /template_experiment/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handlers for various image datasets. 3 | """ 4 | 5 | import os 6 | import socket 7 | import warnings 8 | 9 | import numpy as np 10 | import sklearn.model_selection 11 | import torch 12 | import torchvision.datasets 13 | 14 | 15 | def determine_host(): 16 | r""" 17 | Determine which compute server we are on. 18 | 19 | Returns 20 | ------- 21 | host : str, one of {"vaughan", "mars"} 22 | An identifier for the host compute system. 23 | """ 24 | hostname = socket.gethostname() 25 | slurm_submit_host = os.environ.get("SLURM_SUBMIT_HOST") 26 | slurm_cluster_name = os.environ.get("SLURM_CLUSTER_NAME") 27 | 28 | if slurm_cluster_name and slurm_cluster_name.startswith("vaughan"): 29 | return "vaughan" 30 | if slurm_submit_host and slurm_submit_host in ["q.vector.local", "m.vector.local"]: 31 | return "mars" 32 | if hostname and hostname in ["q.vector.local", "m.vector.local"]: 33 | return "mars" 34 | if hostname and hostname.startswith("v"): 35 | return "vaughan" 36 | if slurm_submit_host and slurm_submit_host.startswith("v"): 37 | return "vaughan" 38 | return "" 39 | 40 | 41 | def image_dataset_sizes(dataset): 42 | r""" 43 | Get the image size and number of classes for a dataset. 44 | 45 | Parameters 46 | ---------- 47 | dataset : str 48 | Name of the dataset. 49 | 50 | Returns 51 | ------- 52 | num_classes : int 53 | Number of classes in the dataset. 54 | img_size : int or None 55 | Size of the images in the dataset, or None if the images are not all 56 | the same size. Images are assumed to be square. 57 | num_channels : int 58 | Number of colour channels in the images in the dataset. This will be 59 | 1 for greyscale images, and 3 for colour images. 60 | """ 61 | dataset = dataset.lower().replace("-", "").replace("_", "").replace(" ", "") 62 | 63 | if dataset == "cifar10": 64 | num_classes = 10 65 | img_size = 32 66 | num_channels = 3 67 | 68 | elif dataset == "cifar100": 69 | num_classes = 100 70 | img_size = 32 71 | num_channels = 3 72 | 73 | elif dataset in ["imagenet", "imagenet1k", "ilsvrc2012"]: 74 | num_classes = 1000 75 | img_size = None 76 | num_channels = 3 77 | 78 | elif dataset.startswith("imagenette"): 79 | num_classes = 10 80 | img_size = None 81 | num_channels = 3 82 | 83 | elif dataset.startswith("imagewoof"): 84 | num_classes = 10 85 | img_size = None 86 | num_channels = 3 87 | 88 | elif dataset == "mnist": 89 | num_classes = 10 90 | img_size = 28 91 | num_channels = 1 92 | 93 | elif dataset == "svhn": 94 | num_classes = 10 95 | img_size = 32 96 | num_channels = 3 97 | 98 | else: 99 | raise ValueError("Unrecognised dataset: {}".format(dataset)) 100 | 101 | return num_classes, img_size, num_channels 102 | 103 | 104 | def fetch_image_dataset( 105 | dataset, 106 | root=None, 107 | transform_train=None, 108 | transform_eval=None, 109 | download=False, 110 | ): 111 | r""" 112 | Fetch a train and test dataset object for a given image dataset name. 
113 | 114 | Parameters 115 | ---------- 116 | dataset : str 117 | Name of dataset. 118 | root : str, optional 119 | Path to root directory containing the dataset. 120 | transform_train : callable, optional 121 | A function/transform that takes in an PIL image and returns a 122 | transformed version, to be applied to the training dataset. 123 | transform_eval : callable, optional 124 | A function/transform that takes in an PIL image and returns a 125 | transformed version, to be applied to the evaluation dataset. 126 | download : bool, optional 127 | Whether to download the dataset to the expected directory if it is not 128 | there. Only supported by some datasets. Default is ``False``. 129 | """ 130 | dataset = dataset.lower().replace("-", "").replace("_", "").replace(" ", "") 131 | host = determine_host() 132 | 133 | if dataset == "cifar10": 134 | if root: 135 | pass 136 | elif host == "vaughan": 137 | root = "/scratch/ssd002/datasets/" 138 | elif host == "mars": 139 | root = "/scratch/gobi1/datasets/" 140 | else: 141 | root = "~/Datasets" 142 | dataset_train = torchvision.datasets.CIFAR10( 143 | os.path.join(root, dataset), 144 | train=True, 145 | transform=transform_train, 146 | download=download, 147 | ) 148 | dataset_val = None 149 | dataset_test = torchvision.datasets.CIFAR10( 150 | os.path.join(root, dataset), 151 | train=False, 152 | transform=transform_eval, 153 | download=download, 154 | ) 155 | 156 | elif dataset == "cifar100": 157 | if root: 158 | pass 159 | elif host == "vaughan": 160 | root = "/scratch/ssd002/datasets/" 161 | else: 162 | root = "~/Datasets" 163 | dataset_train = torchvision.datasets.CIFAR100( 164 | os.path.join(root, dataset), 165 | train=True, 166 | transform=transform_train, 167 | download=download, 168 | ) 169 | dataset_val = None 170 | dataset_test = torchvision.datasets.CIFAR100( 171 | os.path.join(root, dataset), 172 | train=False, 173 | transform=transform_eval, 174 | download=download, 175 | ) 176 | 177 | elif dataset in ["imagenet", "imagenet1k", "ilsvrc2012"]: 178 | if root: 179 | pass 180 | elif host == "vaughan": 181 | root = "/scratch/ssd004/datasets/" 182 | elif host == "mars": 183 | root = "/scratch/gobi1/datasets/" 184 | else: 185 | root = "~/Datasets" 186 | dataset_train = torchvision.datasets.ImageFolder( 187 | os.path.join(root, "imagenet", "train"), 188 | transform=transform_train, 189 | ) 190 | dataset_val = None 191 | dataset_test = torchvision.datasets.ImageFolder( 192 | os.path.join(root, "imagenet", "val"), 193 | transform=transform_eval, 194 | ) 195 | 196 | elif dataset == "imagenette": 197 | if root: 198 | root = os.path.join(root, "imagenette") 199 | elif host == "vaughan": 200 | root = "/scratch/ssd004/datasets/imagenette2/full/" 201 | else: 202 | root = "~/Datasets/imagenette/" 203 | dataset_train = torchvision.datasets.ImageFolder( 204 | os.path.join(root, "train"), 205 | transform=transform_train, 206 | ) 207 | dataset_val = None 208 | dataset_test = torchvision.datasets.ImageFolder( 209 | os.path.join(root, "val"), 210 | transform=transform_eval, 211 | ) 212 | 213 | elif dataset == "imagewoof": 214 | if root: 215 | root = os.path.join(root, "imagewoof") 216 | elif host == "vaughan": 217 | root = "/scratch/ssd004/datasets/imagewoof2/full/" 218 | else: 219 | root = "~/Datasets/imagewoof/" 220 | dataset_train = torchvision.datasets.ImageFolder( 221 | os.path.join(root, "train"), 222 | transform=transform_train, 223 | ) 224 | dataset_val = None 225 | dataset_test = torchvision.datasets.ImageFolder( 226 | os.path.join(root, "val"), 
227 |             transform=transform_eval,
228 |         )
229 |
230 |     elif dataset == "mnist":
231 |         if root:
232 |             pass
233 |         elif host == "vaughan":
234 |             root = "/scratch/ssd004/datasets/"
235 |         else:
236 |             root = "~/Datasets"
237 |         # Will read from [root]/MNIST/processed
238 |         dataset_train = torchvision.datasets.MNIST(
239 |             root,
240 |             train=True,
241 |             transform=transform_train,
242 |             download=download,
243 |         )
244 |         dataset_val = None
245 |         dataset_test = torchvision.datasets.MNIST(
246 |             root,
247 |             train=False,
248 |             transform=transform_eval,
249 |             download=download,
250 |         )
251 |
252 |     elif dataset == "svhn":
253 |         # SVHN has:
254 |         # 73,257 digits for training,
255 |         # 26,032 digits for testing, and
256 |         # 531,131 additional, less difficult, samples to use as extra training data
257 |         # We don't use the extra split here, only train. There are original
258 |         # images which are large and have bounding boxes, but the pytorch class
259 |         # just uses the 32px cropped individual digits.
260 |         if root:
261 |             pass
262 |         elif host == "vaughan":
263 |             root = "/scratch/ssd004/datasets/"
264 |         elif host == "mars":
265 |             root = "/scratch/gobi1/datasets/"
266 |         else:
267 |             root = "~/Datasets"
268 |         dataset_train = torchvision.datasets.SVHN(
269 |             os.path.join(root, dataset),
270 |             split="train",
271 |             transform=transform_train,
272 |             download=download,
273 |         )
274 |         dataset_val = None
275 |         dataset_test = torchvision.datasets.SVHN(
276 |             os.path.join(root, dataset),
277 |             split="test",
278 |             transform=transform_eval,
279 |             download=download,
280 |         )
281 |
282 |     else:
283 |         raise ValueError("Unrecognised dataset: {}".format(dataset))
284 |
285 |     return dataset_train, dataset_val, dataset_test
286 |
287 |
288 | def fetch_dataset(
289 |     dataset,
290 |     root=None,
291 |     prototyping=False,
292 |     transform_train=None,
293 |     transform_eval=None,
294 |     protoval_split_rate=0.1,
295 |     protoval_split_id=0,
296 |     download=False,
297 | ):
298 |     r"""
299 |     Fetch a train and test dataset object for a given dataset name.
300 |
301 |     Parameters
302 |     ----------
303 |     dataset : str
304 |         Name of dataset.
305 |     root : str, optional
306 |         Path to root directory containing the dataset.
307 |     prototyping : bool, default=False
308 |         Whether to return a validation split distinct from the test split.
309 |         If ``False``, the validation split will be the same as the test split
310 |         for datasets which don't intrinsically have a separate val and test
311 |         partition.
312 |         If ``True``, the validation partition is carved out of the train
313 |         partition (resulting in a smaller training set) when there is no
314 |         distinct validation partition available.
315 |     transform_train : callable, optional
316 |         A function/transform that takes in a PIL image and returns a
317 |         transformed version, to be applied to the training dataset.
318 |     transform_eval : callable, optional
319 |         A function/transform that takes in a PIL image and returns a
320 |         transformed version, to be applied to the evaluation dataset.
321 |     protoval_split_rate : float or str, default=0.1
322 |         The fraction of the train data to use for validating when in
323 |         prototyping mode. If this is set to "auto", the split rate will be
324 |         chosen such that the validation partition is the same size as the test
325 |         partition.
326 |     protoval_split_id : int, default=0
327 |         The identity of the random split used for the train/val partitioning.
328 |         This controls the seed of the folds used for the split, and which
329 |         fold to use for the validation set.
330 | The seed is equal to ``int(protoval_split_id * protoval_split_rate)`` 331 | and the fold is equal to ``protoval_split_id % (1 / protoval_split_rate)``. 332 | download : bool, optional 333 | Whether to download the dataset to the expected directory if it is not 334 | there. Only supported by some datasets. Default is ``False``. 335 | 336 | Returns 337 | ------- 338 | dataset_train : torch.utils.data.Dataset 339 | The training dataset. 340 | dataset_val : torch.utils.data.Dataset 341 | The validation dataset. 342 | dataset_test : torch.utils.data.Dataset 343 | The test dataset. 344 | distinct_val_test : bool 345 | Whether the validation and test partitions are distinct (True) or 346 | identical (False). 347 | """ 348 | dataset_train, dataset_val, dataset_test = fetch_image_dataset( 349 | dataset=dataset, 350 | root=root, 351 | transform_train=transform_train, 352 | transform_eval=transform_eval, 353 | download=download, 354 | ) 355 | 356 | # Handle the validation partition 357 | if dataset_val is not None: 358 | distinct_val_test = True 359 | elif not prototyping: 360 | dataset_val = dataset_test 361 | distinct_val_test = False 362 | else: 363 | # Create our own train/val split. 364 | # 365 | # For the validation part, we need a copy of dataset_train with the 366 | # evaluation transform instead. 367 | # The transform argument is *probably* going to be set to an attribute 368 | # on the dataset object called transform and called from there. But we 369 | # can't be completely sure, so to be completely agnostic about the 370 | # internals of the dataset class let's instantiate the dataset again! 371 | dataset_val = fetch_dataset( 372 | dataset, 373 | root=root, 374 | prototyping=False, 375 | transform_train=transform_eval, 376 | )[0] 377 | # dataset_val is a copy of the full training set, but with the transform 378 | # changed to transform_eval 379 | # Handle automatic validation partition sizing option. 380 | if not isinstance(protoval_split_rate, str): 381 | pass 382 | elif protoval_split_rate == "auto": 383 | # We want the validation set to be the same size as the test set. 384 | # This is the same as having a split rate of 1 - test_size. 385 | protoval_split_rate = len(dataset_test) / len(dataset_train) 386 | else: 387 | raise ValueError(f"Unsupported protoval_split_rate: {protoval_split_rate}") 388 | # Create the train/val split using these dataset objects. 389 | dataset_train, dataset_val = create_train_val_split( 390 | dataset_train, 391 | dataset_val, 392 | split_rate=protoval_split_rate, 393 | split_id=protoval_split_id, 394 | ) 395 | distinct_val_test = True 396 | 397 | return ( 398 | dataset_train, 399 | dataset_val, 400 | dataset_test, 401 | distinct_val_test, 402 | ) 403 | 404 | 405 | def create_train_val_split( 406 | dataset_train, 407 | dataset_val=None, 408 | split_rate=0.1, 409 | split_id=0, 410 | ): 411 | r""" 412 | Create a train/val split of a dataset. 413 | 414 | Parameters 415 | ---------- 416 | dataset_train : torch.utils.data.Dataset 417 | The full training dataset with training transforms. 418 | dataset_val : torch.utils.data.Dataset, optional 419 | The full training dataset with evaluation transforms. 420 | If this is not given, the source for the validation set will be 421 | ``dataset_test`` (with the same transforms as the training partition). 422 | Note that ``dataset_val`` must have the same samples as 423 | ``dataset_train``, and the samples must be in the same order. 
424 |     split_rate : float, default=0.1
425 |         The fraction of the train data to use for the validation split.
426 |     split_id : int, default=0
427 |         The identity of the split to use.
428 |         This controls the seed of the folds used for the split, and which
429 |         fold to use for the validation set.
430 |         The seed is equal to ``int(split_id * split_rate)``
431 |         and the fold is equal to ``split_id % (1 / split_rate)``.
432 |
433 |     Returns
434 |     -------
435 |     dataset_train : torch.utils.data.Dataset
436 |         The training subset of the dataset.
437 |     dataset_val : torch.utils.data.Dataset
438 |         The validation subset of the dataset.
439 |     """
440 |     if dataset_val is None:
441 |         dataset_val = dataset_train
442 |     # Now we need to reduce it down to just a subset of the training set.
443 |     # Let's use K-folds so subsequent prototype split IDs will have
444 |     # non-overlapping validation sets. With split_rate = 0.1,
445 |     # there will be 10 folds.
446 |     n_splits = round(1.0 / split_rate)
447 |     if (1.0 / n_splits) != split_rate:
448 |         warnings.warn(
449 |             "The requested train/val split rate is not possible when splitting"
450 |             " the dataset into K folds. The actual split rate will be"
451 |             f" {1.0 / n_splits} instead of {split_rate}.",
452 |             UserWarning,
453 |             stacklevel=2,
454 |         )
455 |     split_seed = int(split_id * split_rate)
456 |     fold_id = split_id % n_splits
457 |     print(
458 |         f"Creating prototyping train/val split #{split_id}."
459 |         f" Using fold {fold_id} of {n_splits} folds, generated with seed"
460 |         f" {split_seed}."
461 |     )
462 |     # Try to do a stratified split.
463 |     classes = get_dataset_labels(dataset_train)
464 |     if classes is None:
465 |         warnings.warn(
466 |             "Creating prototyping splits without stratification.",
467 |             UserWarning,
468 |             stacklevel=2,
469 |         )
470 |         splitter_ftry = sklearn.model_selection.KFold
471 |     else:
472 |         splitter_ftry = sklearn.model_selection.StratifiedKFold
473 |
474 |     # Create our splits. Assuming the dataset objects are always loaded
475 |     # the same way, since a given split ID will always be the same
476 |     # fold from the same seeded KFold splitter, it will yield the same
477 |     # train/val split on each run.
478 |     splitter = splitter_ftry(n_splits=n_splits, shuffle=True, random_state=split_seed)
479 |     splits = splitter.split(np.arange(len(dataset_train)), classes)
480 |     # splits is an iterable and we want to take the n-th fold from it.
481 |     for i, (train_indices, val_indices) in enumerate(splits):  # noqa: B007
482 |         if i == fold_id:
483 |             break
484 |     dataset_train = torch.utils.data.Subset(dataset_train, train_indices)
485 |     dataset_val = torch.utils.data.Subset(dataset_val, val_indices)
486 |     return dataset_train, dataset_val
487 |
488 |
489 | def get_dataset_labels(dataset):
490 |     r"""
491 |     Get the class labels within a :class:`torch.utils.data.Dataset` object.
492 |
493 |     Parameters
494 |     ----------
495 |     dataset : torch.utils.data.Dataset
496 |         The dataset object.
497 |
498 |     Returns
499 |     -------
500 |     array_like or None
501 |         The class labels for each sample.
502 |     """
503 |     if isinstance(dataset, torch.utils.data.Subset):
504 |         # For a dataset subset, we need to get the full set of labels from the
505 |         # interior dataset and then reduce them down to just the labels we have
506 |         # in the subset.
507 | labels = get_dataset_labels(dataset.dataset) 508 | if labels is None: 509 | return labels 510 | return np.array(labels)[dataset.indices] 511 | 512 | labels = None 513 | if hasattr(dataset, "targets"): 514 | # MNIST, CIFAR, ImageFolder, etc 515 | labels = dataset.targets 516 | elif hasattr(dataset, "labels"): 517 | # STL10, SVHN 518 | labels = dataset.labels 519 | elif hasattr(dataset, "_labels"): 520 | # Flowers102 521 | labels = dataset._labels 522 | 523 | return labels 524 | -------------------------------------------------------------------------------- /template_experiment/encoders.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import timm 4 | import torch 5 | from timm.data import resolve_data_config 6 | 7 | 8 | def get_timm_encoder(model_name, pretrained=False, in_chans=3): 9 | r""" 10 | Get the encoder model and its configuration from timm. 11 | 12 | Parameters 13 | ---------- 14 | model_name : str 15 | Name of the model to load. 16 | pretrained : bool, default=False 17 | Whether to load the model with pretrained weights. 18 | in_chans : int, default=3 19 | Number of input channels. 20 | 21 | Returns 22 | ------- 23 | encoder : torch.nn.Module 24 | The encoder model (with pretrained weights loaded if requested). 25 | encoder_config : dict 26 | The data configuration of the encoder model. 27 | """ 28 | if len(timm.list_models(model_name)) == 0: 29 | warnings.warn( 30 | f"Unrecognized model '{model_name}'. Trying to fetch it from the hugging-face hub.", 31 | UserWarning, 32 | stacklevel=2, 33 | ) 34 | model_name = "hf-hub:timm/" + model_name 35 | 36 | # We request the model without the classification head (num_classes=0) 37 | # to get it is an encoder-only model 38 | encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0, in_chans=in_chans) 39 | encoder_config = resolve_data_config({}, model=encoder) 40 | 41 | # Send a dummy input through the encoder to find out the shape of its output 42 | encoder.eval() 43 | dummy_output = encoder(torch.zeros((1, *encoder_config["input_size"]))) 44 | encoder_config["n_feature"] = dummy_output.shape[1] 45 | encoder.train() 46 | 47 | encoder_config["in_channels"] = encoder_config["input_size"][0] 48 | 49 | return encoder, encoder_config 50 | -------------------------------------------------------------------------------- /template_experiment/evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation routines. 3 | """ 4 | 5 | import numpy as np 6 | import sklearn.metrics 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from . import utils 11 | 12 | 13 | def evaluate( 14 | dataloader, 15 | model, 16 | device, 17 | partition_name="Val", 18 | verbosity=1, 19 | is_distributed=False, 20 | ): 21 | r""" 22 | Evaluate model performance on a dataset. 23 | 24 | Parameters 25 | ---------- 26 | dataloader : torch.utils.data.DataLoader 27 | Dataloader for the dataset to evaluate on. 28 | model : torch.nn.Module 29 | Model to evaluate. 30 | device : torch.device 31 | Device to run the model on. 32 | partition_name : str, default="Val" 33 | Name of the partition being evaluated. 34 | verbosity : int, default=1 35 | Verbosity level. 36 | is_distributed : bool, default=False 37 | Whether the model is distributed across multiple GPUs. 38 | 39 | Returns 40 | ------- 41 | results : dict 42 | Dictionary of evaluation results. 
43 | """ 44 | model.eval() 45 | 46 | y_true_all = [] 47 | y_pred_all = [] 48 | xent_all = [] 49 | 50 | for stimuli, y_true in dataloader: 51 | stimuli = stimuli.to(device) 52 | y_true = y_true.to(device) 53 | with torch.no_grad(): 54 | logits = model(stimuli) 55 | xent = F.cross_entropy(logits, y_true, reduction="none") 56 | y_pred = torch.argmax(logits, dim=-1) 57 | 58 | if is_distributed: 59 | # Fetch results from other GPUs 60 | xent = utils.concat_all_gather_ragged(xent) 61 | y_true = utils.concat_all_gather_ragged(y_true) 62 | y_pred = utils.concat_all_gather_ragged(y_pred) 63 | 64 | xent_all.append(xent.cpu().numpy()) 65 | y_true_all.append(y_true.cpu().numpy()) 66 | y_pred_all.append(y_pred.cpu().numpy()) 67 | 68 | # Concatenate the targets and predictions from each batch 69 | xent = np.concatenate(xent_all) 70 | y_true = np.concatenate(y_true_all) 71 | y_pred = np.concatenate(y_pred_all) 72 | # If the dataset size was not evenly divisible by the world size, 73 | # DistributedSampler will pad the end of the list of samples 74 | # with some repetitions. We need to trim these off. 75 | n_samples = len(dataloader.dataset) 76 | xent = xent[:n_samples] 77 | y_true = y_true[:n_samples] 78 | y_pred = y_pred[:n_samples] 79 | # Create results dictionary 80 | results = {} 81 | results["count"] = len(y_true) 82 | results["cross-entropy"] = np.mean(xent) 83 | # Note that these evaluation metrics have all been converted to percentages 84 | results["accuracy"] = 100.0 * sklearn.metrics.accuracy_score(y_true, y_pred) 85 | results["accuracy-balanced"] = 100.0 * sklearn.metrics.balanced_accuracy_score(y_true, y_pred) 86 | results["f1-micro"] = 100.0 * sklearn.metrics.f1_score(y_true, y_pred, average="micro") 87 | results["f1-macro"] = 100.0 * sklearn.metrics.f1_score(y_true, y_pred, average="macro") 88 | results["f1-support"] = 100.0 * sklearn.metrics.f1_score(y_true, y_pred, average="weighted") 89 | # Could expand to other metrics too 90 | 91 | if verbosity >= 1: 92 | print(f"\n{partition_name} evaluation results:") 93 | for k, v in results.items(): 94 | if k == "count": 95 | print(f" {k + ' ':.<21s}{v:7d}") 96 | elif "entropy" in k: 97 | print(f" {k + ' ':.<24s} {v:9.5f} nat") 98 | else: 99 | print(f" {k + ' ':.<24s} {v:6.2f} %") 100 | 101 | return results 102 | -------------------------------------------------------------------------------- /template_experiment/io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Input/output utilities. 3 | """ 4 | 5 | import os 6 | from inspect import getsourcefile 7 | 8 | import torch 9 | 10 | PACKAGE_DIR = os.path.dirname(os.path.abspath(getsourcefile(lambda: 0))) 11 | 12 | 13 | def get_project_root() -> str: 14 | return os.path.dirname(PACKAGE_DIR) 15 | 16 | 17 | def safe_save_model(modules, checkpoint_path=None, config=None, **kwargs): 18 | """ 19 | Save a model to a checkpoint file, along with any additional data. 20 | 21 | Parameters 22 | ---------- 23 | modules : dict 24 | A dictionary of modules to save. The keys are the names of the modules 25 | and the values are the modules themselves. 26 | checkpoint_path : str, optional 27 | Path to the checkpoint file. If not provided, the path will be taken 28 | from the config object. 29 | config : :class:`argparse.Namespace`, optional 30 | A configuration object containing the checkpoint path. 31 | **kwargs 32 | Additional data to save to the checkpoint file. 
33 |     """
34 |     if checkpoint_path is not None:
35 |         pass
36 |     elif config is not None and hasattr(config, "checkpoint_path"):
37 |         checkpoint_path = config.checkpoint_path
38 |     else:
39 |         raise ValueError("No checkpoint path provided")
40 |     print(f"\nSaving model to {checkpoint_path}")
41 |     # Create the directory if it doesn't already exist
42 |     os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
43 |     # Save to a temporary file first, then move the temporary file to the target
44 |     # destination. This is to prevent clobbering the checkpoint with a partially
45 |     # saved file, in the event that the saving process is interrupted. Saving
46 |     # the checkpoint takes a little while and can be disrupted by preemption,
47 |     # whereas moving the file is an atomic operation.
48 |     tmp_a, tmp_b = os.path.split(checkpoint_path)
49 |     tmp_fname = os.path.join(tmp_a, ".tmp." + tmp_b)
50 |     data = {k: v.state_dict() for k, v in modules.items()}
51 |     data.update(kwargs)
52 |     if config is not None:
53 |         data["config"] = config
54 |
55 |     torch.save(data, tmp_fname)
56 |     os.rename(tmp_fname, checkpoint_path)
57 |     print(f"Saved model to {checkpoint_path}")
58 |
--------------------------------------------------------------------------------
/template_experiment/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import builtins
4 | import copy
5 | import os
6 | import shutil
7 | import time
8 | import warnings
9 | from contextlib import nullcontext
10 | from datetime import datetime
11 | from socket import gethostname
12 |
13 | import torch
14 | import torch.optim
15 | from torch import nn
16 | from torch.utils.data.distributed import DistributedSampler
17 |
18 | from template_experiment import data_transformations, datasets, encoders, utils
19 | from template_experiment.evaluation import evaluate
20 | from template_experiment.io import safe_save_model
21 |
22 | BASE_BATCH_SIZE = 128
23 |
24 |
25 | def check_is_distributed():
26 |     r"""
27 |     Check if the current job is running in distributed mode.
28 |
29 |     Returns
30 |     -------
31 |     bool
32 |         Whether the job is running in distributed mode.
33 |     """
34 |     return (
35 |         "WORLD_SIZE" in os.environ
36 |         and "RANK" in os.environ
37 |         and "LOCAL_RANK" in os.environ
38 |         and "MASTER_ADDR" in os.environ
39 |         and "MASTER_PORT" in os.environ
40 |     )
41 |
42 |
43 | def setup_slurm_distributed():
44 |     r"""
45 |     Use SLURM environment variables to set up environment variables needed for DDP.
46 |
47 |     Note: This is not used when using torchrun, as that sets RANK etc. for us,
48 |     but is useful if you're using srun without torchrun (i.e. using srun within
49 |     the sbatch file to launch one task per GPU).
50 |     """
51 |     if "WORLD_SIZE" in os.environ:
52 |         pass
53 |     elif "SLURM_NNODES" in os.environ and "SLURM_GPUS_ON_NODE" in os.environ:
54 |         os.environ["WORLD_SIZE"] = str(int(os.environ["SLURM_NNODES"]) * int(os.environ["SLURM_GPUS_ON_NODE"]))
55 |     elif "SLURM_NPROCS" in os.environ:
56 |         os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"]
57 |     if "RANK" not in os.environ and "SLURM_PROCID" in os.environ:
58 |         os.environ["RANK"] = os.environ["SLURM_PROCID"]
59 |         if int(os.environ["RANK"]) > 0 and "WORLD_SIZE" not in os.environ:
60 |             raise EnvironmentError(
61 |                 f"SLURM_PROCID is {os.environ['SLURM_PROCID']}, implying"
62 |                 " distributed training, but WORLD_SIZE could not be determined."
63 | ) 64 | if "LOCAL_RANK" not in os.environ and "SLURM_LOCALID" in os.environ: 65 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 66 | if "MASTER_ADDR" not in os.environ and "SLURM_NODELIST" in os.environ: 67 | os.environ["MASTER_ADDR"] = os.environ["SLURM_NODELIST"].split("-")[0] 68 | if "MASTER_PORT" not in os.environ and "SLURM_JOB_ID" in os.environ: 69 | os.environ["MASTER_PORT"] = str(49152 + int(os.environ["SLURM_JOB_ID"]) % 16384) 70 | 71 | 72 | def run(config): 73 | r""" 74 | Run training job (one worker if using distributed training). 75 | 76 | Parameters 77 | ---------- 78 | config : argparse.Namespace or OmegaConf 79 | The configuration for this experiment. 80 | """ 81 | if config.log_wandb: 82 | # Lazy import of wandb, since logging to wandb is optional 83 | import wandb 84 | 85 | if config.seed is not None: 86 | utils.set_rng_seeds_fixed(config.seed) 87 | 88 | if config.deterministic: 89 | print("Running in deterministic cuDNN mode. Performance may be slower, but more reproducible.") 90 | torch.backends.cudnn.deterministic = True 91 | torch.backends.cudnn.benchmark = False 92 | 93 | # DISTRIBUTION ============================================================ 94 | # Setup for distributed training 95 | setup_slurm_distributed() 96 | config.world_size = int(os.environ.get("WORLD_SIZE", 1)) 97 | config.distributed = check_is_distributed() 98 | if config.world_size > 1 and not config.distributed: 99 | raise EnvironmentError( 100 | f"WORLD_SIZE is {config.world_size}, but not all other required" 101 | " environment variables for distributed training are set." 102 | ) 103 | # Work out the total batch size depending on the number of GPUs we are using 104 | config.batch_size = config.batch_size_per_gpu * config.world_size 105 | 106 | if config.distributed: 107 | # For multiprocessing distributed training, gpu rank needs to be 108 | # set to the global rank among all the processes. 109 | config.global_rank = int(os.environ["RANK"]) 110 | config.local_rank = int(os.environ["LOCAL_RANK"]) 111 | print( 112 | f"Rank {config.global_rank} of {config.world_size} on {gethostname()}" 113 | f" (local GPU {config.local_rank} of {torch.cuda.device_count()})." 
114 | f" Communicating with master at {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" 115 | ) 116 | torch.distributed.init_process_group(backend="nccl") 117 | else: 118 | config.global_rank = 0 119 | 120 | # Suppress printing if this is not the master process for the node 121 | if config.distributed and config.global_rank != 0: 122 | 123 | def print_pass(*args, **kwargs): 124 | pass 125 | 126 | builtins.print = print_pass 127 | 128 | print() 129 | print("Configuration:") 130 | print() 131 | print(config) 132 | print() 133 | print(f"Found {torch.cuda.device_count()} GPUs and {utils.get_num_cpu_available()} CPUs.") 134 | 135 | # Check which device to use 136 | use_cuda = not config.no_cuda and torch.cuda.is_available() 137 | 138 | if config.distributed and not use_cuda: 139 | raise EnvironmentError("Distributed training with NCCL requires CUDA.") 140 | if not use_cuda: 141 | device = torch.device("cpu") 142 | elif config.local_rank is not None: 143 | device = f"cuda:{config.local_rank}" 144 | else: 145 | device = "cuda" 146 | 147 | print(f"Using device {device}") 148 | 149 | # RESTORE OMITTED CONFIG FROM RESUMPTION CHECKPOINT ======================= 150 | checkpoint = None 151 | config.model_output_dir = None 152 | if config.checkpoint_path: 153 | config.model_output_dir = os.path.dirname(config.checkpoint_path) 154 | if not config.checkpoint_path: 155 | # Not trying to resume from a checkpoint 156 | pass 157 | elif not os.path.isfile(config.checkpoint_path): 158 | # Looks like we're trying to resume from the checkpoint that this job 159 | # will itself create. Let's assume this is to let the job resume upon 160 | # preemption, and it just hasn't been preempted yet. 161 | print(f"Skipping premature resumption from preemption: no checkpoint file found at '{config.checkpoint_path}'") 162 | else: 163 | print(f"Loading resumption checkpoint '{config.checkpoint_path}'") 164 | # Map model parameters to be load to the specified gpu. 165 | checkpoint = torch.load(config.checkpoint_path, map_location=device) 166 | keys = vars(get_parser().parse_args("")).keys() 167 | keys = set(keys).difference(["resume", "gpu", "global_rank", "local_rank", "cpu_workers"]) 168 | for key in keys: 169 | if getattr(checkpoint["config"], key, None) is None: 170 | continue 171 | if getattr(config, key) is None: 172 | print(f" Restoring config value for {key} from checkpoint: {getattr(checkpoint['config'], key)}") 173 | setattr(config, key, getattr(checkpoint["config"], key, None)) 174 | elif getattr(config, key) != getattr(checkpoint["config"], key): 175 | print( 176 | f" Warning: config value for {key} differs from checkpoint:" 177 | f" {getattr(config, key)} (ours) vs {getattr(checkpoint['config'], key)} (checkpoint)" 178 | ) 179 | 180 | if checkpoint is None: 181 | # Our epochs go from 1 to n_epoch, inclusive 182 | start_epoch = 1 183 | else: 184 | # Continue from where we left off 185 | start_epoch = checkpoint["epoch"] + 1 186 | if config.seed is not None: 187 | # Make sure we don't get the same behaviour as we did on the 188 | # first epoch repeated on this resumed epoch. 189 | utils.set_rng_seeds_fixed(config.seed + start_epoch, all_gpu=False) 190 | 191 | # MODEL =================================================================== 192 | 193 | # Encoder ----------------------------------------------------------------- 194 | # Build our Encoder. 
195 | # We have to build the encoder before we load the dataset because it will 196 | # inform us about what size images we should produce in the preprocessing pipeline. 197 | n_class, raw_img_size, img_channels = datasets.image_dataset_sizes(config.dataset_name) 198 | if img_channels > 3 and config.freeze_encoder: 199 | raise ValueError( 200 | "Using a dataset with more than 3 image channels will require retraining the encoder" 201 | ", but a frozen encoder was requested." 202 | ) 203 | if config.arch_framework == "timm": 204 | encoder, encoder_config = encoders.get_timm_encoder(config.arch, config.pretrained, in_chans=img_channels) 205 | elif config.arch_framework == "torchvision": 206 | # It's trickier to implement this for torchvision models, because they 207 | # don't have the same naming conventions for model names as in timm; 208 | # need us to specify the name of the weights when loading a pretrained 209 | # model; and don't support changing the number of input channels. 210 | raise NotImplementedError(f"Unsupported architecture framework: {config.arch_framework}") 211 | else: 212 | raise ValueError(f"Unknown architecture framework: {config.arch_framework}") 213 | 214 | if config.freeze_encoder and not config.pretrained: 215 | warnings.warn( 216 | "A frozen encoder was requested, but the encoder is not pretrained.", 217 | UserWarning, 218 | stacklevel=2, 219 | ) 220 | 221 | if config.image_size is None: 222 | if "input_size" in encoder_config: 223 | config.image_size = encoder_config["input_size"][-1] 224 | print(f"Setting model input image size to encoder's expected input size: {config.image_size}") 225 | else: 226 | config.image_size = 224 227 | print(f"Setting model input image size to default: {config.image_size}") 228 | if raw_img_size: 229 | warnings.warn( 230 | "Be aware that we are using a different input image size" 231 | f" ({config.image_size}px) to the raw image size in the" 232 | f" dataset ({raw_img_size}px).", 233 | UserWarning, 234 | stacklevel=2, 235 | ) 236 | elif "input_size" in encoder_config and config.pretrained and encoder_config["input_size"][-1] != config.image_size: 237 | warnings.warn( 238 | f"A different image size {config.image_size} than what the model was" 239 | f" pretrained with {encoder_config['input_size'][-1]} was suplied", 240 | UserWarning, 241 | stacklevel=2, 242 | ) 243 | 244 | # Classifier ------------------------------------------------------------- 245 | # Build our classifier head 246 | classifier = nn.Linear(encoder_config["n_feature"], n_class) 247 | 248 | # Configure model for distributed training -------------------------------- 249 | print("\nEncoder architecture:") 250 | print(encoder) 251 | print("\nClassifier architecture:") 252 | print(classifier) 253 | print() 254 | 255 | if config.cpu_workers is None: 256 | config.cpu_workers = utils.get_num_cpu_available() 257 | 258 | if not use_cuda: 259 | print("Using CPU (this will be slow)") 260 | elif config.distributed: 261 | # Convert batchnorm into SyncBN, using stats computed from all GPUs 262 | encoder = nn.SyncBatchNorm.convert_sync_batchnorm(encoder) 263 | classifier = nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 264 | # For multiprocessing distributed, the DistributedDataParallel 265 | # constructor should always set a single device scope, otherwise 266 | # DistributedDataParallel will use all available devices. 
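        # Passing device_ids=[config.local_rank] (and a matching output_device)
        # below pins each replica to its own GPU; DistributedDataParallel then
        # all-reduces (averages) gradients across ranks during backward().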
267 | encoder.to(device) 268 | classifier.to(device) 269 | torch.cuda.set_device(device) 270 | encoder = nn.parallel.DistributedDataParallel( 271 | encoder, device_ids=[config.local_rank], output_device=config.local_rank 272 | ) 273 | classifier = nn.parallel.DistributedDataParallel( 274 | classifier, device_ids=[config.local_rank], output_device=config.local_rank 275 | ) 276 | else: 277 | if config.local_rank is not None: 278 | torch.cuda.set_device(config.local_rank) 279 | encoder = encoder.to(device) 280 | classifier = classifier.to(device) 281 | 282 | # DATASET ================================================================= 283 | # Transforms -------------------------------------------------------------- 284 | transform_args = {} 285 | if config.dataset_name in data_transformations.VALID_TRANSFORMS: 286 | transform_args["normalization"] = config.dataset_name 287 | 288 | if "mean" in encoder_config: 289 | transform_args["mean"] = encoder_config["mean"] 290 | if "std" in encoder_config: 291 | transform_args["std"] = encoder_config["std"] 292 | 293 | transform_train, transform_eval = data_transformations.get_transform( 294 | config.transform_type, config.image_size, transform_args 295 | ) 296 | 297 | # Dataset ----------------------------------------------------------------- 298 | dataset_args = { 299 | "dataset": config.dataset_name, 300 | "root": config.data_dir, 301 | "prototyping": config.prototyping, 302 | "download": config.allow_download_dataset, 303 | } 304 | if config.protoval_split_id is not None: 305 | dataset_args["protoval_split_id"] = config.protoval_split_id 306 | ( 307 | dataset_train, 308 | dataset_val, 309 | dataset_test, 310 | distinct_val_test, 311 | ) = datasets.fetch_dataset( 312 | **dataset_args, 313 | transform_train=transform_train, 314 | transform_eval=transform_eval, 315 | ) 316 | eval_set = "Val" if distinct_val_test else "Test" 317 | 318 | # Dataloader -------------------------------------------------------------- 319 | dl_train_kwargs = { 320 | "batch_size": config.batch_size_per_gpu, 321 | "drop_last": True, 322 | "sampler": None, 323 | "shuffle": True, 324 | "worker_init_fn": utils.worker_seed_fn, 325 | } 326 | dl_test_kwargs = { 327 | "batch_size": config.batch_size_per_gpu, 328 | "drop_last": False, 329 | "sampler": None, 330 | "shuffle": False, 331 | "worker_init_fn": utils.worker_seed_fn, 332 | } 333 | if use_cuda: 334 | cuda_kwargs = {"num_workers": config.cpu_workers, "pin_memory": True} 335 | dl_train_kwargs.update(cuda_kwargs) 336 | dl_test_kwargs.update(cuda_kwargs) 337 | 338 | dl_val_kwargs = copy.deepcopy(dl_test_kwargs) 339 | 340 | if config.distributed: 341 | # The DistributedSampler breaks up the dataset across the GPUs 342 | dl_train_kwargs["sampler"] = DistributedSampler( 343 | dataset_train, 344 | shuffle=True, 345 | seed=config.seed if config.seed is not None else 0, 346 | drop_last=False, 347 | ) 348 | dl_train_kwargs["shuffle"] = None 349 | dl_val_kwargs["sampler"] = DistributedSampler( 350 | dataset_val, 351 | shuffle=False, 352 | drop_last=False, 353 | ) 354 | dl_val_kwargs["shuffle"] = None 355 | dl_test_kwargs["sampler"] = DistributedSampler( 356 | dataset_test, 357 | shuffle=False, 358 | drop_last=False, 359 | ) 360 | dl_test_kwargs["shuffle"] = None 361 | 362 | dataloader_train = torch.utils.data.DataLoader(dataset_train, **dl_train_kwargs) 363 | dataloader_val = torch.utils.data.DataLoader(dataset_val, **dl_val_kwargs) 364 | dataloader_test = torch.utils.data.DataLoader(dataset_test, **dl_test_kwargs) 365 | 366 | # OPTIMIZATION 
============================================================ 367 | # Optimizer --------------------------------------------------------------- 368 | # Set up the optimizer 369 | 370 | # Bigger batch sizes mean better estimates of the gradient, so we can use a 371 | # bigger learning rate. See https://arxiv.org/abs/1706.02677 372 | # Hence we scale the learning rate linearly with the total batch size. 373 | config.lr = config.lr_relative * config.batch_size / BASE_BATCH_SIZE 374 | 375 | # Freeze the encoder, if requested 376 | if config.freeze_encoder: 377 | for m in encoder.parameters(): 378 | m.requires_grad = False 379 | 380 | # Set up a parameter group for each component of the model, allowing 381 | # them to have different learning rates (for fine-tuning encoder). 382 | params = [] 383 | if not config.freeze_encoder: 384 | params.append( 385 | { 386 | "params": encoder.parameters(), 387 | "lr": config.lr * config.lr_encoder_mult, 388 | "name": "encoder", 389 | } 390 | ) 391 | params.append( 392 | { 393 | "params": classifier.parameters(), 394 | "lr": config.lr * config.lr_classifier_mult, 395 | "name": "classifier", 396 | } 397 | ) 398 | 399 | # Fetch the constructor of the appropriate optimizer from torch.optim 400 | optimizer = getattr(torch.optim, config.optimizer)(params, lr=config.lr, weight_decay=config.weight_decay) 401 | 402 | # Scheduler --------------------------------------------------------------- 403 | # Set up the learning rate scheduler 404 | if config.scheduler.lower() == "onecycle": 405 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 406 | optimizer, 407 | [p["lr"] for p in optimizer.param_groups], 408 | epochs=config.epochs, 409 | steps_per_epoch=len(dataloader_train), 410 | ) 411 | else: 412 | raise NotImplementedError(f"Scheduler {config.scheduler} not supported.") 413 | 414 | # Loss function ----------------------------------------------------------- 415 | # Set up loss function 416 | criterion = nn.CrossEntropyLoss() 417 | 418 | # LOGGING ================================================================= 419 | # Setup logging and saving 420 | 421 | # If we're using wandb, initialize the run, or resume it if the job was preempted. 422 | if config.log_wandb and config.global_rank == 0: 423 | wandb_run_name = config.run_name 424 | if wandb_run_name is not None and config.run_id is not None: 425 | wandb_run_name = f"{wandb_run_name}__{config.run_id}" 426 | EXCLUDED_WANDB_CONFIG_KEYS = [ 427 | "log_wandb", 428 | "wandb_entity", 429 | "wandb_project", 430 | "global_rank", 431 | "local_rank", 432 | "run_name", 433 | "run_id", 434 | "model_output_dir", 435 | ] 436 | utils.init_or_resume_wandb_run( 437 | config.model_output_dir, 438 | name=wandb_run_name, 439 | id=config.run_id, 440 | entity=config.wandb_entity, 441 | project=config.wandb_project, 442 | config=wandb.helper.parse_config(config, exclude=EXCLUDED_WANDB_CONFIG_KEYS), 443 | job_type="train", 444 | tags=["prototype" if config.prototyping else "final"], 445 | ) 446 | # If a run_id was not supplied at the command prompt, wandb will 447 | # generate a name. Let's use that as the run_name. 448 | if config.run_name is None: 449 | config.run_name = wandb.run.name 450 | if config.run_id is None: 451 | config.run_id = wandb.run.id 452 | 453 | # If we still don't have a run name, generate one from the current time. 
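    # (Illustrative aside.) Putting the fallbacks below together with the
    # default --models-dir and --dataset, a fresh run might end up with, e.g.
    # (timestamp and id are hypothetical values):
    #     config.run_name        -> "2024-01-31_13-45-07"  (datetime fallback)
    #     config.run_id          -> "a1b2c3d4"  (utils.generate_id(): 8 base-36 chars)
    #     config.checkpoint_path -> "models/cifar10/2024-01-31_13-45-07__a1b2c3d4/checkpoint_latest.pt"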
454 | if config.run_name is None: 455 | config.run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 456 | if config.run_id is None: 457 | config.run_id = utils.generate_id() 458 | 459 | # If no checkpoint path was supplied, but models_dir was, we will automatically 460 | # determine the path to which we will save the model checkpoint. 461 | # If both are empty, we won't save the model. 462 | if not config.checkpoint_path and config.models_dir: 463 | config.model_output_dir = os.path.join( 464 | config.models_dir, 465 | config.dataset_name, 466 | f"{config.run_name}__{config.run_id}", 467 | ) 468 | config.checkpoint_path = os.path.join(config.model_output_dir, "checkpoint_latest.pt") 469 | if config.log_wandb and config.global_rank == 0: 470 | wandb.config.update({"checkpoint_path": config.checkpoint_path}, allow_val_change=True) 471 | 472 | if config.checkpoint_path is None: 473 | print("Model will not be saved.") 474 | else: 475 | print(f"Model will be saved to '{config.checkpoint_path}'") 476 | 477 | # RESUME ================================================================== 478 | # Now that everything is set up, we can load the state of the model, 479 | # optimizer, and scheduler from a checkpoint, if supplied. 480 | 481 | # Initialize step related variables as if we're starting from scratch. 482 | # Their values will be overridden by the checkpoint if we're resuming. 483 | total_step = 0 484 | n_samples_seen = 0 485 | 486 | best_stats = {"max_accuracy": 0, "best_epoch": 0} 487 | 488 | if checkpoint is not None: 489 | print(f"Loading state from checkpoint (epoch {checkpoint['epoch']})") 490 | # Map model to be loaded to specified single gpu. 491 | total_step = checkpoint["total_step"] 492 | n_samples_seen = checkpoint["n_samples_seen"] 493 | encoder.load_state_dict(checkpoint["encoder"]) 494 | classifier.load_state_dict(checkpoint["classifier"]) 495 | optimizer.load_state_dict(checkpoint["optimizer"]) 496 | scheduler.load_state_dict(checkpoint["scheduler"]) 497 | best_stats["max_accuracy"] = checkpoint.get("max_accuracy", 0) 498 | best_stats["best_epoch"] = checkpoint.get("best_epoch", 0) 499 | 500 | # TRAIN =================================================================== 501 | print() 502 | print("Configuration:") 503 | print() 504 | print(config) 505 | print() 506 | 507 | # Ensure modules are on the correct device 508 | encoder = encoder.to(device) 509 | classifier = classifier.to(device) 510 | 511 | # Stack the encoder and classifier together to create an overall model. 512 | # At inference time, we don't need to make a distinction between modules 513 | # within this stack. 514 | model = nn.Sequential(encoder, classifier) 515 | 516 | timing_stats = {} 517 | t_end_epoch = time.time() 518 | for epoch in range(start_epoch, config.epochs + 1): 519 | t_start_epoch = time.time() 520 | if config.seed is not None: 521 | # If the job is resumed from preemption, our RNG state is currently set the 522 | # same as it was at the start of the first epoch, not where it was when we 523 | # stopped training. This is not good as it means jobs which are resumed 524 | # don't do the same thing as they would be if they'd run uninterrupted 525 | # (making preempted jobs non-reproducible). 526 | # To address this, we reset the seed at the start of every epoch. Since jobs 527 | # can only save at the end of and resume at the start of an epoch, this 528 | # makes the training process reproducible. 
But we shouldn't use the same 529 | # RNG state for each epoch - instead we use the original seed to define the 530 | # series of seeds that we will use at the start of each epoch. 531 | epoch_seed = utils.determine_epoch_seed(config.seed, epoch=epoch) 532 | # We want each GPU to have a different seed to the others to avoid 533 | # correlated randomness between the workers on the same batch. 534 | # We offset the seed for this epoch by the GPU rank, so every GPU will get a 535 | # unique seed for the epoch. This means the job is only precisely 536 | # reproducible if it is rerun with the same number of GPUs (and the same 537 | # number of CPU workers for the dataloader). 538 | utils.set_rng_seeds_fixed(epoch_seed + config.global_rank, all_gpu=False) 539 | if isinstance(getattr(dataloader_train, "generator", None), torch.Generator): 540 | # Finesse the dataloader's RNG state, if it is not using the global state. 541 | dataloader_train.generator.manual_seed(epoch_seed + config.global_rank) 542 | if isinstance(getattr(dataloader_train.sampler, "generator", None), torch.Generator): 543 | # Finesse the sampler's RNG state, if it is not using the global RNG state. 544 | dataloader_train.sampler.generator.manual_seed(config.seed + epoch + 10000 * config.global_rank) 545 | 546 | if hasattr(dataloader_train.sampler, "set_epoch"): 547 | # Handling for DistributedSampler. 548 | # Set the epoch for the sampler so that it can shuffle the data 549 | # differently for each epoch, but synchronized across all GPUs. 550 | dataloader_train.sampler.set_epoch(epoch) 551 | 552 | # Train ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 553 | # Note the number of samples seen before this epoch started, so we can 554 | # calculate the number of samples seen in this epoch. 
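        # (Illustrative aside.) Because utils.determine_epoch_seed replays
        # `epoch` draws from random.seed(config.seed), the per-epoch seed
        # series depends only on the base seed, e.g.:
        #     seeds = [utils.determine_epoch_seed(config.seed, epoch=e) for e in range(1, 4)]
        # A job resumed at epoch k therefore recomputes exactly the same
        # epoch_seed that an uninterrupted run would have used for epoch k.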
555 | n_samples_seen_before = n_samples_seen 556 | # Run one epoch of training 557 | train_stats, total_step, n_samples_seen = train_one_epoch( 558 | config=config, 559 | encoder=encoder, 560 | classifier=classifier, 561 | optimizer=optimizer, 562 | scheduler=scheduler, 563 | criterion=criterion, 564 | dataloader=dataloader_train, 565 | device=device, 566 | epoch=epoch, 567 | n_epoch=config.epochs, 568 | total_step=total_step, 569 | n_samples_seen=n_samples_seen, 570 | ) 571 | t_end_train = time.time() 572 | 573 | timing_stats["train"] = t_end_train - t_start_epoch 574 | n_epoch_samples = n_samples_seen - n_samples_seen_before 575 | train_stats["throughput"] = n_epoch_samples / timing_stats["train"] 576 | 577 | print(f"Training epoch {epoch}/{config.epochs} summary:") 578 | print(f" Steps ..............{len(dataloader_train):8d}") 579 | print(f" Samples ............{n_epoch_samples:8d}") 580 | if timing_stats["train"] > 172800: 581 | print(f" Duration ...........{timing_stats['train']/86400:11.2f} days") 582 | elif timing_stats["train"] > 5400: 583 | print(f" Duration ...........{timing_stats['train']/3600:11.2f} hours") 584 | elif timing_stats["train"] > 120: 585 | print(f" Duration ...........{timing_stats['train']/60:11.2f} minutes") 586 | else: 587 | print(f" Duration ...........{timing_stats['train']:11.2f} seconds") 588 | print(f" Throughput .........{train_stats['throughput']:11.2f} samples/sec") 589 | print(f" Loss ...............{train_stats['loss']:14.5f}") 590 | print(f" Accuracy ...........{train_stats['accuracy']:11.2f} %") 591 | 592 | # Validate ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 593 | # Evaluate on validation set 594 | t_start_val = time.time() 595 | 596 | eval_stats = evaluate( 597 | dataloader=dataloader_val, 598 | model=model, 599 | device=device, 600 | partition_name=eval_set, 601 | is_distributed=config.distributed, 602 | ) 603 | t_end_val = time.time() 604 | timing_stats["val"] = t_end_val - t_start_val 605 | eval_stats["throughput"] = len(dataloader_val.dataset) / timing_stats["val"] 606 | 607 | # Check if this is the new best model 608 | if eval_stats["accuracy"] >= best_stats["max_accuracy"]: 609 | best_stats["max_accuracy"] = eval_stats["accuracy"] 610 | best_stats["best_epoch"] = epoch 611 | 612 | print(f"Evaluating epoch {epoch}/{config.epochs} summary:") 613 | if timing_stats["val"] > 172800: 614 | print(f" Duration ...........{timing_stats['val']/86400:11.2f} days") 615 | elif timing_stats["val"] > 5400: 616 | print(f" Duration ...........{timing_stats['val']/3600:11.2f} hours") 617 | elif timing_stats["val"] > 120: 618 | print(f" Duration ...........{timing_stats['val']/60:11.2f} minutes") 619 | else: 620 | print(f" Duration ...........{timing_stats['val']:11.2f} seconds") 621 | print(f" Throughput .........{eval_stats['throughput']:11.2f} samples/sec") 622 | print(f" Cross-entropy ......{eval_stats['cross-entropy']:14.5f}") 623 | print(f" Accuracy ...........{eval_stats['accuracy']:11.2f} %") 624 | 625 | # Save model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 626 | t_start_save = time.time() 627 | if config.model_output_dir and (not config.distributed or config.global_rank == 0): 628 | safe_save_model( 629 | { 630 | "encoder": encoder, 631 | "classifier": classifier, 632 | "optimizer": optimizer, 633 | "scheduler": scheduler, 634 | }, 635 | config.checkpoint_path, 636 | config=config, 637 | epoch=epoch, 638 | total_step=total_step, 639 | n_samples_seen=n_samples_seen, 640 | encoder_config=encoder_config, 641 | 
transform_args=transform_args, 642 | **best_stats, 643 | ) 644 | if config.save_best_model and best_stats["best_epoch"] == epoch: 645 | ckpt_path_best = os.path.join(config.model_output_dir, "best_model.pt") 646 | print(f"Copying model to {ckpt_path_best}") 647 | shutil.copyfile(config.checkpoint_path, ckpt_path_best) 648 | 649 | t_end_save = time.time() 650 | timing_stats["saving"] = t_end_save - t_start_save 651 | 652 | # Log to wandb ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 653 | # Overall time won't include uploading to wandb, but there's nothing 654 | # we can do about that. 655 | timing_stats["overall"] = time.time() - t_end_epoch 656 | t_end_epoch = time.time() 657 | 658 | # Send training and eval stats for this epoch to wandb 659 | if config.log_wandb and config.global_rank == 0: 660 | pre = "Training/epochwise" 661 | wandb.log( 662 | { 663 | "Training/stepwise/epoch": epoch, 664 | "Training/stepwise/epoch_progress": epoch, 665 | "Training/stepwise/n_samples_seen": n_samples_seen, 666 | f"{pre}/epoch": epoch, 667 | **{f"{pre}/Train/{k}": v for k, v in train_stats.items()}, 668 | **{f"{pre}/{eval_set}/{k}": v for k, v in eval_stats.items()}, 669 | **{f"{pre}/duration/{k}": v for k, v in timing_stats.items()}, 670 | }, 671 | step=total_step, 672 | ) 673 | # Record the wandb time as contributing to the next epoch 674 | timing_stats = {"wandb": time.time() - t_end_epoch} 675 | else: 676 | # Reset timing stats 677 | timing_stats = {} 678 | # Print with flush=True forces the output buffer to be printed immediately 679 | print("", flush=True) 680 | 681 | if start_epoch > config.epochs: 682 | print("Training already completed!") 683 | else: 684 | print(f"Training complete! (Trained epochs {start_epoch} to {config.epochs})") 685 | print( 686 | f"Best {eval_set} accuracy was {best_stats['max_accuracy']:.2f}%," 687 | f" seen at the end of epoch {best_stats['best_epoch']}" 688 | ) 689 | 690 | # TEST ==================================================================== 691 | print(f"\nEvaluating final model (epoch {config.epochs}) performance") 692 | # Evaluate on test set 693 | print("\nEvaluating final model on test set...") 694 | eval_stats = evaluate( 695 | dataloader=dataloader_test, 696 | model=model, 697 | device=device, 698 | partition_name="Test", 699 | is_distributed=config.distributed, 700 | ) 701 | # Send stats to wandb 702 | if config.log_wandb and config.global_rank == 0: 703 | wandb.log({**{f"Eval/Test/{k}": v for k, v in eval_stats.items()}}, step=total_step) 704 | 705 | if distinct_val_test: 706 | # Evaluate on validation set 707 | print(f"\nEvaluating final model on {eval_set} set...") 708 | eval_stats = evaluate( 709 | dataloader=dataloader_val, 710 | model=model, 711 | device=device, 712 | partition_name=eval_set, 713 | is_distributed=config.distributed, 714 | ) 715 | # Send stats to wandb 716 | if config.log_wandb and config.global_rank == 0: 717 | wandb.log( 718 | {**{f"Eval/{eval_set}/{k}": v for k, v in eval_stats.items()}}, 719 | step=total_step, 720 | ) 721 | 722 | # Create a copy of the train partition with evaluation transforms 723 | # and a dataloader using the evaluation configuration (don't drop last) 724 | print("\nEvaluating final model on train set under test conditions (no augmentation, dropout, etc)...") 725 | dataset_train_eval = datasets.fetch_dataset( 726 | **dataset_args, 727 | transform_train=transform_eval, 728 | transform_eval=transform_eval, 729 | )[0] 730 | dl_train_eval_kwargs = copy.deepcopy(dl_test_kwargs) 731 | if 
config.distributed: 732 | # The DistributedSampler breaks up the dataset across the GPUs 733 | dl_train_eval_kwargs["sampler"] = DistributedSampler( 734 | dataset_train_eval, 735 | shuffle=False, 736 | drop_last=False, 737 | ) 738 | dl_train_eval_kwargs["shuffle"] = None 739 | dataloader_train_eval = torch.utils.data.DataLoader(dataset_train_eval, **dl_train_eval_kwargs) 740 | eval_stats = evaluate( 741 | dataloader=dataloader_train_eval, 742 | model=model, 743 | device=device, 744 | partition_name="Train", 745 | is_distributed=config.distributed, 746 | ) 747 | # Send stats to wandb 748 | if config.log_wandb and config.global_rank == 0: 749 | wandb.log({**{f"Eval/Train/{k}": v for k, v in eval_stats.items()}}, step=total_step) 750 | 751 | 752 | def train_one_epoch( 753 | config, 754 | encoder, 755 | classifier, 756 | optimizer, 757 | scheduler, 758 | criterion, 759 | dataloader, 760 | device="cuda", 761 | epoch=1, 762 | n_epoch=None, 763 | total_step=0, 764 | n_samples_seen=0, 765 | ): 766 | r""" 767 | Train the encoder and classifier for one epoch. 768 | 769 | Parameters 770 | ---------- 771 | config : argparse.Namespace or OmegaConf 772 | The global config object. 773 | encoder : torch.nn.Module 774 | The encoder network. 775 | classifier : torch.nn.Module 776 | The classifier network. 777 | optimizer : torch.optim.Optimizer 778 | The optimizer. 779 | scheduler : torch.optim.lr_scheduler._LRScheduler 780 | The learning rate scheduler. 781 | criterion : torch.nn.Module 782 | The loss function. 783 | dataloader : torch.utils.data.DataLoader 784 | A dataloader for the training set. 785 | device : str or torch.device, default="cuda" 786 | The device to use. 787 | epoch : int, default=1 788 | The current epoch number (indexed from 1). 789 | n_epoch : int, optional 790 | The total number of epochs scheduled to train for. 791 | total_step : int, default=0 792 | The total number of steps taken so far. 793 | n_samples_seen : int, default=0 794 | The total number of samples seen so far. 795 | 796 | Returns 797 | ------- 798 | results: dict 799 | A dictionary containing the training performance for this epoch. 800 | total_step : int 801 | The total number of steps taken after this epoch. 802 | n_samples_seen : int 803 | The total number of samples seen after this epoch. 804 | """ 805 | # Put the model in train mode 806 | encoder.train() 807 | classifier.train() 808 | 809 | if config.log_wandb: 810 | # Lazy import of wandb, since logging to wandb is optional 811 | import wandb 812 | 813 | loss_epoch = 0 814 | acc_epoch = 0 815 | 816 | if config.print_interval is None: 817 | # Default to printing to console every time we log to wandb 818 | config.print_interval = config.log_interval 819 | 820 | t_end_batch = time.time() 821 | t_start_wandb = t_end_wandb = None 822 | for batch_idx, (stimuli, y_true) in enumerate(dataloader): 823 | t_start_batch = time.time() 824 | batch_size_this_gpu = stimuli.shape[0] 825 | 826 | # Move training inputs and targets to the GPU 827 | stimuli = stimuli.to(device) 828 | y_true = y_true.to(device) 829 | 830 | # Forward pass -------------------------------------------------------- 831 | # Perform the forward pass through the model 832 | t_start_encoder = time.time() 833 | # N.B. 
To accurately time steps on GPU we need to use torch.cuda.Event 834 | ct_forward = torch.cuda.Event(enable_timing=True) 835 | ct_forward.record() 836 | with torch.no_grad() if config.freeze_encoder else nullcontext(): 837 | h = encoder(stimuli) 838 | logits = classifier(h) 839 | # Reset gradients 840 | optimizer.zero_grad() 841 | # Measure loss 842 | loss = criterion(logits, y_true) 843 | 844 | # Backward pass ------------------------------------------------------- 845 | # Now the backward pass 846 | ct_backward = torch.cuda.Event(enable_timing=True) 847 | ct_backward.record() 848 | loss.backward() 849 | 850 | # Update -------------------------------------------------------------- 851 | # Use our optimizer to update the model parameters 852 | ct_optimizer = torch.cuda.Event(enable_timing=True) 853 | ct_optimizer.record() 854 | optimizer.step() 855 | 856 | # Step the scheduler each batch 857 | scheduler.step() 858 | 859 | # Increment training progress counters 860 | total_step += 1 861 | batch_size_all = batch_size_this_gpu * config.world_size 862 | n_samples_seen += batch_size_all 863 | 864 | # Logging ------------------------------------------------------------- 865 | # Log details about training progress 866 | t_start_logging = time.time() 867 | ct_logging = torch.cuda.Event(enable_timing=True) 868 | ct_logging.record() 869 | 870 | # Update the total loss for the epoch 871 | if config.distributed: 872 | # Fetch results from other GPUs 873 | loss_batch = torch.mean(utils.concat_all_gather(loss.reshape((1,)))) 874 | loss_batch = loss_batch.item() 875 | else: 876 | loss_batch = loss.item() 877 | loss_epoch += loss_batch 878 | 879 | # Compute accuracy 880 | with torch.no_grad(): 881 | y_pred = torch.argmax(logits, dim=-1) 882 | is_correct = y_pred == y_true 883 | acc = 100.0 * is_correct.sum() / len(is_correct) 884 | if config.distributed: 885 | # Fetch results from other GPUs 886 | acc = torch.mean(utils.concat_all_gather(acc.reshape((1,)))) 887 | acc = acc.item() 888 | acc_epoch += acc 889 | 890 | if epoch <= 1 and batch_idx == 0: 891 | # Debugging 892 | print("stimuli.shape =", stimuli.shape) 893 | print("y_true.shape =", y_true.shape) 894 | print("y_pred.shape =", y_pred.shape) 895 | print("logits.shape =", logits.shape) 896 | print("loss.shape =", loss.shape) 897 | # Debugging intensifies 898 | print("y_true =", y_true) 899 | print("y_pred =", y_pred) 900 | print("logits[0] =", logits[0]) 901 | print("loss =", loss.detach().item()) 902 | 903 | # Log sample training images to show on wandb 904 | if config.log_wandb and batch_idx <= 1: 905 | # Log 8 example training images from each GPU 906 | img_indices = [offset + relative for offset in [0, batch_size_this_gpu // 2] for relative in [0, 1, 2, 3]] 907 | img_indices = sorted(set(img_indices)) 908 | log_images = stimuli[img_indices] 909 | if config.distributed: 910 | # Collate sample images from each GPU 911 | log_images = utils.concat_all_gather(log_images) 912 | if config.global_rank == 0: 913 | wandb.log( 914 | {"Training/stepwise/Train/stimuli": wandb.Image(log_images)}, 915 | step=total_step, 916 | ) 917 | 918 | # Log to console 919 | if batch_idx <= 2 or batch_idx % config.print_interval == 0 or batch_idx >= len(dataloader) - 1: 920 | print( 921 | f"Train Epoch:{epoch:4d}" + (f"/{n_epoch}" if n_epoch is not None else ""), 922 | " Step:{:4d}/{}".format(batch_idx + 1, len(dataloader)), 923 | " Loss:{:8.5f}".format(loss_batch), 924 | " Acc:{:6.2f}%".format(acc), 925 | " LR: {}".format(scheduler.get_last_lr()), 926 | ) 927 | 928 | # Log 
to wandb 929 | if config.log_wandb and config.global_rank == 0 and batch_idx % config.log_interval == 0: 930 | # Create a log dictionary to send to wandb 931 | # Epoch progress interpolates smoothly between epochs 932 | epoch_progress = epoch - 1 + (batch_idx + 1) / len(dataloader) 933 | # Throughput is the number of samples processed per second 934 | throughput = batch_size_all / (t_start_logging - t_end_batch) 935 | log_dict = { 936 | "Training/stepwise/epoch": epoch, 937 | "Training/stepwise/epoch_progress": epoch_progress, 938 | "Training/stepwise/n_samples_seen": n_samples_seen, 939 | "Training/stepwise/Train/throughput": throughput, 940 | "Training/stepwise/Train/loss": loss_batch, 941 | "Training/stepwise/Train/accuracy": acc, 942 | } 943 | # Track the learning rate of each parameter group 944 | for lr_idx in range(len(optimizer.param_groups)): 945 | if "name" in optimizer.param_groups[lr_idx]: 946 | grp_name = optimizer.param_groups[lr_idx]["name"] 947 | elif len(optimizer.param_groups) == 1: 948 | grp_name = "" 949 | else: 950 | grp_name = f"grp{lr_idx}" 951 | if grp_name != "": 952 | grp_name = f"-{grp_name}" 953 | grp_lr = optimizer.param_groups[lr_idx]["lr"] 954 | log_dict[f"Training/stepwise/lr{grp_name}"] = grp_lr 955 | # Synchronize ensures everything has finished running on each GPU 956 | torch.cuda.synchronize() 957 | # Record how long it took to do each step in the pipeline 958 | pre = "Training/stepwise/duration" 959 | if t_start_wandb is not None: 960 | # Record how long it took to send to wandb last time 961 | log_dict[f"{pre}/wandb"] = t_end_wandb - t_start_wandb 962 | log_dict[f"{pre}/dataloader"] = t_start_batch - t_end_batch 963 | log_dict[f"{pre}/preamble"] = t_start_encoder - t_start_batch 964 | log_dict[f"{pre}/forward"] = ct_forward.elapsed_time(ct_backward) / 1000 965 | log_dict[f"{pre}/backward"] = ct_backward.elapsed_time(ct_optimizer) / 1000 966 | log_dict[f"{pre}/optimizer"] = ct_optimizer.elapsed_time(ct_logging) / 1000 967 | log_dict[f"{pre}/overall"] = time.time() - t_end_batch 968 | t_start_wandb = time.time() 969 | log_dict[f"{pre}/logging"] = t_start_wandb - t_start_logging 970 | # Send to wandb 971 | wandb.log(log_dict, step=total_step) 972 | t_end_wandb = time.time() 973 | 974 | # Record the time when we finished this batch 975 | t_end_batch = time.time() 976 | 977 | results = { 978 | "loss": loss_epoch / len(dataloader), 979 | "accuracy": acc_epoch / len(dataloader), 980 | } 981 | return results, total_step, n_samples_seen 982 | 983 | 984 | def get_parser(): 985 | r""" 986 | Build argument parser for the command line interface. 987 | 988 | Returns 989 | ------- 990 | parser : argparse.ArgumentParser 991 | CLI argument parser. 
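
    Examples
    --------
    Flags match the command-line interface; the argument values here are
    only examples.

    >>> parser = get_parser()
    >>> config = parser.parse_args(["--dataset", "cifar10", "--pretrained"])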
992 | """ 993 | import argparse 994 | import sys 995 | 996 | # Use the name of the file called to determine the name of the program 997 | prog = os.path.split(sys.argv[0])[1] 998 | if prog == "__main__.py" or prog == "__main__": 999 | # If the file is called __main__.py, go up a level to the module name 1000 | prog = os.path.split(__file__)[1] 1001 | parser = argparse.ArgumentParser( 1002 | prog=prog, 1003 | description="Train image classification model.", 1004 | add_help=False, 1005 | ) 1006 | # Help arg ---------------------------------------------------------------- 1007 | group = parser.add_argument_group("Help") 1008 | group.add_argument( 1009 | "--help", 1010 | "-h", 1011 | action="help", 1012 | help="Show this help message and exit.", 1013 | ) 1014 | # Dataset args ------------------------------------------------------------ 1015 | group = parser.add_argument_group("Dataset") 1016 | group.add_argument( 1017 | "--dataset", 1018 | dest="dataset_name", 1019 | type=str, 1020 | default="cifar10", 1021 | help="Name of the dataset to learn. Default: %(default)s", 1022 | ) 1023 | group.add_argument( 1024 | "--prototyping", 1025 | dest="protoval_split_id", 1026 | nargs="?", 1027 | const=0, 1028 | type=int, 1029 | help=( 1030 | "Use a subset of the train partition for both train and val." 1031 | " If the dataset doesn't have a separate val and test set with" 1032 | " public labels (which is the case for most datasets), the train" 1033 | " partition will be reduced in size to create the val partition." 1034 | " In all cases where --prototyping is enabled, the test set is" 1035 | " never used during training. Generally, you should use" 1036 | " --prototyping throughout the model exploration and hyperparameter" 1037 | " optimization phases, and disable it for your final experiments so" 1038 | " they can run on a completely held-out test set." 1039 | ), 1040 | ) 1041 | group.add_argument( 1042 | "--data-dir", 1043 | type=str, 1044 | default=None, 1045 | help=( 1046 | "Directory within which the dataset can be found." 1047 | " Default is ~/Datasets, except on Vector servers where it is" 1048 | " adjusted as appropriate depending on the dataset's location." 1049 | ), 1050 | ) 1051 | group.add_argument( 1052 | "--allow-download-dataset", 1053 | action="store_true", 1054 | help="Attempt to download the dataset if it is not found locally.", 1055 | ) 1056 | group.add_argument( 1057 | "--transform-type", 1058 | type=str, 1059 | default="cifar", 1060 | help="Name of augmentation stack to apply to training data. Default: %(default)s", 1061 | ) 1062 | group.add_argument( 1063 | "--image-size", 1064 | type=int, 1065 | help="Size of images to use as model input. Default: encoder's default.", 1066 | ) 1067 | # Architecture args ------------------------------------------------------- 1068 | group = parser.add_argument_group("Architecture") 1069 | group.add_argument( 1070 | "--model", 1071 | "--encoder", 1072 | "--arch", 1073 | "--architecture", 1074 | dest="arch", 1075 | type=str, 1076 | default="resnet18", 1077 | help="Name of model architecture. 
Default: %(default)s", 1078 | ) 1079 | group.add_argument( 1080 | "--pretrained", 1081 | action="store_true", 1082 | help="Use default pretrained model weights, taken from hugging-face hub.", 1083 | ) 1084 | mx_group = group.add_mutually_exclusive_group() 1085 | mx_group.add_argument( 1086 | "--torchvision", 1087 | dest="arch_framework", 1088 | action="store_const", 1089 | const="torchvision", 1090 | default="timm", 1091 | help="Use model architecture from torchvision (default is timm).", 1092 | ) 1093 | mx_group.add_argument( 1094 | "--timm", 1095 | dest="arch_framework", 1096 | action="store_const", 1097 | const="timm", 1098 | default="timm", 1099 | help="Use model architecture from timm (default).", 1100 | ) 1101 | group.add_argument( 1102 | "--freeze-encoder", 1103 | action="store_true", 1104 | ) 1105 | # Optimization args ------------------------------------------------------- 1106 | group = parser.add_argument_group("Optimization routine") 1107 | group.add_argument( 1108 | "--epochs", 1109 | type=int, 1110 | default=5, 1111 | help="Number of epochs to train for. Default: %(default)s", 1112 | ) 1113 | group.add_argument( 1114 | "--lr", 1115 | dest="lr_relative", 1116 | type=float, 1117 | default=0.01, 1118 | help=( 1119 | f"Maximum learning rate, set per {BASE_BATCH_SIZE} batch size." 1120 | " The actual learning rate used will be scaled up by the total" 1121 | " batch size (across all GPUs). Default: %(default)s" 1122 | ), 1123 | ) 1124 | group.add_argument( 1125 | "--lr-encoder-mult", 1126 | type=float, 1127 | default=1.0, 1128 | help="Multiplier for encoder learning rate, relative to overall LR.", 1129 | ) 1130 | group.add_argument( 1131 | "--lr-classifier-mult", 1132 | type=float, 1133 | default=1.0, 1134 | help="Multiplier for classifier head's learning rate, relative to overall LR.", 1135 | ) 1136 | group.add_argument( 1137 | "--weight-decay", 1138 | "--wd", 1139 | dest="weight_decay", 1140 | type=float, 1141 | default=0.0, 1142 | help="Weight decay. Default: %(default)s", 1143 | ) 1144 | group.add_argument( 1145 | "--optimizer", 1146 | type=str, 1147 | default="AdamW", 1148 | help="Name of optimizer (case-sensitive). Default: %(default)s", 1149 | ) 1150 | group.add_argument( 1151 | "--scheduler", 1152 | type=str, 1153 | default="OneCycle", 1154 | help="Learning rate scheduler. Default: %(default)s", 1155 | ) 1156 | # Output checkpoint args -------------------------------------------------- 1157 | group = parser.add_argument_group("Output checkpoint") 1158 | group.add_argument( 1159 | "--models-dir", 1160 | type=str, 1161 | default="models", 1162 | metavar="PATH", 1163 | help="Output directory for all models. Ignored if --checkpoint is set. Default: %(default)s", 1164 | ) 1165 | group.add_argument( 1166 | "--checkpoint", 1167 | dest="checkpoint_path", 1168 | default="", 1169 | type=str, 1170 | metavar="PATH", 1171 | help=( 1172 | "Save and resume partially trained model and optimizer state from this checkpoint." 1173 | " Overrides --models-dir." 1174 | ), 1175 | ) 1176 | group.add_argument( 1177 | "--save-best-model", 1178 | action="store_true", 1179 | help="Save a copy of the model with best validation performance.", 1180 | ) 1181 | # Reproducibility args ---------------------------------------------------- 1182 | group = parser.add_argument_group("Reproducibility") 1183 | group.add_argument( 1184 | "--seed", 1185 | type=int, 1186 | help="Random number generator (RNG) seed. 
Default: not controlled", 1187 | ) 1188 | group.add_argument( 1189 | "--deterministic", 1190 | action="store_true", 1191 | help="Disable non-deterministic features of cuDNN.", 1192 | ) 1193 | # Hardware configuration args --------------------------------------------- 1194 | group = parser.add_argument_group("Hardware configuration") 1195 | group.add_argument( 1196 | "--batch-size", 1197 | dest="batch_size_per_gpu", 1198 | type=int, 1199 | default=BASE_BATCH_SIZE, 1200 | help=( 1201 | "Batch size per GPU. The total batch size will be this value times" 1202 | " the total number of GPUs used. Default: %(default)s" 1203 | ), 1204 | ) 1205 | group.add_argument( 1206 | "--cpu-workers", 1207 | "--workers", 1208 | dest="cpu_workers", 1209 | type=int, 1210 | help="Number of CPU workers per node. Default: number of CPUs available on device.", 1211 | ) 1212 | group.add_argument( 1213 | "--no-cuda", 1214 | action="store_true", 1215 | help="Use CPU only, no GPUs.", 1216 | ) 1217 | group.add_argument( 1218 | "--gpu", 1219 | "--local-rank", 1220 | dest="local_rank", 1221 | default=None, 1222 | type=int, 1223 | help="Index of GPU to use when training a single process. (Ignored for distributed training.)", 1224 | ) 1225 | # Logging args ------------------------------------------------------------ 1226 | group = parser.add_argument_group("Debugging and logging") 1227 | group.add_argument( 1228 | "--log-interval", 1229 | type=int, 1230 | default=20, 1231 | help="Number of batches between each log to wandb (if enabled). Default: %(default)s", 1232 | ) 1233 | group.add_argument( 1234 | "--print-interval", 1235 | type=int, 1236 | default=None, 1237 | help="Number of batches between each print to STDOUT. Default: same as LOG_INTERVAL.", 1238 | ) 1239 | group.add_argument( 1240 | "--log-wandb", 1241 | action="store_true", 1242 | help="Log results with Weights & Biases https://wandb.ai", 1243 | ) 1244 | group.add_argument( 1245 | "--disable-wandb", 1246 | "--no-wandb", 1247 | dest="disable_wandb", 1248 | action="store_true", 1249 | help="Overrides --log-wandb and ensures wandb is always disabled.", 1250 | ) 1251 | group.add_argument( 1252 | "--wandb-entity", 1253 | type=str, 1254 | help=( 1255 | "The entity (organization) within which your wandb project is" 1256 | ' located. By default, this will be your "default location" set on' 1257 | " wandb at https://wandb.ai/settings" 1258 | ), 1259 | ) 1260 | group.add_argument( 1261 | "--wandb-project", 1262 | type=str, 1263 | default="template-experiment", 1264 | help="Name of project on wandb, where these runs will be saved. Default: %(default)s", 1265 | ) 1266 | group.add_argument( 1267 | "--run-name", 1268 | type=str, 1269 | help="Human-readable identifier for the model run or job. Used to name the run on wandb.", 1270 | ) 1271 | group.add_argument( 1272 | "--run-id", 1273 | type=str, 1274 | help="Unique identifier for the model run or job. Used as the run ID on wandb.", 1275 | ) 1276 | 1277 | return parser 1278 | 1279 | 1280 | def cli(): 1281 | r"""Command-line interface for model training.""" 1282 | parser = get_parser() 1283 | config = parser.parse_args() 1284 | # Handle disable_wandb overriding log_wandb and forcing it to be disabled. 
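    # (Illustrative aside.) Typical invocations of this entry point, using the
    # flags defined in get_parser above (dataset, GPU count, etc. are examples):
    #
    #     # Single process/GPU, prototyping on a split of the train partition:
    #     python train.py --dataset cifar10 --model resnet18 --pretrained --prototyping --epochs 5
    #
    #     # Multi-GPU on one node via torchrun, logging to Weights & Biases:
    #     torchrun --nproc_per_node=4 train.py --dataset cifar10 --pretrained --log-wandb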
1285 | if config.disable_wandb: 1286 | config.log_wandb = False 1287 | del config.disable_wandb 1288 | # Set protoval_split_id from prototyping, and turn prototyping into a bool 1289 | config.prototyping = config.protoval_split_id is not None 1290 | return run(config) 1291 | 1292 | 1293 | if __name__ == "__main__": 1294 | cli() 1295 | -------------------------------------------------------------------------------- /template_experiment/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import secrets 4 | import string 5 | import warnings 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def get_num_cpu_available(): 12 | r""" 13 | Get the number of available CPU cores. 14 | 15 | Uses :func:`os.sched_getaffinity` if available, otherwise falls back to 16 | :func:`os.cpu_count`. 17 | 18 | Returns 19 | ------- 20 | ncpus : int 21 | The number of available CPU cores. 22 | """ 23 | try: 24 | # This is the number of CPUs available to this process, which may 25 | # be smaller than the number of CPUs on the system. 26 | # This command is only available on Unix-like systems. 27 | return len(os.sched_getaffinity(0)) 28 | except Exception: 29 | # Fall-back for Windows or other systems which don't support sched_getaffinity 30 | warnings.warn( 31 | "Unable to determine number of available CPUs available to this python" 32 | " process specifically. Falling back to the total number of CPUs on the" 33 | " system.", 34 | RuntimeWarning, 35 | stacklevel=2, 36 | ) 37 | return os.cpu_count() 38 | 39 | 40 | def init_or_resume_wandb_run(output_dir, basename="wandb_runid.txt", **kwargs): 41 | r""" 42 | Initialize a wandb run, resuming if one already exists for this job. 43 | 44 | Parameters 45 | ---------- 46 | output_dir : str 47 | Path to output directory for this job, where the wandb run id file 48 | will be stored. 49 | basename : str, default="wandb_runid.txt" 50 | Basename of wandb run id file. 51 | **kwargs 52 | Additional parameters to be passed through to ``wandb.init``. 53 | Examples include, ``"project"``, ``"name"``, ``"config"``. 54 | 55 | Returns 56 | ------- 57 | run : :class:`wandb.sdk.wandb_run.Run` 58 | A wandb Run object, as returned by :func:`wandb.init`. 59 | """ 60 | import wandb 61 | 62 | if not output_dir: 63 | wandb_id_file_path = None 64 | else: 65 | wandb_id_file_path = os.path.join(output_dir, basename) 66 | if wandb_id_file_path and os.path.isfile(wandb_id_file_path): 67 | # If the run_id was previously saved, get the id and resume it 68 | with open(wandb_id_file_path, "r") as f: 69 | resume_id = f.read() 70 | run = wandb.init(resume=resume_id, **kwargs) 71 | else: 72 | # If the run id file doesn't exist, create a new wandb run 73 | run = wandb.init(**kwargs) 74 | if wandb_id_file_path: 75 | # Write the run id to the expected file for resuming later 76 | with open(wandb_id_file_path, "w") as f: 77 | f.write(run.id) 78 | 79 | return run 80 | 81 | 82 | def set_rng_seeds_fixed(seed, all_gpu=True): 83 | r""" 84 | Seed pseudo-random number generators throughout python's random module, numpy.random, and pytorch. 85 | 86 | Parameters 87 | ---------- 88 | seed : int 89 | The random seed to use. Should be between 0 and 4294967295 to ensure 90 | unique behaviour for numpy, and between 0 and 18446744073709551615 to 91 | ensure unique behaviour for pytorch. 92 | all_gpu : bool, default=True 93 | Whether to set the torch seed on every GPU. If ``False``, only the 94 | current GPU has its seed set. 
95 | 96 | Returns 97 | ------- 98 | None 99 | """ 100 | # Note that random, numpy, and pytorch all use different RNG methods/ 101 | # implementations, so you'll get different outputs from each of them even 102 | # if you use the same seed for them all. 103 | # We use modulo with the maximum values permitted for np.random and torch. 104 | # If your seed exceeds 4294967295, numpy will have looped around to a 105 | random.seed(seed) 106 | np.random.seed(seed % 0xFFFF_FFFF) 107 | torch.manual_seed(seed % 0xFFFF_FFFF_FFFF_FFFF) 108 | if all_gpu: 109 | torch.cuda.manual_seed_all(seed % 0xFFFF_FFFF_FFFF_FFFF) 110 | else: 111 | torch.cuda.manual_seed(seed % 0xFFFF_FFFF_FFFF_FFFF) 112 | 113 | 114 | def worker_seed_fn(worker_id): 115 | r""" 116 | Seed builtin :mod:`random` and :mod:`numpy`. 117 | 118 | A worker initialization function for :class:`torch.utils.data.DataLoader` 119 | objects which seeds builtin :mod:`random` and :mod:`numpy` with the 120 | torch seed for the worker. 121 | 122 | Parameters 123 | ---------- 124 | worker_id : int 125 | The ID of the worker. 126 | """ 127 | worker_seed = torch.utils.data.get_worker_info().seed 128 | random.seed(worker_seed) 129 | np.random.seed(worker_seed % 0xFFFF_FFFF) 130 | 131 | 132 | def determine_epoch_seed(seed, epoch): 133 | r""" 134 | Determine the seed to use for the random number generator for a given epoch. 135 | 136 | Parameters 137 | ---------- 138 | seed : int 139 | The original random seed, used to generate the sequence of seeds for 140 | the epochs. 141 | epoch : int 142 | The epoch for which to determine the seed. 143 | 144 | Returns 145 | ------- 146 | epoch_seed : int 147 | The seed to use for the random number generator for the given epoch. 148 | """ 149 | if epoch == 0: 150 | raise ValueError("Epoch must be indexed from 1, not 0.") 151 | random.seed(seed) 152 | # Generate a seed for every epoch so far. We have to traverse the 153 | # series of RNG calls to reach the next value (our next seed). The final 154 | # value is the one for our current epoch. 155 | # N.B. We use random.randint instead of torch.randint because torch.randint 156 | # only supports int32 at most (max value of 0xFFFF_FFFF). 157 | for _ in range(epoch): 158 | epoch_seed = random.randint(0, 0xFFFF_FFFF_FFFF_FFFF) 159 | return epoch_seed 160 | 161 | 162 | def generate_id(length: int = 8) -> str: 163 | r""" 164 | Generate a random base-36 string of `length` digits. 165 | 166 | Parameters 167 | ---------- 168 | length : int, default=8 169 | Length of the string to generate. 170 | 171 | Returns 172 | ------- 173 | id : str 174 | The randomly generated id. 175 | """ 176 | # Borrowed from https://github.com/wandb/wandb/blob/0e00efd/wandb/sdk/lib/runid.py 177 | # under the MIT license. 178 | # There are ~2.8T base-36 8-digit strings. If we generate 210k ids, 179 | # we'll have a ~1% chance of collision. 180 | alphabet = string.ascii_lowercase + string.digits 181 | return "".join(secrets.choice(alphabet) for _ in range(length)) 182 | 183 | 184 | def count_parameters(model, only_trainable=True): 185 | r""" 186 | Count the number of (trainable) parameters within a model and its children. 187 | 188 | Parameters 189 | ---------- 190 | model : torch.nn.Model 191 | The parametrized model. 192 | only_trainable : bool, optional 193 | Whether the count should be restricted to only trainable parameters 194 | (default), otherwise all parameters are included. 195 | Default is ``True``. 
196 | 197 | Returns 198 | ------- 199 | int 200 | Total number of (trainable) parameters possessed by the model. 201 | """ 202 | if only_trainable: 203 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 204 | else: 205 | return sum(p.numel() for p in model.parameters()) 206 | 207 | 208 | @torch.no_grad() 209 | def concat_all_gather(tensor, **kwargs): 210 | r""" 211 | Gather a tensor over all processes and concatenate them into one. 212 | 213 | Similar to :func:`torch.distributed.all_gather`, except this function 214 | concatenates the result into a single tensor instead of a list of tensors. 215 | 216 | Parameters 217 | ---------- 218 | tensor : torch.Tensor 219 | The distributed tensor on the current process. 220 | group : ProcessGroup, optional 221 | The process group to work on. If ``None``, the default process group 222 | will be used. 223 | async_op : bool, default=False 224 | Whether this op should be an async op. 225 | 226 | Returns 227 | ------- 228 | gathered_tensor : torch.Tensor 229 | The contents of ``tensor`` from every distributed process, gathered 230 | together. None of the entries support a gradient. 231 | 232 | Warning 233 | ------- 234 | As with :func:`torch.distributed.all_gather`, this has no gradient. 235 | """ 236 | world_size = torch.distributed.get_world_size() 237 | tensors_gather = [torch.zeros_like(tensor) for _ in range(world_size)] 238 | torch.distributed.all_gather(tensors_gather, tensor, **kwargs) 239 | output = torch.cat(tensors_gather, dim=0) 240 | return output 241 | 242 | 243 | @torch.no_grad() 244 | def concat_all_gather_ragged(tensor, **kwargs): 245 | r""" 246 | Gather a tensor over all processes and concatenate them into one. 247 | 248 | This version supports ragged tensors where the first dimension is not the 249 | same across all processes. 250 | 251 | Parameters 252 | ---------- 253 | tensor : torch.Tensor 254 | The distributed tensor on the current process. The equivalent tensors 255 | on the other processes may differ in shape only in their first 256 | dimension. 257 | group : ProcessGroup, optional 258 | The process group to work on. If ``None``, the default process group 259 | will be used. 260 | async_op : bool, default=False 261 | Whether this op should be an async op. 262 | 263 | Returns 264 | ------- 265 | gathered_tensor : torch.Tensor 266 | The contents of ``tensor`` from every distributed process, gathered 267 | together. None of the entries support a gradient. 268 | 269 | Warning 270 | ------- 271 | As with :func:`torch.distributed.all_gather`, this has no gradient. 
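
    Examples
    --------
    Gathering per-rank outputs whose first dimension differs across processes
    (an illustrative sketch; assumes a process group is already initialised).

    >>> rank = torch.distributed.get_rank()
    >>> local_scores = torch.arange(rank + 3, dtype=torch.float32)
    >>> all_scores = concat_all_gather_ragged(local_scores)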
272 | """ 273 | world_size = torch.distributed.get_world_size() 274 | # Gather the lengths of the tensors from all processes 275 | local_length = torch.tensor(tensor.shape[0], device=tensor.device) 276 | all_length = [torch.zeros_like(local_length) for _ in range(world_size)] 277 | torch.distributed.all_gather(all_length, local_length, **kwargs) 278 | # We will have to pad them to be the size of the longest tensor 279 | max_length = max(x.item() for x in all_length) 280 | 281 | # Pad our tensor on the current process 282 | length_diff = max_length - local_length.item() 283 | if length_diff: 284 | pad_size = (length_diff, *tensor.shape[1:]) 285 | padding = torch.zeros(pad_size, device=tensor.device, dtype=tensor.dtype) 286 | tensor = torch.cat((tensor, padding), dim=0) 287 | 288 | # Gather the padded tensors from all processes 289 | all_tensors_padded = [torch.zeros_like(tensor) for _ in range(world_size)] 290 | torch.distributed.all_gather(all_tensors_padded, tensor, **kwargs) 291 | # Remove padding 292 | all_tensors = [] 293 | for tensor_i, length_i in zip(all_tensors_padded, all_length): 294 | all_tensors.append(tensor_i[:length_i]) 295 | 296 | # Concatenate the tensors 297 | output = torch.cat(all_tensors, dim=0) 298 | return output 299 | --------------------------------------------------------------------------------