├── .flake8
├── .github
│   └── workflows
│       └── pre-commit.yaml
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.rst
├── pyproject.toml
├── requirements-dev.txt
├── requirements-notebook.txt
├── requirements-train.txt
├── requirements.txt
├── setup.py
├── slogs
│   └── .placeholder
├── slurm
│   ├── notebook.slrm
│   ├── train.slrm
│   └── utils
│       ├── report_env_config.sh
│       ├── report_repo.sh
│       └── report_slurm_config.sh
└── template_experiment
    ├── __init__.py
    ├── __meta__.py
    ├── data_transformations.py
    ├── datasets.py
    ├── encoders.py
    ├── evaluation.py
    ├── io.py
    ├── train.py
    └── utils.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | # E203: whitespace before ":". Sometimes violated by black.
4 | # E402: Module level import not at top of file. Violated by lazy imports.
5 | # D100-D107: Missing docstrings
6 | # D200: One-line docstring should fit on one line with quotes.
7 | extend-ignore = E203,E402,D100,D101,D102,D103,D104,D105,D106,D107,D200
8 | docstring-convention = numpy
9 | # F401: Module imported but unused.
10 | # Ignore missing docstrings within unit testing functions.
11 | per-file-ignores = **/__init__.py:F401 **/tests/:D100,D101,D102,D103,D104,D105,D106,D107
12 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yaml:
--------------------------------------------------------------------------------
1 | # pre-commit workflow
2 | #
3 | # Ensures the codebase passes the pre-commit stack.
4 | # We run this on GHA to catch issues in commits from contributors who haven't
5 | # set up pre-commit.
6 |
7 | name: pre-commit
8 |
9 | on: [push, pull_request]
10 |
11 | jobs:
12 | pre-commit:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 | - uses: actions/setup-python@v5
17 | with:
18 | python-version: "3.10"
19 | cache: pip
20 | - uses: pre-commit/action@v3.0.0
21 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/PyCQA/isort
3 | rev: 6.0.0
4 | hooks:
5 | - id: isort
6 | name: isort
7 | args: ["--profile=black"]
8 | - id: isort
9 | name: isort (cython)
10 | types: [cython]
11 | args: ["--profile=black"]
12 |
13 | - repo: https://github.com/psf/black
14 | rev: 25.1.0
15 | hooks:
16 | - id: black
17 | types: [python]
18 |
19 | - repo: https://github.com/asottile/blacken-docs
20 | rev: 1.19.1
21 | hooks:
22 | - id: blacken-docs
23 | additional_dependencies: ["black==25.1.0"]
24 |
25 | - repo: https://github.com/kynan/nbstripout
26 | rev: 0.8.1
27 | hooks:
28 | - id: nbstripout
29 |
30 | - repo: https://github.com/nbQA-dev/nbQA
31 | rev: 1.9.1
32 | hooks:
33 | - id: nbqa-isort
34 | args: ["--profile=black"]
35 | - id: nbqa-black
36 | additional_dependencies: ["black==25.1.0"]
37 | - id: nbqa-flake8
38 |
39 | - repo: https://github.com/pre-commit/pygrep-hooks
40 | rev: v1.10.0
41 | hooks:
42 | - id: python-check-blanket-noqa
43 | - id: python-check-blanket-type-ignore
44 | - id: python-check-mock-methods
45 | - id: python-no-log-warn
46 | - id: rst-backticks
47 | - id: rst-directive-colons
48 | types: [text]
49 | - id: rst-inline-touching-normal
50 | types: [text]
51 |
52 | - repo: https://github.com/pre-commit/pre-commit-hooks
53 | rev: v5.0.0
54 | hooks:
55 | - id: check-added-large-files
56 | - id: check-ast
57 | - id: check-builtin-literals
58 | - id: check-case-conflict
59 | - id: check-docstring-first
60 | - id: check-shebang-scripts-are-executable
61 | - id: check-merge-conflict
62 | - id: check-json
63 | - id: check-toml
64 | - id: check-xml
65 | - id: check-yaml
66 | - id: debug-statements
67 | - id: destroyed-symlinks
68 | - id: detect-private-key
69 | - id: end-of-file-fixer
70 | exclude: ^LICENSE|\.(html|csv|txt|svg|py)$
71 | - id: pretty-format-json
72 | args: ["--autofix", "--no-ensure-ascii", "--no-sort-keys"]
73 | exclude_types: [jupyter]
74 | - id: requirements-txt-fixer
75 | - id: trailing-whitespace
76 | args: [--markdown-linebreak-ext=md]
77 | exclude: \.(html|svg)$
78 |
79 | - repo: https://github.com/asottile/setup-cfg-fmt
80 | rev: v2.7.0
81 | hooks:
82 | - id: setup-cfg-fmt
83 |
84 | - repo: https://github.com/PyCQA/flake8
85 | rev: 7.1.2
86 | hooks:
87 | - id: flake8
88 | additional_dependencies:
89 | - flake8-2020
90 | - flake8-bugbear
91 | - flake8-comprehensions
92 | - flake8-implicit-str-concat
93 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to
25 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICEN[CS]E*
2 | include pyproject.toml
3 | include README.rst
4 | include requirements*.txt
5 |
6 | global-exclude *.py[co]
7 | global-exclude __pycache__
8 | global-exclude *~
9 | global-exclude *.ipynb_checkpoints/*
10 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | |SLURM| |preempt| |PyTorch| |wandb| |pre-commit| |black|
2 |
3 | PyTorch Experiment Template
4 | ===========================
5 |
6 | This repository gives a fully-featured template or skeleton for new PyTorch
7 | experiments for use on the Vector Institute cluster.
8 | It supports:
9 |
10 | - multi-node, multi-GPU jobs using DistributedDataParallel (DDP),
11 | - preemption handling which gracefully stops and resumes the job,
12 | - logging experiments to Weights & Biases,
13 | - configuration which seamlessly scales up as the amount of resources increases.
14 |
15 | If you want to run the example experiment, you can clone the repository as-is
16 | without modifying it. Then see the `Installation`_ and
17 | `Executing the example experiment`_ sections below.
18 |
19 | If you want to create a new repository from this template, you should follow
20 | the `Creating a git repository using this template`_ instructions first.
21 |
22 |
23 | Creating a git repository using this template
24 | ---------------------------------------------
25 |
26 | When creating a new repository from this template, these are the steps to follow:
27 |
28 | #. *Don't click the fork button.*
29 | The fork button is for making a new template based on this one, not for using the template to make a new repository.
30 |
31 | #. Create repository.
32 |
33 | #. **New GitHub repository**.
34 |
35 | You can create a new repository on GitHub from this template by clicking the `Use this template `_ button.
36 |
37 | Then clone your new repository to your local system [pseudocode]:
38 |
39 | .. code-block:: bash
40 |
41 | git clone git@github.com:your_org/your_repo_name.git
42 | cd your_repo_name
43 |
44 | #. **New repository not on GitHub**.
45 |
46 | Alternatively, if your new repository is not going to be on GitHub, you can download `this repo as a zip `_ and work from there.
47 |
48 | Note that this zip does not include the .gitignore and .gitattributes files (because GitHub automatically omits them, which is usually helpful but is not for our purposes).
49 | Thus you will also need to download the `.gitignore `__ and `.gitattributes `__ files.
50 |
51 | #. Delete the LICENSE file and replace it with a LICENSE file of your own choosing.
52 | If the code is intended to be freely available for anyone to use, use an `open source license`_, such as `MIT License`_ or `GPLv3`_.
53 | If you don't want your code to be used by anyone else, add a LICENSE file which just says:
54 |
55 | .. code-block:: none
56 |
57 | Copyright (c) CURRENT_YEAR, YOUR_NAME
58 |
59 | All rights reserved.
60 |
61 | Note that if you don't include a LICENSE file, you will still have copyright over your own code (this copyright is automatically granted), and your code will be private source (technically nobody else will be permitted to use it, even if you make your code publicly available).
62 |
63 | #. Edit the file ``template_experiment/__meta__.py`` to contain your author and repo details.
64 |
65 | name
66 | The name as it would be on PyPI (users will do ``pip install new_name_here``).
67 | It is `recommended `__ to use a name which is all lowercase with the words run together, but if separators are needed, hyphens are preferred over underscores.
68 |
69 | path
70 | The path to the package. What you will rename the directory ``template_experiment``.
71 | `Should be `__ the same as ``name``, but now hyphens are disallowed and should be swapped for underscores.
72 | By default, this is automatically inferred from ``name``.
73 |
74 | license
75 | Should be the name of the license you just picked and put in the LICENSE file (e.g. ``MIT`` or ``GPLv3``).
76 |
77 | Other fields to enter should be self-explanatory.
78 |
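   As a rough sketch, an edited ``__meta__.py`` might contain values along these
   lines (hypothetical project details; check the file itself for the full set
   of fields):

   .. code-block:: python

      # Hypothetical example values; replace them with your own project details.
      name = "my-cool-experiment"  # PyPI-style name, hyphen separated
      path = name.lower().replace("-", "_").replace(" ", "_")  # package directory
      license = "MIT"  # should match the LICENSE file you chose
      author = "Your Name"
      description = "A one-line description of what the project does."
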
79 | #. Rename the directory ``template_experiment`` to be the ``path`` variable you just added to ``__meta__.py``:
80 |
81 | .. code-block:: bash
82 |
83 | # Define PROJ_HYPH as your actual project name (use hyphens instead of underscores or spaces)
84 | PROJ_HYPH=your-actual-project-name-with-hyphens-for-spaces
85 |
86 | # Automatically convert hyphens to underscores to get the directory name
87 | PROJ_DIRN="${PROJ_HYPH//-/_}"
88 | # Rename the directory
89 | mv template_experiment "$PROJ_DIRN"
90 |
91 | #. Change references to ``template_experiment`` and ``template-experiment``
92 | to your new ``path`` and ``name`` values.
93 |
94 | This can be done with the sed command:
95 |
96 | .. code-block:: bash
97 |
98 | sed -i "s/template_experiment/$PROJ_DIRN/" "$PROJ_DIRN"/*.py setup.py pyproject.toml slurm/*.slrm .pre-commit-config.yaml
99 | sed -i "s/template-experiment/$PROJ_HYPH/" "$PROJ_DIRN"/*.py setup.py pyproject.toml slurm/*.slrm .pre-commit-config.yaml
100 |
101 | This will make changes in the following places.
102 |
103 | - In ``setup.py``, `L51 `__::
104 |
105 | exec(read("template_experiment/__meta__.py"), meta)
106 |
107 | - In ``__meta__.py``, `L2,4 `__::
108 |
109 | name = "template-experiment"
110 |
111 | - In ``train.py``, `L18-19 `__::
112 |
113 | from template_experiment import data_transformations, datasets, encoders, utils
114 | from template_experiment.evaluation import evaluate
115 |
116 | - In ``train.py``, `L1260 `__::
117 |
118 | group.add_argument(
119 | "--wandb-project",
120 | type=str,
121 | default="template-experiment",
122 | help="Name of project on wandb, where these runs will be saved.",
123 | )
124 |
125 | - In ``slurm/train.slrm``, `L19 `__::
126 |
127 | #SBATCH --job-name=template-experiment # Set this to be a shorthand for your project's name.
128 |
129 | - In ``slurm/train.slrm``, `L23 `__::
130 |
131 | PROJECT_NAME="template-experiment"
132 |
133 | - In ``slurm/notebook.slrm``, `L15 `__::
134 |
135 | PROJECT_NAME="template-experiment"
136 |
137 | #. Swap out the contents of ``README.rst`` with an initial description of your project.
138 | If you prefer, you can use markdown (``README.md``) instead of rST:
139 |
140 | .. code-block:: bash
141 |
142 | git rm README.rst
143 | # Either recreate it as rST (touch README.rst), or switch to Markdown:
144 | touch README.md && sed -i "s/.rst/.md/" MANIFEST.in
145 |
146 | #. Add your changes to the repo's initial commit and force-push your changes:
147 |
148 | .. code-block:: bash
149 |
150 | git add .
151 | git commit --amend -m "Initial commit"
152 | git push --force
153 |
154 | .. _PEP-8: https://www.python.org/dev/peps/pep-0008/
155 | .. _open source license: https://choosealicense.com/
156 | .. _MIT License: https://choosealicense.com/licenses/mit/
157 | .. _GPLv3: https://choosealicense.com/licenses/gpl-3.0/
158 |
159 |
160 | Installation
161 | ------------
162 |
163 | I recommend using miniconda to create an environment for your project.
164 | By using one virtual environment dedicated to each project, you are ensured
165 | stability - if you upgrade a package for one project, it won't affect the
166 | environments you already have established for the others.
167 |
168 | Vector one-time set-up
169 | ~~~~~~~~~~~~~~~~~~~~~~
170 |
171 | Run this code block to install miniconda before you make your first environment
172 | (you don't need to re-run this every time you start a new project):
173 |
174 | .. code-block:: bash
175 |
176 | # Login to Vector
177 | ssh USERNAME@v.vectorinstitute.ai
178 | # Enter your password and 2FA code to login.
179 | # Run the rest of this code block on the gateway node of the cluster that
180 | # you get to after establishing the ssh connection.
181 |
182 | # Make a screen session for us to work in
183 | screen;
184 |
185 | # Download miniconda to your ~/Downloads directory
186 | mkdir -p $HOME/Downloads;
187 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \
188 | -O "$HOME/Downloads/miniconda.sh";
189 | # Install miniconda to the home directory, if it isn't there already.
190 | if [ ! -d "$HOME/miniconda/bin" ]; then
191 | if [ -d "$HOME/miniconda" ]; then rm -r "$HOME/miniconda"; fi;
192 | bash $HOME/Downloads/miniconda.sh -b -p "$HOME/miniconda";
193 | fi;
194 |
195 | # Add conda to the PATH environment variable
196 | export PATH="$HOME/miniconda/bin:$PATH";
197 |
198 | # Automatically say yes to any check from conda (optional)
199 | conda config --set always_yes yes
200 |
201 | # Set the command prompt prefix to be the name of the current venv
202 | conda config --set env_prompt '({name}) '
203 |
204 | # Add conda setup to your ~/.bashrc file
205 | conda init;
206 |
207 | # Now exit this screen session (you have to exit the current terminal
208 | # session after conda init, and exiting the screen session achieves that
209 | # without closing the ssh connection)
210 | exit;
211 |
212 | Follow this next step if you want to use `Weights and Biases`_ to log your experiments.
213 | Weights and Biases is an online service for tracking your experiments which is
214 | free for academic usage.
215 | To set this up, you need to install the wandb pip package, and you'll need to
216 | `create a Weights and Biases account `_ if you don't already have one:
217 |
218 | .. code-block:: bash
219 |
220 | # (On v.vectorinstitute.ai)
221 | # You need to run the conda setup instructions that miniconda added to
222 | # your ~/.bashrc file so that conda is on your PATH and you can run it.
223 | # Either create a new screen session - when you launch a new screen session,
224 | # bash automatically runs source ~/.bashrc
225 | screen;
226 | # Or stay in your current window and explicitly yourself run
227 | source ~/.bashrc
228 | # Either way, you'll now see "(miniconda)" at the left of your command prompt,
229 | # indicating miniconda is on your PATH and using your default conda environment.
230 |
231 | # Install wandb
232 | pip install wandb
233 |
234 | # Log in to wandb at the command prompt
235 | wandb login
236 | # wandb asks you for your username, then password
237 | # Then wandb creates a file in ~/.netrc which it uses to automatically login in the future
238 |
239 | .. _Weights and Biases: https://wandb.ai/
240 | .. _wandb-signup: https://wandb.ai/login?signup=true
241 |
242 |
243 | Project one-time set-up
244 | ~~~~~~~~~~~~~~~~~~~~~~~
245 |
246 | Run this code block once every time you start a new project from this template.
247 | Change ENVNAME to equal the name of your project. This code will then create a
248 | new virtual environment to use for the project:
249 |
250 | .. code-block:: bash
251 |
252 | # (On v.vectorinstitute.ai)
253 | # You need to run the conda setup instructions that miniconda added to
254 | # your ~/.bashrc file so that conda is on your PATH and you can run it.
255 | # Either create a new screen session - when you launch a new screen session,
256 | # bash automatically runs source ~/.bashrc
257 | screen;
258 | # Or stay in your current window and explicitly yourself run
259 | source ~/.bashrc
260 | # Either way, you'll now see "(miniconda)" at the left of your command prompt,
261 | # indicating miniconda is on your PATH and using your default conda environment.
262 |
263 | # Now run the following one-time setup per virtual environment (i.e. once per project)
264 |
265 | # Pick a name for the new environment.
266 | # It should correspond to the name of your project (hyphen separated, no spaces)
267 | ENVNAME=template-experiment
268 |
269 | # Create a python3.x conda environment, with pip installed, with this name.
270 | conda create -y --name "$ENVNAME" -q python=3 pip
271 |
272 | # Activate the environment
273 | conda activate "$ENVNAME"
274 | # The command prompt should now have your environment at the left of it, e.g.
275 | # (template-experiment) slowe@v3:~$
276 |
277 |
278 | Resuming work on an existing project
279 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
280 |
281 | Run this code block when you want to resume work on an existing project:
282 |
283 | .. code-block:: bash
284 |
285 | # (On v.vectorinstitute.ai)
286 | # Run the conda setup in ~/.bashrc if it hasn't already been run in this
287 | # terminal session
288 | source ~/.bashrc
289 | # The command prompt should now say (miniconda) at the left of it.
290 |
291 | # Activate the environment
292 | conda activate template-experiment
293 | # The command prompt should now have your environment at the left of it, e.g.
294 | # (template-experiment) slowe@v3:~$
295 |
296 |
297 | Executing the example experiment
298 | --------------------------------
299 |
300 | The following commands describe how to setup and run the example repository
301 | in its unmodified state.
302 |
303 | To run the code in a repository you have
304 | `created from this template `_,
305 | replace ``template-experiment`` with the name of your package and
306 | ``template_experiment`` with the name of your package directory, etc.
307 |
308 | Set-up
309 | ~~~~~~
310 |
311 | #. If you haven't already, then follow the `Vector one-time set-up`_
312 | instructions.
313 |
314 | #. Then clone the repository:
315 |
316 | .. code-block:: bash
317 |
318 | git clone git@github.com:scottclowe/pytorch-experiment-template.git
319 | cd pytorch-experiment-template
320 |
321 | #. Run the `Project one-time set-up`_ (using ``template-experiment`` as
322 | the environment name).
323 |
324 | #. With the project's conda environment activated, install the package and its
325 | training dependencies:
326 |
327 | .. code-block:: bash
328 |
329 | pip install --editable .[train]
330 |
331 | This step will typically take 5-10 minutes to run.
332 |
333 | #. Check the installation by running the help command:
334 |
335 | .. code-block:: bash
336 |
337 | python template_experiment/train.py -h
338 |
339 | This should print the help message for the training script.
340 |
341 |
342 | Example commands
343 | ~~~~~~~~~~~~~~~~
344 |
345 | - To run the default training command locally:
346 |
347 | .. code-block:: bash
348 |
349 | python template_experiment/train.py
350 |
351 | or alternatively::
352 |
353 | template-experiment-train
354 |
355 | - Run the default training command on the cluster with SLURM.
356 | First, ssh into the cluster and cd to the project repository.
357 | You don't need to activate the project's conda environment.
358 | Then use sbatch to add your SLURM job to the queue:
359 |
360 | .. code-block:: bash
361 |
362 | sbatch slurm/train.slrm
363 |
364 | - You can supply arguments to sbatch by including them before the path to the
365 | SLURM script.
366 | Arguments set on the command prompt like this will override the arguments in
367 | ``slurm/train.slrm``.
368 | This is useful for customizing the job name, for example:
369 |
370 | .. code-block:: bash
371 |
372 | sbatch --job-name=exp_cf10_rn18 slurm/train.slrm
373 |
374 | I recommend that you almost always customize the name of your job.
375 | The custom job name will be visible in the output of ``squeue -u "$USER"``
376 | when browsing your active jobs (helpful if you have multiple jobs running
377 | and need to check on their status or cancel one of them).
378 | When using this codebase, the custom job name is also used in the path to the
379 | checkpoint, the path to the SLURM log file, and the name of the job on wandb.
380 |
381 | - Any arguments you include after ``slurm/train.slrm`` will be passed through to train.py.
382 |
383 | For example, you can specify to use a pretrained model:
384 |
385 | .. code-block:: bash
386 |
387 | sbatch --job-name=exp_cf10_rn18-pt slurm/train.slrm --dataset=cifar10 --pretrained
388 |
389 | change the architecture and dataset:
390 |
391 | .. code-block:: bash
392 |
393 | sbatch --job-name=exp_cf100_vit-pt \
394 | slurm/train.slrm --dataset=cifar100 --model=vit_small_patch16_224 --pretrained
395 |
396 | or change the learning rate of the encoder:
397 |
398 | .. code-block:: bash
399 |
400 | sbatch --job-name=exp_cf10_rn18-pt_enc-lr-0.01 \
401 | slurm/train.slrm --dataset=cifar10 --pretrained --lr-encoder-mult=0.01
402 |
403 | - You can trivially scale up the job to run across multiple GPUs, either by
404 | changing the gres argument to use more of the GPUs on the node (up to 8 GPUs
405 | per node on the t4v2 partition, 4 GPUs per node otherwise):
406 |
407 | .. code-block:: bash
408 |
409 | sbatch --job-name=exp_cf10_rn18-pt_4gpu --gres=gpu:4 slurm/train.slrm --pretrained
410 |
411 | or increasing the number of nodes being requested:
412 |
413 | .. code-block:: bash
414 |
415 | sbatch --job-name=exp_cf10_rn18-pt_2x1gpu --nodes=2 slurm/train.slrm --pretrained
416 |
417 | or both:
418 |
419 | .. code-block:: bash
420 |
421 | sbatch --job-name=exp_cf10_rn18-pt_2x4gpu --nodes=2 --gres=gpu:4 slurm/train.slrm --pretrained
422 |
423 | In each case, the amount of memory and CPUs requested in the SLURM job will
424 | automatically be scaled up with the number of GPUs requested.
425 | The total batch size will be scaled up by the number of GPUs requested too.
426 |
427 | As you run these commands, you can see the results logged on wandb at
428 | https://wandb.ai/your-username/template-experiment
429 |
430 |
431 | Jupyter notebook
432 | ~~~~~~~~~~~~~~~~
433 |
434 | You can use the script ``slurm/notebook.slrm`` to launch a Jupyter notebook
435 | server on one of the interactive compute nodes.
436 | This uses the methodology of https://support.vectorinstitute.ai/jupyter_notebook
437 |
438 | You'll need to install jupyter into your conda environment to launch the notebook.
439 | After activating the environment for this project, run:
440 |
441 | .. code-block:: bash
442 |
443 | pip install -r requirements-notebook.txt
444 |
445 | To launch a notebook server and connect to it on your local machine, perform
446 | the following steps.
447 |
448 | #. Run the notebook SLURM script to launch the jupyter notebook:
449 |
450 | .. code-block:: bash
451 |
452 | sbatch slurm/notebook.slrm
453 |
454 | The job will launch on one of the interactive nodes, and will acquire a
455 | random port on that node to serve the notebook on.
456 |
457 | #. Wait for the job to start running. You can monitor it with:
458 |
459 | .. code-block:: bash
460 |
461 | squeue --me
462 |
463 | Note the job id of the notebook job. e.g.:
464 |
465 | .. code-block:: none
466 |
467 | (template-experiment) slowe@v2:~/pytorch-experiment-template$ squeue --me
468 | JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)
469 | 10618891 interacti jnb slowe R 1:07 1 gpu026
470 |
471 | Here we can see our JOBID is 10618891, and it is running on node gpu026.
472 |
473 | #. Inspect the output of the job with:
474 |
475 | .. code-block:: bash
476 |
477 | cat jnb_JOBID.out
478 |
479 | e.g.:
480 |
481 | .. code-block:: bash
482 |
483 | cat jnb_10618891.out
484 |
485 | The output will contain the port number that the notebook server is using,
486 | and the token as follows:
487 |
488 | .. code-block:: none
489 |
490 | To access the server, open this file in a browser:
491 | file:///ssd005/home/slowe/.local/share/jupyter/runtime/jpserver-7885-open.html
492 | Or copy and paste one of these URLs:
493 | http://gpu026:47201/tree?token=f54c10f52e3dad08e19101149a54985d1561dca7eec96b29
494 | http://127.0.0.1:47201/tree?token=f54c10f52e3dad08e19101149a54985d1561dca7eec96b29
495 |
496 | Here we can see the job is on node gpu026 and the notebook is being served
497 | on port 47201.
498 | We will need to use the token f54c10f52e3dad08e19101149a54985d1561dca7eec96b29
499 | to log in to the notebook.
500 |
501 | #. On your local machine, use ssh to forward the port from the compute node to
502 | your local machine:
503 |
504 | .. code-block:: bash
505 |
506 | ssh USERNAME@v.vectorinstitute.ai -N -L 8887:gpu026:47201
507 |
508 | You need to replace USERNAME with your Vector username, gpu026 with the node
509 | your job is running on, and 47201 with the port number from the previous
510 | step.
511 | In this example, the local port which the notebook is being forwarded to is
512 | port 8887.
513 |
514 | #. Open a browser on your local machine and navigate to http://localhost:8887
515 | (or whatever port you chose in the previous step):
516 |
517 | .. code-block:: bash
518 |
519 | sensible-browser http://localhost:8887
520 |
521 | You should see the Jupyter notebook interface.
522 | Copy the token from the URL shown in the log file and paste it into the
523 | ``Password or token: [ ] Log in`` box.
524 | You should now have access to the remote notebook server on your local
525 | machine.
526 |
527 | #. Once you are done working in your notebooks (and have saved your changes),
528 | make sure to end the job running the notebook with:
529 |
530 | .. code-block:: bash
531 |
532 | scancel JOBID
533 |
534 | e.g.:
535 |
536 | .. code-block:: bash
537 |
538 | scancel 10618891
539 |
540 | This will free up the interactive GPU node for other users to use.
541 |
542 | Note that you can skip the need to copy the access token if you
543 | `set up Jupyter notebook to use a password `_ instead.
544 |
545 | .. _jnb-password: https://saturncloud.io/blog/how-to-autoconfigure-jupyter-password-from-command-line/
546 |
547 |
548 | Features
549 | --------
550 |
551 | This template includes the following features.
552 |
553 |
554 | Scalable training script
555 | ~~~~~~~~~~~~~~~~~~~~~~~~
556 |
557 | The SLURM training script ``slurm/train.slrm`` will interface with the python
558 | training script ``template_experiment/train.py`` to train a model on multiple
559 | GPUs across multiple nodes, using DistributedDataParallel_ (DDP).
560 |
561 | The SLURM script is configured to scale up the amount of RAM and CPUs requested
562 | with the GPUs requested.
563 |
564 | The arguments to the python script control the batch size per GPU, and the
565 | learning rate for a fixed batch size of 128 samples.
566 | The total batch size will automatically scale up when deployed on more GPUs,
567 | and the learning rate will automatically scale up linearly with the total batch
568 | size. (This is the linear scaling rule from `Training ImageNet in 1 Hour`_.)
569 |
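As a rough illustration of the linear scaling rule (hypothetical variable
names, not the exact code in ``train.py``):

.. code-block:: python

   # Hypothetical sketch of how the effective hyperparameters scale with the
   # number of GPUs; the real computation lives in template_experiment/train.py.
   batch_size_per_gpu = 32  # batch size argument, defined per GPU
   world_size = 8           # total number of GPUs across all nodes
   base_lr = 0.01           # learning rate argument, defined per 128 samples

   total_batch_size = batch_size_per_gpu * world_size  # 256
   effective_lr = base_lr * total_batch_size / 128     # 0.02
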
570 | .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
571 | .. _Training ImageNet in 1 Hour: https://arxiv.org/abs/1706.02677
572 |
573 |
574 | Preemptable
575 | ~~~~~~~~~~~
576 |
577 | Everything is set up to resume correctly if the job is interrupted by
578 | preemption.
579 |
580 |
581 | Checkpoints
582 | ~~~~~~~~~~~
583 |
584 | The training script will save a checkpoint every epoch, and will resume from
585 | this if the job is interrupted by preemption.
586 |
587 | The checkpoint for a job will be saved to the directory
588 | ``/checkpoint/USERNAME/PROJECT__JOBNAME__JOBID`` (with double-underscores
589 | between each category) along with a record of the conda environment and
590 | frozen pip requirements used to run the job in ``environment.yml`` and
591 | ``frozen-requirements.txt``.
592 |
593 |
594 | Log messages
595 | ~~~~~~~~~~~~
596 |
597 | Any print statements and error messages from the training script will be saved
598 | to the file ``slogs/JOBNAME__JOBID_ARRAYID.out``.
599 | Only the output from the rank 0 worker (the worker which saves the
600 | checkpoints and sends logs to wandb) will be saved to this file.
601 | When using multiple nodes, the output from each node will be saved to a
602 | separate file: ``slogs-inner/JOBNAME__JOBID_ARRAYID-NODERANK.out``.
603 |
604 | You can monitor the progress of a job that is currently running by monitoring
605 | the contents of its log file. For example:
606 |
607 | .. code-block:: bash
608 |
609 | tail -n 50 -f slogs/JOBNAME__JOBID_ARRAYID.out
610 |
611 |
612 | Weights and Biases
613 | ~~~~~~~~~~~~~~~~~~
614 |
615 | `Weights and Biases`_ (wandb) is an online service for tracking your
616 | experiments which is free for academic usage.
617 |
618 | This template repository is set up to automatically log your experiments, using
619 | the same job label across both SLURM and wandb.
620 |
621 | If the job is preempted, the wandb logging will resume to the same wandb job
622 | ID instead of spawning a new one.
623 |
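As a minimal sketch of the underlying mechanism (the template wires this up
for you; the exact call in ``train.py`` may differ):

.. code-block:: python

   import wandb

   # Reuse a deterministic run ID derived from the SLURM job label, so that a
   # requeued job resumes logging to the same wandb run instead of a new one.
   run_id = "template-experiment__exp_cf10_rn18__10618891"  # hypothetical label
   wandb.init(
       project="template-experiment",
       id=run_id,
       resume="allow",  # resume the run if it already exists, else create it
   )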
624 |
625 | Random Number Generator (RNG) state
626 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
627 |
628 | All RNG states are configured based on the overall seed that is set with the
629 | ``--seed`` argument to ``train.py``.
630 |
631 | When running ``train.py`` directly, the seed is **not** set by default, so
632 | behaviour will not be reproducible.
633 | You will need to include the argument ``--seed=0`` (for example), to make sure
634 | your experiments are reproducible.
635 |
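For example, to run a reproducible training run locally:

.. code-block:: bash

   python template_experiment/train.py --seed=0
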
636 | When running on SLURM with slurm/train.slrm, the seed **is** set by default.
637 | The seed used is equal to the `array ID `_ of the job.
638 | This configuration lets you easily run the same job with multiple seeds in one
639 | sbatch command.
640 | Our default job array in ``slurm/train.slrm`` is ``--array=0``, so only one job
641 | will be launched, and that job will use the default seed of ``0``.
642 |
643 | To launch the same job 5 times, each with a different seed (1, 2, 3, 4, and 5):
644 |
645 | .. code-block:: bash
646 |
647 | sbatch --array=1-5 slurm/train.slrm
648 |
649 | or to use seeds 42 and 888:
650 |
651 | .. code-block:: bash
652 |
653 | sbatch --array=42,888 slurm/train.slrm
654 |
655 | or to use a randomly selected seed:
656 |
657 | .. code-block:: bash
658 |
659 | sbatch --array="$RANDOM" slurm/train.slrm
660 |
661 | The seed is used to set the following RNG states:
662 |
663 | - Each epoch gets its own RNG seed (derived from the overall seed and the epoch
664 | number).
665 | The RNG state is set with this seed at the start of each epoch. This makes it
666 | possible to resume from preemption without needing to save all the RNG states
667 | to the model checkpoint and restore them on resume (see the sketch after this list).
668 |
669 | - Each GPU gets its own RNG seed, so any random operations such as dropout
670 | or random masking in the training script itself will be different on each
671 | GPU, but deterministically so.
672 |
673 | - The dataloader workers each have distinct seeds from each other for torch,
674 | numpy and python's random module, so randomly selected augmentations won't be
675 | replicated across workers.
676 | (By default, PyTorch only sets up its own worker seeds correctly, leaving
677 | numpy and python's random module mirrored across all workers.)
678 |
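As a rough sketch of the idea behind the per-epoch, per-device seeds
(illustrative only; the actual derivation in ``train.py`` may differ):

.. code-block:: python

   import random

   import numpy as np
   import torch

   # Hypothetical helper: derive a distinct but reproducible seed for every
   # (overall seed, epoch, GPU rank) combination and reset the RNG states.
   def set_epoch_rng_state(base_seed, epoch, rank):
       seed = base_seed * 100_000 + epoch * 100 + rank
       random.seed(seed)
       np.random.seed(seed % 2**32)
       torch.manual_seed(seed)
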
679 | **Caution:** To get *exactly* the same model produced when training with the
680 | same seed, you will need to run the training script with the ``--deterministic``
681 | flag to disable cuDNN's non-deterministic operations *and* use precisely the
682 | same number of GPU devices and CPU workers on each attempt.
683 | Without these steps, the model will be *almost* the same (because the initial
684 | seed for the model parameters was the same, and the training trajectory was
685 | very similar), but not *exactly* the same, due to (a) non-deterministic cuDNN
686 | operations (b) the batch size increasing with the number of devices
687 | (c) any randomized augmentation operations depending on the identity of the CPU
688 | worker, which will each have an offset seed.
689 |
690 | .. _slurm-job-array: https://slurm.schedmd.com/job_array.html
691 |
692 |
693 | Prototyping mode, with distinct val/test sets
694 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
695 |
696 | Initial experiments and hyperparameter searches should be performed without
697 | seeing the final test performance. They should be run only on a validation set.
698 | Unfortunately, many datasets do not come with a validation set, and it is easy
699 | to accidentally use the test set as a validation set, which can lead to
700 | overfitting the model selection on the test set.
701 |
702 | The image datasets implemented in ``template_experiment/datasets.py`` come with
703 | support for creating a validation set from the training set, which is separate
704 | from the test set. You should use this (with flag ``--prototyping``) during the
705 | initial model development steps and for any hyperparameter searches.
706 |
707 | Your final models should be trained without ``--prototyping`` enabled, so that
708 | the full training set is used for training and the best model is produced.
709 |
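For example, a prototyping run for a hyperparameter search might look like
this (hypothetical job name and hyperparameter values):

.. code-block:: bash

   sbatch --job-name=proto_cf10_rn18 slurm/train.slrm --dataset=cifar10 --prototyping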
710 |
711 | Optional extra package dependencies
712 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
713 |
714 | There are several requirements files in the root directory of the repository.
715 | The idea is the requirements.txt file contains the minimal set of packages
716 | that are needed to use the models in the package.
717 | The other requirements files are for optional extra packages.
718 |
719 | requirements-dev.txt
720 | Extra packages needed for code development (i.e. writing the codebase)
721 |
722 | requirements-notebook.txt
723 | Extra packages needed for running the notebooks.
724 |
725 | requirements-train.txt
726 | Extra packages needed for training the models.
727 |
728 | The setup.py file will automatically parse any requirements files in the
729 | root directory of the repository which are named like ``requirements-*.txt``
730 | and make them available to ``pip`` as extras.
731 |
732 | For example, to install the repository to your virtual environment with the
733 | extra packages needed for training::
734 |
735 | pip install --editable .[train]
736 |
737 | You can also install all the extras at once::
738 |
739 | pip install --editable .[all]
740 |
741 | Or you can install the extras directly from the requirements files::
742 |
743 | pip install -r requirements-train.txt
744 |
745 | As a developer of the repository, you will need to pip install the package
746 | with the ``--editable`` flag so the installed copy is updated automatically
747 | when you make changes to the codebase.
748 |
749 |
750 | Automated code checking and formatting
751 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
752 |
753 | The template repository comes with a pre-commit_ stack.
754 | This is a set of git hooks which are executed every time you make a commit.
755 | The hooks catch errors as they occur, and will automatically fix some of these errors.
756 |
757 | To set up the pre-commit hooks, run the following code from within the repo directory::
758 |
759 | pip install -r requirements-dev.txt
760 | pre-commit install
761 |
762 | Whenever you try to commit code which is flagged by the pre-commit hooks,
763 | *the commit will not go through*. Some of the pre-commit hooks
764 | (such as black_, isort_) will automatically modify your code to fix formatting
765 | issues. When this happens, you'll have to stage the changes made by the commit
766 | hooks and then try your commit again. Other pre-commit hooks, such as flake8_,
767 | will not modify your code and will just tell you about issues in what you tried
768 | to commit (e.g. a variable was declared and never used), and you'll then have
769 | to manually fix these yourself before staging the corrected version.
770 |
771 | After installing it, the pre-commit stack will run every time you try to make
772 | a commit to this repository on that machine.
773 | You can also manually run the pre-commit stack on all the files at any time:
774 |
775 | .. code-block:: bash
776 |
777 | pre-commit run --all-files
778 |
779 | To force a commit to go through without passing the pre-commit hooks use the ``--no-verify`` flag:
780 |
781 | .. code-block:: bash
782 |
783 | git commit --no-verify
784 |
785 | The pre-commit stack which comes with the template is highly opinionated, and
786 | includes the following operations:
787 |
788 | - All **outputs in Jupyter notebooks are cleared** using nbstripout_.
789 |
790 | - Code is reformatted to use the black_ style.
791 | Any code inside docstrings will be formatted to black using blackendocs_.
792 | All code cells in Jupyter notebooks are also formatted to black using black_nbconvert_.
793 |
794 | - Imports are automatically sorted using isort_.
795 |
796 | - Entries in requirements.txt files are automatically sorted alphabetically.
797 |
798 | - Several `hooks from pre-commit `_ are used to screen for
799 | non-language specific git issues, such as incomplete git merges, overly large
800 | files being committed to the repo, and malformed JSON and YAML files.
801 |
802 | - JSON files are also prettified automatically to have standardised indentation.
803 |
804 | The pre-commit stack will also run on github with one of the action workflows,
805 | which ensures the code that is pushed is validated without relying on every
806 | contributor installing pre-commit locally.
807 |
808 | This development practice of using pre-commit_, and standardizing the
809 | code-style using black_, is popular among leading open-source python projects
810 | including numpy, scipy, sklearn, Pillow, and many others.
811 |
812 | If you want to use pre-commit, but **want to commit outputs in Jupyter notebooks**
813 | instead of stripping them, simply remove the nbstripout_ hook from the
814 | `.pre-commit-config.yaml file `__
815 | and commit that change.
816 |
817 | If you don't want to use pre-commit at all, you can uninstall it:
818 |
819 | .. code-block:: bash
820 |
821 | pre-commit uninstall
822 |
823 | and purge it (along with black and flake8) from the repository:
824 |
825 | .. code-block:: bash
826 |
827 | git rm .pre-commit-config.yaml .flake8 .github/workflows/pre-commit.yaml
828 | git commit -m "DEV: Remove pre-commit hooks"
829 |
830 | .. _black: https://github.com/psf/black
831 | .. _black_nbconvert: https://github.com/dfm/black_nbconvert
832 | .. _blackendocs: https://github.com/asottile/blacken-docs
833 | .. _flake8: https://gitlab.com/pycqa/flake8
834 | .. _isort: https://github.com/timothycrosley/isort
835 | .. _nbstripout: https://github.com/kynan/nbstripout
836 | .. _pre-commit: https://pre-commit.com/
837 | .. _pre-commit-hooks: https://github.com/pre-commit/pre-commit-hooks
838 | .. _pre-commit-py-hooks: https://github.com/pre-commit/pygrep-hooks
839 |
840 |
841 | Additional features
842 | -------------------
843 |
844 | This template was forked from a more general `python template repository`_.
845 |
846 | For more information on the features of the python template repository, see
847 | `here `_.
848 |
849 | .. _`python template repository`: https://github.com/scottclowe/python-template-repo
850 | .. _`python-template-repository-features`: https://github.com/scottclowe/python-template-repo#features
851 |
852 |
853 | Contributing
854 | ------------
855 |
856 | Contributions are welcome! If you can see a way to improve this template:
857 |
858 | - Clone this repo
859 | - Create a feature branch
860 | - Make your changes in the feature branch
861 | - Push your branch and make a pull request
862 |
863 | Or to report a bug or request something new, make an issue.
864 |
865 |
866 | .. |SLURM| image:: https://img.shields.io/badge/scheduler-SLURM-40B1EC
867 | :target: https://slurm.schedmd.com/
868 | :alt: SLURM
869 | .. |preempt| image:: https://img.shields.io/badge/preemption-supported-brightgreen
870 | :alt: preemption
871 | .. |PyTorch| image:: https://img.shields.io/badge/PyTorch-DDP-EE4C2C?logo=pytorch&logoColor=EE4C2C
872 | :target: https://pytorch.org/
873 | :alt: pytorch
874 | .. |wandb| image:: https://img.shields.io/badge/Weights_%26_Biases-enabled-FFCC33?logo=WeightsAndBiases&logoColor=FFCC33
875 | :target: https://wandb.ai
876 | :alt: Weights&Biases
877 | .. |pre-commit| image:: https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white
878 | :target: https://github.com/pre-commit/pre-commit
879 | :alt: pre-commit
880 | .. |black| image:: https://img.shields.io/badge/code%20style-black-000000.svg
881 | :target: https://github.com/psf/black
882 | :alt: black
883 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.black]
6 | line-length = 120
7 | target-version = ["py38"]
8 |
9 | [tool.isort]
10 | src_paths = ["template_experiment"]
11 | known_first_party = ["template_experiment"]
12 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | black==25.1.0
2 | identify>=1.4.20
3 | pre-commit
4 |
--------------------------------------------------------------------------------
/requirements-notebook.txt:
--------------------------------------------------------------------------------
1 | jupyter[notebook]
2 |
--------------------------------------------------------------------------------
/requirements-train.txt:
--------------------------------------------------------------------------------
1 | wandb
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-learn>=1.2.1
3 | timm>=0.6.12
4 | torch>=1.12
5 | torchvision>=0.13
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import glob
4 | import os
5 |
6 | from setuptools import find_packages, setup
7 |
8 |
9 | def read(fname):
10 | """
11 | Read the contents of a file.
12 |
13 | Parameters
14 | ----------
15 | fname : str
16 | Path to file.
17 |
18 | Returns
19 | -------
20 | str
21 | File contents.
22 | """
23 | with open(os.path.join(os.path.dirname(__file__), fname)) as f:
24 | return f.read()
25 |
26 |
27 | install_requires = read("requirements.txt").splitlines()
28 |
29 | # Dynamically determine extra dependencies
30 | extras_require = {}
31 | extra_req_files = glob.glob("requirements-*.txt")
32 | for extra_req_file in extra_req_files:
33 | name = os.path.splitext(extra_req_file)[0].replace("requirements-", "", 1)
34 | extras_require[name] = read(extra_req_file).splitlines()
35 |
36 | # If there are any extras, add a catch-all case that includes everything.
37 | # This assumes that entries in extras_require are lists (not single strings),
38 | # and that there are no duplicated packages across the extras.
39 | if extras_require:
40 | extras_require["all"] = sorted({x for v in extras_require.values() for x in v})
41 |
42 |
43 | # Import meta data from __meta__.py
44 | #
45 | # We use exec for this because importing __meta__.py as a module would run the
46 | # package's __init__.py first; __init__.py may assume the requirements are
47 | # already installed, but this code runs during the `python setup.py install`
48 | # step, before the requirements are installed.
49 | # https://packaging.python.org/guides/single-sourcing-package-version/
50 | meta = {}
51 | exec(read("template_experiment/__meta__.py"), meta)
52 |
53 |
54 | # Import the README and use it as the long-description.
55 | # If your readme path is different, add it here.
56 | possible_readme_names = ["README.rst", "README.md", "README.txt", "README"]
57 |
58 | # Handle turning a README file into long_description
59 | long_description = meta["description"]
60 | readme_fname = ""
61 | for fname in possible_readme_names:
62 | try:
63 | long_description = read(fname)
64 | except IOError:
65 | # doesn't exist
66 | continue
67 | else:
68 | # exists
69 | readme_fname = fname
70 | break
71 |
72 | # Infer the content type of the README file from its extension.
73 | # If the contents of your README do not match its extension, manually assign
74 | # long_description_content_type to the appropriate value.
75 | readme_ext = os.path.splitext(readme_fname)[1]
76 | if readme_ext.lower() == ".rst":
77 | long_description_content_type = "text/x-rst"
78 | elif readme_ext.lower() == ".md":
79 | long_description_content_type = "text/markdown"
80 | else:
81 | long_description_content_type = "text/plain"
82 |
83 |
84 | setup(
85 | # Essential details on the package and its dependencies
86 | name=meta["name"],
87 | version=meta["version"],
88 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
89 | package_dir={meta["name"]: os.path.join(".", meta["path"])},
90 | # If any package contains *.txt or *.rst files, include them:
91 | # package_data={"": ["*.txt", "*.rst"],}
92 | python_requires=">=3.8",
93 | install_requires=install_requires,
94 | extras_require=extras_require,
95 | # Metadata to display on PyPI
96 | author=meta["author"],
97 | author_email=meta["author_email"],
98 | description=meta["description"],
99 | long_description=long_description,
100 | long_description_content_type=long_description_content_type,
101 | license=meta["license"],
102 | url=meta["url"],
103 | classifiers=[
104 | # Trove classifiers
105 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
106 | "Natural Language :: English",
107 | "Intended Audience :: Science/Research",
108 | "Programming Language :: Python",
109 | "Programming Language :: Python :: 3",
110 | "Programming Language :: Python :: 3.8",
111 | "Programming Language :: Python :: 3.9",
112 | "Programming Language :: Python :: 3.10",
113 | "Programming Language :: Python :: 3.11",
114 | "Programming Language :: Python :: 3.12",
115 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
116 | ],
117 | entry_points={
118 | "console_scripts": [
119 | "template-experiment-train=template_experiment.train:cli",
120 | ],
121 | },
122 | )
123 |
--------------------------------------------------------------------------------
/slogs/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scottclowe/pytorch-experiment-template/138c3bdec82bf4fb026d458032e05ae405de1fdb/slogs/.placeholder
--------------------------------------------------------------------------------
/slurm/notebook.slrm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --partition=t4v1,t4v2 # Which node partitions to use. Use a comma-separated list if you don't mind which partition: t4v1,t4v2,rtx6000,a40
3 | #SBATCH --nodes=1 # Number of nodes to request. Usually use 1 for interactive jobs.
4 | #SBATCH --tasks-per-node=1 # Number of processes to spawn per node. Should always be set to 1, regardless of number of GPUs!
5 | #SBATCH --gres=gpu:1 # Number of GPUs per node to request
6 | #SBATCH --cpus-per-gpu=4 # Number of CPUs to request per GPU (soft maximum of 4 per GPU requested)
7 | #SBATCH --mem-per-gpu=10G # RAM per GPU
8 | #SBATCH --time=16:00:00 # You must specify a maximum run-time if you want to run for more than 2h
9 | #SBATCH --output=jnb_%j.out # You'll need to inspect this log file to find out how to connect to the notebook
10 | #SBATCH --job-name=jnb
11 |
12 |
13 | # Manually define the project name.
14 | # This should also be the name of your conda environment used for this project.
15 | PROJECT_NAME="template-experiment"
16 |
17 | # Exit if any command hits an error
18 | set -e
19 |
20 | # Store the time at which the script was launched
21 | start_time="$SECONDS"
22 |
23 | echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) begins on $(hostname), submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)"
24 |
25 | # Print slurm config report
26 | echo "Running slurm/utils/report_slurm_config.sh"
27 | source "slurm/utils/report_slurm_config.sh"
28 |
29 | echo ""
30 | echo "-------- Activating environment ----------------------------------------"
31 | date
32 | echo ""
33 | echo "Running ~/.bashrc"
34 | source ~/.bashrc
35 |
36 | # Activate virtual environment
37 | ENVNAME="$PROJECT_NAME"
38 | echo "Activating environment $ENVNAME"
39 | conda activate "$ENVNAME"
40 | echo ""
41 |
42 | # We don't want to use a fixed port number (e.g. 8888) because that would lead
43 | # to collisions between jobs on the same node as each other.
44 | # Try to use a port number automatically selected from the job id number.
45 | JOB_SOCKET=$(( $SLURM_JOB_ID % 16384 + 49152 ))
46 | if ss -tulpn | grep -q ":$JOB_SOCKET ";
47 | then
48 | # The port we selected is in use, so we'll get a random available port instead.
49 | JOB_SOCKET="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')";
50 | fi
51 | echo "Will use port $JOB_SOCKET to serve the notebook"
52 | echo ""
53 | echo "-------- Connection instructions ---------------------------------------"
54 | echo ""
55 | echo "You should be able to access the notebook locally with the following commands:"
56 | echo ""
57 | echo "ssh ${USER}@v.vectorinstitute.ai -N -L 8887:$(hostname):${JOB_SOCKET}"
58 | echo "sensible-browser http://localhost:8887"
59 | echo "# Then enter the token printed below, or your password if you set one"
60 | echo ""
61 | echo "-------- Starting jupyter notebook -------------------------------------"
62 | date
63 | export XDG_RUNTIME_DIR=""
64 | python -m jupyter notebook --ip 0.0.0.0 --port "$JOB_SOCKET"
65 |
--------------------------------------------------------------------------------
/slurm/train.slrm:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --partition=t4v1,t4v2 # Which node partitions to use. Use a comma-separated list if you don't mind which partition: t4v1,t4v2,rtx6000,a40
3 | #SBATCH --nodes=1 # Number of nodes to request. Can increase to --nodes=2, etc, for more GPUs (spread out over different nodes).
4 | #SBATCH --tasks-per-node=1 # Number of processes to spawn per node. Should always be set to 1, regardless of number of GPUs!
5 | #SBATCH --gres=gpu:1 # Number of GPUs per node. Can increase to --gres=gpu:2, etc, for more GPUs (together on the same node).
6 | #SBATCH --cpus-per-gpu=4 # Number of CPUs per GPU. Soft maximum of 4 per GPU requested on t4, 8 otherwise. Hard maximum of 32 per node.
7 | #SBATCH --mem-per-gpu=10G # RAM per GPU. Soft maximum of 20G per GPU requested on t4v2, 41G otherwise. Hard maximum of 167G per node.
8 | #SBATCH --time=08:00:00 # You must specify a maximum run-time if you want to run for more than 2h
9 | #SBATCH --signal=B:USR1@120 # Send signal SIGUSR1 120 seconds before the job hits the time limit
10 | #SBATCH --output=slogs/%x__%A_%a.out
11 | # %x=job-name, %A=job ID, %a=array value, %n=node rank, %t=task rank, %N=hostname
12 | # Note: You must create the output directory "slogs" before launching the job, otherwise it will immediately
13 | # fail without an error message.
14 | # Note: If you specify --output and not --error, then STDOUT and STDERR will both be sent to the
15 | # file specified by --output.
16 | #SBATCH --array=0 # Use array to run multiple jobs that are identical except for $SLURM_ARRAY_TASK_ID.
17 | # In this example, we use this to set the seed. You can run multiple seeds with --array=0-4, for example.
18 | #SBATCH --open-mode=append # Use append mode otherwise preemption resets the checkpoint file.
19 | #SBATCH --job-name=template-experiment # Set this to be a shorthand for your project's name.
20 |
21 | # Manually define the project name.
22 | # This must also be the name of your conda environment used for this project.
23 | PROJECT_NAME="template-experiment"
24 | # Automatically convert hyphens to underscores, to get the name of the project directory.
25 | PROJECT_DIRN="${PROJECT_NAME//-/_}"
26 |
27 | # Exit the script if any command hits an error
28 | set -e
29 |
30 | # Set up a handler to requeue the job if it hits the time-limit without terminating
31 | function term_handler()
32 | {
33 | echo "** Job $SLURM_JOB_NAME ($SLURM_JOB_ID) received SIGUSR1 at $(date) **"
34 | echo "** Requeuing job $SLURM_JOB_ID so it can run for longer **"
35 | scontrol requeue "${SLURM_JOB_ID}"
36 | }
37 | # Call this term_handler function when the job receives the SIGUSR1 signal
38 | trap term_handler SIGUSR1
39 |
40 | # sbatch script for Vector
41 | # Inspired by:
42 | # https://github.com/VectorInstitute/TechAndEngineering/blob/master/benchmarks/resnet_torch/sample_script/script.sh
43 | # https://github.com/VectorInstitute/TechAndEngineering/blob/master/checkpoint_examples/PyTorch/launch_job.slrm
44 | # https://github.com/VectorInstitute/TechAndEngineering/blob/master/checkpoint_examples/PyTorch/run_train.sh
45 | # https://github.com/PrincetonUniversity/multi_gpu_training/tree/main/02_pytorch_ddp
46 | # https://pytorch.org/docs/stable/elastic/run.html
47 | # https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
48 | # https://unix.stackexchange.com/a/146770/154576
49 |
50 | # Store the time at which the script was launched, so we can measure how long has elapsed.
51 | start_time="$SECONDS"
52 |
53 | echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) begins on $(hostname), submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)"
54 | echo ""
55 | # Print slurm config report (SLURM environment variables, some of which we use later in the script)
56 | # By sourcing the script, we execute it as if its code were here in the script
57 | # N.B. This script only prints things out, it doesn't assign any environment variables.
58 | echo "Running slurm/utils/report_slurm_config.sh"
59 | source "slurm/utils/report_slurm_config.sh"
60 | # Print repo status report (current branch, commit ref, where any uncommitted changes are located)
61 | # N.B. This script only prints things out, it doesn't assign any environment variables.
62 | echo "Running slurm/utils/report_repo.sh"
63 | source "slurm/utils/report_repo.sh"
64 | echo ""
65 | if false; then
66 | # Print disk usage report, to catch errors due to lack of file space.
67 | # This is disabled by default to prevent confusing new users with too
68 | # much output.
69 | echo "------------------------------------"
70 | echo "df -h:"
71 | # Print header, then sort the rows alphabetically by mount point
72 | df -h --output=target,pcent,size,used,avail,source | head -n 1
73 | df -h --output=target,pcent,size,used,avail,source | tail -n +2 | sort -h
74 | echo ""
75 | fi
76 | echo "-------- Input handling ------------------------------------------------"
77 | date
78 | echo ""
79 | # Use the SLURM job array to select the seed for the experiment
80 | SEED="$SLURM_ARRAY_TASK_ID"
81 | if [[ "$SEED" == "" ]];
82 | then
83 | SEED=0
84 | fi
85 | echo "SEED = $SEED"
86 |
87 | # Check if the first argument is a path to the python script to run
88 | if [[ "$1" == *.py ]];
89 | then
90 | # If it is, we'll run this python script and remove it from the list of
91 | # arguments to pass on to the script.
92 | SCRIPT_PATH="$1"
93 | shift
94 | else
95 | # Otherwise, use our default python training script.
96 | SCRIPT_PATH="$PROJECT_DIRN/train.py"
97 | fi
98 | echo "SCRIPT_PATH = $SCRIPT_PATH"
99 |
100 | # Any arguments provided to sbatch after the name of the slurm script will be
101 | # passed through to the main script later.
102 | # (The pass-through works like *args or **kwargs in python.)
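# For example, "sbatch slurm/train.slrm --epochs=10" would forward "--epochs=10"
# to the python script in addition to the arguments set below (an illustration
# only; the forwarded flag must be one the python script actually accepts).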
103 | echo "Pass-through args: ${@}"
104 | echo ""
105 | echo "-------- Activating environment ----------------------------------------"
106 | date
107 | echo ""
108 | echo "Running ~/.bashrc"
109 | source ~/.bashrc
110 |
111 | # Activate the conda environment
112 | ENVNAME="$PROJECT_NAME"
113 | echo "Activating conda environment $ENVNAME"
114 | conda activate "$ENVNAME"
115 | echo ""
116 | # Print env status (which packages you have installed - useful for diagnostics)
117 | # N.B. This script only prints things out, it doesn't assign any environment variables.
118 | echo "Running slurm/utils/report_env_config.sh"
119 | source "slurm/utils/report_env_config.sh"
120 |
121 | # Set the JOB_LABEL environment variable
122 | echo "-------- Setting JOB_LABEL ---------------------------------------------"
123 | echo ""
124 | # Decide the name of the paths to use for saving this job
125 | if [ "$SLURM_ARRAY_TASK_COUNT" != "" ] && [ "$SLURM_ARRAY_TASK_COUNT" -gt 1 ];
126 | then
127 | JOB_ID="${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}";
128 | else
129 | JOB_ID="${SLURM_JOB_ID}";
130 | fi
131 | # Decide the name of the paths to use for saving this job
132 | JOB_LABEL="${SLURM_JOB_NAME}__${JOB_ID}";
133 | echo "JOB_ID = $JOB_ID"
134 | echo "JOB_LABEL = $JOB_LABEL"
135 | echo ""
136 |
137 | # Set checkpoint directory ($CKPT_DIR) environment variables
138 | echo "-------- Setting checkpoint and output path variables ------------------"
139 | echo ""
140 | # Vector provides a fast parallel filesystem local to the GPU nodes, dedicated
141 | # for checkpointing. It is mounted under /checkpoint. It is strongly
142 | # recommended that you keep your intermediary checkpoints under this directory
143 | CKPT_DIR="/checkpoint/${SLURM_JOB_USER}/${SLURM_JOB_ID}"
144 | echo "CKPT_DIR = $CKPT_DIR"
145 | CKPT_PTH="$CKPT_DIR/checkpoint_latest.pt"
146 | echo "CKPT_PTH = $CKPT_PTH"
147 | echo ""
148 | # Ensure the checkpoint dir exists
149 | mkdir -p "$CKPT_DIR"
150 | echo "Current contents of ${CKPT_DIR}:"
151 | ls -lh "${CKPT_DIR}"
152 | echo ""
153 | # Create a symlink to the job's checkpoint directory within a subfolder of the
154 | # current directory (repository directory) named "checkpoints_working".
155 | mkdir -p "checkpoints_working"
156 | ln -sfn "$CKPT_DIR" "$PWD/checkpoints_working/$SLURM_JOB_NAME"
157 | # Specify an output directory to place checkpoints for long term storage once
158 | # the job is finished.
159 | if [[ -d "/scratch/hdd001/home/$SLURM_JOB_USER" ]];
160 | then
161 | OUTPUT_DIR="/scratch/hdd001/home/$SLURM_JOB_USER"
162 | elif [[ -d "/scratch/ssd004/scratch/$SLURM_JOB_USER" ]];
163 | then
164 | OUTPUT_DIR="/scratch/ssd004/scratch/$SLURM_JOB_USER"
165 | else
166 | OUTPUT_DIR=""
167 | fi
168 | if [[ "$OUTPUT_DIR" != "" ]];
169 | then
170 | # Directory OUTPUT_DIR will contain all completed jobs for this project.
171 | OUTPUT_DIR="$OUTPUT_DIR/checkpoints/$PROJECT_NAME"
172 | # Subdirectory JOB_OUTPUT_DIR will contain the outputs from this job.
173 | JOB_OUTPUT_DIR="$OUTPUT_DIR/$JOB_LABEL"
174 | echo "JOB_OUTPUT_DIR = $JOB_OUTPUT_DIR"
175 | if [[ -d "$JOB_OUTPUT_DIR" ]];
176 | then
177 | echo "Current contents of ${JOB_OUTPUT_DIR}"
178 | ls -lh "${JOB_OUTPUT_DIR}"
179 | fi
180 | echo ""
181 | fi
182 |
183 | # Save a list of installed packages and their versions to files in the checkpoint directory
184 | conda env export > "$CKPT_DIR/environment.yml"
185 | pip freeze > "$CKPT_DIR/frozen-requirements.txt"
186 |
187 | if [[ "$SLURM_RESTART_COUNT" -gt 0 && ! -f "$CKPT_PTH" ]];
188 | then
189 | echo ""
190 | echo "====================================================================="
191 | echo "SLURM SCRIPT ERROR:"
192 | echo " Resuming after pre-emption (SLURM_RESTART_COUNT=$SLURM_RESTART_COUNT)"
193 | echo " but there is no checkpoint file at $CKPT_PTH"
194 | echo "====================================================================="
195 | exit 1;
196 | fi;
197 |
198 | echo ""
199 | echo "------------------------------------"
200 | elapsed=$(( SECONDS - start_time ))
201 | eval "echo Running total elapsed time for restart $SLURM_RESTART_COUNT: $(date -ud "@$elapsed" +'$((%s/3600/24)) days %H hr %M min %S sec')"
202 | echo ""
203 | echo "-------- Begin main script ---------------------------------------------"
204 | date
205 | echo ""
206 | # Store the master node's IP address in the MASTER_ADDR environment variable,
207 | # which torch.distributed will use to initialize DDP.
208 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
209 | echo "Rank 0 node is at $MASTER_ADDR"
210 |
211 | # Use a port number automatically selected from the job id number.
212 | # This will let us use the same port for every task in the job without having
213 | # to create a file to store the port number.
214 | # We don't want to use a fixed port number because that would lead to
215 | # collisions between jobs when they are scheduled on the same node.
216 | # We only use ports in the range 49152-65535 (inclusive), which are the
217 | # Dynamic Ports, also known as Private Ports.
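# For example, a job ID of 12345678 would give port 12345678 % 16384 + 49152 = 57678.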
218 | MASTER_PORT="$(( $SLURM_JOB_ID % 16384 + 49152 ))"
219 | if ss -tulpn | grep -q ":$MASTER_PORT ";
220 | then
221 | # The port we selected is in use, so we'll get a random available port instead.
222 | echo "Finding a free port to use for $SLURM_NNODES node training"
223 | MASTER_PORT="$(python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1])')";
224 | fi
225 | export MASTER_PORT;
226 | echo "Will use port $MASTER_PORT for c10d communication"
227 |
228 | export WORLD_SIZE="$(($SLURM_NNODES * $SLURM_GPUS_ON_NODE))"
229 | echo "WORLD_SIZE = $WORLD_SIZE"
230 |
231 | # NCCL options ----------------------------------------------------------------
232 |
233 | # This is needed to print debug info from NCCL; it can be removed if all goes well
234 | # export NCCL_DEBUG=INFO
235 |
236 | # This is needed to stop NCCL from using InfiniBand, which the cluster does not have
237 | export NCCL_IB_DISABLE=1
238 |
239 | # This tells NCCL to use the bonded network interface (bond0) for communication
240 | if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]];
241 | then
242 | echo "Using NCCL_SOCKET_IFNAME=bond0 on ${SLURM_JOB_PARTITION}"
243 | export NCCL_SOCKET_IFNAME=bond0
244 | fi
245 |
246 | # Set this when using the NCCL backend for inter-GPU communication.
247 | export TORCH_NCCL_BLOCKING_WAIT=1
248 | # -----------------------------------------------------------------------------
249 |
250 | # Multi-GPU configuration
251 | echo ""
252 | echo "Main script begins via torchrun with host tcp://${MASTER_ADDR}:$MASTER_PORT with backend NCCL"
253 | if [[ "$SLURM_JOB_NUM_NODES" == "1" ]];
254 | then
255 | echo "Single ($SLURM_JOB_NUM_NODES) node training ($SLURM_GPUS_ON_NODE GPUs)"
256 | else
257 | echo "Multiple ($SLURM_JOB_NUM_NODES) node training (x$SLURM_GPUS_ON_NODE GPUs per node)"
258 | fi
259 | echo ""
260 |
261 | # We use the torchrun command to launch our main python script.
262 | # It will automatically set up the necessary environment variables for DDP,
263 | # and will launch the script once for each GPU on each node.
264 | #
265 | # We pass the CKPT_DIR environment variable on as the output path for our
266 | # python script, and also try to resume from a checkpoint in this directory
267 | # in case of pre-emption. The python script should run from scratch if there
268 | # is no checkpoint at this path to resume from.
269 | #
270 | # We pass on to train.py an array of arbitrary extra arguments given to this
271 | # slurm script contained in the `$@` magic variable.
272 | #
273 | # We execute the srun command in the background with `&` (and then capture its
274 | # process ID and wait for it to finish before continuing) so the main process
275 | # can handle the SIGUSR1 signal. Otherwise if a child process is running, the
276 | # signal will be ignored.
277 | srun -N "$SLURM_NNODES" --ntasks-per-node=1 \
278 | torchrun \
279 | --nnodes="$SLURM_JOB_NUM_NODES" \
280 | --nproc_per_node="$SLURM_GPUS_ON_NODE" \
281 | --rdzv_id="$SLURM_JOB_ID" \
282 | --rdzv_backend=c10d \
283 | --rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \
284 | "$SCRIPT_PATH" \
285 | --cpu-workers="$SLURM_CPUS_PER_GPU" \
286 | --seed="$SEED" \
287 | --checkpoint="$CKPT_PTH" \
288 | --log-wandb \
289 | --run-name="$SLURM_JOB_NAME" \
290 | --run-id="$JOB_ID" \
291 | "${@}" &
292 | child="$!"
293 | wait "$child"
294 |
295 | echo ""
296 | echo "------------------------------------"
297 | elapsed=$(( SECONDS - start_time ))
298 | eval "echo Running total elapsed time for restart $SLURM_RESTART_COUNT: $(date -ud "@$elapsed" +'$((%s/3600/24)) days %H hr %M min %S sec')"
299 | echo ""
300 | # Now the job is finished, remove the symlink to the job's checkpoint directory
301 | # from checkpoints_working
302 | rm "$PWD/checkpoints_working/$SLURM_JOB_NAME"
303 | # By overriding the JOB_OUTPUT_DIR environment variable, we disable saving
304 | # checkpoints to long-term storage. This is disabled by default to preserve
305 | # disk space. When you are sure your job config is correct and you are sure
306 | # you need to save your checkpoints for posterity, comment out this line.
307 | JOB_OUTPUT_DIR=""
308 | #
309 | if [[ "$CKPT_DIR" == "" ]];
310 | then
311 |     # This shouldn't ever happen, but we check just in case.
312 | # If $CKPT_DIR were somehow not set, we would mistakenly try to copy far
313 | # too much data to $JOB_OUTPUT_DIR.
314 | echo "CKPT_DIR is unset. Will not copy outputs to $JOB_OUTPUT_DIR."
315 | elif [[ "$JOB_OUTPUT_DIR" == "" ]];
316 | then
317 | echo "JOB_OUTPUT_DIR is unset. Will not copy outputs from $CKPT_DIR."
318 | else
319 | echo "-------- Saving outputs for long term storage --------------------------"
320 | date
321 | echo ""
322 | echo "Copying outputs from $CKPT_DIR to $JOB_OUTPUT_DIR"
323 | mkdir -p "$JOB_OUTPUT_DIR"
324 | rsync -rutlzv "$CKPT_DIR/" "$JOB_OUTPUT_DIR/"
325 | echo ""
326 | echo "Output contents of ${JOB_OUTPUT_DIR}:"
327 | ls -lh "$JOB_OUTPUT_DIR"
328 | # Set up a symlink to the long term storage directory
329 | ln -sfn "$OUTPUT_DIR" "checkpoints_finished"
330 | fi
331 | echo ""
332 | echo "------------------------------------------------------------------------"
333 | echo ""
334 | echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) finished, submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)"
335 | date
336 | echo "------------------------------------"
337 | elapsed=$(( SECONDS - start_time ))
338 | eval "echo Total elapsed time for restart $SLURM_RESTART_COUNT: $(date -ud "@$elapsed" +'$((%s/3600/24)) days %H hr %M min %S sec')"
339 | echo "========================================================================"
340 |
--------------------------------------------------------------------------------
/slurm/utils/report_env_config.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "========================================================================"
4 | echo "-------- Reporting environment configuration ---------------------------"
5 | date
6 | echo ""
7 | echo "pwd:"
8 | pwd
9 | echo ""
10 | echo "which python:"
11 | which python
12 | echo ""
13 | echo "python version:"
14 | python --version
15 | echo ""
16 | echo "which conda:"
17 | which conda
18 | echo ""
19 | echo "conda info:"
20 | conda info
21 | echo ""
22 | echo "which pip:"
23 | which pip
24 | echo ""
25 | ## Don't bother looking at system nvcc, as we have a conda installation
26 | # echo "which nvcc:"
27 | # which nvcc || echo "No nvcc"
28 | # echo ""
29 | # echo "nvcc version:"
30 | # nvcc --version || echo "No nvcc"
31 | # echo ""
32 | echo "nvidia-smi:"
33 | nvidia-smi || echo "No nvidia-smi"
34 | echo ""
35 | echo "torch info:"
36 | python -c "import torch; print(f'pytorch={torch.__version__}, cuda={torch.cuda.is_available()}, gpus={torch.cuda.device_count()}')"
37 | python -c "import torch; print(str(torch.ones(1, device=torch.device('cuda')))); print('able to use cuda')"
38 | echo ""
39 | echo "========================================================================"
40 | echo ""
41 |
--------------------------------------------------------------------------------
/slurm/utils/report_repo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "========================================================================"
4 | echo "-------- Reporting git repo configuration ------------------------------"
5 | date
6 | echo ""
7 | echo "pwd: $(pwd)"
8 | echo "commit ref: $(git rev-parse HEAD)"
9 | echo ""
10 | git status
11 | echo "========================================================================"
12 | echo ""
13 |
--------------------------------------------------------------------------------
/slurm/utils/report_slurm_config.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "========================================================================"
4 | echo "-------- Reporting SLURM configuration ---------------------------------"
5 | date
6 | echo ""
7 | echo "SLURM_CLUSTER_NAME = $SLURM_CLUSTER_NAME" # Name of the cluster on which the job is executing.
8 | echo "SLURM_JOB_QOS = $SLURM_JOB_QOS" # Quality Of Service (QOS) of the job allocation.
9 | echo "SLURM_JOB_ID = $SLURM_JOB_ID" # The ID of the job allocation.
10 | echo "SLURM_RESTART_COUNT = $SLURM_RESTART_COUNT" # The number of times the job has been restarted.
11 | if [ "$SLURM_ARRAY_TASK_COUNT" != "" ]; then
12 | echo ""
13 | echo "SLURM_ARRAY_JOB_ID = $SLURM_ARRAY_JOB_ID" # Job array's master job ID number.
14 | echo "SLURM_ARRAY_TASK_COUNT = $SLURM_ARRAY_TASK_COUNT" # Total number of tasks in a job array.
15 | echo "SLURM_ARRAY_TASK_ID = $SLURM_ARRAY_TASK_ID" # Job array ID (index) number.
16 | echo "SLURM_ARRAY_TASK_MAX = $SLURM_ARRAY_TASK_MAX" # Job array's maximum ID (index) number.
17 | echo "SLURM_ARRAY_TASK_STEP = $SLURM_ARRAY_TASK_STEP" # Job array's index step size.
18 | fi;
19 | echo ""
20 | echo "SLURM_JOB_NUM_NODES = $SLURM_JOB_NUM_NODES" # Total number of nodes in the job's resource allocation.
21 | echo "SLURM_JOB_NODELIST = $SLURM_JOB_NODELIST" # List of nodes allocated to the job.
22 | echo "SLURM_TASKS_PER_NODE = $SLURM_TASKS_PER_NODE" # Number of tasks to be initiated on each node.
23 | echo "SLURM_NTASKS = $SLURM_NTASKS" # Number of tasks to spawn.
24 | echo "SLURM_PROCID = $SLURM_PROCID" # The MPI rank (or relative process ID) of the current process
25 | echo ""
26 | echo "SLURM_GPUS_ON_NODE = $SLURM_GPUS_ON_NODE" # Number of allocated GPUs per node.
27 | echo "SLURM_CPUS_ON_NODE = $SLURM_CPUS_ON_NODE" # Number of allocated CPUs per node.
28 | echo "SLURM_CPUS_PER_GPU = $SLURM_CPUS_PER_GPU" # Number of CPUs requested per GPU. Only set if the --cpus-per-gpu option is specified.
29 | echo "SLURM_MEM_PER_GPU = $SLURM_MEM_PER_GPU" # Memory per allocated GPU. Only set if the --mem-per-gpu option is specified.
30 | echo ""
31 | if [[ "$SLURM_TMPDIR" != "" ]];
32 | then
33 | echo "------------------------------------"
34 | echo ""
35 | echo "SLURM_TMPDIR = $SLURM_TMPDIR"
36 | echo ""
37 | echo "Contents of $SLURM_TMPDIR"
38 | ls -lh "$SLURM_TMPDIR"
39 | echo ""
40 | fi;
41 | echo "========================================================================"
42 | echo ""
43 |
--------------------------------------------------------------------------------
/template_experiment/__init__.py:
--------------------------------------------------------------------------------
1 | from . import __meta__
2 |
3 | __version__ = __meta__.version
4 |
--------------------------------------------------------------------------------
/template_experiment/__meta__.py:
--------------------------------------------------------------------------------
1 | # `name` is the name of the package as used for `pip install package` (use hyphens)
2 | name = "template-experiment"
3 | # `path` is the name of the package for `import package` (use underscores)
4 | path = name.lower().replace("-", "_").replace(" ", "_")
5 | # Your version number should follow https://python.org/dev/peps/pep-0440 and
6 | # https://semver.org
7 | version = "0.1.dev0"
8 | author = "Author Name"
9 | author_email = ""
10 | description = "" # One-liner
11 | url = "" # your project homepage
12 | license = "Unlicense" # See https://choosealicense.com
13 |
--------------------------------------------------------------------------------
/template_experiment/data_transformations.py:
--------------------------------------------------------------------------------
1 | import timm.data
2 | import torch
3 | from torchvision import transforms
4 |
5 | NORMALIZATION = {
6 | "imagenet": [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)],
7 | "mnist": [(0.1307,), (0.3081,)],
8 | "cifar": [(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)],
9 | }
10 |
11 | VALID_TRANSFORMS = ["imagenet", "cifar", "mnist"]
12 |
13 |
14 | def get_transform(transform_type="noaug", image_size=32, args=None):
15 | if args is None:
16 | args = {}
17 | mean, std = NORMALIZATION[args.get("normalization", "imagenet")]
18 | if "mean" in args:
19 | mean = args["mean"]
20 | if "std" in args:
21 | std = args["std"]
22 |
23 | if transform_type == "noaug":
24 | # No augmentations, just resize and normalize.
25 | # N.B. If the raw training image isn't square, there is a small
26 | # "augmentation" as we will randomly crop a square (of length equal to
27 | # the shortest side) from it. We do this because we assume inputs to
28 | # the network must be square.
29 | train_transform = transforms.Compose(
30 | [
31 | transforms.Resize(image_size), # Resize shortest side to image_size
32 | transforms.RandomCrop(image_size), # If it is not square, *random* crop
33 | transforms.ToTensor(),
34 | transforms.Normalize(mean=mean, std=std),
35 | ]
36 | )
37 | test_transform = transforms.Compose(
38 | [
39 | transforms.Resize(image_size), # Resize shortest side to image_size
40 | transforms.CenterCrop(image_size), # If it is not square, center crop
41 | transforms.ToTensor(),
42 | transforms.Normalize(mean=mean, std=std),
43 | ]
44 | )
45 |
46 | elif transform_type == "imagenet":
47 |         # Appropriate for large natural images, as in ImageNet.
48 | # For training:
49 | # - Zoom in randomly with scale (big range of how much to zoom in)
50 | # - Stretch with random aspect ratio
51 | # - Flip horizontally
52 | # - Randomly adjust brightness/contrast/saturation
53 | # - (No rotation or skew)
54 | # - Interpolation is randomly either bicubic or bilinear
55 | train_transform = timm.data.create_transform(
56 | input_size=image_size,
57 | is_training=True,
58 | scale=(0.08, 1.0), # default imagenet scale range
59 | ratio=(3.0 / 4.0, 4.0 / 3.0), # default imagenet ratio range
60 | hflip=0.5,
61 | vflip=0.0,
62 | color_jitter=0.4,
63 | interpolation="random",
64 | mean=mean,
65 | std=std,
66 | )
67 | # For testing:
68 | # - Zoom in 87.5%
69 | # - Center crop
70 | # - Interpolation is bilinear
71 | test_transform = transforms.Compose(
72 | [
73 | transforms.Resize(int(image_size / 0.875)),
74 | transforms.CenterCrop(image_size),
75 | transforms.ToTensor(),
76 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
77 | ]
78 | )
79 |
80 | elif transform_type == "cifar":
81 | # Appropriate for smaller natural images, as in CIFAR-10/100.
82 | # For training:
83 | # - Zoom in randomly with scale (small range of how much to zoom in by)
84 | # - Stretch with random aspect ratio
85 | # - Flip horizontally
86 | # - Randomly adjust brightness/contrast/saturation
87 | # - (No rotation or skew)
88 | train_transform = timm.data.create_transform(
89 | input_size=image_size,
90 | is_training=True,
91 | scale=(0.7, 1.0), # reduced scale range
92 | ratio=(3.0 / 4.0, 4.0 / 3.0), # default imagenet ratio range
93 | hflip=0.5,
94 | vflip=0.0,
95 | color_jitter=0.4, # default imagenet color jitter
96 | interpolation="random",
97 | mean=mean,
98 | std=std,
99 | )
100 | # For testing:
101 | # - Resize to desired size only, with a center crop step included in
102 | # case the raw image was not square.
103 | test_transform = transforms.Compose(
104 | [
105 | transforms.Resize(image_size),
106 | transforms.CenterCrop(image_size),
107 | transforms.ToTensor(),
108 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
109 | ]
110 | )
111 |
112 | elif transform_type == "digits":
113 | # Appropriate for smaller images containing digits, as in MNIST.
114 | # - Zoom in randomly with scale (small range of how much to zoom in by)
115 | # - Stretch with random aspect ratio
116 | # - Don't flip the images (that would change the digit)
117 | # - Randomly adjust brightness/contrast/saturation
118 | # - (No rotation or skew)
119 | train_transform = timm.data.create_transform(
120 | input_size=image_size,
121 | is_training=True,
122 | scale=(0.7, 1.0), # reduced scale range
123 | ratio=(3.0 / 4.0, 4.0 / 3.0), # default imagenet ratio range
124 | hflip=0.0,
125 | vflip=0.0,
126 | color_jitter=0.4, # default imagenet color jitter
127 | interpolation="random",
128 | mean=mean,
129 | std=std,
130 | )
131 | # For testing:
132 | # - Resize to desired size only, with a center crop step included in
133 | # case the raw image was not square.
134 | test_transform = transforms.Compose(
135 | [
136 | transforms.Resize(image_size),
137 | transforms.CenterCrop(image_size),
138 | transforms.ToTensor(),
139 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
140 | ]
141 | )
142 |
143 | elif transform_type == "autoaugment-imagenet":
144 | # Augmentation policy learnt by AutoAugment, described in
145 | # https://arxiv.org/abs/1805.09501
146 | # The policies mostly concern changing the colours of the image,
147 | # but there is a little rotation and shear too. We need to include
148 | # our own random cropping, stretching, and flipping.
149 | train_transform = transforms.Compose(
150 | [
151 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
152 | transforms.AutoAugment(
153 | policy=transforms.AutoAugmentPolicy.IMAGENET,
154 | interpolation=transforms.InterpolationMode.BILINEAR,
155 | ),
156 | transforms.RandomHorizontalFlip(0.5),
157 | transforms.ToTensor(),
158 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
159 | ]
160 | )
161 | # For testing:
162 | # - Zoom in 87.5%
163 | # - Center crop
164 | test_transform = transforms.Compose(
165 | [
166 | transforms.Resize(int(image_size / 0.875)),
167 | transforms.CenterCrop(image_size),
168 | transforms.ToTensor(),
169 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
170 | ]
171 | )
172 |
173 | elif transform_type == "autoaugment-cifar":
174 | # Augmentation policy learnt by AutoAugment, described in
175 | # https://arxiv.org/abs/1805.09501
176 | # The policies mostly concern changing the colours of the image,
177 | # but there is a little rotation and shear too. We need to include
178 | # our own random cropping, stretching, and flipping.
179 | train_transform = transforms.Compose(
180 | [
181 | transforms.AutoAugment(
182 | policy=transforms.AutoAugmentPolicy.CIFAR10,
183 | interpolation=transforms.InterpolationMode.BILINEAR,
184 | ),
185 | transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
186 | transforms.RandomHorizontalFlip(0.5),
187 | transforms.ToTensor(),
188 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
189 | ]
190 | )
191 | # For testing:
192 | # - Resize to desired size only, with a center crop step included in
193 | # case the raw image was not square.
194 | test_transform = transforms.Compose(
195 | [
196 | transforms.Resize(image_size),
197 | transforms.CenterCrop(image_size),
198 | transforms.ToTensor(),
199 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
200 | ]
201 | )
202 |
203 | elif transform_type == "randaugment-imagenet":
204 | # Augmentation policy learnt by RandAugment, described in
205 | # https://arxiv.org/abs/1909.13719
206 | train_transform = transforms.Compose(
207 | [
208 | transforms.RandAugment(
209 | interpolation=transforms.InterpolationMode.BILINEAR,
210 | ),
211 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
212 | transforms.RandomHorizontalFlip(0.5),
213 | transforms.ToTensor(),
214 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
215 | ]
216 | )
217 | # For testing:
218 | # - Zoom in 87.5%
219 | # - Center crop
220 | test_transform = transforms.Compose(
221 | [
222 | transforms.Resize(int(image_size / 0.875)),
223 | transforms.CenterCrop(image_size),
224 | transforms.ToTensor(),
225 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
226 | ]
227 | )
228 |
229 | elif transform_type == "randaugment-cifar":
230 | # Augmentation policy learnt by RandAugment, described in
231 | # https://arxiv.org/abs/1909.13719
232 | train_transform = transforms.Compose(
233 | [
234 | transforms.RandAugment(
235 | interpolation=transforms.InterpolationMode.BILINEAR,
236 | ),
237 | transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
238 | transforms.RandomHorizontalFlip(0.5),
239 | transforms.ToTensor(),
240 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
241 | ]
242 | )
243 | # For testing:
244 | # - Resize to desired size only, with a center crop step included in
245 | # case the raw image was not square.
246 | test_transform = transforms.Compose(
247 | [
248 | transforms.Resize(image_size),
249 | transforms.CenterCrop(image_size),
250 | transforms.ToTensor(),
251 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
252 | ]
253 | )
254 |
255 | elif transform_type == "trivialaugment-imagenet":
256 | # Trivial augmentation policy, described in https://arxiv.org/abs/2103.10158
257 | train_transform = transforms.Compose(
258 | [
259 | transforms.TrivialAugmentWide(
260 | interpolation=transforms.InterpolationMode.BILINEAR,
261 | ),
262 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
263 | transforms.RandomHorizontalFlip(0.5),
264 | transforms.ToTensor(),
265 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
266 | ]
267 | )
268 | # For testing:
269 | # - Zoom in 87.5%
270 | # - Center crop
271 | test_transform = transforms.Compose(
272 | [
273 | transforms.Resize(int(image_size / 0.875)),
274 | transforms.CenterCrop(image_size),
275 | transforms.ToTensor(),
276 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
277 | ]
278 | )
279 |
280 | elif transform_type == "trivialaugment-cifar":
281 | # Trivial augmentation policy, described in https://arxiv.org/abs/2103.10158
282 | train_transform = transforms.Compose(
283 | [
284 | transforms.TrivialAugmentWide(
285 | interpolation=transforms.InterpolationMode.BILINEAR,
286 | ),
287 | transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
288 | transforms.RandomHorizontalFlip(0.5),
289 | transforms.ToTensor(),
290 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
291 | ]
292 | )
293 | # For testing:
294 | # - Resize to desired size only, with a center crop step included in
295 | # case the raw image was not square.
296 | test_transform = transforms.Compose(
297 | [
298 | transforms.Resize(image_size),
299 | transforms.CenterCrop(image_size),
300 | transforms.ToTensor(),
301 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
302 | ]
303 | )
304 |
305 | elif transform_type == "randomerasing-imagenet":
306 | train_transform = transforms.Compose(
307 | [
308 | transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
309 | transforms.RandomHorizontalFlip(0.5),
310 | transforms.ToTensor(),
311 | transforms.RandomErasing(),
312 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
313 | ]
314 | )
315 | # For testing:
316 | # - Zoom in 87.5%
317 | # - Center crop
318 | test_transform = transforms.Compose(
319 | [
320 | transforms.Resize(int(image_size / 0.875)),
321 | transforms.CenterCrop(image_size),
322 | transforms.ToTensor(),
323 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
324 | ]
325 | )
326 |
327 | elif transform_type == "randomerasing-cifar":
328 | train_transform = transforms.Compose(
329 | [
330 | transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
331 | transforms.RandomHorizontalFlip(0.5),
332 | transforms.ToTensor(),
333 | transforms.RandomErasing(),
334 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
335 | ]
336 | )
337 | # For testing:
338 | # - Resize to desired size only, with a center crop step included in
339 | # case the raw image was not square.
340 | test_transform = transforms.Compose(
341 | [
342 | transforms.Resize(image_size),
343 | transforms.CenterCrop(image_size),
344 | transforms.ToTensor(),
345 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
346 | ]
347 | )
348 |
349 | else:
350 | raise NotImplementedError
351 |
352 | return (train_transform, test_transform)
353 |
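# A minimal usage sketch (for illustration only; the argument values below are
# arbitrary examples of the options handled above):
#
#     train_tf, test_tf = get_transform("cifar", image_size=32, args={"normalization": "cifar"})
#
# The returned transforms can then be passed as the transform arguments of the
# dataset constructors in datasets.py.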
--------------------------------------------------------------------------------
/template_experiment/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Handlers for various image datasets.
3 | """
4 |
5 | import os
6 | import socket
7 | import warnings
8 |
9 | import numpy as np
10 | import sklearn.model_selection
11 | import torch
12 | import torchvision.datasets
13 |
14 |
15 | def determine_host():
16 | r"""
17 | Determine which compute server we are on.
18 |
19 | Returns
20 | -------
21 |     host : str, one of {"vaughan", "mars", ""}
22 |         An identifier for the host compute system ("" if it cannot be determined).
23 | """
24 | hostname = socket.gethostname()
25 | slurm_submit_host = os.environ.get("SLURM_SUBMIT_HOST")
26 | slurm_cluster_name = os.environ.get("SLURM_CLUSTER_NAME")
27 |
28 | if slurm_cluster_name and slurm_cluster_name.startswith("vaughan"):
29 | return "vaughan"
30 | if slurm_submit_host and slurm_submit_host in ["q.vector.local", "m.vector.local"]:
31 | return "mars"
32 | if hostname and hostname in ["q.vector.local", "m.vector.local"]:
33 | return "mars"
34 | if hostname and hostname.startswith("v"):
35 | return "vaughan"
36 | if slurm_submit_host and slurm_submit_host.startswith("v"):
37 | return "vaughan"
38 | return ""
39 |
40 |
41 | def image_dataset_sizes(dataset):
42 | r"""
43 | Get the image size and number of classes for a dataset.
44 |
45 | Parameters
46 | ----------
47 | dataset : str
48 | Name of the dataset.
49 |
50 | Returns
51 | -------
52 | num_classes : int
53 | Number of classes in the dataset.
54 | img_size : int or None
55 | Size of the images in the dataset, or None if the images are not all
56 | the same size. Images are assumed to be square.
57 | num_channels : int
58 | Number of colour channels in the images in the dataset. This will be
59 | 1 for greyscale images, and 3 for colour images.
60 | """
61 | dataset = dataset.lower().replace("-", "").replace("_", "").replace(" ", "")
62 |
63 | if dataset == "cifar10":
64 | num_classes = 10
65 | img_size = 32
66 | num_channels = 3
67 |
68 | elif dataset == "cifar100":
69 | num_classes = 100
70 | img_size = 32
71 | num_channels = 3
72 |
73 | elif dataset in ["imagenet", "imagenet1k", "ilsvrc2012"]:
74 | num_classes = 1000
75 | img_size = None
76 | num_channels = 3
77 |
78 | elif dataset.startswith("imagenette"):
79 | num_classes = 10
80 | img_size = None
81 | num_channels = 3
82 |
83 | elif dataset.startswith("imagewoof"):
84 | num_classes = 10
85 | img_size = None
86 | num_channels = 3
87 |
88 | elif dataset == "mnist":
89 | num_classes = 10
90 | img_size = 28
91 | num_channels = 1
92 |
93 | elif dataset == "svhn":
94 | num_classes = 10
95 | img_size = 32
96 | num_channels = 3
97 |
98 | else:
99 | raise ValueError("Unrecognised dataset: {}".format(dataset))
100 |
101 | return num_classes, img_size, num_channels
102 |
103 |
104 | def fetch_image_dataset(
105 | dataset,
106 | root=None,
107 | transform_train=None,
108 | transform_eval=None,
109 | download=False,
110 | ):
111 | r"""
112 | Fetch a train and test dataset object for a given image dataset name.
113 |
114 | Parameters
115 | ----------
116 | dataset : str
117 | Name of dataset.
118 | root : str, optional
119 | Path to root directory containing the dataset.
120 | transform_train : callable, optional
121 |         A function/transform that takes in a PIL image and returns a
122 | transformed version, to be applied to the training dataset.
123 | transform_eval : callable, optional
124 |         A function/transform that takes in a PIL image and returns a
125 | transformed version, to be applied to the evaluation dataset.
126 | download : bool, optional
127 | Whether to download the dataset to the expected directory if it is not
128 | there. Only supported by some datasets. Default is ``False``.
129 | """
130 | dataset = dataset.lower().replace("-", "").replace("_", "").replace(" ", "")
131 | host = determine_host()
132 |
133 | if dataset == "cifar10":
134 | if root:
135 | pass
136 | elif host == "vaughan":
137 | root = "/scratch/ssd002/datasets/"
138 | elif host == "mars":
139 | root = "/scratch/gobi1/datasets/"
140 | else:
141 | root = "~/Datasets"
142 | dataset_train = torchvision.datasets.CIFAR10(
143 | os.path.join(root, dataset),
144 | train=True,
145 | transform=transform_train,
146 | download=download,
147 | )
148 | dataset_val = None
149 | dataset_test = torchvision.datasets.CIFAR10(
150 | os.path.join(root, dataset),
151 | train=False,
152 | transform=transform_eval,
153 | download=download,
154 | )
155 |
156 | elif dataset == "cifar100":
157 | if root:
158 | pass
159 | elif host == "vaughan":
160 | root = "/scratch/ssd002/datasets/"
161 | else:
162 | root = "~/Datasets"
163 | dataset_train = torchvision.datasets.CIFAR100(
164 | os.path.join(root, dataset),
165 | train=True,
166 | transform=transform_train,
167 | download=download,
168 | )
169 | dataset_val = None
170 | dataset_test = torchvision.datasets.CIFAR100(
171 | os.path.join(root, dataset),
172 | train=False,
173 | transform=transform_eval,
174 | download=download,
175 | )
176 |
177 | elif dataset in ["imagenet", "imagenet1k", "ilsvrc2012"]:
178 | if root:
179 | pass
180 | elif host == "vaughan":
181 | root = "/scratch/ssd004/datasets/"
182 | elif host == "mars":
183 | root = "/scratch/gobi1/datasets/"
184 | else:
185 | root = "~/Datasets"
186 | dataset_train = torchvision.datasets.ImageFolder(
187 | os.path.join(root, "imagenet", "train"),
188 | transform=transform_train,
189 | )
190 | dataset_val = None
191 | dataset_test = torchvision.datasets.ImageFolder(
192 | os.path.join(root, "imagenet", "val"),
193 | transform=transform_eval,
194 | )
195 |
196 | elif dataset == "imagenette":
197 | if root:
198 | root = os.path.join(root, "imagenette")
199 | elif host == "vaughan":
200 | root = "/scratch/ssd004/datasets/imagenette2/full/"
201 | else:
202 | root = "~/Datasets/imagenette/"
203 | dataset_train = torchvision.datasets.ImageFolder(
204 | os.path.join(root, "train"),
205 | transform=transform_train,
206 | )
207 | dataset_val = None
208 | dataset_test = torchvision.datasets.ImageFolder(
209 | os.path.join(root, "val"),
210 | transform=transform_eval,
211 | )
212 |
213 | elif dataset == "imagewoof":
214 | if root:
215 | root = os.path.join(root, "imagewoof")
216 | elif host == "vaughan":
217 | root = "/scratch/ssd004/datasets/imagewoof2/full/"
218 | else:
219 | root = "~/Datasets/imagewoof/"
220 | dataset_train = torchvision.datasets.ImageFolder(
221 | os.path.join(root, "train"),
222 | transform=transform_train,
223 | )
224 | dataset_val = None
225 | dataset_test = torchvision.datasets.ImageFolder(
226 | os.path.join(root, "val"),
227 | transform=transform_eval,
228 | )
229 |
230 | elif dataset == "mnist":
231 | if root:
232 | pass
233 | elif host == "vaughan":
234 | root = "/scratch/ssd004/datasets/"
235 | else:
236 | root = "~/Datasets"
237 | # Will read from [root]/MNIST/processed
238 | dataset_train = torchvision.datasets.MNIST(
239 | root,
240 | train=True,
241 | transform=transform_train,
242 | download=download,
243 | )
244 | dataset_val = None
245 | dataset_test = torchvision.datasets.MNIST(
246 | root,
247 | train=False,
248 | transform=transform_eval,
249 | download=download,
250 | )
251 |
252 | elif dataset == "svhn":
253 | # SVHN has:
254 | # 73,257 digits for training,
255 | # 26,032 digits for testing, and
256 | # 531,131 additional, less difficult, samples to use as extra training data
257 | # We don't use the extra split here, only train. There are original
258 | # images which are large and have bounding boxes, but the pytorch class
259 | # just uses the 32px cropped individual digits.
260 | if root:
261 | pass
262 | elif host == "vaughan":
263 | root = "/scratch/ssd004/datasets/"
264 | elif host == "mars":
265 | root = "/scratch/gobi1/datasets/"
266 | else:
267 | root = "~/Datasets"
268 | dataset_train = torchvision.datasets.SVHN(
269 | os.path.join(root, dataset),
270 | split="train",
271 | transform=transform_train,
272 | download=download,
273 | )
274 | dataset_val = None
275 | dataset_test = torchvision.datasets.SVHN(
276 | os.path.join(root, dataset),
277 | split="test",
278 | transform=transform_eval,
279 | download=download,
280 | )
281 |
282 | else:
283 | raise ValueError("Unrecognised dataset: {}".format(dataset))
284 |
285 | return dataset_train, dataset_val, dataset_test
286 |
287 |
288 | def fetch_dataset(
289 | dataset,
290 | root=None,
291 | prototyping=False,
292 | transform_train=None,
293 | transform_eval=None,
294 | protoval_split_rate=0.1,
295 | protoval_split_id=0,
296 | download=False,
297 | ):
298 | r"""
299 | Fetch a train and test dataset object for a given dataset name.
300 |
301 | Parameters
302 | ----------
303 | dataset : str
304 | Name of dataset.
305 | root : str, optional
306 | Path to root directory containing the dataset.
307 | prototyping : bool, default=False
308 | Whether to return a validation split distinct from the test split.
309 | If ``False``, the validation split will be the same as the test split
310 |         for datasets which don't intrinsically have a separate val and test
311 | partition.
312 | If ``True``, the validation partition is carved out of the train
313 | partition (resulting in a smaller training set) when there is no
314 | distinct validation partition available.
315 | transform_train : callable, optional
316 |         A function/transform that takes in a PIL image and returns a
317 | transformed version, to be applied to the training dataset.
318 | transform_eval : callable, optional
319 |         A function/transform that takes in a PIL image and returns a
320 | transformed version, to be applied to the evaluation dataset.
321 | protoval_split_rate : float or str, default=0.1
322 | The fraction of the train data to use for validating when in
323 | prototyping mode. If this is set to "auto", the split rate will be
324 | chosen such that the validation partition is the same size as the test
325 | partition.
326 | protoval_split_id : int, default=0
327 | The identity of the random split used for the train/val partitioning.
328 | This controls the seed of the folds used for the split, and which
329 | fold to use for the validation set.
330 | The seed is equal to ``int(protoval_split_id * protoval_split_rate)``
331 | and the fold is equal to ``protoval_split_id % (1 / protoval_split_rate)``.
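        For example, with ``protoval_split_rate=0.1`` and ``protoval_split_id=13``,
        the seed is ``int(13 * 0.1) = 1`` and the fold is ``13 % 10 = 3``.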
332 | download : bool, optional
333 | Whether to download the dataset to the expected directory if it is not
334 | there. Only supported by some datasets. Default is ``False``.
335 |
336 | Returns
337 | -------
338 | dataset_train : torch.utils.data.Dataset
339 | The training dataset.
340 | dataset_val : torch.utils.data.Dataset
341 | The validation dataset.
342 | dataset_test : torch.utils.data.Dataset
343 | The test dataset.
344 | distinct_val_test : bool
345 | Whether the validation and test partitions are distinct (True) or
346 | identical (False).
347 | """
348 | dataset_train, dataset_val, dataset_test = fetch_image_dataset(
349 | dataset=dataset,
350 | root=root,
351 | transform_train=transform_train,
352 | transform_eval=transform_eval,
353 | download=download,
354 | )
355 |
356 | # Handle the validation partition
357 | if dataset_val is not None:
358 | distinct_val_test = True
359 | elif not prototyping:
360 | dataset_val = dataset_test
361 | distinct_val_test = False
362 | else:
363 | # Create our own train/val split.
364 | #
365 | # For the validation part, we need a copy of dataset_train with the
366 | # evaluation transform instead.
367 | # The transform argument is *probably* going to be set to an attribute
368 | # on the dataset object called transform and called from there. But we
369 | # can't be completely sure, so to be completely agnostic about the
370 | # internals of the dataset class let's instantiate the dataset again!
371 | dataset_val = fetch_dataset(
372 | dataset,
373 | root=root,
374 | prototyping=False,
375 | transform_train=transform_eval,
376 | )[0]
377 | # dataset_val is a copy of the full training set, but with the transform
378 | # changed to transform_eval
379 | # Handle automatic validation partition sizing option.
380 | if not isinstance(protoval_split_rate, str):
381 | pass
382 | elif protoval_split_rate == "auto":
383 | # We want the validation set to be the same size as the test set.
384 |             # This corresponds to a split rate of len(test) / len(train).
385 | protoval_split_rate = len(dataset_test) / len(dataset_train)
386 | else:
387 | raise ValueError(f"Unsupported protoval_split_rate: {protoval_split_rate}")
388 | # Create the train/val split using these dataset objects.
389 | dataset_train, dataset_val = create_train_val_split(
390 | dataset_train,
391 | dataset_val,
392 | split_rate=protoval_split_rate,
393 | split_id=protoval_split_id,
394 | )
395 | distinct_val_test = True
396 |
397 | return (
398 | dataset_train,
399 | dataset_val,
400 | dataset_test,
401 | distinct_val_test,
402 | )
403 |
404 |
405 | def create_train_val_split(
406 | dataset_train,
407 | dataset_val=None,
408 | split_rate=0.1,
409 | split_id=0,
410 | ):
411 | r"""
412 | Create a train/val split of a dataset.
413 |
414 | Parameters
415 | ----------
416 | dataset_train : torch.utils.data.Dataset
417 | The full training dataset with training transforms.
418 | dataset_val : torch.utils.data.Dataset, optional
419 | The full training dataset with evaluation transforms.
420 | If this is not given, the source for the validation set will be
421 |         ``dataset_train`` (with the same transforms as the training partition).
422 | Note that ``dataset_val`` must have the same samples as
423 | ``dataset_train``, and the samples must be in the same order.
424 | split_rate : float, default=0.1
425 | The fraction of the train data to use for the validation split.
426 | split_id : int, default=0
427 | The identity of the split to use.
428 | This controls the seed of the folds used for the split, and which
429 | fold to use for the validation set.
430 | The seed is equal to ``int(split_id * split_rate)``
431 | and the fold is equal to ``split_id % (1 / split_rate)``.
432 |
433 | Returns
434 | -------
435 | dataset_train : torch.utils.data.Dataset
436 | The training subset of the dataset.
437 | dataset_val : torch.utils.data.Dataset
438 | The validation subset of the dataset.
439 | """
440 | if dataset_val is None:
441 | dataset_val = dataset_train
442 | # Now we need to reduce it down to just a subset of the training set.
443 | # Let's use K-folds so subsequent prototype split IDs will have
444 | # non-overlapping validation sets. With split_rate = 0.1,
445 | # there will be 10 folds.
446 | n_splits = round(1.0 / split_rate)
447 | if (1.0 / n_splits) != split_rate:
448 | warnings.warn(
449 | "The requested train/val split rate is not possible when using"
450 | " dataset into K folds. The actual split rate will be"
451 | f" {1.0 / n_splits} instead of {split_rate}.",
452 | UserWarning,
453 | stacklevel=2,
454 | )
455 | split_seed = int(split_id * split_rate)
456 | fold_id = split_id % n_splits
457 | print(
458 | f"Creating prototyping train/val split #{split_id}."
459 | f" Using fold {fold_id} of {n_splits} folds, generated with seed"
460 | f" {split_seed}."
461 | )
462 | # Try to do a stratified split.
463 | classes = get_dataset_labels(dataset_train)
464 | if classes is None:
465 | warnings.warn(
466 | "Creating prototyping splits without stratification.",
467 | UserWarning,
468 | stacklevel=2,
469 | )
470 | splitter_ftry = sklearn.model_selection.KFold
471 | else:
472 | splitter_ftry = sklearn.model_selection.StratifiedKFold
473 |
474 | # Create our splits. Assuming the dataset objects are always loaded
475 | # the same way, since a given split ID will always be the same
476 | # fold from the same seeded KFold splitter, it will yield the same
477 | # train/val split on each run.
478 | splitter = splitter_ftry(n_splits=n_splits, shuffle=True, random_state=split_seed)
479 | splits = splitter.split(np.arange(len(dataset_train)), classes)
480 | # splits is an iterable and we want to take the n-th fold from it.
481 | for i, (train_indices, val_indices) in enumerate(splits): # noqa: B007
482 | if i == fold_id:
483 | break
484 | dataset_train = torch.utils.data.Subset(dataset_train, train_indices)
485 | dataset_val = torch.utils.data.Subset(dataset_val, val_indices)
486 | return dataset_train, dataset_val
487 |
488 |
489 | def get_dataset_labels(dataset):
490 | r"""
491 | Get the class labels within a :class:`torch.utils.data.Dataset` object.
492 |
493 | Parameters
494 | ----------
495 | dataset : torch.utils.data.Dataset
496 | The dataset object.
497 |
498 | Returns
499 | -------
500 | array_like or None
501 | The class labels for each sample.
502 | """
503 | if isinstance(dataset, torch.utils.data.Subset):
504 | # For a dataset subset, we need to get the full set of labels from the
505 | # interior subset and then reduce them down to just the labels we have
506 | # in the subset.
507 | labels = get_dataset_labels(dataset.dataset)
508 | if labels is None:
509 | return labels
510 | return np.array(labels)[dataset.indices]
511 |
512 | labels = None
513 | if hasattr(dataset, "targets"):
514 | # MNIST, CIFAR, ImageFolder, etc
515 | labels = dataset.targets
516 | elif hasattr(dataset, "labels"):
517 | # STL10, SVHN
518 | labels = dataset.labels
519 | elif hasattr(dataset, "_labels"):
520 | # Flowers102
521 | labels = dataset._labels
522 |
523 | return labels
524 |
--------------------------------------------------------------------------------
/template_experiment/encoders.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import timm
4 | import torch
5 | from timm.data import resolve_data_config
6 |
7 |
8 | def get_timm_encoder(model_name, pretrained=False, in_chans=3):
9 | r"""
10 | Get the encoder model and its configuration from timm.
11 |
12 | Parameters
13 | ----------
14 | model_name : str
15 | Name of the model to load.
16 | pretrained : bool, default=False
17 | Whether to load the model with pretrained weights.
18 | in_chans : int, default=3
19 | Number of input channels.
20 |
21 | Returns
22 | -------
23 | encoder : torch.nn.Module
24 | The encoder model (with pretrained weights loaded if requested).
25 | encoder_config : dict
26 | The data configuration of the encoder model.
27 | """
28 | if len(timm.list_models(model_name)) == 0:
29 | warnings.warn(
30 | f"Unrecognized model '{model_name}'. Trying to fetch it from the hugging-face hub.",
31 | UserWarning,
32 | stacklevel=2,
33 | )
34 | model_name = "hf-hub:timm/" + model_name
35 |
36 | # We request the model without the classification head (num_classes=0)
37 |     # so that we get an encoder-only model.
38 | encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0, in_chans=in_chans)
39 | encoder_config = resolve_data_config({}, model=encoder)
40 |
41 | # Send a dummy input through the encoder to find out the shape of its output
42 | encoder.eval()
43 | dummy_output = encoder(torch.zeros((1, *encoder_config["input_size"])))
44 | encoder_config["n_feature"] = dummy_output.shape[1]
45 | encoder.train()
46 |
47 | encoder_config["in_channels"] = encoder_config["input_size"][0]
48 |
49 | return encoder, encoder_config
50 |
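# A minimal usage sketch (for illustration only; "resnet18" is just an example
# of a model name available through timm):
#
#     encoder, cfg = get_timm_encoder("resnet18", pretrained=False, in_chans=3)
#     print(cfg["n_feature"], cfg["input_size"])  # feature width and expected input size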
--------------------------------------------------------------------------------
/template_experiment/evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation routines.
3 | """
4 |
5 | import numpy as np
6 | import sklearn.metrics
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from . import utils
11 |
12 |
13 | def evaluate(
14 | dataloader,
15 | model,
16 | device,
17 | partition_name="Val",
18 | verbosity=1,
19 | is_distributed=False,
20 | ):
21 | r"""
22 | Evaluate model performance on a dataset.
23 |
24 | Parameters
25 | ----------
26 | dataloader : torch.utils.data.DataLoader
27 | Dataloader for the dataset to evaluate on.
28 | model : torch.nn.Module
29 | Model to evaluate.
30 | device : torch.device
31 | Device to run the model on.
32 | partition_name : str, default="Val"
33 | Name of the partition being evaluated.
34 | verbosity : int, default=1
35 | Verbosity level.
36 | is_distributed : bool, default=False
37 | Whether the model is distributed across multiple GPUs.
38 |
39 | Returns
40 | -------
41 | results : dict
42 | Dictionary of evaluation results.
43 | """
44 | model.eval()
45 |
46 | y_true_all = []
47 | y_pred_all = []
48 | xent_all = []
49 |
50 | for stimuli, y_true in dataloader:
51 | stimuli = stimuli.to(device)
52 | y_true = y_true.to(device)
53 | with torch.no_grad():
54 | logits = model(stimuli)
55 | xent = F.cross_entropy(logits, y_true, reduction="none")
56 | y_pred = torch.argmax(logits, dim=-1)
57 |
58 | if is_distributed:
59 | # Fetch results from other GPUs
60 | xent = utils.concat_all_gather_ragged(xent)
61 | y_true = utils.concat_all_gather_ragged(y_true)
62 | y_pred = utils.concat_all_gather_ragged(y_pred)
63 |
64 | xent_all.append(xent.cpu().numpy())
65 | y_true_all.append(y_true.cpu().numpy())
66 | y_pred_all.append(y_pred.cpu().numpy())
67 |
68 | # Concatenate the targets and predictions from each batch
69 | xent = np.concatenate(xent_all)
70 | y_true = np.concatenate(y_true_all)
71 | y_pred = np.concatenate(y_pred_all)
72 | # If the dataset size was not evenly divisible by the world size,
73 | # DistributedSampler will pad the end of the list of samples
74 | # with some repetitions. We need to trim these off.
75 | n_samples = len(dataloader.dataset)
76 | xent = xent[:n_samples]
77 | y_true = y_true[:n_samples]
78 | y_pred = y_pred[:n_samples]
79 | # Create results dictionary
80 | results = {}
81 | results["count"] = len(y_true)
82 | results["cross-entropy"] = np.mean(xent)
83 | # Note that these evaluation metrics have all been converted to percentages
84 | results["accuracy"] = 100.0 * sklearn.metrics.accuracy_score(y_true, y_pred)
85 | results["accuracy-balanced"] = 100.0 * sklearn.metrics.balanced_accuracy_score(y_true, y_pred)
86 | results["f1-micro"] = 100.0 * sklearn.metrics.f1_score(y_true, y_pred, average="micro")
87 | results["f1-macro"] = 100.0 * sklearn.metrics.f1_score(y_true, y_pred, average="macro")
88 | results["f1-support"] = 100.0 * sklearn.metrics.f1_score(y_true, y_pred, average="weighted")
89 | # Could expand to other metrics too
90 |
91 | if verbosity >= 1:
92 | print(f"\n{partition_name} evaluation results:")
93 | for k, v in results.items():
94 | if k == "count":
95 | print(f" {k + ' ':.<21s}{v:7d}")
96 | elif "entropy" in k:
97 | print(f" {k + ' ':.<24s} {v:9.5f} nat")
98 | else:
99 | print(f" {k + ' ':.<24s} {v:6.2f} %")
100 |
101 | return results
102 |
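# A minimal usage sketch (for illustration only; `model`, `dataloader_val`, and
# `device` are assumed to be constructed elsewhere, e.g. in train.py):
#
#     results = evaluate(dataloader_val, model, device, partition_name="Val")
#     print(results["accuracy"], results["cross-entropy"])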
--------------------------------------------------------------------------------
/template_experiment/io.py:
--------------------------------------------------------------------------------
1 | """
2 | Input/output utilities.
3 | """
4 |
5 | import os
6 | from inspect import getsourcefile
7 |
8 | import torch
9 |
10 | PACKAGE_DIR = os.path.dirname(os.path.abspath(getsourcefile(lambda: 0)))
11 |
12 |
13 | def get_project_root() -> str:
14 | return os.path.dirname(PACKAGE_DIR)
15 |
16 |
17 | def safe_save_model(modules, checkpoint_path=None, config=None, **kwargs):
18 | """
19 | Save a model to a checkpoint file, along with any additional data.
20 |
21 | Parameters
22 | ----------
23 | modules : dict
24 | A dictionary of modules to save. The keys are the names of the modules
25 | and the values are the modules themselves.
26 | checkpoint_path : str, optional
27 | Path to the checkpoint file. If not provided, the path will be taken
28 | from the config object.
29 | config : :class:`argparse.Namespace`, optional
30 | A configuration object containing the checkpoint path.
31 | **kwargs
32 | Additional data to save to the checkpoint file.
33 | """
34 | if checkpoint_path is not None:
35 | pass
36 | elif config is not None and hasattr(config, "checkpoint_path"):
37 | checkpoint_path = config.checkpoint_path
38 | else:
39 | raise ValueError("No checkpoint path provided")
40 | print(f"\nSaving model to {checkpoint_path}")
41 | # Create the directory if it doesn't already exist
42 | os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
43 | # Save to a temporary file first, then move the temporary file to the target
44 | # destination. This is to prevent clobbering the checkpoint with a partially
45 | # saved file, in the event that the saving process is interrupted. Saving
46 | # the checkpoint takes a little while and can be disrupted by preemption,
47 | # whereas moving the file is an atomic operation.
48 | tmp_a, tmp_b = os.path.split(checkpoint_path)
49 | tmp_fname = os.path.join(tmp_a, ".tmp." + tmp_b)
50 | data = {k: v.state_dict() for k, v in modules.items()}
51 | data.update(kwargs)
52 | if config is not None:
53 | data["config"] = config
54 |
55 | torch.save(data, tmp_fname)
56 | os.rename(tmp_fname, checkpoint_path)
57 | print(f"Saved model to {checkpoint_path}")
58 |
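# A minimal usage sketch (for illustration only; `model` and `optimizer` are
# assumed to be a module and an optimizer built elsewhere, e.g. in train.py):
#
#     safe_save_model(
#         {"model": model, "optimizer": optimizer},
#         "checkpoints/checkpoint_latest.pt",
#         epoch=5,
#     )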
--------------------------------------------------------------------------------
/template_experiment/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import builtins
4 | import copy
5 | import os
6 | import shutil
7 | import time
8 | import warnings
9 | from contextlib import nullcontext
10 | from datetime import datetime
11 | from socket import gethostname
12 |
13 | import torch
14 | import torch.optim
15 | from torch import nn
16 | from torch.utils.data.distributed import DistributedSampler
17 |
18 | from template_experiment import data_transformations, datasets, encoders, utils
19 | from template_experiment.evaluation import evaluate
20 | from template_experiment.io import safe_save_model
21 |
22 | BASE_BATCH_SIZE = 128
23 |
24 |
25 | def check_is_distributed():
26 | r"""
27 | Check if the current job is running in distributed mode.
28 |
29 | Returns
30 | -------
31 | bool
32 | Whether the job is running in distributed mode.
33 | """
34 | return (
35 | "WORLD_SIZE" in os.environ
36 | and "RANK" in os.environ
37 | and "LOCAL_RANK" in os.environ
38 | and "MASTER_ADDR" in os.environ
39 | and "MASTER_PORT" in os.environ
40 | )
41 |
42 |
43 | def setup_slurm_distributed():
44 | r"""
45 | Use SLURM environment variables to set up environment variables needed for DDP.
46 |
47 | Note: This is not used when using torchrun, as that sets RANK etc. for us,
48 | but is useful if you're using srun without torchrun (i.e. using srun within
49 |     the sbatch file to launch one task per GPU).
50 | """
51 | if "WORLD_SIZE" in os.environ:
52 | pass
53 | elif "SLURM_NNODES" in os.environ and "SLURM_GPUS_ON_NODE" in os.environ:
54 | os.environ["WORLD_SIZE"] = str(int(os.environ["SLURM_NNODES"]) * int(os.environ["SLURM_GPUS_ON_NODE"]))
55 | elif "SLURM_NPROCS" in os.environ:
56 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
57 | if "RANK" not in os.environ and "SLURM_PROCID" in os.environ:
58 | os.environ["RANK"] = os.environ["SLURM_PROCID"]
59 | if int(os.environ["RANK"]) > 0 and "WORLD_SIZE" not in os.environ:
60 | raise EnvironmentError(
61 | f"SLURM_PROCID is {os.environ['SLURM_PROCID']}, implying"
62 | " distributed training, but WORLD_SIZE could not be determined."
63 | )
64 | if "LOCAL_RANK" not in os.environ and "SLURM_LOCALID" in os.environ:
65 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
66 | if "MASTER_ADDR" not in os.environ and "SLURM_NODELIST" in os.environ:
67 | os.environ["MASTER_ADDR"] = os.environ["SLURM_NODELIST"].split("-")[0]
68 | if "MASTER_PORT" not in os.environ and "SLURM_JOB_ID" in os.environ:
69 | os.environ["MASTER_PORT"] = str(49152 + int(os.environ["SLURM_JOB_ID"]) % 16384)
70 |
71 |
72 | def run(config):
73 | r"""
74 | Run training job (one worker if using distributed training).
75 |
76 | Parameters
77 | ----------
78 | config : argparse.Namespace or OmegaConf
79 | The configuration for this experiment.
80 | """
81 | if config.log_wandb:
82 | # Lazy import of wandb, since logging to wandb is optional
83 | import wandb
84 |
85 | if config.seed is not None:
86 | utils.set_rng_seeds_fixed(config.seed)
87 |
88 | if config.deterministic:
89 | print("Running in deterministic cuDNN mode. Performance may be slower, but more reproducible.")
90 | torch.backends.cudnn.deterministic = True
91 | torch.backends.cudnn.benchmark = False
92 |
93 | # DISTRIBUTION ============================================================
94 | # Setup for distributed training
95 | setup_slurm_distributed()
96 | config.world_size = int(os.environ.get("WORLD_SIZE", 1))
97 | config.distributed = check_is_distributed()
98 | if config.world_size > 1 and not config.distributed:
99 | raise EnvironmentError(
100 | f"WORLD_SIZE is {config.world_size}, but not all other required"
101 | " environment variables for distributed training are set."
102 | )
103 | # Work out the total batch size depending on the number of GPUs we are using
104 | config.batch_size = config.batch_size_per_gpu * config.world_size
105 |
106 | if config.distributed:
107 | # For multiprocessing distributed training, gpu rank needs to be
108 | # set to the global rank among all the processes.
109 | config.global_rank = int(os.environ["RANK"])
110 | config.local_rank = int(os.environ["LOCAL_RANK"])
111 | print(
112 | f"Rank {config.global_rank} of {config.world_size} on {gethostname()}"
113 | f" (local GPU {config.local_rank} of {torch.cuda.device_count()})."
114 | f" Communicating with master at {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
115 | )
116 | torch.distributed.init_process_group(backend="nccl")
117 | else:
118 | config.global_rank = 0
119 |
120 |     # Suppress printing if this is not the global master process
121 | if config.distributed and config.global_rank != 0:
122 |
123 | def print_pass(*args, **kwargs):
124 | pass
125 |
126 | builtins.print = print_pass
127 |
128 | print()
129 | print("Configuration:")
130 | print()
131 | print(config)
132 | print()
133 | print(f"Found {torch.cuda.device_count()} GPUs and {utils.get_num_cpu_available()} CPUs.")
134 |
135 | # Check which device to use
136 | use_cuda = not config.no_cuda and torch.cuda.is_available()
137 |
138 | if config.distributed and not use_cuda:
139 | raise EnvironmentError("Distributed training with NCCL requires CUDA.")
140 | if not use_cuda:
141 | device = torch.device("cpu")
142 | elif config.local_rank is not None:
143 | device = f"cuda:{config.local_rank}"
144 | else:
145 | device = "cuda"
146 |
147 | print(f"Using device {device}")
148 |
149 | # RESTORE OMITTED CONFIG FROM RESUMPTION CHECKPOINT =======================
150 | checkpoint = None
151 | config.model_output_dir = None
152 | if config.checkpoint_path:
153 | config.model_output_dir = os.path.dirname(config.checkpoint_path)
154 | if not config.checkpoint_path:
155 | # Not trying to resume from a checkpoint
156 | pass
157 | elif not os.path.isfile(config.checkpoint_path):
158 | # Looks like we're trying to resume from the checkpoint that this job
159 | # will itself create. Let's assume this is to let the job resume upon
160 | # preemption, and it just hasn't been preempted yet.
161 | print(f"Skipping premature resumption from preemption: no checkpoint file found at '{config.checkpoint_path}'")
162 | else:
163 | print(f"Loading resumption checkpoint '{config.checkpoint_path}'")
164 |         # Map model parameters to be loaded onto the specified GPU.
165 | checkpoint = torch.load(config.checkpoint_path, map_location=device)
166 | keys = vars(get_parser().parse_args("")).keys()
167 | keys = set(keys).difference(["resume", "gpu", "global_rank", "local_rank", "cpu_workers"])
168 | for key in keys:
169 | if getattr(checkpoint["config"], key, None) is None:
170 | continue
171 | if getattr(config, key) is None:
172 | print(f" Restoring config value for {key} from checkpoint: {getattr(checkpoint['config'], key)}")
173 | setattr(config, key, getattr(checkpoint["config"], key, None))
174 | elif getattr(config, key) != getattr(checkpoint["config"], key):
175 | print(
176 | f" Warning: config value for {key} differs from checkpoint:"
177 | f" {getattr(config, key)} (ours) vs {getattr(checkpoint['config'], key)} (checkpoint)"
178 | )
179 |
180 | if checkpoint is None:
181 | # Our epochs go from 1 to n_epoch, inclusive
182 | start_epoch = 1
183 | else:
184 | # Continue from where we left off
185 | start_epoch = checkpoint["epoch"] + 1
186 | if config.seed is not None:
187 | # Make sure we don't get the same behaviour as we did on the
188 | # first epoch repeated on this resumed epoch.
189 | utils.set_rng_seeds_fixed(config.seed + start_epoch, all_gpu=False)
190 |
191 | # MODEL ===================================================================
192 |
193 | # Encoder -----------------------------------------------------------------
194 | # Build our Encoder.
195 | # We have to build the encoder before we load the dataset because it will
196 | # inform us about what size images we should produce in the preprocessing pipeline.
197 | n_class, raw_img_size, img_channels = datasets.image_dataset_sizes(config.dataset_name)
198 | if img_channels > 3 and config.freeze_encoder:
199 | raise ValueError(
200 | "Using a dataset with more than 3 image channels will require retraining the encoder"
201 | ", but a frozen encoder was requested."
202 | )
203 | if config.arch_framework == "timm":
204 | encoder, encoder_config = encoders.get_timm_encoder(config.arch, config.pretrained, in_chans=img_channels)
205 | elif config.arch_framework == "torchvision":
206 |         # It's trickier to implement this for torchvision models, because they
207 |         # don't follow the same naming conventions for model names as timm does;
208 |         # they require us to specify the name of the weights when loading a
209 |         # pretrained model; and they don't support changing the number of input channels.
210 | raise NotImplementedError(f"Unsupported architecture framework: {config.arch_framework}")
211 | else:
212 | raise ValueError(f"Unknown architecture framework: {config.arch_framework}")
213 |
214 | if config.freeze_encoder and not config.pretrained:
215 | warnings.warn(
216 | "A frozen encoder was requested, but the encoder is not pretrained.",
217 | UserWarning,
218 | stacklevel=2,
219 | )
220 |
221 | if config.image_size is None:
222 | if "input_size" in encoder_config:
223 | config.image_size = encoder_config["input_size"][-1]
224 | print(f"Setting model input image size to encoder's expected input size: {config.image_size}")
225 | else:
226 | config.image_size = 224
227 | print(f"Setting model input image size to default: {config.image_size}")
228 | if raw_img_size:
229 | warnings.warn(
230 | "Be aware that we are using a different input image size"
231 |                 f" ({config.image_size}px) from the raw image size in the"
232 | f" dataset ({raw_img_size}px).",
233 | UserWarning,
234 | stacklevel=2,
235 | )
236 | elif "input_size" in encoder_config and config.pretrained and encoder_config["input_size"][-1] != config.image_size:
237 | warnings.warn(
238 |             f"A different image size ({config.image_size}) than the one the model was"
239 |             f" pretrained with ({encoder_config['input_size'][-1]}) was supplied.",
240 | UserWarning,
241 | stacklevel=2,
242 | )
243 |
244 | # Classifier -------------------------------------------------------------
245 | # Build our classifier head
246 | classifier = nn.Linear(encoder_config["n_feature"], n_class)
247 |
248 | # Configure model for distributed training --------------------------------
249 | print("\nEncoder architecture:")
250 | print(encoder)
251 | print("\nClassifier architecture:")
252 | print(classifier)
253 | print()
254 |
255 | if config.cpu_workers is None:
256 | config.cpu_workers = utils.get_num_cpu_available()
257 |
258 | if not use_cuda:
259 | print("Using CPU (this will be slow)")
260 | elif config.distributed:
261 | # Convert batchnorm into SyncBN, using stats computed from all GPUs
262 | encoder = nn.SyncBatchNorm.convert_sync_batchnorm(encoder)
263 | classifier = nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
264 | # For multiprocessing distributed, the DistributedDataParallel
265 | # constructor should always set a single device scope, otherwise
266 | # DistributedDataParallel will use all available devices.
267 | encoder.to(device)
268 | classifier.to(device)
269 | torch.cuda.set_device(device)
270 | encoder = nn.parallel.DistributedDataParallel(
271 | encoder, device_ids=[config.local_rank], output_device=config.local_rank
272 | )
273 | classifier = nn.parallel.DistributedDataParallel(
274 | classifier, device_ids=[config.local_rank], output_device=config.local_rank
275 | )
276 | else:
277 | if config.local_rank is not None:
278 | torch.cuda.set_device(config.local_rank)
279 | encoder = encoder.to(device)
280 | classifier = classifier.to(device)
281 |
282 | # DATASET =================================================================
283 | # Transforms --------------------------------------------------------------
284 | transform_args = {}
285 | if config.dataset_name in data_transformations.VALID_TRANSFORMS:
286 | transform_args["normalization"] = config.dataset_name
287 |
288 | if "mean" in encoder_config:
289 | transform_args["mean"] = encoder_config["mean"]
290 | if "std" in encoder_config:
291 | transform_args["std"] = encoder_config["std"]
292 |
293 | transform_train, transform_eval = data_transformations.get_transform(
294 | config.transform_type, config.image_size, transform_args
295 | )
296 |
297 | # Dataset -----------------------------------------------------------------
298 | dataset_args = {
299 | "dataset": config.dataset_name,
300 | "root": config.data_dir,
301 | "prototyping": config.prototyping,
302 | "download": config.allow_download_dataset,
303 | }
304 | if config.protoval_split_id is not None:
305 | dataset_args["protoval_split_id"] = config.protoval_split_id
306 | (
307 | dataset_train,
308 | dataset_val,
309 | dataset_test,
310 | distinct_val_test,
311 | ) = datasets.fetch_dataset(
312 | **dataset_args,
313 | transform_train=transform_train,
314 | transform_eval=transform_eval,
315 | )
316 | eval_set = "Val" if distinct_val_test else "Test"
317 |
318 | # Dataloader --------------------------------------------------------------
319 | dl_train_kwargs = {
320 | "batch_size": config.batch_size_per_gpu,
321 | "drop_last": True,
322 | "sampler": None,
323 | "shuffle": True,
324 | "worker_init_fn": utils.worker_seed_fn,
325 | }
326 | dl_test_kwargs = {
327 | "batch_size": config.batch_size_per_gpu,
328 | "drop_last": False,
329 | "sampler": None,
330 | "shuffle": False,
331 | "worker_init_fn": utils.worker_seed_fn,
332 | }
333 | if use_cuda:
334 | cuda_kwargs = {"num_workers": config.cpu_workers, "pin_memory": True}
335 | dl_train_kwargs.update(cuda_kwargs)
336 | dl_test_kwargs.update(cuda_kwargs)
337 |
338 | dl_val_kwargs = copy.deepcopy(dl_test_kwargs)
339 |
340 | if config.distributed:
341 | # The DistributedSampler breaks up the dataset across the GPUs
342 | dl_train_kwargs["sampler"] = DistributedSampler(
343 | dataset_train,
344 | shuffle=True,
345 | seed=config.seed if config.seed is not None else 0,
346 | drop_last=False,
347 | )
348 | dl_train_kwargs["shuffle"] = None
349 | dl_val_kwargs["sampler"] = DistributedSampler(
350 | dataset_val,
351 | shuffle=False,
352 | drop_last=False,
353 | )
354 | dl_val_kwargs["shuffle"] = None
355 | dl_test_kwargs["sampler"] = DistributedSampler(
356 | dataset_test,
357 | shuffle=False,
358 | drop_last=False,
359 | )
360 | dl_test_kwargs["shuffle"] = None
361 |
362 | dataloader_train = torch.utils.data.DataLoader(dataset_train, **dl_train_kwargs)
363 | dataloader_val = torch.utils.data.DataLoader(dataset_val, **dl_val_kwargs)
364 | dataloader_test = torch.utils.data.DataLoader(dataset_test, **dl_test_kwargs)
365 |
366 | # OPTIMIZATION ============================================================
367 | # Optimizer ---------------------------------------------------------------
368 | # Set up the optimizer
369 |
370 | # Bigger batch sizes mean better estimates of the gradient, so we can use a
371 | # bigger learning rate. See https://arxiv.org/abs/1706.02677
372 | # Hence we scale the learning rate linearly with the total batch size.
373 | config.lr = config.lr_relative * config.batch_size / BASE_BATCH_SIZE
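    |     # For example, with the default lr_relative of 0.01 and a total batch
    |     # size of 512 (e.g. 4 GPUs x 128 per GPU), lr = 0.01 * 512 / 128 = 0.04.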
374 |
375 | # Freeze the encoder, if requested
376 | if config.freeze_encoder:
377 | for m in encoder.parameters():
378 | m.requires_grad = False
379 |
380 | # Set up a parameter group for each component of the model, allowing
381 | # them to have different learning rates (for fine-tuning encoder).
382 | params = []
383 | if not config.freeze_encoder:
384 | params.append(
385 | {
386 | "params": encoder.parameters(),
387 | "lr": config.lr * config.lr_encoder_mult,
388 | "name": "encoder",
389 | }
390 | )
391 | params.append(
392 | {
393 | "params": classifier.parameters(),
394 | "lr": config.lr * config.lr_classifier_mult,
395 | "name": "classifier",
396 | }
397 | )
398 |
399 | # Fetch the constructor of the appropriate optimizer from torch.optim
400 | optimizer = getattr(torch.optim, config.optimizer)(params, lr=config.lr, weight_decay=config.weight_decay)
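    |     # For example, the default config.optimizer of "AdamW" resolves to
    |     # torch.optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay).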
401 |
402 | # Scheduler ---------------------------------------------------------------
403 | # Set up the learning rate scheduler
404 | if config.scheduler.lower() == "onecycle":
405 | scheduler = torch.optim.lr_scheduler.OneCycleLR(
406 | optimizer,
407 | [p["lr"] for p in optimizer.param_groups],
408 | epochs=config.epochs,
409 | steps_per_epoch=len(dataloader_train),
410 | )
411 | else:
412 | raise NotImplementedError(f"Scheduler {config.scheduler} not supported.")
413 |
414 | # Loss function -----------------------------------------------------------
415 | # Set up loss function
416 | criterion = nn.CrossEntropyLoss()
417 |
418 | # LOGGING =================================================================
419 | # Setup logging and saving
420 |
421 | # If we're using wandb, initialize the run, or resume it if the job was preempted.
422 | if config.log_wandb and config.global_rank == 0:
423 | wandb_run_name = config.run_name
424 | if wandb_run_name is not None and config.run_id is not None:
425 | wandb_run_name = f"{wandb_run_name}__{config.run_id}"
426 | EXCLUDED_WANDB_CONFIG_KEYS = [
427 | "log_wandb",
428 | "wandb_entity",
429 | "wandb_project",
430 | "global_rank",
431 | "local_rank",
432 | "run_name",
433 | "run_id",
434 | "model_output_dir",
435 | ]
436 | utils.init_or_resume_wandb_run(
437 | config.model_output_dir,
438 | name=wandb_run_name,
439 | id=config.run_id,
440 | entity=config.wandb_entity,
441 | project=config.wandb_project,
442 | config=wandb.helper.parse_config(config, exclude=EXCLUDED_WANDB_CONFIG_KEYS),
443 | job_type="train",
444 | tags=["prototype" if config.prototyping else "final"],
445 | )
446 | # If a run_id was not supplied at the command prompt, wandb will
447 | # generate a name. Let's use that as the run_name.
448 | if config.run_name is None:
449 | config.run_name = wandb.run.name
450 | if config.run_id is None:
451 | config.run_id = wandb.run.id
452 |
453 | # If we still don't have a run name, generate one from the current time.
454 | if config.run_name is None:
455 | config.run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
456 | if config.run_id is None:
457 | config.run_id = utils.generate_id()
458 |
459 | # If no checkpoint path was supplied, but models_dir was, we will automatically
460 | # determine the path to which we will save the model checkpoint.
461 | # If both are empty, we won't save the model.
462 | if not config.checkpoint_path and config.models_dir:
463 | config.model_output_dir = os.path.join(
464 | config.models_dir,
465 | config.dataset_name,
466 | f"{config.run_name}__{config.run_id}",
467 | )
468 | config.checkpoint_path = os.path.join(config.model_output_dir, "checkpoint_latest.pt")
469 | if config.log_wandb and config.global_rank == 0:
470 | wandb.config.update({"checkpoint_path": config.checkpoint_path}, allow_val_change=True)
471 |
472 | if config.checkpoint_path is None:
473 | print("Model will not be saved.")
474 | else:
475 | print(f"Model will be saved to '{config.checkpoint_path}'")
476 |
477 | # RESUME ==================================================================
478 | # Now that everything is set up, we can load the state of the model,
479 | # optimizer, and scheduler from a checkpoint, if supplied.
480 |
481 | # Initialize step related variables as if we're starting from scratch.
482 | # Their values will be overridden by the checkpoint if we're resuming.
483 | total_step = 0
484 | n_samples_seen = 0
485 |
486 | best_stats = {"max_accuracy": 0, "best_epoch": 0}
487 |
488 | if checkpoint is not None:
489 | print(f"Loading state from checkpoint (epoch {checkpoint['epoch']})")
490 |         # Restore training state from the checkpoint, which was already mapped to our device when loaded.
491 | total_step = checkpoint["total_step"]
492 | n_samples_seen = checkpoint["n_samples_seen"]
493 | encoder.load_state_dict(checkpoint["encoder"])
494 | classifier.load_state_dict(checkpoint["classifier"])
495 | optimizer.load_state_dict(checkpoint["optimizer"])
496 | scheduler.load_state_dict(checkpoint["scheduler"])
497 | best_stats["max_accuracy"] = checkpoint.get("max_accuracy", 0)
498 | best_stats["best_epoch"] = checkpoint.get("best_epoch", 0)
499 |
500 | # TRAIN ===================================================================
501 | print()
502 | print("Configuration:")
503 | print()
504 | print(config)
505 | print()
506 |
507 | # Ensure modules are on the correct device
508 | encoder = encoder.to(device)
509 | classifier = classifier.to(device)
510 |
511 | # Stack the encoder and classifier together to create an overall model.
512 | # At inference time, we don't need to make a distinction between modules
513 | # within this stack.
514 | model = nn.Sequential(encoder, classifier)
515 |
516 | timing_stats = {}
517 | t_end_epoch = time.time()
518 | for epoch in range(start_epoch, config.epochs + 1):
519 | t_start_epoch = time.time()
520 | if config.seed is not None:
521 | # If the job is resumed from preemption, our RNG state is currently set the
522 | # same as it was at the start of the first epoch, not where it was when we
523 | # stopped training. This is not good as it means jobs which are resumed
524 |             # don't do the same thing as they would have if they'd run uninterrupted
525 | # (making preempted jobs non-reproducible).
526 | # To address this, we reset the seed at the start of every epoch. Since jobs
527 | # can only save at the end of and resume at the start of an epoch, this
528 | # makes the training process reproducible. But we shouldn't use the same
529 | # RNG state for each epoch - instead we use the original seed to define the
530 | # series of seeds that we will use at the start of each epoch.
531 | epoch_seed = utils.determine_epoch_seed(config.seed, epoch=epoch)
532 | # We want each GPU to have a different seed to the others to avoid
533 | # correlated randomness between the workers on the same batch.
534 | # We offset the seed for this epoch by the GPU rank, so every GPU will get a
535 | # unique seed for the epoch. This means the job is only precisely
536 | # reproducible if it is rerun with the same number of GPUs (and the same
537 | # number of CPU workers for the dataloader).
538 | utils.set_rng_seeds_fixed(epoch_seed + config.global_rank, all_gpu=False)
539 | if isinstance(getattr(dataloader_train, "generator", None), torch.Generator):
540 | # Finesse the dataloader's RNG state, if it is not using the global state.
541 | dataloader_train.generator.manual_seed(epoch_seed + config.global_rank)
542 | if isinstance(getattr(dataloader_train.sampler, "generator", None), torch.Generator):
543 | # Finesse the sampler's RNG state, if it is not using the global RNG state.
544 | dataloader_train.sampler.generator.manual_seed(config.seed + epoch + 10000 * config.global_rank)
545 |
546 | if hasattr(dataloader_train.sampler, "set_epoch"):
547 | # Handling for DistributedSampler.
548 | # Set the epoch for the sampler so that it can shuffle the data
549 | # differently for each epoch, but synchronized across all GPUs.
550 | dataloader_train.sampler.set_epoch(epoch)
551 |
552 | # Train ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
553 | # Note the number of samples seen before this epoch started, so we can
554 | # calculate the number of samples seen in this epoch.
555 | n_samples_seen_before = n_samples_seen
556 | # Run one epoch of training
557 | train_stats, total_step, n_samples_seen = train_one_epoch(
558 | config=config,
559 | encoder=encoder,
560 | classifier=classifier,
561 | optimizer=optimizer,
562 | scheduler=scheduler,
563 | criterion=criterion,
564 | dataloader=dataloader_train,
565 | device=device,
566 | epoch=epoch,
567 | n_epoch=config.epochs,
568 | total_step=total_step,
569 | n_samples_seen=n_samples_seen,
570 | )
571 | t_end_train = time.time()
572 |
573 | timing_stats["train"] = t_end_train - t_start_epoch
574 | n_epoch_samples = n_samples_seen - n_samples_seen_before
575 | train_stats["throughput"] = n_epoch_samples / timing_stats["train"]
576 |
577 | print(f"Training epoch {epoch}/{config.epochs} summary:")
578 | print(f" Steps ..............{len(dataloader_train):8d}")
579 | print(f" Samples ............{n_epoch_samples:8d}")
580 | if timing_stats["train"] > 172800:
581 | print(f" Duration ...........{timing_stats['train']/86400:11.2f} days")
582 | elif timing_stats["train"] > 5400:
583 | print(f" Duration ...........{timing_stats['train']/3600:11.2f} hours")
584 | elif timing_stats["train"] > 120:
585 | print(f" Duration ...........{timing_stats['train']/60:11.2f} minutes")
586 | else:
587 | print(f" Duration ...........{timing_stats['train']:11.2f} seconds")
588 | print(f" Throughput .........{train_stats['throughput']:11.2f} samples/sec")
589 | print(f" Loss ...............{train_stats['loss']:14.5f}")
590 | print(f" Accuracy ...........{train_stats['accuracy']:11.2f} %")
591 |
592 | # Validate ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
593 | # Evaluate on validation set
594 | t_start_val = time.time()
595 |
596 | eval_stats = evaluate(
597 | dataloader=dataloader_val,
598 | model=model,
599 | device=device,
600 | partition_name=eval_set,
601 | is_distributed=config.distributed,
602 | )
603 | t_end_val = time.time()
604 | timing_stats["val"] = t_end_val - t_start_val
605 | eval_stats["throughput"] = len(dataloader_val.dataset) / timing_stats["val"]
606 |
607 | # Check if this is the new best model
608 | if eval_stats["accuracy"] >= best_stats["max_accuracy"]:
609 | best_stats["max_accuracy"] = eval_stats["accuracy"]
610 | best_stats["best_epoch"] = epoch
611 |
612 | print(f"Evaluating epoch {epoch}/{config.epochs} summary:")
613 | if timing_stats["val"] > 172800:
614 | print(f" Duration ...........{timing_stats['val']/86400:11.2f} days")
615 | elif timing_stats["val"] > 5400:
616 | print(f" Duration ...........{timing_stats['val']/3600:11.2f} hours")
617 | elif timing_stats["val"] > 120:
618 | print(f" Duration ...........{timing_stats['val']/60:11.2f} minutes")
619 | else:
620 | print(f" Duration ...........{timing_stats['val']:11.2f} seconds")
621 | print(f" Throughput .........{eval_stats['throughput']:11.2f} samples/sec")
622 | print(f" Cross-entropy ......{eval_stats['cross-entropy']:14.5f}")
623 | print(f" Accuracy ...........{eval_stats['accuracy']:11.2f} %")
624 |
625 | # Save model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
626 | t_start_save = time.time()
627 | if config.model_output_dir and (not config.distributed or config.global_rank == 0):
628 | safe_save_model(
629 | {
630 | "encoder": encoder,
631 | "classifier": classifier,
632 | "optimizer": optimizer,
633 | "scheduler": scheduler,
634 | },
635 | config.checkpoint_path,
636 | config=config,
637 | epoch=epoch,
638 | total_step=total_step,
639 | n_samples_seen=n_samples_seen,
640 | encoder_config=encoder_config,
641 | transform_args=transform_args,
642 | **best_stats,
643 | )
644 | if config.save_best_model and best_stats["best_epoch"] == epoch:
645 | ckpt_path_best = os.path.join(config.model_output_dir, "best_model.pt")
646 | print(f"Copying model to {ckpt_path_best}")
647 | shutil.copyfile(config.checkpoint_path, ckpt_path_best)
648 |
649 | t_end_save = time.time()
650 | timing_stats["saving"] = t_end_save - t_start_save
651 |
652 | # Log to wandb ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
653 | # Overall time won't include uploading to wandb, but there's nothing
654 | # we can do about that.
655 | timing_stats["overall"] = time.time() - t_end_epoch
656 | t_end_epoch = time.time()
657 |
658 | # Send training and eval stats for this epoch to wandb
659 | if config.log_wandb and config.global_rank == 0:
660 | pre = "Training/epochwise"
661 | wandb.log(
662 | {
663 | "Training/stepwise/epoch": epoch,
664 | "Training/stepwise/epoch_progress": epoch,
665 | "Training/stepwise/n_samples_seen": n_samples_seen,
666 | f"{pre}/epoch": epoch,
667 | **{f"{pre}/Train/{k}": v for k, v in train_stats.items()},
668 | **{f"{pre}/{eval_set}/{k}": v for k, v in eval_stats.items()},
669 | **{f"{pre}/duration/{k}": v for k, v in timing_stats.items()},
670 | },
671 | step=total_step,
672 | )
673 | # Record the wandb time as contributing to the next epoch
674 | timing_stats = {"wandb": time.time() - t_end_epoch}
675 | else:
676 | # Reset timing stats
677 | timing_stats = {}
678 |         # Printing with flush=True forces the output buffer to be flushed immediately
679 | print("", flush=True)
680 |
681 | if start_epoch > config.epochs:
682 | print("Training already completed!")
683 | else:
684 | print(f"Training complete! (Trained epochs {start_epoch} to {config.epochs})")
685 | print(
686 | f"Best {eval_set} accuracy was {best_stats['max_accuracy']:.2f}%,"
687 | f" seen at the end of epoch {best_stats['best_epoch']}"
688 | )
689 |
690 | # TEST ====================================================================
691 | print(f"\nEvaluating final model (epoch {config.epochs}) performance")
692 | # Evaluate on test set
693 | print("\nEvaluating final model on test set...")
694 | eval_stats = evaluate(
695 | dataloader=dataloader_test,
696 | model=model,
697 | device=device,
698 | partition_name="Test",
699 | is_distributed=config.distributed,
700 | )
701 | # Send stats to wandb
702 | if config.log_wandb and config.global_rank == 0:
703 | wandb.log({**{f"Eval/Test/{k}": v for k, v in eval_stats.items()}}, step=total_step)
704 |
705 | if distinct_val_test:
706 | # Evaluate on validation set
707 | print(f"\nEvaluating final model on {eval_set} set...")
708 | eval_stats = evaluate(
709 | dataloader=dataloader_val,
710 | model=model,
711 | device=device,
712 | partition_name=eval_set,
713 | is_distributed=config.distributed,
714 | )
715 | # Send stats to wandb
716 | if config.log_wandb and config.global_rank == 0:
717 | wandb.log(
718 | {**{f"Eval/{eval_set}/{k}": v for k, v in eval_stats.items()}},
719 | step=total_step,
720 | )
721 |
722 | # Create a copy of the train partition with evaluation transforms
723 | # and a dataloader using the evaluation configuration (don't drop last)
724 | print("\nEvaluating final model on train set under test conditions (no augmentation, dropout, etc)...")
725 | dataset_train_eval = datasets.fetch_dataset(
726 | **dataset_args,
727 | transform_train=transform_eval,
728 | transform_eval=transform_eval,
729 | )[0]
730 | dl_train_eval_kwargs = copy.deepcopy(dl_test_kwargs)
731 | if config.distributed:
732 | # The DistributedSampler breaks up the dataset across the GPUs
733 | dl_train_eval_kwargs["sampler"] = DistributedSampler(
734 | dataset_train_eval,
735 | shuffle=False,
736 | drop_last=False,
737 | )
738 | dl_train_eval_kwargs["shuffle"] = None
739 | dataloader_train_eval = torch.utils.data.DataLoader(dataset_train_eval, **dl_train_eval_kwargs)
740 | eval_stats = evaluate(
741 | dataloader=dataloader_train_eval,
742 | model=model,
743 | device=device,
744 | partition_name="Train",
745 | is_distributed=config.distributed,
746 | )
747 | # Send stats to wandb
748 | if config.log_wandb and config.global_rank == 0:
749 | wandb.log({**{f"Eval/Train/{k}": v for k, v in eval_stats.items()}}, step=total_step)
750 |
751 |
752 | def train_one_epoch(
753 | config,
754 | encoder,
755 | classifier,
756 | optimizer,
757 | scheduler,
758 | criterion,
759 | dataloader,
760 | device="cuda",
761 | epoch=1,
762 | n_epoch=None,
763 | total_step=0,
764 | n_samples_seen=0,
765 | ):
766 | r"""
767 | Train the encoder and classifier for one epoch.
768 |
769 | Parameters
770 | ----------
771 | config : argparse.Namespace or OmegaConf
772 | The global config object.
773 | encoder : torch.nn.Module
774 | The encoder network.
775 | classifier : torch.nn.Module
776 | The classifier network.
777 | optimizer : torch.optim.Optimizer
778 | The optimizer.
779 | scheduler : torch.optim.lr_scheduler._LRScheduler
780 | The learning rate scheduler.
781 | criterion : torch.nn.Module
782 | The loss function.
783 | dataloader : torch.utils.data.DataLoader
784 | A dataloader for the training set.
785 | device : str or torch.device, default="cuda"
786 | The device to use.
787 | epoch : int, default=1
788 | The current epoch number (indexed from 1).
789 | n_epoch : int, optional
790 | The total number of epochs scheduled to train for.
791 | total_step : int, default=0
792 | The total number of steps taken so far.
793 | n_samples_seen : int, default=0
794 | The total number of samples seen so far.
795 |
796 | Returns
797 | -------
798 |     results : dict
799 | A dictionary containing the training performance for this epoch.
800 | total_step : int
801 | The total number of steps taken after this epoch.
802 | n_samples_seen : int
803 | The total number of samples seen after this epoch.
804 | """
805 | # Put the model in train mode
806 | encoder.train()
807 | classifier.train()
808 |
809 | if config.log_wandb:
810 | # Lazy import of wandb, since logging to wandb is optional
811 | import wandb
812 |
813 | loss_epoch = 0
814 | acc_epoch = 0
815 |
816 | if config.print_interval is None:
817 | # Default to printing to console every time we log to wandb
818 | config.print_interval = config.log_interval
819 |
820 | t_end_batch = time.time()
821 | t_start_wandb = t_end_wandb = None
822 | for batch_idx, (stimuli, y_true) in enumerate(dataloader):
823 | t_start_batch = time.time()
824 | batch_size_this_gpu = stimuli.shape[0]
825 |
826 | # Move training inputs and targets to the GPU
827 | stimuli = stimuli.to(device)
828 | y_true = y_true.to(device)
829 |
830 | # Forward pass --------------------------------------------------------
831 | # Perform the forward pass through the model
832 | t_start_encoder = time.time()
833 | # N.B. To accurately time steps on GPU we need to use torch.cuda.Event
834 | ct_forward = torch.cuda.Event(enable_timing=True)
835 | ct_forward.record()
836 | with torch.no_grad() if config.freeze_encoder else nullcontext():
837 | h = encoder(stimuli)
838 | logits = classifier(h)
839 | # Reset gradients
840 | optimizer.zero_grad()
841 | # Measure loss
842 | loss = criterion(logits, y_true)
843 |
844 | # Backward pass -------------------------------------------------------
845 | # Now the backward pass
846 | ct_backward = torch.cuda.Event(enable_timing=True)
847 | ct_backward.record()
848 | loss.backward()
849 |
850 | # Update --------------------------------------------------------------
851 | # Use our optimizer to update the model parameters
852 | ct_optimizer = torch.cuda.Event(enable_timing=True)
853 | ct_optimizer.record()
854 | optimizer.step()
855 |
856 | # Step the scheduler each batch
857 | scheduler.step()
858 |
859 | # Increment training progress counters
860 | total_step += 1
861 | batch_size_all = batch_size_this_gpu * config.world_size
862 | n_samples_seen += batch_size_all
863 |
864 | # Logging -------------------------------------------------------------
865 | # Log details about training progress
866 | t_start_logging = time.time()
867 | ct_logging = torch.cuda.Event(enable_timing=True)
868 | ct_logging.record()
869 |
870 | # Update the total loss for the epoch
871 | if config.distributed:
872 | # Fetch results from other GPUs
873 | loss_batch = torch.mean(utils.concat_all_gather(loss.reshape((1,))))
874 | loss_batch = loss_batch.item()
875 | else:
876 | loss_batch = loss.item()
877 | loss_epoch += loss_batch
878 |
879 | # Compute accuracy
880 | with torch.no_grad():
881 | y_pred = torch.argmax(logits, dim=-1)
882 | is_correct = y_pred == y_true
883 | acc = 100.0 * is_correct.sum() / len(is_correct)
884 | if config.distributed:
885 | # Fetch results from other GPUs
886 | acc = torch.mean(utils.concat_all_gather(acc.reshape((1,))))
887 | acc = acc.item()
888 | acc_epoch += acc
889 |
890 | if epoch <= 1 and batch_idx == 0:
891 | # Debugging
892 | print("stimuli.shape =", stimuli.shape)
893 | print("y_true.shape =", y_true.shape)
894 | print("y_pred.shape =", y_pred.shape)
895 | print("logits.shape =", logits.shape)
896 | print("loss.shape =", loss.shape)
897 | # Debugging intensifies
898 | print("y_true =", y_true)
899 | print("y_pred =", y_pred)
900 | print("logits[0] =", logits[0])
901 | print("loss =", loss.detach().item())
902 |
903 | # Log sample training images to show on wandb
904 | if config.log_wandb and batch_idx <= 1:
905 | # Log 8 example training images from each GPU
906 | img_indices = [offset + relative for offset in [0, batch_size_this_gpu // 2] for relative in [0, 1, 2, 3]]
907 | img_indices = sorted(set(img_indices))
908 | log_images = stimuli[img_indices]
909 | if config.distributed:
910 | # Collate sample images from each GPU
911 | log_images = utils.concat_all_gather(log_images)
912 | if config.global_rank == 0:
913 | wandb.log(
914 | {"Training/stepwise/Train/stimuli": wandb.Image(log_images)},
915 | step=total_step,
916 | )
917 |
918 | # Log to console
919 | if batch_idx <= 2 or batch_idx % config.print_interval == 0 or batch_idx >= len(dataloader) - 1:
920 | print(
921 | f"Train Epoch:{epoch:4d}" + (f"/{n_epoch}" if n_epoch is not None else ""),
922 | " Step:{:4d}/{}".format(batch_idx + 1, len(dataloader)),
923 | " Loss:{:8.5f}".format(loss_batch),
924 | " Acc:{:6.2f}%".format(acc),
925 | " LR: {}".format(scheduler.get_last_lr()),
926 | )
927 |
928 | # Log to wandb
929 | if config.log_wandb and config.global_rank == 0 and batch_idx % config.log_interval == 0:
930 | # Create a log dictionary to send to wandb
931 | # Epoch progress interpolates smoothly between epochs
932 | epoch_progress = epoch - 1 + (batch_idx + 1) / len(dataloader)
933 | # Throughput is the number of samples processed per second
934 | throughput = batch_size_all / (t_start_logging - t_end_batch)
935 | log_dict = {
936 | "Training/stepwise/epoch": epoch,
937 | "Training/stepwise/epoch_progress": epoch_progress,
938 | "Training/stepwise/n_samples_seen": n_samples_seen,
939 | "Training/stepwise/Train/throughput": throughput,
940 | "Training/stepwise/Train/loss": loss_batch,
941 | "Training/stepwise/Train/accuracy": acc,
942 | }
943 | # Track the learning rate of each parameter group
944 | for lr_idx in range(len(optimizer.param_groups)):
945 | if "name" in optimizer.param_groups[lr_idx]:
946 | grp_name = optimizer.param_groups[lr_idx]["name"]
947 | elif len(optimizer.param_groups) == 1:
948 | grp_name = ""
949 | else:
950 | grp_name = f"grp{lr_idx}"
951 | if grp_name != "":
952 | grp_name = f"-{grp_name}"
953 | grp_lr = optimizer.param_groups[lr_idx]["lr"]
954 | log_dict[f"Training/stepwise/lr{grp_name}"] = grp_lr
955 | # Synchronize ensures everything has finished running on each GPU
956 | torch.cuda.synchronize()
957 | # Record how long it took to do each step in the pipeline
958 | pre = "Training/stepwise/duration"
959 | if t_start_wandb is not None:
960 | # Record how long it took to send to wandb last time
961 | log_dict[f"{pre}/wandb"] = t_end_wandb - t_start_wandb
962 | log_dict[f"{pre}/dataloader"] = t_start_batch - t_end_batch
963 | log_dict[f"{pre}/preamble"] = t_start_encoder - t_start_batch
964 | log_dict[f"{pre}/forward"] = ct_forward.elapsed_time(ct_backward) / 1000
965 | log_dict[f"{pre}/backward"] = ct_backward.elapsed_time(ct_optimizer) / 1000
966 | log_dict[f"{pre}/optimizer"] = ct_optimizer.elapsed_time(ct_logging) / 1000
967 | log_dict[f"{pre}/overall"] = time.time() - t_end_batch
968 | t_start_wandb = time.time()
969 | log_dict[f"{pre}/logging"] = t_start_wandb - t_start_logging
970 | # Send to wandb
971 | wandb.log(log_dict, step=total_step)
972 | t_end_wandb = time.time()
973 |
974 | # Record the time when we finished this batch
975 | t_end_batch = time.time()
976 |
977 | results = {
978 | "loss": loss_epoch / len(dataloader),
979 | "accuracy": acc_epoch / len(dataloader),
980 | }
981 | return results, total_step, n_samples_seen
982 |
983 |
984 | def get_parser():
985 | r"""
986 | Build argument parser for the command line interface.
987 |
988 | Returns
989 | -------
990 | parser : argparse.ArgumentParser
991 | CLI argument parser.
992 | """
993 | import argparse
994 | import sys
995 |
996 | # Use the name of the file called to determine the name of the program
997 | prog = os.path.split(sys.argv[0])[1]
998 | if prog == "__main__.py" or prog == "__main__":
999 |         # If the script was invoked via __main__.py, use this module's filename instead
1000 | prog = os.path.split(__file__)[1]
1001 | parser = argparse.ArgumentParser(
1002 | prog=prog,
1003 | description="Train image classification model.",
1004 | add_help=False,
1005 | )
1006 | # Help arg ----------------------------------------------------------------
1007 | group = parser.add_argument_group("Help")
1008 | group.add_argument(
1009 | "--help",
1010 | "-h",
1011 | action="help",
1012 | help="Show this help message and exit.",
1013 | )
1014 | # Dataset args ------------------------------------------------------------
1015 | group = parser.add_argument_group("Dataset")
1016 | group.add_argument(
1017 | "--dataset",
1018 | dest="dataset_name",
1019 | type=str,
1020 | default="cifar10",
1021 | help="Name of the dataset to learn. Default: %(default)s",
1022 | )
1023 | group.add_argument(
1024 | "--prototyping",
1025 | dest="protoval_split_id",
1026 | nargs="?",
1027 | const=0,
1028 | type=int,
1029 | help=(
1030 | "Use a subset of the train partition for both train and val."
1031 | " If the dataset doesn't have a separate val and test set with"
1032 | " public labels (which is the case for most datasets), the train"
1033 | " partition will be reduced in size to create the val partition."
1034 | " In all cases where --prototyping is enabled, the test set is"
1035 | " never used during training. Generally, you should use"
1036 | " --prototyping throughout the model exploration and hyperparameter"
1037 | " optimization phases, and disable it for your final experiments so"
1038 | " they can run on a completely held-out test set."
1039 | ),
1040 | )
1041 | group.add_argument(
1042 | "--data-dir",
1043 | type=str,
1044 | default=None,
1045 | help=(
1046 | "Directory within which the dataset can be found."
1047 | " Default is ~/Datasets, except on Vector servers where it is"
1048 | " adjusted as appropriate depending on the dataset's location."
1049 | ),
1050 | )
1051 | group.add_argument(
1052 | "--allow-download-dataset",
1053 | action="store_true",
1054 | help="Attempt to download the dataset if it is not found locally.",
1055 | )
1056 | group.add_argument(
1057 | "--transform-type",
1058 | type=str,
1059 | default="cifar",
1060 | help="Name of augmentation stack to apply to training data. Default: %(default)s",
1061 | )
1062 | group.add_argument(
1063 | "--image-size",
1064 | type=int,
1065 | help="Size of images to use as model input. Default: encoder's default.",
1066 | )
1067 | # Architecture args -------------------------------------------------------
1068 | group = parser.add_argument_group("Architecture")
1069 | group.add_argument(
1070 | "--model",
1071 | "--encoder",
1072 | "--arch",
1073 | "--architecture",
1074 | dest="arch",
1075 | type=str,
1076 | default="resnet18",
1077 | help="Name of model architecture. Default: %(default)s",
1078 | )
1079 | group.add_argument(
1080 | "--pretrained",
1081 | action="store_true",
1082 | help="Use default pretrained model weights, taken from hugging-face hub.",
1083 | )
1084 | mx_group = group.add_mutually_exclusive_group()
1085 | mx_group.add_argument(
1086 | "--torchvision",
1087 | dest="arch_framework",
1088 | action="store_const",
1089 | const="torchvision",
1090 | default="timm",
1091 | help="Use model architecture from torchvision (default is timm).",
1092 | )
1093 | mx_group.add_argument(
1094 | "--timm",
1095 | dest="arch_framework",
1096 | action="store_const",
1097 | const="timm",
1098 | default="timm",
1099 | help="Use model architecture from timm (default).",
1100 | )
1101 |     group.add_argument(
1102 |         "--freeze-encoder",
1103 |         action="store_true",
1104 |         help="Freeze the encoder's weights and train only the classifier head.",
     |     )
1105 | # Optimization args -------------------------------------------------------
1106 | group = parser.add_argument_group("Optimization routine")
1107 | group.add_argument(
1108 | "--epochs",
1109 | type=int,
1110 | default=5,
1111 | help="Number of epochs to train for. Default: %(default)s",
1112 | )
1113 | group.add_argument(
1114 | "--lr",
1115 | dest="lr_relative",
1116 | type=float,
1117 | default=0.01,
1118 | help=(
1119 | f"Maximum learning rate, set per {BASE_BATCH_SIZE} batch size."
1120 | " The actual learning rate used will be scaled up by the total"
1121 | " batch size (across all GPUs). Default: %(default)s"
1122 | ),
1123 | )
1124 | group.add_argument(
1125 | "--lr-encoder-mult",
1126 | type=float,
1127 | default=1.0,
1128 | help="Multiplier for encoder learning rate, relative to overall LR.",
1129 | )
1130 | group.add_argument(
1131 | "--lr-classifier-mult",
1132 | type=float,
1133 | default=1.0,
1134 | help="Multiplier for classifier head's learning rate, relative to overall LR.",
1135 | )
1136 | group.add_argument(
1137 | "--weight-decay",
1138 | "--wd",
1139 | dest="weight_decay",
1140 | type=float,
1141 | default=0.0,
1142 | help="Weight decay. Default: %(default)s",
1143 | )
1144 | group.add_argument(
1145 | "--optimizer",
1146 | type=str,
1147 | default="AdamW",
1148 | help="Name of optimizer (case-sensitive). Default: %(default)s",
1149 | )
1150 | group.add_argument(
1151 | "--scheduler",
1152 | type=str,
1153 | default="OneCycle",
1154 | help="Learning rate scheduler. Default: %(default)s",
1155 | )
1156 | # Output checkpoint args --------------------------------------------------
1157 | group = parser.add_argument_group("Output checkpoint")
1158 | group.add_argument(
1159 | "--models-dir",
1160 | type=str,
1161 | default="models",
1162 | metavar="PATH",
1163 | help="Output directory for all models. Ignored if --checkpoint is set. Default: %(default)s",
1164 | )
1165 | group.add_argument(
1166 | "--checkpoint",
1167 | dest="checkpoint_path",
1168 | default="",
1169 | type=str,
1170 | metavar="PATH",
1171 | help=(
1172 | "Save and resume partially trained model and optimizer state from this checkpoint."
1173 | " Overrides --models-dir."
1174 | ),
1175 | )
1176 | group.add_argument(
1177 | "--save-best-model",
1178 | action="store_true",
1179 | help="Save a copy of the model with best validation performance.",
1180 | )
1181 | # Reproducibility args ----------------------------------------------------
1182 | group = parser.add_argument_group("Reproducibility")
1183 | group.add_argument(
1184 | "--seed",
1185 | type=int,
1186 | help="Random number generator (RNG) seed. Default: not controlled",
1187 | )
1188 | group.add_argument(
1189 | "--deterministic",
1190 | action="store_true",
1191 | help="Disable non-deterministic features of cuDNN.",
1192 | )
1193 | # Hardware configuration args ---------------------------------------------
1194 | group = parser.add_argument_group("Hardware configuration")
1195 | group.add_argument(
1196 | "--batch-size",
1197 | dest="batch_size_per_gpu",
1198 | type=int,
1199 | default=BASE_BATCH_SIZE,
1200 | help=(
1201 | "Batch size per GPU. The total batch size will be this value times"
1202 | " the total number of GPUs used. Default: %(default)s"
1203 | ),
1204 | )
1205 | group.add_argument(
1206 | "--cpu-workers",
1207 | "--workers",
1208 | dest="cpu_workers",
1209 | type=int,
1210 | help="Number of CPU workers per node. Default: number of CPUs available on device.",
1211 | )
1212 | group.add_argument(
1213 | "--no-cuda",
1214 | action="store_true",
1215 | help="Use CPU only, no GPUs.",
1216 | )
1217 | group.add_argument(
1218 | "--gpu",
1219 | "--local-rank",
1220 | dest="local_rank",
1221 | default=None,
1222 | type=int,
1223 | help="Index of GPU to use when training a single process. (Ignored for distributed training.)",
1224 | )
1225 | # Logging args ------------------------------------------------------------
1226 | group = parser.add_argument_group("Debugging and logging")
1227 | group.add_argument(
1228 | "--log-interval",
1229 | type=int,
1230 | default=20,
1231 | help="Number of batches between each log to wandb (if enabled). Default: %(default)s",
1232 | )
1233 | group.add_argument(
1234 | "--print-interval",
1235 | type=int,
1236 | default=None,
1237 | help="Number of batches between each print to STDOUT. Default: same as LOG_INTERVAL.",
1238 | )
1239 | group.add_argument(
1240 | "--log-wandb",
1241 | action="store_true",
1242 | help="Log results with Weights & Biases https://wandb.ai",
1243 | )
1244 | group.add_argument(
1245 | "--disable-wandb",
1246 | "--no-wandb",
1247 | dest="disable_wandb",
1248 | action="store_true",
1249 | help="Overrides --log-wandb and ensures wandb is always disabled.",
1250 | )
1251 | group.add_argument(
1252 | "--wandb-entity",
1253 | type=str,
1254 | help=(
1255 | "The entity (organization) within which your wandb project is"
1256 | ' located. By default, this will be your "default location" set on'
1257 | " wandb at https://wandb.ai/settings"
1258 | ),
1259 | )
1260 | group.add_argument(
1261 | "--wandb-project",
1262 | type=str,
1263 | default="template-experiment",
1264 | help="Name of project on wandb, where these runs will be saved. Default: %(default)s",
1265 | )
1266 | group.add_argument(
1267 | "--run-name",
1268 | type=str,
1269 | help="Human-readable identifier for the model run or job. Used to name the run on wandb.",
1270 | )
1271 | group.add_argument(
1272 | "--run-id",
1273 | type=str,
1274 | help="Unique identifier for the model run or job. Used as the run ID on wandb.",
1275 | )
1276 |
1277 | return parser
1278 |
1279 |
1280 | def cli():
1281 | r"""Command-line interface for model training."""
1282 | parser = get_parser()
1283 | config = parser.parse_args()
1284 | # Handle disable_wandb overriding log_wandb and forcing it to be disabled.
1285 | if config.disable_wandb:
1286 | config.log_wandb = False
1287 | del config.disable_wandb
1288 | # Set protoval_split_id from prototyping, and turn prototyping into a bool
1289 | config.prototyping = config.protoval_split_id is not None
1290 | return run(config)
1291 |
1292 |
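     | # Example invocation (a sketch; the dataset and flags shown are illustrative):
     | #   python template_experiment/train.py --dataset cifar10 --epochs 5 --pretrained --log-wandb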
1293 | if __name__ == "__main__":
1294 | cli()
1295 |
--------------------------------------------------------------------------------
/template_experiment/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import secrets
4 | import string
5 | import warnings
6 |
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def get_num_cpu_available():
12 | r"""
13 | Get the number of available CPU cores.
14 |
15 | Uses :func:`os.sched_getaffinity` if available, otherwise falls back to
16 | :func:`os.cpu_count`.
17 |
18 | Returns
19 | -------
20 | ncpus : int
21 | The number of available CPU cores.
22 | """
23 | try:
24 | # This is the number of CPUs available to this process, which may
25 | # be smaller than the number of CPUs on the system.
26 | # This command is only available on Unix-like systems.
27 | return len(os.sched_getaffinity(0))
28 | except Exception:
29 | # Fall-back for Windows or other systems which don't support sched_getaffinity
30 | warnings.warn(
31 |             "Unable to determine the number of CPUs available to this Python"
32 | " process specifically. Falling back to the total number of CPUs on the"
33 | " system.",
34 | RuntimeWarning,
35 | stacklevel=2,
36 | )
37 | return os.cpu_count()
38 |
39 |
40 | def init_or_resume_wandb_run(output_dir, basename="wandb_runid.txt", **kwargs):
41 | r"""
42 | Initialize a wandb run, resuming if one already exists for this job.
43 |
44 | Parameters
45 | ----------
46 | output_dir : str
47 | Path to output directory for this job, where the wandb run id file
48 | will be stored.
49 | basename : str, default="wandb_runid.txt"
50 | Basename of wandb run id file.
51 | **kwargs
52 | Additional parameters to be passed through to ``wandb.init``.
53 | Examples include, ``"project"``, ``"name"``, ``"config"``.
54 |
55 | Returns
56 | -------
57 | run : :class:`wandb.sdk.wandb_run.Run`
58 | A wandb Run object, as returned by :func:`wandb.init`.
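    |
    |     Examples
    |     --------
    |     A minimal usage sketch; the output directory, project, and config values
    |     are illustrative::
    |
    |         run = init_or_resume_wandb_run(
    |             "models/cifar10/myrun__abc123",
    |             project="template-experiment",
    |             name="myrun",
    |             config={"lr": 0.01},
    |         )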
59 | """
60 | import wandb
61 |
62 | if not output_dir:
63 | wandb_id_file_path = None
64 | else:
65 | wandb_id_file_path = os.path.join(output_dir, basename)
66 | if wandb_id_file_path and os.path.isfile(wandb_id_file_path):
67 | # If the run_id was previously saved, get the id and resume it
68 | with open(wandb_id_file_path, "r") as f:
69 | resume_id = f.read()
70 | run = wandb.init(resume=resume_id, **kwargs)
71 | else:
72 | # If the run id file doesn't exist, create a new wandb run
73 | run = wandb.init(**kwargs)
74 | if wandb_id_file_path:
75 | # Write the run id to the expected file for resuming later
76 | with open(wandb_id_file_path, "w") as f:
77 | f.write(run.id)
78 |
79 | return run
80 |
81 |
82 | def set_rng_seeds_fixed(seed, all_gpu=True):
83 | r"""
84 | Seed pseudo-random number generators throughout python's random module, numpy.random, and pytorch.
85 |
86 | Parameters
87 | ----------
88 | seed : int
89 | The random seed to use. Should be between 0 and 4294967295 to ensure
90 | unique behaviour for numpy, and between 0 and 18446744073709551615 to
91 | ensure unique behaviour for pytorch.
92 | all_gpu : bool, default=True
93 | Whether to set the torch seed on every GPU. If ``False``, only the
94 | current GPU has its seed set.
95 |
96 | Returns
97 | -------
98 | None
99 | """
100 | # Note that random, numpy, and pytorch all use different RNG methods/
101 | # implementations, so you'll get different outputs from each of them even
102 | # if you use the same seed for them all.
103 | # We use modulo with the maximum values permitted for np.random and torch.
104 |     # If your seed exceeds 4294967295, the numpy seed will wrap around to a value already used by a smaller seed.
105 | random.seed(seed)
106 | np.random.seed(seed % 0xFFFF_FFFF)
107 | torch.manual_seed(seed % 0xFFFF_FFFF_FFFF_FFFF)
108 | if all_gpu:
109 | torch.cuda.manual_seed_all(seed % 0xFFFF_FFFF_FFFF_FFFF)
110 | else:
111 | torch.cuda.manual_seed(seed % 0xFFFF_FFFF_FFFF_FFFF)
112 |
113 |
114 | def worker_seed_fn(worker_id):
115 | r"""
116 | Seed builtin :mod:`random` and :mod:`numpy`.
117 |
118 | A worker initialization function for :class:`torch.utils.data.DataLoader`
119 | objects which seeds builtin :mod:`random` and :mod:`numpy` with the
120 | torch seed for the worker.
121 |
122 | Parameters
123 | ----------
124 | worker_id : int
125 | The ID of the worker.
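    |
    |     Examples
    |     --------
    |     A minimal sketch; ``dataset`` is assumed to be an existing
    |     :class:`torch.utils.data.Dataset`::
    |
    |         dataloader = torch.utils.data.DataLoader(
    |             dataset,
    |             batch_size=128,
    |             num_workers=4,
    |             worker_init_fn=worker_seed_fn,
    |         )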
126 | """
127 | worker_seed = torch.utils.data.get_worker_info().seed
128 | random.seed(worker_seed)
129 | np.random.seed(worker_seed % 0xFFFF_FFFF)
130 |
131 |
132 | def determine_epoch_seed(seed, epoch):
133 | r"""
134 | Determine the seed to use for the random number generator for a given epoch.
135 |
136 | Parameters
137 | ----------
138 | seed : int
139 | The original random seed, used to generate the sequence of seeds for
140 | the epochs.
141 | epoch : int
142 | The epoch for which to determine the seed.
143 |
144 | Returns
145 | -------
146 | epoch_seed : int
147 | The seed to use for the random number generator for the given epoch.
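    |
    |     Examples
    |     --------
    |     The mapping is deterministic, so repeated calls with the same arguments
    |     always return the same seed (the argument values here are just a sketch):
    |
    |     >>> determine_epoch_seed(42, epoch=3) == determine_epoch_seed(42, epoch=3)
    |     True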
148 | """
149 | if epoch == 0:
150 | raise ValueError("Epoch must be indexed from 1, not 0.")
151 | random.seed(seed)
152 | # Generate a seed for every epoch so far. We have to traverse the
153 | # series of RNG calls to reach the next value (our next seed). The final
154 | # value is the one for our current epoch.
155 | # N.B. We use random.randint instead of torch.randint because torch.randint
156 | # only supports int32 at most (max value of 0xFFFF_FFFF).
157 | for _ in range(epoch):
158 | epoch_seed = random.randint(0, 0xFFFF_FFFF_FFFF_FFFF)
159 | return epoch_seed
160 |
161 |
162 | def generate_id(length: int = 8) -> str:
163 | r"""
164 |     Generate a random base-36 string of ``length`` digits.
165 |
166 | Parameters
167 | ----------
168 | length : int, default=8
169 | Length of the string to generate.
170 |
171 | Returns
172 | -------
173 | id : str
174 | The randomly generated id.
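    |
    |     Examples
    |     --------
    |     The returned id always has the requested number of characters:
    |
    |     >>> len(generate_id())
    |     8
    |     >>> len(generate_id(length=4))
    |     4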
175 | """
176 | # Borrowed from https://github.com/wandb/wandb/blob/0e00efd/wandb/sdk/lib/runid.py
177 | # under the MIT license.
178 | # There are ~2.8T base-36 8-digit strings. If we generate 210k ids,
179 | # we'll have a ~1% chance of collision.
180 | alphabet = string.ascii_lowercase + string.digits
181 | return "".join(secrets.choice(alphabet) for _ in range(length))
182 |
183 |
184 | def count_parameters(model, only_trainable=True):
185 | r"""
186 | Count the number of (trainable) parameters within a model and its children.
187 |
188 | Parameters
189 | ----------
190 |     model : torch.nn.Module
191 | The parametrized model.
192 |     only_trainable : bool, default=True
193 |         Whether the count should be restricted to only trainable parameters
194 |         (the default), or whether all parameters, including frozen ones,
195 |         should be included in the count.
196 |
197 | Returns
198 | -------
199 | int
200 | Total number of (trainable) parameters possessed by the model.
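    |
    |     Examples
    |     --------
    |     A small sketch with a linear layer (10 * 2 weights plus 2 biases):
    |
    |     >>> from torch import nn
    |     >>> count_parameters(nn.Linear(10, 2))
    |     22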
201 | """
202 | if only_trainable:
203 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
204 | else:
205 | return sum(p.numel() for p in model.parameters())
206 |
207 |
208 | @torch.no_grad()
209 | def concat_all_gather(tensor, **kwargs):
210 | r"""
211 | Gather a tensor over all processes and concatenate them into one.
212 |
213 | Similar to :func:`torch.distributed.all_gather`, except this function
214 | concatenates the result into a single tensor instead of a list of tensors.
215 |
216 | Parameters
217 | ----------
218 | tensor : torch.Tensor
219 | The distributed tensor on the current process.
220 | group : ProcessGroup, optional
221 | The process group to work on. If ``None``, the default process group
222 | will be used.
223 | async_op : bool, default=False
224 | Whether this op should be an async op.
225 |
226 | Returns
227 | -------
228 | gathered_tensor : torch.Tensor
229 | The contents of ``tensor`` from every distributed process, gathered
230 | together. None of the entries support a gradient.
231 |
232 | Warning
233 | -------
234 | As with :func:`torch.distributed.all_gather`, this has no gradient.
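    |
    |     Examples
    |     --------
    |     A minimal sketch; this must be called by every process of an initialized
    |     process group (e.g. after ``torch.distributed.init_process_group``), and
    |     ``loss`` is assumed to be a scalar tensor on the current device::
    |
    |         gathered = concat_all_gather(loss.reshape((1,)))
    |         loss_batch = gathered.mean().item()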
235 | """
236 | world_size = torch.distributed.get_world_size()
237 | tensors_gather = [torch.zeros_like(tensor) for _ in range(world_size)]
238 | torch.distributed.all_gather(tensors_gather, tensor, **kwargs)
239 | output = torch.cat(tensors_gather, dim=0)
240 | return output
241 |
242 |
243 | @torch.no_grad()
244 | def concat_all_gather_ragged(tensor, **kwargs):
245 | r"""
246 | Gather a tensor over all processes and concatenate them into one.
247 |
248 | This version supports ragged tensors where the first dimension is not the
249 | same across all processes.
250 |
251 | Parameters
252 | ----------
253 | tensor : torch.Tensor
254 | The distributed tensor on the current process. The equivalent tensors
255 | on the other processes may differ in shape only in their first
256 | dimension.
257 | group : ProcessGroup, optional
258 | The process group to work on. If ``None``, the default process group
259 | will be used.
260 | async_op : bool, default=False
261 | Whether this op should be an async op.
262 |
263 | Returns
264 | -------
265 | gathered_tensor : torch.Tensor
266 | The contents of ``tensor`` from every distributed process, gathered
267 | together. None of the entries support a gradient.
268 |
269 | Warning
270 | -------
271 | As with :func:`torch.distributed.all_gather`, this has no gradient.
272 | """
273 | world_size = torch.distributed.get_world_size()
274 | # Gather the lengths of the tensors from all processes
275 | local_length = torch.tensor(tensor.shape[0], device=tensor.device)
276 | all_length = [torch.zeros_like(local_length) for _ in range(world_size)]
277 | torch.distributed.all_gather(all_length, local_length, **kwargs)
278 | # We will have to pad them to be the size of the longest tensor
279 | max_length = max(x.item() for x in all_length)
280 |
281 | # Pad our tensor on the current process
282 | length_diff = max_length - local_length.item()
283 | if length_diff:
284 | pad_size = (length_diff, *tensor.shape[1:])
285 | padding = torch.zeros(pad_size, device=tensor.device, dtype=tensor.dtype)
286 | tensor = torch.cat((tensor, padding), dim=0)
287 |
288 | # Gather the padded tensors from all processes
289 | all_tensors_padded = [torch.zeros_like(tensor) for _ in range(world_size)]
290 | torch.distributed.all_gather(all_tensors_padded, tensor, **kwargs)
291 | # Remove padding
292 | all_tensors = []
293 | for tensor_i, length_i in zip(all_tensors_padded, all_length):
294 | all_tensors.append(tensor_i[:length_i])
295 |
296 | # Concatenate the tensors
297 | output = torch.cat(all_tensors, dim=0)
298 | return output
299 |
--------------------------------------------------------------------------------