├── gpt
│   ├── assets
│   │   ├── nanogpt.jpg
│   │   └── gpt2_124M_loss.png
│   ├── data
│   │   ├── shakespeare
│   │   │   ├── readme.md
│   │   │   └── prepare.py
│   │   ├── shakespeare_char
│   │   │   ├── readme.md
│   │   │   └── prepare.py
│   │   └── openwebtext
│   │       ├── readme.md
│   │       └── prepare.py
│   ├── config
│   │   ├── eval_gpt2.py
│   │   ├── eval_gpt2_xl.py
│   │   ├── eval_gpt2_large.py
│   │   ├── eval_gpt2_medium.py
│   │   ├── finetune_shakespeare.py
│   │   ├── train_gpt2.py
│   │   └── train_shakespeare_char.py
│   ├── LICENSE
│   ├── README.md
│   ├── configurator.py
│   ├── sample.py
│   ├── bench.py
│   ├── train.py
│   ├── transformer_sizing.ipynb
│   └── model.py
├── src
│   └── adopt
│       ├── __init__.py
│       └── adopt.py
├── imagenet
│   ├── README.md
│   ├── sampler.py
│   ├── presets.py
│   ├── transforms.py
│   ├── train_quantization.py
│   ├── utils.py
│   └── train.py
├── pyproject.toml
├── README.md
├── .gitignore
└── LICENSE
/gpt/assets/nanogpt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iShohei220/adopt/HEAD/gpt/assets/nanogpt.jpg
--------------------------------------------------------------------------------
/gpt/assets/gpt2_124M_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iShohei220/adopt/HEAD/gpt/assets/gpt2_124M_loss.png
--------------------------------------------------------------------------------
/src/adopt/__init__.py:
--------------------------------------------------------------------------------
1 | from .adopt import ADOPT, adopt
2 |
3 | __version__ = "0.1.0"
4 | __all__ = ["ADOPT", "adopt"]
5 |
--------------------------------------------------------------------------------
/gpt/data/shakespeare/readme.md:
--------------------------------------------------------------------------------
1 |
2 | # tiny shakespeare
3 |
4 | Tiny shakespeare, of the good old char-rnn fame :)
5 |
6 | After running `prepare.py`:
7 |
8 | - train.bin has 301,966 tokens
9 | - val.bin has 36,059 tokens
10 |
--------------------------------------------------------------------------------
/gpt/config/eval_gpt2.py:
--------------------------------------------------------------------------------
1 | # evaluate the base gpt2
2 | # n_layer=12, n_head=12, n_embd=768
3 | # 124M parameters
4 | batch_size = 8
5 | eval_iters = 500 # use more iterations to get good estimate
6 | eval_only = True
7 | wandb_log = False
8 | init_from = 'gpt2'
9 |
--------------------------------------------------------------------------------
/gpt/config/eval_gpt2_xl.py:
--------------------------------------------------------------------------------
1 | # evaluate the base gpt2
2 | # n_layer=48, n_head=25, n_embd=1600
3 | # 1558M parameters
4 | batch_size = 8
5 | eval_iters = 500 # use more iterations to get good estimate
6 | eval_only = True
7 | wandb_log = False
8 | init_from = 'gpt2-xl'
9 |
--------------------------------------------------------------------------------
/gpt/config/eval_gpt2_large.py:
--------------------------------------------------------------------------------
1 | # evaluate the base gpt2
2 | # n_layer=36, n_head=20, n_embd=1280
3 | # 774M parameters
4 | batch_size = 8
5 | eval_iters = 500 # use more iterations to get good estimate
6 | eval_only = True
7 | wandb_log = False
8 | init_from = 'gpt2-large'
9 |
--------------------------------------------------------------------------------
/gpt/config/eval_gpt2_medium.py:
--------------------------------------------------------------------------------
1 | # evaluate the base gpt2
2 | # n_layer=24, n_head=16, n_embd=1024
3 | # 350M parameters
4 | batch_size = 8
5 | eval_iters = 500 # use more iterations to get good estimate
6 | eval_only = True
7 | wandb_log = False
8 | init_from = 'gpt2-medium'
9 |
--------------------------------------------------------------------------------
/gpt/data/shakespeare_char/readme.md:
--------------------------------------------------------------------------------
1 |
2 | # tiny shakespeare, character-level
3 |
4 | Tiny shakespeare, of the good old char-rnn fame :) Treated at the character level.
5 |
6 | After running `prepare.py`:
7 |
8 | - train.bin has 1,003,854 tokens
9 | - val.bin has 111,540 tokens
10 |
--------------------------------------------------------------------------------
/gpt/data/openwebtext/readme.md:
--------------------------------------------------------------------------------
1 |
2 | ## openwebtext dataset
3 |
4 | after running `prepare.py` (preprocess) we get:
5 |
6 | - train.bin is ~17GB, val.bin ~8.5MB
7 | - train has ~9B tokens (9,035,582,198)
8 | - val has ~4M tokens (4,434,897)
9 |
10 | this came from 8,013,769 documents in total.
11 |
12 | references:
13 |
14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset
16 |
--------------------------------------------------------------------------------
/gpt/config/finetune_shakespeare.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | out_dir = 'out-shakespeare'
4 | eval_interval = 5
5 | eval_iters = 40
6 | wandb_log = False # feel free to turn on
7 | wandb_project = 'shakespeare'
8 | wandb_run_name = 'ft-' + str(time.time())
9 |
10 | dataset = 'shakespeare'
11 | init_from = 'gpt2-xl' # this is the largest GPT-2 model
12 |
13 | # only save checkpoints if the validation loss improves
14 | always_save_checkpoint = False
15 |
16 | # the number of examples per iter:
17 | # 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter
18 | # shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters
19 | batch_size = 1
20 | gradient_accumulation_steps = 32
21 | max_iters = 20
22 |
23 | # finetune at constant LR
24 | learning_rate = 3e-5
25 | decay_lr = False
26 |
--------------------------------------------------------------------------------
/gpt/config/train_gpt2.py:
--------------------------------------------------------------------------------
1 | # config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB
2 | # launch as the following (e.g. in a screen session) and wait ~5 days:
3 | # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
4 |
5 | wandb_log = True
6 | wandb_project = 'owt'
7 | wandb_run_name='gpt2-124M'
8 |
9 | # these make the total batch size be ~0.5M
10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
11 | batch_size = 12
12 | block_size = 1024
13 | gradient_accumulation_steps = 5 * 8
14 |
15 | # this makes total number of tokens be 300B
16 | max_iters = 600000
17 | lr_decay_iters = 600000
18 |
19 | # eval stuff
20 | eval_interval = 1000
21 | eval_iters = 200
22 | log_interval = 10
23 |
24 | # weight decay
25 | weight_decay = 1e-1
26 |
--------------------------------------------------------------------------------
/gpt/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Andrej Karpathy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/imagenet/README.md:
--------------------------------------------------------------------------------
1 | This code is based on the official training recipe for ImageNet classification provided by [Torchvision](https://github.com/pytorch/vision/tree/main/references/classification).
2 |
3 | # Image classification reference training scripts
4 |
5 | This folder contains reference training scripts for image classification.
6 | They serve as a log of how to train specific models, as well as provide baseline
7 | training and evaluation scripts to quickly bootstrap research.
8 |
9 | ### SwinTransformer
10 | ```
11 | torchrun --nproc_per_node=8 train.py\
12 | --model $MODEL --epochs 300 --batch-size 128 --opt adoptw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 224
13 | ```
14 | Here `$MODEL` is one of `swin_t`, `swin_s` or `swin_b`.
15 | Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
16 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "torch-adopt"
7 | version = "0.1.0"
8 | authors = [
9 | { name = "Shohei Taniguchi", email = "ishohei220@gmail.com" }
10 | ]
11 | description = "ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate"
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "Programming Language :: Python :: 3.8",
17 | "Programming Language :: Python :: 3.9",
18 | "Programming Language :: Python :: 3.10",
19 | "Programming Language :: Python :: 3.11",
20 | "License :: OSI Approved :: Apache Software License",
21 | "Operating System :: OS Independent",
22 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
23 | "Intended Audience :: Science/Research",
24 | ]
25 | dependencies = [
26 | "torch>=2.5.0",
27 | ]
28 |
29 | [project.urls]
30 | "Homepage" = "https://github.com/iShohei220/adopt"
31 | "Bug Tracker" = "https://github.com/iShohei220/adopt/issues"
32 | "Documentation" = "https://github.com/iShohei220/adopt"
33 | "Research Paper" = "https://arxiv.org/abs/2411.02853"
34 |
35 | [tool.hatch.build.targets.wheel]
36 | packages = ["src/adopt"]
37 |
--------------------------------------------------------------------------------
/gpt/data/shakespeare/prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import tiktoken
4 | import numpy as np
5 |
6 | # download the tiny shakespeare dataset
7 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
8 | if not os.path.exists(input_file_path):
9 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
10 | with open(input_file_path, 'w', encoding='utf-8') as f:
11 | f.write(requests.get(data_url).text)
12 |
13 | with open(input_file_path, 'r', encoding='utf-8') as f:
14 | data = f.read()
15 | n = len(data)
16 | train_data = data[:int(n*0.9)]
17 | val_data = data[int(n*0.9):]
18 |
19 | # encode with tiktoken gpt2 bpe
20 | enc = tiktoken.get_encoding("gpt2")
21 | train_ids = enc.encode_ordinary(train_data)
22 | val_ids = enc.encode_ordinary(val_data)
23 | print(f"train has {len(train_ids):,} tokens")
24 | print(f"val has {len(val_ids):,} tokens")
25 |
26 | # export to bin files
27 | train_ids = np.array(train_ids, dtype=np.uint16)
28 | val_ids = np.array(val_ids, dtype=np.uint16)
29 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
30 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
31 |
32 | # train.bin has 301,966 tokens
33 | # val.bin has 36,059 tokens
34 |
--------------------------------------------------------------------------------
/gpt/config/train_shakespeare_char.py:
--------------------------------------------------------------------------------
1 | # train a miniature character-level shakespeare model
2 | # good for debugging and playing on macbooks and such
3 |
4 | out_dir = 'out-shakespeare-char'
5 | eval_interval = 250 # keep frequent because we'll overfit
6 | eval_iters = 200
7 | log_interval = 10 # don't print too too often
8 |
9 | # we expect to overfit on this small dataset, so only save when val improves
10 | always_save_checkpoint = False
11 |
12 | wandb_log = False # override via command line if you like
13 | wandb_project = 'shakespeare-char'
14 | wandb_run_name = 'mini-gpt'
15 |
16 | dataset = 'shakespeare_char'
17 | gradient_accumulation_steps = 1
18 | batch_size = 64
19 | block_size = 256 # context of up to 256 previous characters
20 |
21 | # baby GPT model :)
22 | n_layer = 6
23 | n_head = 6
24 | n_embd = 384
25 | dropout = 0.2
26 |
27 | learning_rate = 1e-3 # with baby networks can afford to go a bit higher
28 | max_iters = 5000
29 | lr_decay_iters = 5000 # make equal to max_iters usually
30 | min_lr = 1e-4 # learning_rate / 10 usually
31 | beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
32 |
33 | warmup_iters = 100 # not super necessary potentially
34 |
35 | # on macbook also add
36 | # device = 'cpu' # run on cpu only
37 | # compile = False # do not torch compile the model
38 |
--------------------------------------------------------------------------------
/gpt/README.md:
--------------------------------------------------------------------------------
1 | This code is based on the [nanoGPT](https://github.com/karpathy/nanoGPT) repository.
2 |
3 | ## Install
4 |
5 | ```
6 | pip install torch numpy transformers datasets tiktoken wandb tqdm
7 | ```
8 |
9 | Dependencies:
10 |
11 | - [pytorch](https://pytorch.org) <3
12 | - [numpy](https://numpy.org/install/) <3
13 | - `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints)
14 | - `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
15 | - `tiktoken` for OpenAI's fast BPE code <3
16 | - `wandb` for optional logging <3
17 | - `tqdm` for progress bars <3
18 |
19 | ## Train GPT-2
20 |
21 | We first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText:
22 |
23 | ```sh
24 | python data/openwebtext/prepare.py
25 | ```
26 |
27 | This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and a `val.bin`, each holding the GPT-2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node; run:
28 |
29 | ```sh
30 | torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py --optim_type=adopt
31 | ```
32 |
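As a quick sanity check after `prepare.py` finishes, the `.bin` files can be read back directly; a minimal sketch mirroring the comment at the bottom of `data/openwebtext/prepare.py`:

```python
import numpy as np

# token ids are stored as raw uint16, so they can be read back with a memmap
train_ids = np.memmap('data/openwebtext/train.bin', dtype=np.uint16, mode='r')
print(f"train has {len(train_ids):,} tokens")
```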
--------------------------------------------------------------------------------
/gpt/configurator.py:
--------------------------------------------------------------------------------
1 | """
2 | Poor Man's Configurator. Probably a terrible idea. Example usage:
3 | $ python train.py config/override_file.py --batch_size=32
4 | this will first run config/override_file.py, then override batch_size to 32
5 |
6 | The code in this file will be run as follows from e.g. train.py:
7 | >>> exec(open('configurator.py').read())
8 |
9 | So it's not a Python module, it's just shuttling this code away from train.py
10 | The code in this script then overrides the globals()
11 |
12 | I know people are not going to love this, I just really dislike configuration
13 | complexity and having to prepend config. to every single variable. If someone
14 | comes up with a better simple Python solution I am all ears.
15 | """
16 |
17 | import sys
18 | from ast import literal_eval
19 |
20 | for arg in sys.argv[1:]:
21 | if '=' not in arg:
22 | # assume it's the name of a config file
23 | assert not arg.startswith('--')
24 | config_file = arg
25 | print(f"Overriding config with {config_file}:")
26 | with open(config_file) as f:
27 | print(f.read())
28 | exec(open(config_file).read())
29 | else:
30 | # assume it's a --key=value argument
31 | assert arg.startswith('--')
32 | key, val = arg.split('=')
33 | key = key[2:]
34 | if key in globals():
35 | try:
36 |                 # attempt to eval it (e.g. if bool, number, etc.)
37 | attempt = literal_eval(val)
38 | except (SyntaxError, ValueError):
39 | # if that goes wrong, just use the string
40 | attempt = val
41 | # ensure the types match ok
42 | assert type(attempt) == type(globals()[key])
43 | # cross fingers
44 | print(f"Overriding: {key} = {attempt}")
45 | globals()[key] = attempt
46 | else:
47 | raise ValueError(f"Unknown config key: {key}")
48 |
--------------------------------------------------------------------------------
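For context, a minimal, hypothetical sketch of the consuming side of `configurator.py` (it mirrors the `exec` pattern from the docstring above; the variable names and default values are illustrative only):

```python
# hypothetical train.py-style globals, overridable from the command line,
# e.g. `python train.py config/train_gpt2.py --batch_size=32`
batch_size = 12
learning_rate = 6e-4

exec(open('configurator.py').read())  # applies config-file and --key=value overrides to these globals
print(f"batch_size={batch_size}, learning_rate={learning_rate}")
```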
/gpt/data/shakespeare_char/prepare.py:
--------------------------------------------------------------------------------
1 | """
2 | Prepare the Shakespeare dataset for character-level language modeling.
3 | So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
4 | Will save train.bin, val.bin containing the ids, and meta.pkl containing the
5 | encoder and decoder and some other related info.
6 | """
7 | import os
8 | import pickle
9 | import requests
10 | import numpy as np
11 |
12 | # download the tiny shakespeare dataset
13 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
14 | if not os.path.exists(input_file_path):
15 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
16 | with open(input_file_path, 'w') as f:
17 | f.write(requests.get(data_url).text)
18 |
19 | with open(input_file_path, 'r') as f:
20 | data = f.read()
21 | print(f"length of dataset in characters: {len(data):,}")
22 |
23 | # get all the unique characters that occur in this text
24 | chars = sorted(list(set(data)))
25 | vocab_size = len(chars)
26 | print("all the unique characters:", ''.join(chars))
27 | print(f"vocab size: {vocab_size:,}")
28 |
29 | # create a mapping from characters to integers
30 | stoi = { ch:i for i,ch in enumerate(chars) }
31 | itos = { i:ch for i,ch in enumerate(chars) }
32 | def encode(s):
33 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers
34 | def decode(l):
35 | return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string
36 |
37 | # create the train and test splits
38 | n = len(data)
39 | train_data = data[:int(n*0.9)]
40 | val_data = data[int(n*0.9):]
41 |
42 | # encode both to integers
43 | train_ids = encode(train_data)
44 | val_ids = encode(val_data)
45 | print(f"train has {len(train_ids):,} tokens")
46 | print(f"val has {len(val_ids):,} tokens")
47 |
48 | # export to bin files
49 | train_ids = np.array(train_ids, dtype=np.uint16)
50 | val_ids = np.array(val_ids, dtype=np.uint16)
51 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
52 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
53 |
54 | # save the meta information as well, to help us encode/decode later
55 | meta = {
56 | 'vocab_size': vocab_size,
57 | 'itos': itos,
58 | 'stoi': stoi,
59 | }
60 | with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f:
61 | pickle.dump(meta, f)
62 |
63 | # length of dataset in characters: 1115394
64 | # all the unique characters:
65 | # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
66 | # vocab size: 65
67 | # train has 1003854 tokens
68 | # val has 111540 tokens
69 |
--------------------------------------------------------------------------------
/imagenet/sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 | class RASampler(torch.utils.data.Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset for distributed,
9 | with repeated augmentation.
10 | It ensures that each augmented version of a sample will be visible to a
11 | different process (GPU).
12 | Heavily based on 'torch.utils.data.DistributedSampler'.
13 |
14 | This is borrowed from the DeiT Repo:
15 | https://github.com/facebookresearch/deit/blob/main/samplers.py
16 | """
17 |
18 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
19 | if num_replicas is None:
20 | if not dist.is_available():
21 | raise RuntimeError("Requires distributed package to be available!")
22 | num_replicas = dist.get_world_size()
23 | if rank is None:
24 | if not dist.is_available():
25 | raise RuntimeError("Requires distributed package to be available!")
26 | rank = dist.get_rank()
27 | self.dataset = dataset
28 | self.num_replicas = num_replicas
29 | self.rank = rank
30 | self.epoch = 0
31 | self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
32 | self.total_size = self.num_samples * self.num_replicas
33 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
34 | self.shuffle = shuffle
35 | self.seed = seed
36 | self.repetitions = repetitions
37 |
38 | def __iter__(self):
39 | if self.shuffle:
40 | # Deterministically shuffle based on epoch
41 | g = torch.Generator()
42 | g.manual_seed(self.seed + self.epoch)
43 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
44 | else:
45 | indices = list(range(len(self.dataset)))
46 |
47 | # Add extra samples to make it evenly divisible
48 | indices = [ele for ele in indices for i in range(self.repetitions)]
49 | indices += indices[: (self.total_size - len(indices))]
50 | assert len(indices) == self.total_size
51 |
52 | # Subsample
53 | indices = indices[self.rank : self.total_size : self.num_replicas]
54 | assert len(indices) == self.num_samples
55 |
56 | return iter(indices[: self.num_selected_samples])
57 |
58 | def __len__(self):
59 | return self.num_selected_samples
60 |
61 | def set_epoch(self, epoch):
62 | self.epoch = epoch
63 |
--------------------------------------------------------------------------------
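A minimal usage sketch for `RASampler` above, under the usual distributed-training assumptions (the process group is already initialized, e.g. via `torchrun`; `train_dataset` and `num_epochs` are placeholders):

```python
import torch
from sampler import RASampler

# repetitions=4 matches the --ra-reps 4 flag in the imagenet README command
train_sampler = RASampler(train_dataset, shuffle=True, repetitions=4)
loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, sampler=train_sampler, num_workers=8)

for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)  # reshuffle deterministically each epoch
    for images, targets in loader:
        ...  # forward / backward / optimizer step
```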
/README.md:
--------------------------------------------------------------------------------
1 | # ADOPT: Modified Adam Can Converge with Any $β_2$ with the Optimal Rate
2 | Official Implementation of "[ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853)", which was presented at NeurIPS 2024.
3 |
4 | ## Update on Nov 22, 2024
5 |
6 | Based on feedback from some practitioners, we have updated the implementation to improve the stability of our ADOPT algorithm.
7 | In the original version, ADOPT sometimes becomes unstable, especially in the early stage of training.
8 | This seems to be because a division by a near-zero second moment estimate occurs when some elements of the parameter gradient are near zero at initialization.
9 | For example, when some parameters (e.g., the last layer of a neural net) are initialized to zero, which is an often-used technique in deep learning, a near-zero gradient is observed at the first parameter update.
10 | To avoid such near-zero divisions, we have decided to add a clipping operation to the momentum update.
11 |
12 | 
13 |
14 | Even when clipping is applied, the theoretical convergence guarantee is maintained by properly scheduling the clipping value (see the updated arXiv paper).
15 | In our implementation, the clipping value is controlled by the argument `clip_lambda`, a callable that determines the schedule of the clipping value as a function of the number of gradient steps.
16 | By default, the clipping value is set to `step**0.25`, which aligns with the theory to ensure convergence.
17 | We observe that the clipped ADOPT works much more stably than the original, so we recommend using it over the unclipped version.
18 | If you want to reproduce the behavior of the original version, set `clip_lambda = None`.
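
For illustration, a minimal sketch of passing `clip_lambda` explicitly (the lambda below simply restates the default schedule; `model` is assumed to be a `torch.nn.Module`):

```python3
from adopt import ADOPT

# default clipping schedule described above: the clip value grows as step**0.25
optimizer = ADOPT(model.parameters(), lr=1e-3, clip_lambda=lambda step: step**0.25)

# to reproduce the original (unclipped) behavior instead:
# optimizer = ADOPT(model.parameters(), lr=1e-3, clip_lambda=None)
```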
19 |
20 | ## Requirements
21 |
22 | ADOPT requires PyTorch 2.5.0 or later.
23 |
24 | ## Installation
25 |
26 | ```bash
27 | pip install torch-adopt
28 | ```
29 |
30 | ## Usage
31 |
32 | You can use ADOPT just like any other PyTorch optimizer by importing the `ADOPT` class.
33 |
34 | To switch from the `Adam` optimizer to our `ADOPT`, you just need to replace the optimizer as follows:
35 |
36 | ```python3
37 | from adopt import ADOPT
38 | # optimizer = Adam(model.parameters(), lr=1e-3)
39 | optimizer = ADOPT(model.parameters(), lr=1e-3)
40 | ```
41 |
42 | When you are using `AdamW` as your default optimizer, set `decouple=True` for our `ADOPT`:
43 |
44 | ```python3
45 | # optimizer = AdamW(model.parameters(), lr=1e-3)
46 | optimizer = ADOPT(model.parameters(), lr=1e-3, decouple=True)
47 | ```
48 |
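For completeness, a minimal, hypothetical training step with `ADOPT`, used exactly like any other PyTorch optimizer (`model`, `loss_fn`, and `dataloader` are placeholders):

```python3
from adopt import ADOPT

optimizer = ADOPT(model.parameters(), lr=1e-3, decouple=True)

for x, y in dataloader:
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
```
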
49 | ## Citation
50 | If you use ADOPT in your research, please cite the paper.
51 | ```text
52 | @inproceedings{taniguchi2024adopt,
53 | author={Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka},
54 | booktitle = {Advances in Neural Information Processing Systems},
55 | title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
56 | year = {2024}
57 | }
58 | ```
59 |
60 | ## License
61 | [Apache 2.0](./LICENSE)
62 |
--------------------------------------------------------------------------------
/gpt/data/openwebtext/prepare.py:
--------------------------------------------------------------------------------
1 | # saves the openwebtext dataset to a binary file for training. following was helpful:
2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
3 |
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 | import tiktoken
8 | from datasets import load_dataset # huggingface datasets
9 |
10 | # number of workers in .map() call
11 | # good number to use is ~order number of cpu cores // 2
12 | num_proc = 8
13 |
14 | # number of workers in load_dataset() call
15 | # best number might be different from num_proc above as it also depends on NW speed.
16 | # it is better than 1 usually though
17 | num_proc_load_dataset = num_proc
18 |
19 | enc = tiktoken.get_encoding("gpt2")
20 |
21 | if __name__ == '__main__':
22 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
23 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)
24 |
25 | # owt by default only contains the 'train' split, so create a test split
26 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
27 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val
28 |
29 | # this results in:
30 | # >>> split_dataset
31 | # DatasetDict({
32 | # train: Dataset({
33 | # features: ['text'],
34 | # num_rows: 8009762
35 | # })
36 | # val: Dataset({
37 | # features: ['text'],
38 | # num_rows: 4007
39 | # })
40 | # })
41 |
42 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
43 | def process(example):
44 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
45 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
46 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
47 | out = {'ids': ids, 'len': len(ids)}
48 | return out
49 |
50 | # tokenize the dataset
51 | tokenized = split_dataset.map(
52 | process,
53 | remove_columns=['text'],
54 | desc="tokenizing the splits",
55 | num_proc=num_proc,
56 | )
57 |
58 | # concatenate all the ids in each dataset into one large file we can use for training
59 | for split, dset in tokenized.items():
60 | arr_len = np.sum(dset['len'], dtype=np.uint64)
61 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
62 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
63 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
64 | total_batches = 1024
65 |
66 | idx = 0
67 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
68 | # Batch together samples for faster write
69 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
70 | arr_batch = np.concatenate(batch['ids'])
71 | # Write into mmap
72 | arr[idx : idx + len(arr_batch)] = arr_batch
73 | idx += len(arr_batch)
74 | arr.flush()
75 |
76 | # train.bin is ~17GB, val.bin ~8.5MB
77 | # train has ~9B tokens (9,035,582,198)
78 | # val has ~4M tokens (4,434,897)
79 |
80 | # to read the bin files later, e.g. with numpy:
81 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
82 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/gpt/sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Sample from a trained model
3 | """
4 | import os
5 | import pickle
6 | from contextlib import nullcontext
7 | import torch
8 | import tiktoken
9 | from model import GPTConfig, GPT
10 |
11 | # -----------------------------------------------------------------------------
12 | init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
13 | out_dir = 'out' # ignored if init_from is not 'resume'
14 | start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
15 | num_samples = 10 # number of samples to draw
16 | max_new_tokens = 500 # number of tokens generated in each sample
17 | temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
18 | top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
19 | seed = 1337
20 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
21 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
22 | compile = False # use PyTorch 2.0 to compile the model to be faster
23 | exec(open('configurator.py').read()) # overrides from command line or config file
24 | # -----------------------------------------------------------------------------
25 |
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
29 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
30 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
31 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
32 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
33 |
34 | # model
35 | if init_from == 'resume':
36 | # init from a model saved in a specific directory
37 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
38 | checkpoint = torch.load(ckpt_path, map_location=device)
39 | gptconf = GPTConfig(**checkpoint['model_args'])
40 | model = GPT(gptconf)
41 | state_dict = checkpoint['model']
42 | unwanted_prefix = '_orig_mod.'
43 | for k,v in list(state_dict.items()):
44 | if k.startswith(unwanted_prefix):
45 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
46 | model.load_state_dict(state_dict)
47 | elif init_from.startswith('gpt2'):
48 | # init from a given GPT-2 model
49 | model = GPT.from_pretrained(init_from, dict(dropout=0.0))
50 |
51 | model.eval()
52 | model.to(device)
53 | if compile:
54 | model = torch.compile(model) # requires PyTorch 2.0 (optional)
55 |
56 | # look for the meta pickle in case it is available in the dataset folder
57 | load_meta = False
58 | if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
59 | meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
60 | load_meta = os.path.exists(meta_path)
61 | if load_meta:
62 | print(f"Loading meta from {meta_path}...")
63 | with open(meta_path, 'rb') as f:
64 | meta = pickle.load(f)
65 | # TODO want to make this more general to arbitrary encoder/decoder schemes
66 | stoi, itos = meta['stoi'], meta['itos']
67 | encode = lambda s: [stoi[c] for c in s]
68 | decode = lambda l: ''.join([itos[i] for i in l])
69 | else:
70 | # ok let's assume gpt-2 encodings by default
71 | print("No meta.pkl found, assuming GPT-2 encodings...")
72 | enc = tiktoken.get_encoding("gpt2")
73 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
74 | decode = lambda l: enc.decode(l)
75 |
76 | # encode the beginning of the prompt
77 | if start.startswith('FILE:'):
78 | with open(start[5:], 'r', encoding='utf-8') as f:
79 | start = f.read()
80 | start_ids = encode(start)
81 | x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
82 |
83 | # run generation
84 | with torch.no_grad():
85 | with ctx:
86 | for k in range(num_samples):
87 | y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
88 | print(decode(y[0].tolist()))
89 | print('---------------')
90 |
--------------------------------------------------------------------------------
/imagenet/presets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms.functional import InterpolationMode
3 |
4 |
5 | def get_module(use_v2):
6 | # We need a protected import to avoid the V2 warning in case just V1 is used
7 | if use_v2:
8 | import torchvision.transforms.v2
9 |
10 | return torchvision.transforms.v2
11 | else:
12 | import torchvision.transforms
13 |
14 | return torchvision.transforms
15 |
16 |
17 | class ClassificationPresetTrain:
18 |     # Note: this transform assumes that the inputs to forward() are always PIL
19 | # images, regardless of the backend parameter. We may change that in the
20 | # future though, if we change the output type from the dataset.
21 | def __init__(
22 | self,
23 | *,
24 | crop_size,
25 | mean=(0.485, 0.456, 0.406),
26 | std=(0.229, 0.224, 0.225),
27 | interpolation=InterpolationMode.BILINEAR,
28 | hflip_prob=0.5,
29 | auto_augment_policy=None,
30 | ra_magnitude=9,
31 | augmix_severity=3,
32 | random_erase_prob=0.0,
33 | backend="pil",
34 | use_v2=False,
35 | ):
36 | T = get_module(use_v2)
37 |
38 | transforms = []
39 | backend = backend.lower()
40 | if backend == "tensor":
41 | transforms.append(T.PILToTensor())
42 | elif backend != "pil":
43 | raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
44 |
45 | transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
46 | if hflip_prob > 0:
47 | transforms.append(T.RandomHorizontalFlip(hflip_prob))
48 | if auto_augment_policy is not None:
49 | if auto_augment_policy == "ra":
50 | transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
51 | elif auto_augment_policy == "ta_wide":
52 | transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
53 | elif auto_augment_policy == "augmix":
54 | transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
55 | else:
56 | aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
57 | transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))
58 |
59 | if backend == "pil":
60 | transforms.append(T.PILToTensor())
61 |
62 | transforms.extend(
63 | [
64 | T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
65 | T.Normalize(mean=mean, std=std),
66 | ]
67 | )
68 | if random_erase_prob > 0:
69 | transforms.append(T.RandomErasing(p=random_erase_prob))
70 |
71 | if use_v2:
72 | transforms.append(T.ToPureTensor())
73 |
74 | self.transforms = T.Compose(transforms)
75 |
76 | def __call__(self, img):
77 | return self.transforms(img)
78 |
79 |
80 | class ClassificationPresetEval:
81 | def __init__(
82 | self,
83 | *,
84 | crop_size,
85 | resize_size=256,
86 | mean=(0.485, 0.456, 0.406),
87 | std=(0.229, 0.224, 0.225),
88 | interpolation=InterpolationMode.BILINEAR,
89 | backend="pil",
90 | use_v2=False,
91 | ):
92 | T = get_module(use_v2)
93 | transforms = []
94 | backend = backend.lower()
95 | if backend == "tensor":
96 | transforms.append(T.PILToTensor())
97 | elif backend != "pil":
98 | raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
99 |
100 | transforms += [
101 | T.Resize(resize_size, interpolation=interpolation, antialias=True),
102 | T.CenterCrop(crop_size),
103 | ]
104 |
105 | if backend == "pil":
106 | transforms.append(T.PILToTensor())
107 |
108 | transforms += [
109 | T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
110 | T.Normalize(mean=mean, std=std),
111 | ]
112 |
113 | if use_v2:
114 | transforms.append(T.ToPureTensor())
115 |
116 | self.transforms = T.Compose(transforms)
117 |
118 | def __call__(self, img):
119 | return self.transforms(img)
120 |
--------------------------------------------------------------------------------
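A short, hypothetical usage sketch for the presets above (the crop and resize values are illustrative; `resize_size=224` matches `--val-resize-size 224` from the imagenet README):

```python
from PIL import Image
from presets import ClassificationPresetTrain, ClassificationPresetEval

# ta_wide auto-augment and random erasing 0.25 match the README command flags
train_tf = ClassificationPresetTrain(crop_size=224, auto_augment_policy="ta_wide", random_erase_prob=0.25)
eval_tf = ClassificationPresetEval(crop_size=224, resize_size=224)

img = Image.open("example.jpg").convert("RGB")  # placeholder image path
x_train = train_tf(img)  # normalized float tensor of shape (3, 224, 224)
x_eval = eval_tf(img)
```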
/gpt/bench.py:
--------------------------------------------------------------------------------
1 | """
2 | A much shorter version of train.py for benchmarking
3 | """
4 | import os
5 | from contextlib import nullcontext
6 | import numpy as np
7 | import time
8 | import torch
9 | from model import GPTConfig, GPT
10 |
11 | # -----------------------------------------------------------------------------
12 | batch_size = 12
13 | block_size = 1024
14 | bias = False
15 | real_data = True
16 | seed = 1337
17 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
18 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
19 | compile = True # use PyTorch 2.0 to compile the model to be faster
20 | profile = False # use pytorch profiler, or just simple benchmarking?
21 | exec(open('configurator.py').read()) # overrides from command line or config file
22 | # -----------------------------------------------------------------------------
23 |
24 | torch.manual_seed(seed)
25 | torch.cuda.manual_seed(seed)
26 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
27 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
28 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
29 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
30 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
31 |
32 | # data loading init
33 | if real_data:
34 | dataset = 'openwebtext'
35 | data_dir = os.path.join('data', dataset)
36 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
37 | def get_batch(split):
38 | data = train_data # note ignore split in benchmarking script
39 | ix = torch.randint(len(data) - block_size, (batch_size,))
40 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
41 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
42 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
43 | return x, y
44 | else:
45 | # alternatively, if fixed data is desired to not care about data loading
46 | x = torch.randint(50304, (batch_size, block_size), device=device)
47 | y = torch.randint(50304, (batch_size, block_size), device=device)
48 | get_batch = lambda split: (x, y)
49 |
50 | # model init
51 | gptconf = GPTConfig(
52 | block_size = block_size, # how far back does the model look? i.e. context size
53 | n_layer = 12, n_head = 12, n_embd = 768, # size of the model
54 | dropout = 0, # for determinism
55 | bias = bias,
56 | )
57 | model = GPT(gptconf)
58 | model.to(device)
59 |
60 | optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)
61 |
62 | if compile:
63 | print("Compiling model...")
64 | model = torch.compile(model) # pytorch 2.0
65 |
66 | if profile:
67 | # useful docs on pytorch profiler:
68 | # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
69 | # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile
70 | wait, warmup, active = 5, 5, 5
71 | num_steps = wait + warmup + active
72 | with torch.profiler.profile(
73 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
74 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
75 | on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
76 | record_shapes=False,
77 | profile_memory=False,
78 | with_stack=False, # incurs an additional overhead, disable if not needed
79 | with_flops=True,
80 | with_modules=False, # only for torchscript models atm
81 | ) as prof:
82 |
83 | X, Y = get_batch('train')
84 | for k in range(num_steps):
85 | with ctx:
86 | logits, loss = model(X, Y)
87 | X, Y = get_batch('train')
88 | optimizer.zero_grad(set_to_none=True)
89 | loss.backward()
90 | optimizer.step()
91 | lossf = loss.item()
92 | print(f"{k}/{num_steps} loss: {lossf:.4f}")
93 |
94 | prof.step() # notify the profiler at end of each step
95 |
96 | else:
97 |
98 | # simple benchmarking
99 | torch.cuda.synchronize()
100 | for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
101 | t0 = time.time()
102 | X, Y = get_batch('train')
103 | for k in range(num_steps):
104 | with ctx:
105 | logits, loss = model(X, Y)
106 | X, Y = get_batch('train')
107 | optimizer.zero_grad(set_to_none=True)
108 | loss.backward()
109 | optimizer.step()
110 | lossf = loss.item()
111 | print(f"{k}/{num_steps} loss: {lossf:.4f}")
112 | torch.cuda.synchronize()
113 | t1 = time.time()
114 | dt = t1-t0
115 | mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt)
116 | if stage == 1:
117 | print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%")
118 |
--------------------------------------------------------------------------------
/imagenet/transforms.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple
3 |
4 | import torch
5 | from presets import get_module
6 | from torch import Tensor
7 | from torchvision.transforms import functional as F
8 |
9 |
10 | def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
11 | transforms_module = get_module(use_v2)
12 |
13 | mixup_cutmix = []
14 | if mixup_alpha > 0:
15 | mixup_cutmix.append(
16 | transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes)
17 | if use_v2
18 | else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
19 | )
20 | if cutmix_alpha > 0:
21 | mixup_cutmix.append(
22 | transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes)
23 | if use_v2
24 | else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha)
25 | )
26 | if not mixup_cutmix:
27 | return None
28 |
29 | return transforms_module.RandomChoice(mixup_cutmix)
30 |
31 |
32 | class RandomMixUp(torch.nn.Module):
33 | """Randomly apply MixUp to the provided batch and targets.
34 | The class implements the data augmentations as described in the paper
35 | `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
36 |
37 | Args:
38 | num_classes (int): number of classes used for one-hot encoding.
39 | p (float): probability of the batch being transformed. Default value is 0.5.
40 | alpha (float): hyperparameter of the Beta distribution used for mixup.
41 | Default value is 1.0.
42 | inplace (bool): boolean to make this transform inplace. Default set to False.
43 | """
44 |
45 | def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
46 | super().__init__()
47 |
48 | if num_classes < 1:
49 | raise ValueError(
50 | f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
51 | )
52 |
53 | if alpha <= 0:
54 | raise ValueError("Alpha param can't be zero.")
55 |
56 | self.num_classes = num_classes
57 | self.p = p
58 | self.alpha = alpha
59 | self.inplace = inplace
60 |
61 | def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
62 | """
63 | Args:
64 | batch (Tensor): Float tensor of size (B, C, H, W)
65 | target (Tensor): Integer tensor of size (B, )
66 |
67 | Returns:
68 | Tensor: Randomly transformed batch.
69 | """
70 | if batch.ndim != 4:
71 | raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
72 | if target.ndim != 1:
73 | raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
74 | if not batch.is_floating_point():
75 | raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
76 | if target.dtype != torch.int64:
77 | raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
78 |
79 | if not self.inplace:
80 | batch = batch.clone()
81 | target = target.clone()
82 |
83 | if target.ndim == 1:
84 | target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
85 |
86 | if torch.rand(1).item() >= self.p:
87 | return batch, target
88 |
89 | # It's faster to roll the batch by one instead of shuffling it to create image pairs
90 | batch_rolled = batch.roll(1, 0)
91 | target_rolled = target.roll(1, 0)
92 |
93 | # Implemented as on mixup paper, page 3.
94 | lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
95 | batch_rolled.mul_(1.0 - lambda_param)
96 | batch.mul_(lambda_param).add_(batch_rolled)
97 |
98 | target_rolled.mul_(1.0 - lambda_param)
99 | target.mul_(lambda_param).add_(target_rolled)
100 |
101 | return batch, target
102 |
103 | def __repr__(self) -> str:
104 | s = (
105 | f"{self.__class__.__name__}("
106 | f"num_classes={self.num_classes}"
107 | f", p={self.p}"
108 | f", alpha={self.alpha}"
109 | f", inplace={self.inplace}"
110 | f")"
111 | )
112 | return s
113 |
114 |
115 | class RandomCutMix(torch.nn.Module):
116 | """Randomly apply CutMix to the provided batch and targets.
117 | The class implements the data augmentations as described in the paper
118 | `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
119 | <https://arxiv.org/abs/1905.04899>`_.
120 |
121 | Args:
122 | num_classes (int): number of classes used for one-hot encoding.
123 | p (float): probability of the batch being transformed. Default value is 0.5.
124 | alpha (float): hyperparameter of the Beta distribution used for cutmix.
125 | Default value is 1.0.
126 | inplace (bool): boolean to make this transform inplace. Default set to False.
127 | """
128 |
129 | def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
130 | super().__init__()
131 | if num_classes < 1:
132 | raise ValueError("Please provide a valid positive value for the num_classes.")
133 | if alpha <= 0:
134 | raise ValueError("Alpha param can't be zero.")
135 |
136 | self.num_classes = num_classes
137 | self.p = p
138 | self.alpha = alpha
139 | self.inplace = inplace
140 |
141 | def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
142 | """
143 | Args:
144 | batch (Tensor): Float tensor of size (B, C, H, W)
145 | target (Tensor): Integer tensor of size (B, )
146 |
147 | Returns:
148 | Tensor: Randomly transformed batch.
149 | """
150 | if batch.ndim != 4:
151 | raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
152 | if target.ndim != 1:
153 | raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
154 | if not batch.is_floating_point():
155 | raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
156 | if target.dtype != torch.int64:
157 | raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
158 |
159 | if not self.inplace:
160 | batch = batch.clone()
161 | target = target.clone()
162 |
163 | if target.ndim == 1:
164 | target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
165 |
166 | if torch.rand(1).item() >= self.p:
167 | return batch, target
168 |
169 | # It's faster to roll the batch by one instead of shuffling it to create image pairs
170 | batch_rolled = batch.roll(1, 0)
171 | target_rolled = target.roll(1, 0)
172 |
173 | # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
174 | lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
175 | _, H, W = F.get_dimensions(batch)
176 |
177 | r_x = torch.randint(W, (1,))
178 | r_y = torch.randint(H, (1,))
179 |
180 | r = 0.5 * math.sqrt(1.0 - lambda_param)
181 | r_w_half = int(r * W)
182 | r_h_half = int(r * H)
183 |
184 | x1 = int(torch.clamp(r_x - r_w_half, min=0))
185 | y1 = int(torch.clamp(r_y - r_h_half, min=0))
186 | x2 = int(torch.clamp(r_x + r_w_half, max=W))
187 | y2 = int(torch.clamp(r_y + r_h_half, max=H))
188 |
189 | batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
190 | lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
191 |
192 | target_rolled.mul_(1.0 - lambda_param)
193 | target.mul_(lambda_param).add_(target_rolled)
194 |
195 | return batch, target
196 |
197 | def __repr__(self) -> str:
198 | s = (
199 | f"{self.__class__.__name__}("
200 | f"num_classes={self.num_classes}"
201 | f", p={self.p}"
202 | f", alpha={self.alpha}"
203 | f", inplace={self.inplace}"
204 | f")"
205 | )
206 | return s
207 |
--------------------------------------------------------------------------------
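A minimal usage sketch for `get_mixup_cutmix` above (the alpha values match the imagenet README command; the tensors are random placeholders for real images and labels):

```python
import torch
from transforms import get_mixup_cutmix

# --mixup-alpha 0.8 --cutmix-alpha 1.0 as in the README command; v1 transforms path
mixup_cutmix = get_mixup_cutmix(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=1000, use_v2=False)

images = torch.rand(8, 3, 224, 224)     # float image batch of shape (B, C, H, W)
targets = torch.randint(0, 1000, (8,))  # integer class labels (int64)
mixed_images, mixed_targets = mixup_cutmix(images, targets)
print(mixed_targets.shape)  # torch.Size([8, 1000]): labels become mixed one-hot vectors
```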
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/imagenet/train_quantization.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import datetime
3 | import os
4 | import time
5 |
6 | import torch
7 | import torch.ao.quantization
8 | import torch.utils.data
9 | import torchvision
10 | import utils
11 | from torch import nn
12 | from train import evaluate, load_data, train_one_epoch
13 |
14 |
15 | def main(args):
16 | if args.output_dir:
17 | utils.mkdir(args.output_dir)
18 |
19 | utils.init_distributed_mode(args)
20 | print(args)
21 |
22 | if args.post_training_quantize and args.distributed:
 23 |         raise RuntimeError("Post training quantization example should not be performed in distributed mode")
24 |
25 | # Set backend engine to ensure that quantized model runs on the correct kernels
26 | if args.qbackend not in torch.backends.quantized.supported_engines:
27 | raise RuntimeError("Quantized backend not supported: " + str(args.qbackend))
28 | torch.backends.quantized.engine = args.qbackend
29 |
30 | device = torch.device(args.device)
31 | torch.backends.cudnn.benchmark = True
32 |
33 | # Data loading code
34 | print("Loading data")
35 | train_dir = os.path.join(args.data_path, "train")
36 | val_dir = os.path.join(args.data_path, "val")
37 |
38 | dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
39 | data_loader = torch.utils.data.DataLoader(
40 | dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
41 | )
42 |
43 | data_loader_test = torch.utils.data.DataLoader(
44 | dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
45 | )
46 |
47 | print("Creating model", args.model)
48 | # when training quantized models, we always start from a pre-trained fp32 reference model
49 | prefix = "quantized_"
50 | model_name = args.model
51 | if not model_name.startswith(prefix):
52 | model_name = prefix + model_name
53 | model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
54 | model.to(device)
55 |
56 | if not (args.test_only or args.post_training_quantize):
57 | model.fuse_model(is_qat=True)
58 | model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend)
59 | torch.ao.quantization.prepare_qat(model, inplace=True)
60 |
61 | if args.distributed and args.sync_bn:
62 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
63 |
64 | optimizer = torch.optim.SGD(
65 | model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
66 | )
67 |
68 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
69 |
70 | criterion = nn.CrossEntropyLoss()
71 | model_without_ddp = model
72 | if args.distributed:
73 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
74 | model_without_ddp = model.module
75 |
76 | if args.resume:
77 | checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
78 | model_without_ddp.load_state_dict(checkpoint["model"])
79 | optimizer.load_state_dict(checkpoint["optimizer"])
80 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
81 | args.start_epoch = checkpoint["epoch"] + 1
82 |
83 | if args.post_training_quantize:
84 | # perform calibration on a subset of the training dataset
85 | # for that, create a subset of the training dataset
86 | ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
87 | data_loader_calibration = torch.utils.data.DataLoader(
88 | ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
89 | )
90 | model.eval()
91 | model.fuse_model(is_qat=False)
92 | model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend)
93 | torch.ao.quantization.prepare(model, inplace=True)
94 | # Calibrate first
95 | print("Calibrating")
96 | evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
97 | torch.ao.quantization.convert(model, inplace=True)
98 | if args.output_dir:
99 | print("Saving quantized model")
100 | if utils.is_main_process():
101 | torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
102 | print("Evaluating post-training quantized model")
103 | evaluate(model, criterion, data_loader_test, device=device)
104 | return
105 |
106 | if args.test_only:
107 | evaluate(model, criterion, data_loader_test, device=device)
108 | return
109 |
110 | model.apply(torch.ao.quantization.enable_observer)
111 | model.apply(torch.ao.quantization.enable_fake_quant)
112 | start_time = time.time()
113 | for epoch in range(args.start_epoch, args.epochs):
114 | if args.distributed:
115 | train_sampler.set_epoch(epoch)
116 | print("Starting training for epoch", epoch)
117 | train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
118 | lr_scheduler.step()
119 | with torch.inference_mode():
120 | if epoch >= args.num_observer_update_epochs:
121 | print("Disabling observer for subseq epochs, epoch = ", epoch)
122 | model.apply(torch.ao.quantization.disable_observer)
123 | if epoch >= args.num_batch_norm_update_epochs:
124 | print("Freezing BN for subseq epochs, epoch = ", epoch)
125 | model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
126 | print("Evaluate QAT model")
127 |
128 | evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
129 | quantized_eval_model = copy.deepcopy(model_without_ddp)
130 | quantized_eval_model.eval()
131 | quantized_eval_model.to(torch.device("cpu"))
132 | torch.ao.quantization.convert(quantized_eval_model, inplace=True)
133 |
134 | print("Evaluate Quantized model")
135 | evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))
136 |
137 | model.train()
138 |
139 | if args.output_dir:
140 | checkpoint = {
141 | "model": model_without_ddp.state_dict(),
142 | "eval_model": quantized_eval_model.state_dict(),
143 | "optimizer": optimizer.state_dict(),
144 | "lr_scheduler": lr_scheduler.state_dict(),
145 | "epoch": epoch,
146 | "args": args,
147 | }
148 | utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
149 | utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
150 | print("Saving models after epoch ", epoch)
151 |
152 | total_time = time.time() - start_time
153 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
154 | print(f"Training time {total_time_str}")
155 |
156 |
157 | def get_args_parser(add_help=True):
158 | import argparse
159 |
160 | parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)
161 |
162 | parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
163 | parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
164 | parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack")
165 |     parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu; default: cuda)")
166 |
167 | parser.add_argument(
168 | "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
169 | )
170 | parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
171 | parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
172 | parser.add_argument(
173 | "--num-observer-update-epochs",
174 | default=4,
175 | type=int,
176 | metavar="N",
177 | help="number of total epochs to update observers",
178 | )
179 | parser.add_argument(
180 | "--num-batch-norm-update-epochs",
181 | default=3,
182 | type=int,
183 | metavar="N",
184 | help="number of total epochs to update batch norm stats",
185 | )
186 | parser.add_argument(
187 | "--num-calibration-batches",
188 | default=32,
189 | type=int,
190 | metavar="N",
191 | help="number of batches of training set for \
192 | observer calibration ",
193 | )
194 |
195 | parser.add_argument(
196 | "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
197 | )
198 | parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
199 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
200 | parser.add_argument(
201 | "--wd",
202 | "--weight-decay",
203 | default=1e-4,
204 | type=float,
205 | metavar="W",
206 | help="weight decay (default: 1e-4)",
207 | dest="weight_decay",
208 | )
209 | parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
210 | parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
211 | parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
212 | parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
213 | parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
214 | parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
215 | parser.add_argument(
216 | "--cache-dataset",
217 | dest="cache_dataset",
218 | help="Cache the datasets for quicker initialization. \
219 | It also serializes the transforms",
220 | action="store_true",
221 | )
222 | parser.add_argument(
223 | "--sync-bn",
224 | dest="sync_bn",
225 | help="Use sync batch norm",
226 | action="store_true",
227 | )
228 | parser.add_argument(
229 | "--test-only",
230 | dest="test_only",
231 | help="Only test the model",
232 | action="store_true",
233 | )
234 | parser.add_argument(
235 | "--post-training-quantize",
236 | dest="post_training_quantize",
237 | help="Post training quantize the model",
238 | action="store_true",
239 | )
240 |
241 | # distributed training parameters
242 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
243 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
244 |
245 | parser.add_argument(
246 | "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
247 | )
248 | parser.add_argument(
249 | "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
250 | )
251 | parser.add_argument(
252 | "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
253 | )
254 | parser.add_argument(
255 | "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
256 | )
257 | parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
258 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
259 |
260 | parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
261 | parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
262 |
263 | return parser
264 |
265 |
266 | if __name__ == "__main__":
267 | args = get_args_parser().parse_args()
268 | if args.backend in ("fbgemm", "qnnpack"):
269 | raise ValueError(
270 | "The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) "
271 | "instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend."
272 | )
273 | main(args)
274 |
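The script above follows the standard eager-mode quantization recipe: fuse modules, attach a qconfig, prepare the model, train (QAT) or calibrate (post-training quantization), and finally convert to an int8 model. A minimal sketch of the QAT path in isolation is shown below; the model name and backend are example choices only, and the full script above additionally handles DDP, checkpointing, and per-epoch evaluation.

import torch
import torch.ao.quantization
import torchvision

# start from the quantizable variant of the architecture, in float
model = torchvision.models.get_model("quantized_mobilenet_v2", weights=None, quantize=False)
model.fuse_model(is_qat=True)  # fuse conv/bn/relu blocks so they quantize as one unit
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
torch.ao.quantization.prepare_qat(model, inplace=True)  # insert observers and fake-quant modules

# ... run the usual training loop here with observers and fake-quant enabled ...

model.eval()
model.to("cpu")
quantized_model = torch.ao.quantization.convert(model, inplace=False)  # int8 model for CPU inference

A typical invocation, using the flags defined in get_args_parser above, would look like:

$ python train_quantization.py --data-path /path/to/imagenet --model mobilenet_v2 --qbackend fbgemm

or the same command launched with torchrun for the distributed QAT path.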
--------------------------------------------------------------------------------
/gpt/train.py:
--------------------------------------------------------------------------------
1 | """
2 | This training script can be run both on a single gpu in debug mode,
3 | and also in a larger training run with distributed data parallel (ddp).
4 |
5 | To run on a single GPU, example:
6 | $ python train.py --batch_size=32 --compile=False
7 |
8 | To run with DDP on 4 gpus on 1 node, example:
9 | $ torchrun --standalone --nproc_per_node=4 train.py
10 |
11 | To run with DDP across 2 nodes (8 gpus each), example:
12 | - Run on the first (master) node with example IP 123.456.123.456:
13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
14 | - Run on the worker node:
15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
16 | (If your cluster does not have an Infiniband interconnect, prepend NCCL_IB_DISABLE=1)
17 | """
18 |
19 | import os
20 | import time
21 | import math
22 | import pickle
23 | from contextlib import nullcontext
24 |
25 | import numpy as np
26 | import torch
27 | from torch.nn.parallel import DistributedDataParallel as DDP
28 | from torch.distributed import init_process_group, destroy_process_group
29 |
30 | from model import GPTConfig, GPT
31 |
32 |
33 | # -----------------------------------------------------------------------------
34 | # default config values designed to train a gpt2 (124M) on OpenWebText
35 | # I/O
36 | out_dir = 'out'
37 | eval_interval = 2000
38 | log_interval = 1
39 | eval_iters = 200
40 | eval_only = False # if True, script exits right after the first eval
41 | always_save_checkpoint = True # if True, always save a checkpoint after each eval
42 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
43 | # wandb logging
44 | wandb_log = False # disabled by default
45 | wandb_project = 'owt'
46 | wandb_run_name = 'gpt2' # 'run' + str(time.time())
47 | # data
48 | dataset = 'openwebtext'
49 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
50 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
51 | block_size = 1024
52 | # model
53 | n_layer = 12
54 | n_head = 12
55 | n_embd = 768
56 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
57 | bias = False # do we use bias inside LayerNorm and Linear layers?
58 | # optimizer
59 | optim_type = 'adamw'
60 | learning_rate = 6e-4 # max learning rate
61 | max_iters = 600000 # total number of training iterations
62 | weight_decay = 1e-1
63 | beta1 = 0.9
64 | beta2 = 0.95
65 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
66 | # learning rate decay settings
67 | decay_lr = True # whether to decay the learning rate
68 | warmup_iters = 2000 # how many steps to warm up for
69 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
70 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
71 | # DDP settings
72 | backend = 'nccl' # 'nccl', 'gloo', etc.
73 | # system
74 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
75 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16'; the latter will automatically use a GradScaler
76 | compile = True # use PyTorch 2.0 to compile the model to be faster
77 | # -----------------------------------------------------------------------------
78 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
79 | exec(open('configurator.py').read()) # overrides from command line or config file
80 | config = {k: globals()[k] for k in config_keys} # will be useful for logging
81 | # -----------------------------------------------------------------------------
82 |
83 | # various inits, derived attributes, I/O setup
84 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
85 | if ddp:
86 | init_process_group(backend=backend)
87 | ddp_rank = int(os.environ['RANK'])
88 | ddp_local_rank = int(os.environ['LOCAL_RANK'])
89 | ddp_world_size = int(os.environ['WORLD_SIZE'])
90 | device = f'cuda:{ddp_local_rank}'
91 | torch.cuda.set_device(device)
92 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
93 | seed_offset = ddp_rank # each process gets a different seed
94 | # world_size number of processes will be training simultaneously, so we can scale
95 | # down the desired gradient accumulation iterations per process proportionally
96 | assert gradient_accumulation_steps % ddp_world_size == 0
97 | gradient_accumulation_steps //= ddp_world_size
98 | else:
99 | # if not ddp, we are running on a single gpu, and one process
100 | master_process = True
101 | seed_offset = 0
102 | ddp_world_size = 1
103 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
104 | print(f"tokens per iteration will be: {tokens_per_iter:,}")
105 |
106 | if master_process:
107 | os.makedirs(out_dir, exist_ok=True)
108 | torch.manual_seed(1337 + seed_offset)
109 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
110 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
111 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
112 | # note: float16 data type will automatically use a GradScaler
113 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
114 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
115 |
116 | # poor man's data loader
117 | data_dir = os.path.join('data', dataset)
118 | def get_batch(split):
119 | # We recreate np.memmap every batch to avoid a memory leak, as per
120 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
121 | if split == 'train':
122 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
123 | else:
124 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
125 | ix = torch.randint(len(data) - block_size, (batch_size,))
126 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
127 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
128 | if device_type == 'cuda':
129 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
130 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
131 | else:
132 | x, y = x.to(device), y.to(device)
133 | return x, y
134 |
135 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
136 | iter_num = 0
137 | best_val_loss = 1e9
138 |
139 | # attempt to derive vocab_size from the dataset
140 | meta_path = os.path.join(data_dir, 'meta.pkl')
141 | meta_vocab_size = None
142 | if os.path.exists(meta_path):
143 | with open(meta_path, 'rb') as f:
144 | meta = pickle.load(f)
145 | meta_vocab_size = meta['vocab_size']
146 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
147 |
148 | # model init
149 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
150 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
151 | if init_from == 'scratch':
152 | # init a new model from scratch
153 | print("Initializing a new model from scratch")
154 | # determine the vocab size we'll use for from-scratch training
155 | if meta_vocab_size is None:
156 |         print("defaulting vocab_size to GPT-2's 50304 (50257 rounded up for efficiency)")
157 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
158 | gptconf = GPTConfig(**model_args)
159 | model = GPT(gptconf)
160 | elif init_from == 'resume':
161 | print(f"Resuming training from {out_dir}")
162 | # resume training from a checkpoint.
163 | ckpt_path = os.path.join(out_dir, 'ckpt.pt')
164 | checkpoint = torch.load(ckpt_path, map_location=device)
165 | checkpoint_model_args = checkpoint['model_args']
166 | # force these config attributes to be equal otherwise we can't even resume training
167 | # the rest of the attributes (e.g. dropout) can stay as desired from command line
168 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
169 | model_args[k] = checkpoint_model_args[k]
170 | # create the model
171 | gptconf = GPTConfig(**model_args)
172 | model = GPT(gptconf)
173 | state_dict = checkpoint['model']
174 | # fix the keys of the state dictionary :(
175 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more
176 | unwanted_prefix = '_orig_mod.'
177 | for k,v in list(state_dict.items()):
178 | if k.startswith(unwanted_prefix):
179 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
180 | model.load_state_dict(state_dict)
181 | iter_num = checkpoint['iter_num']
182 | best_val_loss = checkpoint['best_val_loss']
183 | elif init_from.startswith('gpt2'):
184 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
185 | # initialize from OpenAI GPT-2 weights
186 | override_args = dict(dropout=dropout)
187 | model = GPT.from_pretrained(init_from, override_args)
188 | # read off the created config params, so we can store them into checkpoint correctly
189 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
190 | model_args[k] = getattr(model.config, k)
191 | # crop down the model block size if desired, using model surgery
192 | if block_size < model.config.block_size:
193 | model.crop_block_size(block_size)
194 | model_args['block_size'] = block_size # so that the checkpoint will have the right value
195 | model.to(device)
196 |
197 | # initialize a GradScaler. If enabled=False scaler is a no-op
198 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
199 |
200 | # optimizer
201 | optimizer = model.configure_optimizers(optim_type, weight_decay, learning_rate, (beta1, beta2), device_type)
202 | if init_from == 'resume':
203 | optimizer.load_state_dict(checkpoint['optimizer'])
204 | checkpoint = None # free up memory
205 |
206 | # compile the model
207 | if compile:
208 | print("compiling the model... (takes a ~minute)")
209 | unoptimized_model = model
210 | model = torch.compile(model) # requires PyTorch 2.0
211 |
212 | # wrap model into DDP container
213 | if ddp:
214 | model = DDP(model, device_ids=[ddp_local_rank])
215 |
216 | # helps estimate an arbitrarily accurate loss over either split using many batches
217 | @torch.no_grad()
218 | def estimate_loss():
219 | out = {}
220 | model.eval()
221 | for split in ['train', 'val']:
222 | losses = torch.zeros(eval_iters)
223 | for k in range(eval_iters):
224 | X, Y = get_batch(split)
225 | with ctx:
226 | logits, loss = model(X, Y)
227 | losses[k] = loss.item()
228 | out[split] = losses.mean()
229 | model.train()
230 | return out
231 |
232 | # learning rate decay scheduler (cosine with warmup)
233 | def get_lr(it):
234 | # 1) linear warmup for warmup_iters steps
235 | if it < warmup_iters:
236 | return learning_rate * it / warmup_iters
237 | # 2) if it > lr_decay_iters, return min learning rate
238 | if it > lr_decay_iters:
239 | return min_lr
240 | # 3) in between, use cosine decay down to min learning rate
241 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
242 | assert 0 <= decay_ratio <= 1
243 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
244 | return min_lr + coeff * (learning_rate - min_lr)
245 |
246 | # logging
247 | if wandb_log and master_process:
248 | import wandb
249 | wandb.init(project=wandb_project, name=wandb_run_name, config=config)
250 |
251 | # training loop
252 | X, Y = get_batch('train') # fetch the very first batch
253 | t0 = time.time()
254 | local_iter_num = 0 # number of iterations in the lifetime of this process
255 | raw_model = model.module if ddp else model # unwrap DDP container if needed
256 | running_mfu = -1.0
257 | while True:
258 |
259 | # determine and set the learning rate for this iteration
260 | lr = get_lr(iter_num) if decay_lr else learning_rate
261 | for param_group in optimizer.param_groups:
262 | param_group['lr'] = lr
263 |
264 | # evaluate the loss on train/val sets and write checkpoints
265 | if iter_num % eval_interval == 0 and master_process:
266 | losses = estimate_loss()
267 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
268 | if wandb_log:
269 | wandb.log({
270 | "iter": iter_num,
271 | "train/loss": losses['train'],
272 | "val/loss": losses['val'],
273 | "lr": lr,
274 | "mfu": running_mfu*100, # convert to percentage
275 | })
276 | if losses['val'] < best_val_loss or always_save_checkpoint:
277 | best_val_loss = losses['val']
278 | if iter_num > 0:
279 | checkpoint = {
280 | 'model': raw_model.state_dict(),
281 | 'optimizer': optimizer.state_dict(),
282 | 'model_args': model_args,
283 | 'iter_num': iter_num,
284 | 'best_val_loss': best_val_loss,
285 | 'config': config,
286 | }
287 | print(f"saving checkpoint to {out_dir}")
288 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
289 | if iter_num == 0 and eval_only:
290 | break
291 |
292 | # forward backward update, with optional gradient accumulation to simulate larger batch size
293 | # and using the GradScaler if data type is float16
294 | for micro_step in range(gradient_accumulation_steps):
295 | if ddp:
296 | # in DDP training we only need to sync gradients at the last micro step.
297 | # the official way to do this is with model.no_sync() context manager, but
298 | # I really dislike that this bloats the code and forces us to repeat code
299 | # looking at the source of that context manager, it just toggles this variable
300 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
301 | with ctx:
302 | logits, loss = model(X, Y)
303 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
304 | # immediately async prefetch next batch while model is doing the forward pass on the GPU
305 | X, Y = get_batch('train')
306 | # backward pass, with gradient scaling if training in fp16
307 | scaler.scale(loss).backward()
308 | # clip the gradient
309 | if grad_clip != 0.0:
310 | scaler.unscale_(optimizer)
311 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
312 | # step the optimizer and scaler if training in fp16
313 | scaler.step(optimizer)
314 | scaler.update()
315 | # flush the gradients as soon as we can, no need for this memory anymore
316 | optimizer.zero_grad(set_to_none=True)
317 |
318 | # timing and logging
319 | t1 = time.time()
320 | dt = t1 - t0
321 | t0 = t1
322 | if iter_num % log_interval == 0 and master_process:
323 | # get loss as float. note: this is a CPU-GPU sync point
324 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
325 | lossf = loss.item() * gradient_accumulation_steps
326 | if local_iter_num >= 5: # let the training loop settle a bit
327 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
328 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
329 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
330 | iter_num += 1
331 | local_iter_num += 1
332 |
333 | # termination conditions
334 | if iter_num > max_iters:
335 | break
336 |
337 | if ddp:
338 | destroy_process_group()
339 |
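In the script above the optimizer is selected through optim_type (line 59) and passed into model.configure_optimizers(...) (line 201), and every config value can be overridden from the command line via configurator.py (line 79), so switching from AdamW to the ADOPT optimizer this repository provides should amount to something like:

$ python train.py --optim_type=adopt

Below is a minimal sketch of what such a dispatch could look like. This is an illustration only, not the repository's actual configure_optimizers implementation, and it assumes the installed adopt package exposes an ADOPT class with a torch.optim-style constructor.

import torch

def build_optimizer(optim_type, params, learning_rate, betas, weight_decay):
    # hypothetical dispatch; in this repo the real selection lives in model.configure_optimizers
    if optim_type == 'adamw':
        return torch.optim.AdamW(params, lr=learning_rate, betas=betas, weight_decay=weight_decay)
    if optim_type == 'adopt':
        from adopt import ADOPT  # assumed import; constructor signature assumed to mirror AdamW's
        return ADOPT(params, lr=learning_rate, betas=betas, weight_decay=weight_decay)
    raise ValueError(f"unknown optim_type: {optim_type}")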
--------------------------------------------------------------------------------
/gpt/transformer_sizing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "### Transformer Theoretical Model\n",
9 | "\n",
10 | "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc."
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from collections import OrderedDict"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "# config_args = {\n",
29 | "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n",
30 | "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n",
31 | "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n",
32 | "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n",
33 | "# }[model_type]\n",
34 | "\n",
35 | "block_size = 1024\n",
36 | "vocab_size = 50257\n",
37 | "n_layer = 12\n",
38 | "n_head = 12\n",
39 | "n_embd = 768\n",
40 | "bias = False\n",
41 | "assert not bias, \"this notebook assumes bias=False just for simplicity\""
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 3,
47 | "metadata": {},
48 | "outputs": [
49 | {
50 | "name": "stdout",
51 | "output_type": "stream",
52 | "text": [
53 | "we see: 124337664, expected: 124337664, match: True\n",
54 | "name params ratio (%) \n",
 55 |       "embedding/position      786432     0.6325\n",
56 | "embedding/token 38597376 31.0424\n",
57 | "embedding 39383808 31.6749\n",
58 | "attention/ln 768 0.0006\n",
59 | "attention/kqv 1769472 1.4231\n",
60 | "attention/proj 589824 0.4744\n",
61 | "attention 2360064 1.8981\n",
62 | "mlp/ln 768 0.0006\n",
63 | "mlp/ffw 2359296 1.8975\n",
64 | "mlp/proj 2359296 1.8975\n",
65 | "mlp 4719360 3.7956\n",
66 | "block 7079424 5.6937\n",
67 | "transformer 84953088 68.3245\n",
68 | "ln_f 768 0.0006\n",
69 | "dense 0 0.0000\n",
70 | "total 124337664 100.0000\n"
71 | ]
72 | }
73 | ],
74 | "source": [
75 | "def params():\n",
76 | " \"\"\" estimates the number of parameters in the model\"\"\"\n",
77 | " out = OrderedDict()\n",
78 | "\n",
79 | " # token and position embeddings\n",
 80 |     "    out['embedding/position'] = n_embd * block_size\n",
 81 |     "    out['embedding/token'] = n_embd * vocab_size\n",
 82 |     "    out['embedding'] = out['embedding/position'] + out['embedding/token']\n",
83 | "\n",
84 | " # attention blocks\n",
85 | " out['attention/ln'] = n_embd # note, bias=False in our LN\n",
86 | " out['attention/kqv'] = n_embd * 3*n_embd\n",
87 | " out['attention/proj'] = n_embd**2\n",
88 | " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n",
89 | "\n",
90 | " # MLP blocks\n",
91 | " ffw_size = 4*n_embd # feed forward size\n",
92 | " out['mlp/ln'] = n_embd\n",
93 | " out['mlp/ffw'] = n_embd * ffw_size\n",
94 | " out['mlp/proj'] = ffw_size * n_embd\n",
95 | " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n",
96 | " \n",
97 | " # the transformer and the rest of it\n",
98 | " out['block'] = out['attention'] + out['mlp']\n",
99 | " out['transformer'] = n_layer * out['block']\n",
100 | " out['ln_f'] = n_embd # final layernorm\n",
101 | " out['dense'] = 0 # 0 because of parameter sharing. This layer uses the weights from the embedding layer\n",
102 | "\n",
103 | " # total\n",
104 | " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n",
105 | "\n",
106 | " return out\n",
107 | "\n",
108 | "# compare our param count to that reported by PyTorch\n",
109 | "p = params()\n",
110 | "params_total = p['total']\n",
111 | "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n",
112 | "# create a header\n",
113 | "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n",
114 | "for k,v in p.items():\n",
115 | " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n",
116 | " "
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 4,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "name": "stdout",
126 | "output_type": "stream",
127 | "text": [
128 | "est checkpoint size: 1.49 GB\n",
129 | "measured with wc -c ckpt.pt: 1542470366\n",
130 | "fluff ratio: 103.38%\n"
131 | ]
132 | }
133 | ],
134 | "source": [
135 | "# we can now calculate the size of each checkpoint\n",
136 | "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n",
137 | "params_bytes = params_total*4\n",
138 | "params_and_buffers_bytes = params_bytes + 2*params_bytes\n",
139 | "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n",
140 | "measured_bytes = 1542470366 # from wc -c ckpt.pt\n",
141 | "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n",
142 | "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")"
143 | ]
144 | },
145 | {
146 | "attachments": {},
147 | "cell_type": "markdown",
148 | "metadata": {},
149 | "source": [
150 | "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 5,
156 | "metadata": {},
157 | "outputs": [
158 | {
159 | "name": "stdout",
160 | "output_type": "stream",
161 | "text": [
162 | "memory ratio taken up just for parameters: 3.73%\n"
163 | ]
164 | }
165 | ],
166 | "source": [
167 | "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n",
168 | "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")"
169 | ]
170 | },
171 | {
172 | "attachments": {},
173 | "cell_type": "markdown",
174 | "metadata": {},
175 | "source": [
176 |     "i.e. not that much of the memory for this tiny model; most of the memory goes to activations (forward and backward). This of course changes dramatically for larger models."
177 | ]
178 | },
179 | {
180 | "attachments": {},
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "Let's estimate FLOPs for a single forward pass."
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 6,
190 | "metadata": {},
191 | "outputs": [
192 | {
193 | "name": "stdout",
194 | "output_type": "stream",
195 | "text": [
196 | "name flops ratio (%) \n",
197 | "attention/kqv 3623878656 1.2426\n",
198 | "attention/scores 1610612736 0.5522\n",
199 | "attention/reduce 1610612736 0.5522\n",
200 | "attention/proj 1207959552 0.4142\n",
201 | "attention 8053063680 2.7612\n",
202 | "mlp/ffw1 4831838208 1.6567\n",
203 | "mlp/ffw2 4831838208 1.6567\n",
204 | "mlp 9663676416 3.3135\n",
205 | "block 17716740096 6.0747\n",
206 | "transformer 212600881152 72.8963\n",
207 | "dense 79047426048 27.1037\n",
208 | "forward_total 291648307200 100.0000\n",
209 | "backward_total 583296614400 200.0000\n",
210 | "total 874944921600 300.0000\n"
211 | ]
212 | }
213 | ],
214 | "source": [
215 | "def flops():\n",
216 | " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n",
217 | " # we count actual FLOPs, not MACs. Hence 2* all over the place\n",
218 | " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n",
219 | "\n",
220 | " out = OrderedDict()\n",
221 | " head_size = n_embd // n_head\n",
222 | "\n",
223 | " # attention blocks\n",
224 | " # 1) the projection to key, query, values\n",
225 | " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n",
226 | " # 2) calculating the attention scores\n",
227 | " out['attention/scores'] = 2 * block_size * block_size * n_embd\n",
228 | " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
229 | " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n",
230 | " # 4) the final linear projection\n",
231 | " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n",
232 | " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n",
233 | "\n",
234 | " # MLP blocks\n",
235 | " ffw_size = 4*n_embd # feed forward size\n",
236 | " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n",
237 | " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n",
238 | " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n",
239 | "\n",
240 | " # the transformer and the rest of it\n",
241 | " out['block'] = out['attention'] + out['mlp']\n",
242 | " out['transformer'] = n_layer * out['block']\n",
243 | " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n",
244 | "\n",
245 | " # forward,backward,total\n",
246 | " out['forward_total'] = out['transformer'] + out['dense']\n",
247 | " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n",
248 | " out['total'] = out['forward_total'] + out['backward_total']\n",
249 | "\n",
250 | " return out\n",
251 | " \n",
252 |     "# compute the flops estimate and print a breakdown\n",
253 | "f = flops()\n",
254 | "flops_total = f['forward_total']\n",
255 | "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n",
256 | "for k,v in f.items():\n",
257 | " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n",
258 | " "
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": 7,
264 | "metadata": {},
265 | "outputs": [
266 | {
267 | "name": "stdout",
268 | "output_type": "stream",
269 | "text": [
270 | "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n"
271 | ]
272 | }
273 | ],
274 | "source": [
275 | "# now here is an estimate copy pasted from the PaLM paper\n",
276 | "# this formula is often used to calculate MFU (model flops utilization)\n",
277 | "def palm_flops():\n",
278 | " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n",
279 | " # non-embedding model parameters. note that we do not subtract the\n",
280 | " # embedding/token params because those are tied and get used in the last layer.\n",
281 |     "    N = params()['total'] - params()['embedding/position']\n",
282 | " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n",
283 | " mf_per_token = 6*N + 12*L*H*Q*T\n",
284 | " mf = mf_per_token * block_size\n",
285 | " return mf\n",
286 | "\n",
287 | "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")"
288 | ]
289 | },
290 | {
291 | "attachments": {},
292 | "cell_type": "markdown",
293 | "metadata": {},
294 | "source": [
295 |     "Ok they are quite similar, giving some confidence that my math in the flops() function was ~ok. Now, the A100 is cited at 312 TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. We get:"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 8,
301 | "metadata": {},
302 | "outputs": [
303 | {
304 | "name": "stdout",
305 | "output_type": "stream",
306 | "text": [
307 | "fraction of A100 used: 37.14%\n"
308 | ]
309 | }
310 | ],
311 | "source": [
312 | "# here is what we currently roughly measure\n",
313 | "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n",
314 | "measured_time = 0.755 # in seconds per iteration\n",
315 | "measured_throughput = batch_size / measured_time\n",
316 | "flops_achieved = f['total'] * measured_throughput\n",
317 | "\n",
318 |     "# A100 is cited to be 312 TFLOPS of bfloat16 running on tensor cores\n",
319 | "a100_flops_promised = 312e12\n",
320 | "\n",
321 | "# the fraction of the A100 that we are using:\n",
322 | "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")"
323 | ]
324 | },
325 | {
326 | "attachments": {},
327 | "cell_type": "markdown",
328 | "metadata": {},
329 | "source": [
330 | "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU."
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": 9,
336 | "metadata": {},
337 | "outputs": [
338 | {
339 | "name": "stdout",
340 | "output_type": "stream",
341 | "text": [
342 | "time needed to train the model: 3.46 days\n"
343 | ]
344 | }
345 | ],
346 | "source": [
347 | "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n",
348 | "model_size = params()['total'] # this is number of parameters, N\n",
349 | "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n",
350 | "a100_flops = 312e12 # 312 TFLOPS\n",
351 | "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n",
352 | "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n",
353 | "flops_needed = 6 * model_size * tokens_num # 6ND\n",
354 | "time_needed_s = flops_needed / flops_throughput # in seconds\n",
355 | "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")"
356 | ]
357 | },
358 | {
359 | "attachments": {},
360 | "cell_type": "markdown",
361 | "metadata": {},
362 | "source": [
363 | "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)."
364 | ]
365 | },
366 | {
367 | "attachments": {},
368 | "cell_type": "markdown",
369 | "metadata": {},
370 | "source": [
371 |     "Now, FLOPs are just one constraint; the other one we have to keep close track of is memory bandwidth. TODO: estimate the LOAD/STORE costs of our model later."
372 | ]
373 | }
374 | ],
375 | "metadata": {
376 | "kernelspec": {
377 | "display_name": "pytorch2",
378 | "language": "python",
379 | "name": "python3"
380 | },
381 | "language_info": {
382 | "codemirror_mode": {
383 | "name": "ipython",
384 | "version": 3
385 | },
386 | "file_extension": ".py",
387 | "mimetype": "text/x-python",
388 | "name": "python",
389 | "nbconvert_exporter": "python",
390 | "pygments_lexer": "ipython3",
391 | "version": "3.10.8"
392 | },
393 | "orig_nbformat": 4,
394 | "vscode": {
395 | "interpreter": {
396 | "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
397 | }
398 | }
399 | },
400 | "nbformat": 4,
401 | "nbformat_minor": 2
402 | }
403 |
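A compact restatement of the arithmetic the notebook above walks through (all numbers are taken from the notebook's own cells):

    MFU  = FLOPs per iteration / (iteration time x peak FLOPs)
         = (100 x 8.749e11) / (0.755 x 3.12e14) ~= 0.371, i.e. the 37.14% printed above
    6ND  = 6 x 1.243e8 params x 3e11 tokens ~= 2.24e20 FLOPs for the full run
    time = 2.24e20 / (8 x 3.12e14 x 0.3) ~= 3.0e5 s ~= 3.46 days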
--------------------------------------------------------------------------------
/imagenet/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import datetime
3 | import errno
4 | import hashlib
5 | import os
6 | import time
7 | from collections import defaultdict, deque, OrderedDict
8 | from typing import List, Optional, Tuple
9 |
10 | import torch
11 | import torch.distributed as dist
12 |
13 |
14 | class SmoothedValue:
15 | """Track a series of values and provide access to smoothed values over a
16 | window or the global series average.
17 | """
18 |
19 | def __init__(self, window_size=20, fmt=None):
20 | if fmt is None:
21 | fmt = "{median:.4f} ({global_avg:.4f})"
22 | self.deque = deque(maxlen=window_size)
23 | self.total = 0.0
24 | self.count = 0
25 | self.fmt = fmt
26 |
27 | def update(self, value, n=1):
28 | self.deque.append(value)
29 | self.count += n
30 | self.total += value * n
31 |
32 | def synchronize_between_processes(self):
33 | """
34 | Warning: does not synchronize the deque!
35 | """
36 | t = reduce_across_processes([self.count, self.total])
37 | t = t.tolist()
38 | self.count = int(t[0])
39 | self.total = t[1]
40 |
41 | @property
42 | def median(self):
43 | d = torch.tensor(list(self.deque))
44 | return d.median().item()
45 |
46 | @property
47 | def avg(self):
48 | d = torch.tensor(list(self.deque), dtype=torch.float32)
49 | return d.mean().item()
50 |
51 | @property
52 | def global_avg(self):
53 | return self.total / self.count
54 |
55 | @property
56 | def max(self):
57 | return max(self.deque)
58 |
59 | @property
60 | def value(self):
61 | return self.deque[-1]
62 |
63 | def __str__(self):
64 | return self.fmt.format(
65 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
66 | )
67 |
68 |
69 | class MetricLogger:
70 | def __init__(self, delimiter="\t"):
71 | self.meters = defaultdict(SmoothedValue)
72 | self.delimiter = delimiter
73 |
74 | def update(self, **kwargs):
75 | for k, v in kwargs.items():
76 | if isinstance(v, torch.Tensor):
77 | v = v.item()
78 | assert isinstance(v, (float, int))
79 | self.meters[k].update(v)
80 |
81 | def __getattr__(self, attr):
82 | if attr in self.meters:
83 | return self.meters[attr]
84 | if attr in self.__dict__:
85 | return self.__dict__[attr]
86 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
87 |
88 | def __str__(self):
89 | loss_str = []
90 | for name, meter in self.meters.items():
91 | loss_str.append(f"{name}: {str(meter)}")
92 | return self.delimiter.join(loss_str)
93 |
94 | def synchronize_between_processes(self):
95 | for meter in self.meters.values():
96 | meter.synchronize_between_processes()
97 |
98 | def add_meter(self, name, meter):
99 | self.meters[name] = meter
100 |
101 | def log_every(self, iterable, print_freq, header=None):
102 | i = 0
103 | if not header:
104 | header = ""
105 | start_time = time.time()
106 | end = time.time()
107 | iter_time = SmoothedValue(fmt="{avg:.4f}")
108 | data_time = SmoothedValue(fmt="{avg:.4f}")
109 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
110 | if torch.cuda.is_available():
111 | log_msg = self.delimiter.join(
112 | [
113 | header,
114 | "[{0" + space_fmt + "}/{1}]",
115 | "eta: {eta}",
116 | "{meters}",
117 | "time: {time}",
118 | "data: {data}",
119 | "max mem: {memory:.0f}",
120 | ]
121 | )
122 | else:
123 | log_msg = self.delimiter.join(
124 | [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
125 | )
126 | MB = 1024.0 * 1024.0
127 | for obj in iterable:
128 | data_time.update(time.time() - end)
129 | yield obj
130 | iter_time.update(time.time() - end)
131 | if i % print_freq == 0:
132 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
133 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
134 | if torch.cuda.is_available():
135 | print(
136 | log_msg.format(
137 | i,
138 | len(iterable),
139 | eta=eta_string,
140 | meters=str(self),
141 | time=str(iter_time),
142 | data=str(data_time),
143 | memory=torch.cuda.max_memory_allocated() / MB,
144 | )
145 | )
146 | else:
147 | print(
148 | log_msg.format(
149 | i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
150 | )
151 | )
152 | i += 1
153 | end = time.time()
154 | total_time = time.time() - start_time
155 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
156 | print(f"{header} Total time: {total_time_str}")
157 |
158 |
159 | class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
160 | """Maintains moving averages of model parameters using an exponential decay.
161 | ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
162 |     ``torch.optim.swa_utils.AveragedModel``
163 | is used to compute the EMA.
164 | """
165 |
166 | def __init__(self, model, decay, device="cpu"):
167 | def ema_avg(avg_model_param, model_param, num_averaged):
168 | return decay * avg_model_param + (1 - decay) * model_param
169 |
170 | super().__init__(model, device, ema_avg, use_buffers=True)
171 |
172 |
173 | def accuracy(output, target, topk=(1,)):
174 | """Computes the accuracy over the k top predictions for the specified values of k"""
175 | with torch.inference_mode():
176 | maxk = max(topk)
177 | batch_size = target.size(0)
178 | if target.ndim == 2:
179 | target = target.max(dim=1)[1]
180 |
181 | _, pred = output.topk(maxk, 1, True, True)
182 | pred = pred.t()
183 | correct = pred.eq(target[None])
184 |
185 | res = []
186 | for k in topk:
187 | correct_k = correct[:k].flatten().sum(dtype=torch.float32)
188 | res.append(correct_k * (100.0 / batch_size))
189 | return res
190 |
191 |
192 | def mkdir(path):
193 | try:
194 | os.makedirs(path)
195 | except OSError as e:
196 | if e.errno != errno.EEXIST:
197 | raise
198 |
199 |
200 | def setup_for_distributed(is_master):
201 | """
202 | This function disables printing when not in master process
203 | """
204 | import builtins as __builtin__
205 |
206 | builtin_print = __builtin__.print
207 |
208 | def print(*args, **kwargs):
209 | force = kwargs.pop("force", False)
210 | if is_master or force:
211 | builtin_print(*args, **kwargs)
212 |
213 | __builtin__.print = print
214 |
215 |
216 | def is_dist_avail_and_initialized():
217 | if not dist.is_available():
218 | return False
219 | if not dist.is_initialized():
220 | return False
221 | return True
222 |
223 |
224 | def get_world_size():
225 | if not is_dist_avail_and_initialized():
226 | return 1
227 | return dist.get_world_size()
228 |
229 |
230 | def get_rank():
231 | if not is_dist_avail_and_initialized():
232 | return 0
233 | return dist.get_rank()
234 |
235 |
236 | def is_main_process():
237 | return get_rank() == 0
238 |
239 |
240 | def save_on_master(*args, **kwargs):
241 | if is_main_process():
242 | torch.save(*args, **kwargs)
243 |
244 |
245 | def init_distributed_mode(args):
246 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
247 | args.rank = int(os.environ["RANK"])
248 | args.world_size = int(os.environ["WORLD_SIZE"])
249 | args.gpu = int(os.environ["LOCAL_RANK"])
250 | elif "SLURM_PROCID" in os.environ:
251 | args.rank = int(os.environ["SLURM_PROCID"])
252 | args.gpu = args.rank % torch.cuda.device_count()
253 | elif hasattr(args, "rank"):
254 | pass
255 | else:
256 | print("Not using distributed mode")
257 | args.distributed = False
258 | return
259 |
260 | args.distributed = True
261 |
262 | torch.cuda.set_device(args.gpu)
263 | args.dist_backend = "nccl"
264 | print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
265 | torch.distributed.init_process_group(
266 | backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
267 | )
268 | torch.distributed.barrier()
269 | setup_for_distributed(args.rank == 0)
270 |
271 |
272 | def average_checkpoints(inputs):
273 | """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
274 | https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
275 |
276 | Args:
277 | inputs (List[str]): An iterable of string paths of checkpoints to load from.
278 | Returns:
279 | A dict of string keys mapping to various values. The 'model' key
280 | from the returned dict should correspond to an OrderedDict mapping
281 | string parameter names to torch Tensors.
282 | """
283 | params_dict = OrderedDict()
284 | params_keys = None
285 | new_state = None
286 | num_models = len(inputs)
287 | for fpath in inputs:
288 | with open(fpath, "rb") as f:
289 | state = torch.load(
290 | f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True
291 | )
292 | # Copies over the settings from the first checkpoint
293 | if new_state is None:
294 | new_state = state
295 | model_params = state["model"]
296 | model_params_keys = list(model_params.keys())
297 | if params_keys is None:
298 | params_keys = model_params_keys
299 | elif params_keys != model_params_keys:
300 | raise KeyError(
301 | f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}"
302 | )
303 | for k in params_keys:
304 | p = model_params[k]
305 | if isinstance(p, torch.HalfTensor):
306 | p = p.float()
307 | if k not in params_dict:
308 | params_dict[k] = p.clone()
309 | # NOTE: clone() is needed in case of p is a shared parameter
310 | else:
311 | params_dict[k] += p
312 | averaged_params = OrderedDict()
313 | for k, v in params_dict.items():
314 | averaged_params[k] = v
315 | if averaged_params[k].is_floating_point():
316 | averaged_params[k].div_(num_models)
317 | else:
318 | averaged_params[k] //= num_models
319 | new_state["model"] = averaged_params
320 | return new_state
321 |
322 |
323 | def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
324 | """
325 | This method can be used to prepare weights files for new models. It receives as
326 | input a model architecture and a checkpoint from the training script and produces
327 | a file with the weights ready for release.
328 |
329 | Examples:
330 | from torchvision import models as M
331 |
332 | # Classification
333 | model = M.mobilenet_v3_large(weights=None)
334 | print(store_model_weights(model, './class.pth'))
335 |
336 | # Quantized Classification
337 | model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
338 | model.fuse_model(is_qat=True)
339 | model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
340 | _ = torch.ao.quantization.prepare_qat(model, inplace=True)
341 | print(store_model_weights(model, './qat.pth'))
342 |
343 | # Object Detection
344 | model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
345 | print(store_model_weights(model, './obj.pth'))
346 |
347 | # Segmentation
348 | model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
349 | print(store_model_weights(model, './segm.pth', strict=False))
350 |
351 | Args:
352 | model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
353 | checkpoint_path (str): The path of the checkpoint we will load.
354 | checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
355 | Default: "model".
356 | strict (bool): whether to strictly enforce that the keys
357 | in :attr:`state_dict` match the keys returned by this module's
358 | :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
359 |
360 | Returns:
361 | output_path (str): The location where the weights are saved.
362 | """
363 | # Store the new model next to the checkpoint_path
364 | checkpoint_path = os.path.abspath(checkpoint_path)
365 | output_dir = os.path.dirname(checkpoint_path)
366 |
367 | # Deep copy to avoid side effects on the model object.
368 | model = copy.deepcopy(model)
369 | checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
370 |
371 | # Load the weights to the model to validate that everything works
372 | # and remove unnecessary weights (such as auxiliaries, etc.)
373 | if checkpoint_key == "model_ema":
374 | del checkpoint[checkpoint_key]["n_averaged"]
375 | torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
376 | model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
377 |
378 | tmp_path = os.path.join(output_dir, str(model.__hash__()))
379 | torch.save(model.state_dict(), tmp_path)
380 |
381 | sha256_hash = hashlib.sha256()
382 | with open(tmp_path, "rb") as f:
383 | # Read and update hash string value in blocks of 4K
384 | for byte_block in iter(lambda: f.read(4096), b""):
385 | sha256_hash.update(byte_block)
386 | hh = sha256_hash.hexdigest()
387 |
388 | output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
389 | os.replace(tmp_path, output_path)
390 |
391 | return output_path
392 |
393 |
394 | def reduce_across_processes(val):
395 | if not is_dist_avail_and_initialized():
396 | # nothing to sync, but we still convert to tensor for consistency with the distributed case.
397 | return torch.tensor(val)
398 |
399 | t = torch.tensor(val, device="cuda")
400 | dist.barrier()
401 | dist.all_reduce(t)
402 | return t
403 |
404 |
405 | def set_weight_decay(
406 | model: torch.nn.Module,
407 | weight_decay: float,
408 | norm_weight_decay: Optional[float] = None,
409 | norm_classes: Optional[List[type]] = None,
410 | custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
411 | ):
412 | if not norm_classes:
413 | norm_classes = [
414 | torch.nn.modules.batchnorm._BatchNorm,
415 | torch.nn.LayerNorm,
416 | torch.nn.GroupNorm,
417 | torch.nn.modules.instancenorm._InstanceNorm,
418 | torch.nn.LocalResponseNorm,
419 | ]
420 | norm_classes = tuple(norm_classes)
421 |
422 | params = {
423 | "other": [],
424 | "norm": [],
425 | }
426 | params_weight_decay = {
427 | "other": weight_decay,
428 | "norm": norm_weight_decay,
429 | }
430 | custom_keys = []
431 | if custom_keys_weight_decay is not None:
432 | for key, weight_decay in custom_keys_weight_decay:
433 | params[key] = []
434 | params_weight_decay[key] = weight_decay
435 | custom_keys.append(key)
436 |
437 | def _add_params(module, prefix=""):
438 | for name, p in module.named_parameters(recurse=False):
439 | if not p.requires_grad:
440 | continue
441 | is_custom_key = False
442 | for key in custom_keys:
443 | target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
444 | if key == target_name:
445 | params[key].append(p)
446 | is_custom_key = True
447 | break
448 | if not is_custom_key:
449 | if norm_weight_decay is not None and isinstance(module, norm_classes):
450 | params["norm"].append(p)
451 | else:
452 | params["other"].append(p)
453 |
454 | for child_name, child_module in module.named_children():
455 | child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
456 | _add_params(child_module, prefix=child_prefix)
457 |
458 | _add_params(model)
459 |
460 | param_groups = []
461 | for key in params:
462 | if len(params[key]) > 0:
463 | param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
464 | return param_groups
465 |
--------------------------------------------------------------------------------
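
Note: `set_weight_decay` above returns parameter groups rather than a configured optimizer, so the caller decides which optimizer consumes them; /imagenet/train.py below hands them to SGD, RMSprop, AdamW or ADOPT. A minimal sketch of that pattern, not part of the repository, assuming this file is importable as `utils` and using an illustrative torchvision model with made-up hyper-parameters:

import torch
import torchvision

import utils  # hypothetical import of /imagenet/utils.py above

# Per-group weight decay: norm-layer parameters get their own (zero) value,
# everything else uses the regular weight decay.
model = torchvision.models.resnet18(weights=None)
param_groups = utils.set_weight_decay(
    model,
    weight_decay=2e-5,      # applied to the "other" group (conv/linear weights)
    norm_weight_decay=0.0,  # BatchNorm parameters are exempt from decay
)
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)
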
/gpt/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Full definition of a GPT Language Model, all of it in this single file.
3 | References:
4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5 | https://github.com/openai/gpt-2/blob/master/src/model.py
6 | 2) huggingface/transformers PyTorch implementation:
7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8 | """
9 |
10 | import math
11 | import inspect
12 | from dataclasses import dataclass
13 | import sys
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch.nn import functional as F
18 |
19 | from adopt import ADOPT
20 |
21 |
22 | class LayerNorm(nn.Module):
23 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
24 |
25 | def __init__(self, ndim, bias):
26 | super().__init__()
27 | self.weight = nn.Parameter(torch.ones(ndim))
28 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
29 |
30 | def forward(self, input):
31 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
32 |
33 | class CausalSelfAttention(nn.Module):
34 |
35 | def __init__(self, config):
36 | super().__init__()
37 | assert config.n_embd % config.n_head == 0
38 | # key, query, value projections for all heads, but in a batch
39 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
40 | # output projection
41 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
42 | # regularization
43 | self.attn_dropout = nn.Dropout(config.dropout)
44 | self.resid_dropout = nn.Dropout(config.dropout)
45 | self.n_head = config.n_head
46 | self.n_embd = config.n_embd
47 | self.dropout = config.dropout
48 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
49 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
50 | if not self.flash:
51 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
52 | # causal mask to ensure that attention is only applied to the left in the input sequence
53 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
54 | .view(1, 1, config.block_size, config.block_size))
55 |
56 | def forward(self, x):
57 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
58 |
59 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
60 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
61 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
62 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
63 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
64 |
65 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
66 | if self.flash:
67 | # efficient attention using Flash Attention CUDA kernels
68 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
69 | else:
70 | # manual implementation of attention
71 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
72 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
73 | att = F.softmax(att, dim=-1)
74 | att = self.attn_dropout(att)
75 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
76 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
77 |
78 | # output projection
79 | y = self.resid_dropout(self.c_proj(y))
80 | return y
81 |
82 | class MLP(nn.Module):
83 |
84 | def __init__(self, config):
85 | super().__init__()
86 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
87 | self.gelu = nn.GELU()
88 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
89 | self.dropout = nn.Dropout(config.dropout)
90 |
91 | def forward(self, x):
92 | x = self.c_fc(x)
93 | x = self.gelu(x)
94 | x = self.c_proj(x)
95 | x = self.dropout(x)
96 | return x
97 |
98 | class Block(nn.Module):
99 |
100 | def __init__(self, config):
101 | super().__init__()
102 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
103 | self.attn = CausalSelfAttention(config)
104 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
105 | self.mlp = MLP(config)
106 |
107 | def forward(self, x):
108 | x = x + self.attn(self.ln_1(x))
109 | x = x + self.mlp(self.ln_2(x))
110 | return x
111 |
112 | @dataclass
113 | class GPTConfig:
114 | block_size: int = 1024
115 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
116 | n_layer: int = 12
117 | n_head: int = 12
118 | n_embd: int = 768
119 | dropout: float = 0.0
120 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
121 |
122 | class GPT(nn.Module):
123 |
124 | def __init__(self, config):
125 | super().__init__()
126 | assert config.vocab_size is not None
127 | assert config.block_size is not None
128 | self.config = config
129 |
130 | self.transformer = nn.ModuleDict(dict(
131 | wte = nn.Embedding(config.vocab_size, config.n_embd),
132 | wpe = nn.Embedding(config.block_size, config.n_embd),
133 | drop = nn.Dropout(config.dropout),
134 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
135 | ln_f = LayerNorm(config.n_embd, bias=config.bias),
136 | ))
137 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
138 | # with weight tying when using torch.compile() some warnings get generated:
139 | # "UserWarning: functional_call was passed multiple values for tied weights.
140 | # This behavior is deprecated and will be an error in future versions"
141 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
142 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
143 |
144 | # init all weights
145 | self.apply(self._init_weights)
146 | # apply special scaled init to the residual projections, per GPT-2 paper
147 | for pn, p in self.named_parameters():
148 | if pn.endswith('c_proj.weight'):
149 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
150 |
151 | # report number of parameters
152 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
153 |
154 | def get_num_params(self, non_embedding=True):
155 | """
156 | Return the number of parameters in the model.
157 | For non-embedding count (default), the position embeddings get subtracted.
158 | The token embeddings would too, except due to the parameter sharing these
159 | params are actually used as weights in the final layer, so we include them.
160 | """
161 | n_params = sum(p.numel() for p in self.parameters())
162 | if non_embedding:
163 | n_params -= self.transformer.wpe.weight.numel()
164 | return n_params
165 |
166 | def _init_weights(self, module):
167 | if isinstance(module, nn.Linear):
168 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
169 | if module.bias is not None:
170 | torch.nn.init.zeros_(module.bias)
171 | elif isinstance(module, nn.Embedding):
172 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
173 |
174 | def forward(self, idx, targets=None):
175 | device = idx.device
176 | b, t = idx.size()
177 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
178 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
179 |
180 | # forward the GPT model itself
181 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
182 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
183 | x = self.transformer.drop(tok_emb + pos_emb)
184 | for block in self.transformer.h:
185 | x = block(x)
186 | x = self.transformer.ln_f(x)
187 |
188 | if targets is not None:
189 | # if we are given some desired targets also calculate the loss
190 | logits = self.lm_head(x)
191 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
192 | else:
193 | # inference-time mini-optimization: only forward the lm_head on the very last position
194 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
195 | loss = None
196 |
197 | return logits, loss
198 |
199 | def crop_block_size(self, block_size):
200 | # model surgery to decrease the block size if necessary
201 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
202 | # but want to use a smaller block size for some smaller, simpler model
203 | assert block_size <= self.config.block_size
204 | self.config.block_size = block_size
205 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
206 | for block in self.transformer.h:
207 | if hasattr(block.attn, 'bias'):
208 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
209 |
210 | @classmethod
211 | def from_pretrained(cls, model_type, override_args=None):
212 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
213 | override_args = override_args or {} # default to empty dict
214 | # only dropout can be overridden see more notes below
215 | assert all(k == 'dropout' for k in override_args)
216 | from transformers import GPT2LMHeadModel
217 | print("loading weights from pretrained gpt: %s" % model_type)
218 |
219 | # n_layer, n_head and n_embd are determined from model_type
220 | config_args = {
221 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
222 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
223 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
224 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
225 | }[model_type]
226 | print("forcing vocab_size=50257, block_size=1024, bias=True")
227 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
228 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
229 | config_args['bias'] = True # always True for GPT model checkpoints
230 | # we can override the dropout rate, if desired
231 | if 'dropout' in override_args:
232 | print(f"overriding dropout rate to {override_args['dropout']}")
233 | config_args['dropout'] = override_args['dropout']
234 | # create a from-scratch initialized minGPT model
235 | config = GPTConfig(**config_args)
236 | model = GPT(config)
237 | sd = model.state_dict()
238 | sd_keys = sd.keys()
239 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
240 |
241 | # init a huggingface/transformers model
242 | model_hf = GPT2LMHeadModel.from_pretrained(model_type)
243 | sd_hf = model_hf.state_dict()
244 |
245 | # copy while ensuring all of the parameters are aligned and match in names and shapes
246 | sd_keys_hf = sd_hf.keys()
247 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
248 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
249 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
250 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
251 | # this means that we have to transpose these weights when we import them
252 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
253 | for k in sd_keys_hf:
254 | if any(k.endswith(w) for w in transposed):
255 | # special treatment for the Conv1D weights we need to transpose
256 | assert sd_hf[k].shape[::-1] == sd[k].shape
257 | with torch.no_grad():
258 | sd[k].copy_(sd_hf[k].t())
259 | else:
260 | # vanilla copy over the other parameters
261 | assert sd_hf[k].shape == sd[k].shape
262 | with torch.no_grad():
263 | sd[k].copy_(sd_hf[k])
264 |
265 | return model
266 |
267 | def configure_optimizers(self, optim_type, weight_decay, learning_rate, betas, device_type):
268 | # start with all of the candidate parameters
269 | param_dict = {pn: p for pn, p in self.named_parameters()}
270 | # filter out those that do not require grad
271 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
272 |         # create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
273 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
274 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
275 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
276 | optim_groups = [
277 | {'params': decay_params, 'weight_decay': weight_decay},
278 | {'params': nodecay_params, 'weight_decay': 0.0}
279 | ]
280 | num_decay_params = sum(p.numel() for p in decay_params)
281 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
282 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
283 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
284 | if optim_type == 'adopt':
285 |             optimizer = ADOPT(optim_groups, lr=learning_rate, decouple=True)
286 |             print("using ADOPT")
287 | else:
288 | # Create AdamW optimizer and use the fused version if it is available
289 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
290 | use_fused = fused_available and device_type == 'cuda'
291 | extra_args = dict(fused=True) if use_fused else dict()
292 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
293 | print(f"using fused AdamW: {use_fused}")
294 |
295 | return optimizer
296 |
297 | def estimate_mfu(self, fwdbwd_per_iter, dt):
298 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
299 | # first estimate the number of flops we do per iteration.
300 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
301 | N = self.get_num_params()
302 | cfg = self.config
303 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
304 | flops_per_token = 6*N + 12*L*H*Q*T
305 | flops_per_fwdbwd = flops_per_token * T
306 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
307 | # express our flops throughput as ratio of A100 bfloat16 peak flops
308 | flops_achieved = flops_per_iter * (1.0/dt) # per second
309 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
310 | mfu = flops_achieved / flops_promised
311 | return mfu
312 |
313 | @torch.no_grad()
314 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
315 | """
316 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
317 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
318 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
319 | """
320 | for _ in range(max_new_tokens):
321 | # if the sequence context is growing too long we must crop it at block_size
322 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
323 | # forward the model to get the logits for the index in the sequence
324 | logits, _ = self(idx_cond)
325 | # pluck the logits at the final step and scale by desired temperature
326 | logits = logits[:, -1, :] / temperature
327 | # optionally crop the logits to only the top k options
328 | if top_k is not None:
329 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
330 | logits[logits < v[:, [-1]]] = -float('Inf')
331 | # apply softmax to convert logits to (normalized) probabilities
332 | probs = F.softmax(logits, dim=-1)
333 | # sample from the distribution
334 | idx_next = torch.multinomial(probs, num_samples=1)
335 | # append sampled index to the running sequence and continue
336 | idx = torch.cat((idx, idx_next), dim=1)
337 |
338 | return idx
339 |
--------------------------------------------------------------------------------
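
A minimal sketch, not part of the repository, of how the model above pairs with the ADOPT optimizer through `configure_optimizers`; the sizes and hyper-parameters are illustrative, and /gpt/model.py is assumed to be importable as `model`:

import torch

from model import GPT, GPTConfig  # assumes /gpt/model.py is on the import path

config = GPTConfig(block_size=256, n_layer=4, n_head=4, n_embd=128, dropout=0.0, bias=False)
gpt = GPT(config)

# optim_type='adopt' selects ADOPT with decoupled weight decay; anything else
# falls back to (optionally fused) AdamW.
optimizer = gpt.configure_optimizers(
    optim_type='adopt',
    weight_decay=0.1,
    learning_rate=6e-4,
    betas=(0.9, 0.95),   # only consumed by the AdamW branch
    device_type='cpu',
)

idx = torch.randint(0, config.vocab_size, (2, 64))
targets = torch.randint(0, config.vocab_size, (2, 64))
logits, loss = gpt(idx, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
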
/src/adopt/adopt.py:
--------------------------------------------------------------------------------
1 | # mypy: allow-untyped-decorators
2 | # mypy: allow-untyped-defs
3 | from typing import cast, Callable, List, Optional, Tuple, Union
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | from torch.optim.optimizer import (
9 | _capturable_doc,
10 | _default_to_fused_or_foreach,
11 | _device_dtype_check_for_fused,
12 | _differentiable_doc,
13 | _disable_dynamo_if_unsupported,
14 | _foreach_doc,
15 | _fused_doc,
16 | _get_capturable_supported_devices,
17 | _get_scalar_dtype,
18 | _get_value,
19 | _maximize_doc,
20 | _stack_if_compiling,
21 | _use_grad_for_differentiable,
22 | _view_as_real,
23 | DeviceDict,
24 | Optimizer,
25 | ParamsT,
26 | )
27 |
28 |
29 | __all__ = ["ADOPT", "adopt"]
30 |
31 |
32 | class ADOPT(Optimizer):
33 | def __init__(
34 | self,
35 | params: ParamsT,
36 | lr: Union[float, Tensor] = 1e-3,
37 | betas: Tuple[float, float] = (0.9, 0.9999),
38 | eps: float = 1e-6,
39 | clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
40 | weight_decay: float = 0.0,
41 | decouple: bool = False,
42 | *,
43 | foreach: Optional[bool] = None,
44 | maximize: bool = False,
45 | capturable: bool = False,
46 | differentiable: bool = False,
47 | fused: Optional[bool] = None,
48 | ):
49 | if isinstance(lr, Tensor):
50 | if foreach and not capturable:
51 | raise ValueError(
52 | "lr as a Tensor is not supported for capturable=False and foreach=True"
53 | )
54 | if lr.numel() != 1:
55 | raise ValueError("Tensor lr must be 1-element")
56 | if not 0.0 <= lr:
57 | raise ValueError(f"Invalid learning rate: {lr}")
58 | if not 0.0 <= eps:
59 | raise ValueError(f"Invalid epsilon value: {eps}")
60 | if not 0.0 <= betas[0] < 1.0:
61 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
62 | if not 0.0 <= betas[1] < 1.0:
63 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
64 | if not 0.0 <= weight_decay:
65 | raise ValueError(f"Invalid weight_decay value: {weight_decay}")
66 |
67 | self.clip_lambda = clip_lambda
68 |
69 | defaults = dict(
70 | lr=lr,
71 | betas=betas,
72 | eps=eps,
73 | weight_decay=weight_decay,
74 | decouple=decouple,
75 | maximize=maximize,
76 | foreach=foreach,
77 | capturable=capturable,
78 | differentiable=differentiable,
79 | fused=fused,
80 | )
81 | super().__init__(params, defaults)
82 |
83 |         if fused:
84 |             # TODO: support fused. The guards below mirror torch.optim.Adam's
85 |             # fused path; for now a fused request is rejected outright.
86 |             if differentiable:
87 |                 raise RuntimeError("`fused` does not support `differentiable`")
88 |             if foreach:
89 |                 raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
90 |             # TODO(crcrpar): [low prec params & their higher prec copy]
91 |             # Support AMP with FP16/BF16 model params which would need
92 |             # higher prec copy of params to do update math in higher prec to
93 |             # alleviate the loss of information.
94 |             self._step_supports_amp_scaling = True
95 |             raise RuntimeError("`fused` is not currently supported")
96 |
97 | def __setstate__(self, state):
98 | super().__setstate__(state)
99 | for group in self.param_groups:
100 | group.setdefault("maximize", False)
101 | group.setdefault("foreach", None)
102 | group.setdefault("capturable", False)
103 | group.setdefault("differentiable", False)
104 | fused = group.setdefault("fused", None)
105 | for p in group["params"]:
106 | p_state = self.state.get(p, [])
107 | if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
108 | step_val = float(p_state["step"])
109 | p_state["step"] = (
110 | torch.tensor(
111 | step_val,
112 | dtype=_get_scalar_dtype(is_fused=fused),
113 | device=p.device,
114 | )
115 | if group["capturable"] or group["fused"]
116 | else torch.tensor(step_val, dtype=_get_scalar_dtype())
117 | )
118 |
119 | def _init_group(
120 | self,
121 | group,
122 | params_with_grad,
123 | grads,
124 | exp_avgs,
125 | exp_avg_sqs,
126 | state_steps,
127 | ):
128 | has_complex = False
129 | for p in group["params"]:
130 | if p.grad is not None:
131 | has_complex |= torch.is_complex(p)
132 | params_with_grad.append(p)
133 | if p.grad.is_sparse:
134 | raise RuntimeError(
135 | "ADOPT does not support sparse gradients"
136 | )
137 | grads.append(p.grad)
138 |
139 | state = self.state[p]
140 | # Lazy state initialization
141 | if len(state) == 0:
142 | if group["fused"]:
143 | _device_dtype_check_for_fused(p)
144 | # note(crcrpar): [special device hosting for step]
145 | # Deliberately host `step` on CPU if both capturable and fused are off.
146 | # This is because kernel launches are costly on CUDA and XLA.
147 | state["step"] = (
148 | torch.zeros(
149 | (),
150 | dtype=_get_scalar_dtype(is_fused=group["fused"]),
151 | device=p.device,
152 | )
153 | if group["capturable"] or group["fused"]
154 | else torch.tensor(0.0, dtype=_get_scalar_dtype())
155 | )
156 | # Exponential moving average of gradient values
157 | state["exp_avg"] = torch.zeros_like(
158 | p, memory_format=torch.preserve_format
159 | )
160 | # Exponential moving average of squared gradient values
161 | state["exp_avg_sq"] = torch.zeros_like(
162 | p, memory_format=torch.preserve_format
163 | )
164 |
165 | exp_avgs.append(state["exp_avg"])
166 | exp_avg_sqs.append(state["exp_avg_sq"])
167 |
168 | if group["differentiable"] and state["step"].requires_grad:
169 | raise RuntimeError(
170 | "`requires_grad` is not supported for `step` in differentiable mode"
171 | )
172 |
173 | # Foreach without capturable does not support a tensor lr
174 | if (
175 | group["foreach"]
176 | and torch.is_tensor(group["lr"])
177 | and not group["capturable"]
178 | ):
179 | raise RuntimeError(
180 | "lr as a Tensor is not supported for capturable=False and foreach=True"
181 | )
182 |
183 | state_steps.append(state["step"])
184 | return has_complex
185 |
186 | @_use_grad_for_differentiable
187 | def step(self, closure=None):
188 | """Perform a single optimization step.
189 |
190 | Args:
191 | closure (Callable, optional): A closure that reevaluates the model
192 | and returns the loss.
193 | """
194 | self._cuda_graph_capture_health_check()
195 |
196 | loss = None
197 | if closure is not None:
198 | with torch.enable_grad():
199 | loss = closure()
200 |
201 | for group in self.param_groups:
202 | params_with_grad: List[Tensor] = []
203 | grads: List[Tensor] = []
204 | exp_avgs: List[Tensor] = []
205 | exp_avg_sqs: List[Tensor] = []
206 | state_steps: List[Tensor] = []
207 | beta1, beta2 = group["betas"]
208 |
209 | has_complex = self._init_group(
210 | group,
211 | params_with_grad,
212 | grads,
213 | exp_avgs,
214 | exp_avg_sqs,
215 | state_steps,
216 | )
217 |
218 | adopt(
219 | params_with_grad,
220 | grads,
221 | exp_avgs,
222 | exp_avg_sqs,
223 | state_steps,
224 | has_complex=has_complex,
225 | beta1=beta1,
226 | beta2=beta2,
227 | lr=group["lr"],
228 | clip_lambda=self.clip_lambda,
229 | weight_decay=group["weight_decay"],
230 | decouple=group["decouple"],
231 | eps=group["eps"],
232 | maximize=group["maximize"],
233 | foreach=group["foreach"],
234 | capturable=group["capturable"],
235 | differentiable=group["differentiable"],
236 | fused=group["fused"],
237 | grad_scale=getattr(self, "grad_scale", None),
238 | found_inf=getattr(self, "found_inf", None),
239 | )
240 |
241 | return loss
242 |
243 |
244 | def _single_tensor_adopt(
245 | params: List[Tensor],
246 | grads: List[Tensor],
247 | exp_avgs: List[Tensor],
248 | exp_avg_sqs: List[Tensor],
249 | state_steps: List[Tensor],
250 | grad_scale: Optional[Tensor],
251 | found_inf: Optional[Tensor],
252 | *,
253 | has_complex: bool,
254 | beta1: float,
255 | beta2: float,
256 | lr: Union[float, Tensor],
257 | clip_lambda: Optional[Callable[[int], float]],
258 | weight_decay: float,
259 | decouple: bool,
260 | eps: float,
261 | maximize: bool,
262 | capturable: bool,
263 | differentiable: bool,
264 | ):
265 | assert grad_scale is None and found_inf is None
266 |
267 | if torch.jit.is_scripting():
268 | # this assert is due to JIT being dumb and not realizing that the ops below
269 | # have overloads to handle both float and Tensor lrs, so we just assert it's
270 | # a float since most people using JIT are using floats
271 | assert isinstance(lr, float)
272 |
273 | for i, param in enumerate(params):
274 | grad = grads[i] if not maximize else -grads[i]
275 | exp_avg = exp_avgs[i]
276 | exp_avg_sq = exp_avg_sqs[i]
277 | step_t = state_steps[i]
278 |
279 | # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
280 | if not torch._utils.is_compiling() and capturable:
281 | capturable_supported_devices = _get_capturable_supported_devices()
282 | assert (
283 | param.device.type == step_t.device.type
284 | and param.device.type in capturable_supported_devices
285 | ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
286 |
287 | step = step_t if capturable or differentiable else _get_value(step_t)
288 |
289 | if weight_decay != 0 and not decouple:
290 | grad = grad.add(param, alpha=weight_decay)
291 |
292 | if torch.is_complex(param):
293 | grad = torch.view_as_real(grad)
294 | if exp_avg is not None:
295 | exp_avg = torch.view_as_real(exp_avg)
296 | if exp_avg_sq is not None:
297 | exp_avg_sq = torch.view_as_real(exp_avg_sq)
298 | param = torch.view_as_real(param)
299 |
300 | if step == 0:
301 | exp_avg_sq.addcmul_(grad, grad.conj())
302 | # update step
303 | step_t += 1
304 | continue
305 |
306 | if weight_decay != 0 and decouple:
307 | param.add_(param, alpha=-lr*weight_decay)
308 |
309 | denom = torch.clamp(exp_avg_sq.sqrt(), eps)
310 | normed_grad = grad.div(denom)
311 | if clip_lambda is not None:
312 | clip = clip_lambda(step)
313 | normed_grad.clamp_(-clip, clip)
314 |
315 | exp_avg.lerp_(normed_grad, 1 - beta1)
316 |
317 | param.add_(exp_avg, alpha=-lr)
318 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
319 |
320 | # update step
321 | step_t += 1
322 |
323 |
324 | def _multi_tensor_adopt(
325 | params: List[Tensor],
326 | grads: List[Tensor],
327 | exp_avgs: List[Tensor],
328 | exp_avg_sqs: List[Tensor],
329 | state_steps: List[Tensor],
330 | grad_scale: Optional[Tensor],
331 | found_inf: Optional[Tensor],
332 | *,
333 | has_complex: bool,
334 | beta1: float,
335 | beta2: float,
336 | lr: Union[float, Tensor],
337 | clip_lambda: Optional[Callable[[int], float]],
338 | weight_decay: float,
339 | decouple: bool,
340 | eps: float,
341 | maximize: bool,
342 | capturable: bool,
343 | differentiable: bool,
344 | ):
345 | if len(params) == 0:
346 | return
347 |
348 | if isinstance(lr, Tensor) and not capturable:
349 | raise RuntimeError(
350 | "lr as a Tensor is not supported for capturable=False and foreach=True"
351 | )
352 |
353 | # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
354 | if not torch._utils.is_compiling() and capturable:
355 | capturable_supported_devices = _get_capturable_supported_devices(
356 | supports_xla=False
357 | )
358 | assert all(
359 | p.device.type == step.device.type
360 | and p.device.type in capturable_supported_devices
361 | for p, step in zip(params, state_steps)
362 | ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
363 |
364 | assert grad_scale is None and found_inf is None
365 |
366 | assert not differentiable, "_foreach ops don't support autograd"
367 |
368 | grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
369 | [params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
370 | )
371 | for (
372 | device_params_,
373 | device_grads_,
374 | device_exp_avgs_,
375 | device_exp_avg_sqs_,
376 | device_state_steps_,
377 | ), _ in grouped_tensors.values():
378 | device_params = cast(List[Tensor], device_params_)
379 | device_grads = cast(List[Tensor], device_grads_)
380 | device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
381 | device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
382 | device_state_steps = cast(List[Tensor], device_state_steps_)
383 |
384 | # Handle complex parameters
385 | if has_complex:
386 | _view_as_real(
387 | device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
388 | )
389 |
390 | if maximize:
391 | device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
392 |
393 | if weight_decay != 0 and not decouple:
394 | # Re-use the intermediate memory (device_grads) already allocated for maximize
395 | if maximize:
396 | torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
397 | else:
398 | device_grads = torch._foreach_add( # type: ignore[assignment]
399 | device_grads, device_params, alpha=weight_decay
400 | )
401 |
402 | if device_state_steps[0] == 0:
403 | torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
404 |
405 | # Update steps
406 | # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
407 | # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
408 | # wrapped it once now. The alpha is required to assure we go to the right overload.
409 | if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
410 | torch._foreach_add_(
411 | device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
412 | )
413 | else:
414 | torch._foreach_add_(device_state_steps, 1)
415 |
416 | continue
417 |
418 | if weight_decay != 0 and decouple:
419 | torch._foreach_add_(device_params, device_params, alpha=-lr*weight_decay)
420 |
421 | exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
422 | torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
423 |
424 | normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
425 | if clip_lambda is not None:
426 | clip = clip_lambda(device_state_steps[0])
427 | torch._foreach_maximum_(normed_grad, -clip)
428 | torch._foreach_minimum_(normed_grad, clip)
429 |
430 | torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
431 |
432 | torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
433 | torch._foreach_mul_(device_exp_avg_sqs, beta2)
434 | torch._foreach_addcmul_(
435 | device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
436 | )
437 |
438 | # Update steps
439 | # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
440 | # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
441 | # wrapped it once now. The alpha is required to assure we go to the right overload.
442 | if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
443 | torch._foreach_add_(
444 | device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
445 | )
446 | else:
447 | torch._foreach_add_(device_state_steps, 1)
448 |
449 |
450 | def _fused_adopt(
451 | params: List[Tensor],
452 | grads: List[Tensor],
453 | exp_avgs: List[Tensor],
454 | exp_avg_sqs: List[Tensor],
455 | state_steps: List[Tensor],
456 | grad_scale: Optional[Tensor],
457 | found_inf: Optional[Tensor],
458 | *,
459 | has_complex: bool, # Needed for consistency.
460 | beta1: float,
461 | beta2: float,
462 | lr: Union[float, Tensor],
463 | clip_lambda: Optional[Callable[[int], float]],
464 | weight_decay: float,
465 | decouple: bool,
466 | eps: float,
467 | maximize: bool,
468 | capturable: bool, # Needed for consistency.
469 | differentiable: bool,
470 | ) -> None:
471 | raise NotImplementedError
472 |
473 |
474 | @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
475 | def adopt(
476 | params: List[Tensor],
477 | grads: List[Tensor],
478 | exp_avgs: List[Tensor],
479 | exp_avg_sqs: List[Tensor],
480 | state_steps: List[Tensor],
481 | # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
482 | # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
483 | foreach: Optional[bool] = None,
484 | capturable: bool = False,
485 | differentiable: bool = False,
486 | fused: Optional[bool] = None,
487 | grad_scale: Optional[Tensor] = None,
488 | found_inf: Optional[Tensor] = None,
489 | has_complex: bool = False,
490 | *,
491 | beta1: float,
492 | beta2: float,
493 | lr: Union[float, Tensor],
494 | clip_lambda: Optional[Callable[[int], float]],
495 | weight_decay: float,
496 | decouple: bool,
497 | eps: float,
498 | maximize: bool,
499 | ):
500 | r"""Functional API that performs ADOPT algorithm computation.
501 |
502 | """
503 | # Respect when the user inputs False/True for foreach or fused. We only want to change
504 | # the default when neither have been user-specified. Note that we default to foreach
505 | # and pass False to use_fused. This is not a mistake--we want to give the fused impl
506 | # bake-in time before making it the default, even if it is typically faster.
507 | if fused is None and foreach is None:
508 | _, foreach = _default_to_fused_or_foreach(
509 | params, differentiable, use_fused=False
510 | )
511 | # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
512 | if foreach and isinstance(lr, Tensor) and not capturable:
513 | foreach = False
514 | if fused is None:
515 | fused = False
516 | if foreach is None:
517 | foreach = False
518 |
519 | # this check is slow during compilation, so we skip it
520 | # if it's strictly needed we can add this check back in dynamo
521 | if not torch._utils.is_compiling() and not all(
522 | isinstance(t, torch.Tensor) for t in state_steps
523 | ):
524 | raise RuntimeError(
525 | "API has changed, `state_steps` argument must contain a list of singleton tensors"
526 | )
527 |
528 | if foreach and torch.jit.is_scripting():
529 | raise RuntimeError("torch.jit.script not supported with foreach optimizers")
530 | if fused and torch.jit.is_scripting():
531 | raise RuntimeError("torch.jit.script not supported with fused optimizers")
532 |
533 | if fused and not torch.jit.is_scripting():
534 | func = _fused_adopt
535 | elif foreach and not torch.jit.is_scripting():
536 | func = _multi_tensor_adopt
537 | else:
538 | func = _single_tensor_adopt
539 |
540 | func(
541 | params,
542 | grads,
543 | exp_avgs,
544 | exp_avg_sqs,
545 | state_steps,
546 | has_complex=has_complex,
547 | beta1=beta1,
548 | beta2=beta2,
549 | lr=lr,
550 | clip_lambda=clip_lambda,
551 | weight_decay=weight_decay,
552 | decouple=decouple,
553 | eps=eps,
554 | maximize=maximize,
555 | capturable=capturable,
556 | differentiable=differentiable,
557 | grad_scale=grad_scale,
558 | found_inf=found_inf,
559 | )
560 |
--------------------------------------------------------------------------------
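
To summarize the update implemented above: each step divides the raw gradient by the previous second-moment estimate (its square root clamped from below at `eps`), optionally clips the result to +/- `clip_lambda(step)` (default `step**0.25`), folds it into the first-moment buffer, applies that buffer scaled by `lr`, and only then refreshes `exp_avg_sq`; the very first step just initializes `exp_avg_sq` from the squared gradient. A minimal usage sketch, not part of the repository, with illustrative hyper-parameters:

import torch

from adopt import ADOPT

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
# decouple=True applies AdamW-style decoupled weight decay instead of adding
# weight_decay * param to the gradient.
opt = ADOPT(model.parameters(), lr=1e-3, weight_decay=1e-2, decouple=True)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(5):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # step 0 only initializes exp_avg_sq; parameters move from step 1 on
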
/imagenet/train.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import sys
4 | import time
5 | import warnings
6 |
7 | import presets
8 | import torch
9 | import torch.utils.data
10 | import torchvision
11 | import torchvision.transforms
12 | import utils
13 | from sampler import RASampler
14 | from torch import nn
15 | from torch.utils.data.dataloader import default_collate
16 | from torchvision.transforms.functional import InterpolationMode
17 | from transforms import get_mixup_cutmix
18 |
19 | from adopt import ADOPT
20 |
21 |
22 | def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
23 | model.train()
24 | metric_logger = utils.MetricLogger(delimiter=" ")
25 | metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
26 | metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
27 |
28 | header = f"Epoch: [{epoch}]"
29 | for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
30 | start_time = time.time()
31 | image, target = image.to(device), target.to(device)
32 | with torch.cuda.amp.autocast(enabled=scaler is not None):
33 | output = model(image)
34 | loss = criterion(output, target)
35 |
36 | optimizer.zero_grad()
37 | if scaler is not None:
38 | scaler.scale(loss).backward()
39 | if args.clip_grad_norm is not None:
40 | # we should unscale the gradients of optimizer's assigned params if do gradient clipping
41 | scaler.unscale_(optimizer)
42 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
43 | scaler.step(optimizer)
44 | scaler.update()
45 | else:
46 | loss.backward()
47 | if args.clip_grad_norm is not None:
48 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
49 | optimizer.step()
50 |
51 | if model_ema and i % args.model_ema_steps == 0:
52 | model_ema.update_parameters(model)
53 | if epoch < args.lr_warmup_epochs:
54 | # Reset ema buffer to keep copying weights during warmup period
55 | model_ema.n_averaged.fill_(0)
56 |
57 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
58 | batch_size = image.shape[0]
59 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
60 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
61 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
62 | metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
63 |
64 |
65 | def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
66 | model.eval()
67 | metric_logger = utils.MetricLogger(delimiter=" ")
68 | header = f"Test: {log_suffix}"
69 |
70 | num_processed_samples = 0
71 | with torch.inference_mode():
72 | for image, target in metric_logger.log_every(data_loader, print_freq, header):
73 | image = image.to(device, non_blocking=True)
74 | target = target.to(device, non_blocking=True)
75 | output = model(image)
76 | loss = criterion(output, target)
77 |
78 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
79 | # FIXME need to take into account that the datasets
80 | # could have been padded in distributed setup
81 | batch_size = image.shape[0]
82 | metric_logger.update(loss=loss.item())
83 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
84 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
85 | num_processed_samples += batch_size
86 | # gather the stats from all processes
87 |
88 | num_processed_samples = utils.reduce_across_processes(num_processed_samples)
89 | if (
90 | hasattr(data_loader.dataset, "__len__")
91 | and len(data_loader.dataset) != num_processed_samples
92 | and torch.distributed.get_rank() == 0
93 | ):
94 | # See FIXME above
95 | warnings.warn(
96 | f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
97 | "samples were used for the validation, which might bias the results. "
98 | "Try adjusting the batch size and / or the world size. "
99 | "Setting the world size to 1 is always a safe bet."
100 | )
101 |
102 | metric_logger.synchronize_between_processes()
103 |
104 | print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
105 | return metric_logger.acc1.global_avg
106 |
107 |
108 | def _get_cache_path(filepath):
109 | import hashlib
110 |
111 | h = hashlib.sha1(filepath.encode()).hexdigest()
112 | cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
113 | cache_path = os.path.expanduser(cache_path)
114 | return cache_path
115 |
116 |
117 | def load_data(traindir, valdir, args):
118 | # Data loading code
119 | print("Loading data")
120 | val_resize_size, val_crop_size, train_crop_size = (
121 | args.val_resize_size,
122 | args.val_crop_size,
123 | args.train_crop_size,
124 | )
125 | interpolation = InterpolationMode(args.interpolation)
126 |
127 | print("Loading training data")
128 | st = time.time()
129 | cache_path = _get_cache_path(traindir)
130 | if args.cache_dataset and os.path.exists(cache_path):
131 | # Attention, as the transforms are also cached!
132 | print(f"Loading dataset_train from {cache_path}")
133 | # TODO: this could probably be weights_only=True
134 | dataset, _ = torch.load(cache_path, weights_only=False)
135 | else:
136 | # We need a default value for the variables below because args may come
137 | # from train_quantization.py which doesn't define them.
138 | auto_augment_policy = getattr(args, "auto_augment", None)
139 | random_erase_prob = getattr(args, "random_erase", 0.0)
140 | ra_magnitude = getattr(args, "ra_magnitude", None)
141 | augmix_severity = getattr(args, "augmix_severity", None)
142 | dataset = torchvision.datasets.ImageFolder(
143 | traindir,
144 | presets.ClassificationPresetTrain(
145 | crop_size=train_crop_size,
146 | interpolation=interpolation,
147 | auto_augment_policy=auto_augment_policy,
148 | random_erase_prob=random_erase_prob,
149 | ra_magnitude=ra_magnitude,
150 | augmix_severity=augmix_severity,
151 | backend=args.backend,
152 | use_v2=args.use_v2,
153 | ),
154 | )
155 | if args.cache_dataset:
156 | print(f"Saving dataset_train to {cache_path}")
157 | utils.mkdir(os.path.dirname(cache_path))
158 | utils.save_on_master((dataset, traindir), cache_path)
159 | print("Took", time.time() - st)
160 |
161 | print("Loading validation data")
162 | cache_path = _get_cache_path(valdir)
163 | if args.cache_dataset and os.path.exists(cache_path):
164 | # Attention, as the transforms are also cached!
165 | print(f"Loading dataset_test from {cache_path}")
166 | # TODO: this could probably be weights_only=True
167 | dataset_test, _ = torch.load(cache_path, weights_only=False)
168 | else:
169 | if args.weights and args.test_only:
170 | weights = torchvision.models.get_weight(args.weights)
171 | preprocessing = weights.transforms(antialias=True)
172 | if args.backend == "tensor":
173 | preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
174 |
175 | else:
176 | preprocessing = presets.ClassificationPresetEval(
177 | crop_size=val_crop_size,
178 | resize_size=val_resize_size,
179 | interpolation=interpolation,
180 | backend=args.backend,
181 | use_v2=args.use_v2,
182 | )
183 |
184 | dataset_test = torchvision.datasets.ImageFolder(
185 | valdir,
186 | preprocessing,
187 | )
188 | if args.cache_dataset:
189 | print(f"Saving dataset_test to {cache_path}")
190 | utils.mkdir(os.path.dirname(cache_path))
191 | utils.save_on_master((dataset_test, valdir), cache_path)
192 |
193 | print("Creating data loaders")
194 | if args.distributed:
195 | if hasattr(args, "ra_sampler") and args.ra_sampler:
196 | train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
197 | else:
198 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
199 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
200 | else:
201 | train_sampler = torch.utils.data.RandomSampler(dataset)
202 | test_sampler = torch.utils.data.SequentialSampler(dataset_test)
203 |
204 | return dataset, dataset_test, train_sampler, test_sampler
205 |
206 |
207 | def main(args):
208 | if args.output_dir:
209 | utils.mkdir(args.output_dir)
210 |
211 | utils.init_distributed_mode(args)
212 | print(args)
213 |
214 | device = torch.device(args.device)
215 |
216 | if args.use_deterministic_algorithms:
217 | torch.backends.cudnn.benchmark = False
218 | torch.use_deterministic_algorithms(True)
219 | else:
220 | torch.backends.cudnn.benchmark = True
221 |
222 | train_dir = os.path.join(args.data_path, "train")
223 | val_dir = os.path.join(args.data_path, "val")
224 | dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
225 |
226 | num_classes = len(dataset.classes)
227 | mixup_cutmix = get_mixup_cutmix(
228 | mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_classes=num_classes, use_v2=args.use_v2
229 | )
230 | if mixup_cutmix is not None:
231 |
232 | def collate_fn(batch):
233 | return mixup_cutmix(*default_collate(batch))
234 |
235 | else:
236 | collate_fn = default_collate
237 |
238 | data_loader = torch.utils.data.DataLoader(
239 | dataset,
240 | batch_size=args.batch_size,
241 | sampler=train_sampler,
242 | num_workers=args.workers,
243 | pin_memory=True,
244 | collate_fn=collate_fn,
245 | )
246 | data_loader_test = torch.utils.data.DataLoader(
247 | dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
248 | )
249 |
250 | print("Creating model")
251 | model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
252 | model.to(device)
253 |
254 | if args.distributed and args.sync_bn:
255 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
256 |
257 | criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
258 |
259 | custom_keys_weight_decay = []
260 | if args.bias_weight_decay is not None:
261 | custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
262 | if args.transformer_embedding_decay is not None:
263 | for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
264 | custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
265 | parameters = utils.set_weight_decay(
266 | model,
267 | args.weight_decay,
268 | norm_weight_decay=args.norm_weight_decay,
269 | custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
270 | )
271 |
272 | opt_name = args.opt.lower()
273 | if opt_name.startswith("sgd"):
274 | optimizer = torch.optim.SGD(
275 | parameters,
276 | lr=args.lr,
277 | momentum=args.momentum,
278 | weight_decay=args.weight_decay,
279 | nesterov="nesterov" in opt_name,
280 | )
281 | elif opt_name == "rmsprop":
282 | optimizer = torch.optim.RMSprop(
283 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
284 | )
285 | elif opt_name == "adamw":
286 | optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
287 | elif opt_name == "adopt":
288 | optimizer = ADOPT(parameters, lr=args.lr, weight_decay=args.weight_decay)
289 | elif opt_name == "adoptw":
290 |         optimizer = ADOPT(parameters, lr=args.lr, weight_decay=args.weight_decay, decouple=True)
291 | else:
292 | raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop, AdamW and ADOPT are supported.")
293 |
294 | scaler = torch.cuda.amp.GradScaler() if args.amp else None
295 |
296 | args.lr_scheduler = args.lr_scheduler.lower()
297 | if args.lr_scheduler == "steplr":
298 | main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
299 | elif args.lr_scheduler == "cosineannealinglr":
300 | main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
301 | optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
302 | )
303 | elif args.lr_scheduler == "exponentiallr":
304 | main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
305 | else:
306 | raise RuntimeError(
307 | f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
308 | "are supported."
309 | )
310 |
311 | if args.lr_warmup_epochs > 0:
312 | if args.lr_warmup_method == "linear":
313 | warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
314 | optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
315 | )
316 | elif args.lr_warmup_method == "constant":
317 | warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
318 | optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
319 | )
320 | else:
321 | raise RuntimeError(
322 | f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
323 | )
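        # Chain the warmup and main schedulers: SequentialLR hands over to
        # main_lr_scheduler once args.lr_warmup_epochs epochs have elapsed.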
324 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
325 | optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
326 | )
327 | else:
328 | lr_scheduler = main_lr_scheduler
329 |
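    # Keep a handle to the underlying module so checkpoints (and EMA) use the plain,
    # non-"module."-prefixed state_dict keys when DDP wraps the model.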
330 | model_without_ddp = model
331 | if args.distributed:
332 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
333 | model_without_ddp = model.module
334 |
335 | model_ema = None
336 | if args.model_ema:
337 | # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
338 | # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
339 | #
340 | # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
341 | # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
342 | # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
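        # Example with hypothetical values: world_size=8, batch_size=32, model_ema_steps=32,
        # epochs=600 -> adjust = 8 * 32 * 32 / 600 ~= 13.65, so the effective per-update decay
        # becomes 1 - min(1.0, (1 - 0.99998) * 13.65) ~= 0.99973.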
343 | adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
344 | alpha = 1.0 - args.model_ema_decay
345 | alpha = min(1.0, alpha * adjust)
346 | model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
347 |
348 | if args.resume:
349 | checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
350 | model_without_ddp.load_state_dict(checkpoint["model"])
351 | if not args.test_only:
352 | optimizer.load_state_dict(checkpoint["optimizer"])
353 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
354 | args.start_epoch = checkpoint["epoch"] + 1
355 | if model_ema:
356 | model_ema.load_state_dict(checkpoint["model_ema"])
357 | if scaler:
358 | scaler.load_state_dict(checkpoint["scaler"])
359 |
360 | if args.test_only:
361 | # We disable the cudnn benchmarking because it can noticeably affect the accuracy
362 | torch.backends.cudnn.benchmark = False
363 | torch.backends.cudnn.deterministic = True
364 | if model_ema:
365 | evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
366 | else:
367 | evaluate(model, criterion, data_loader_test, device=device)
368 | return
369 |
370 | print("Start training")
371 | start_time = time.time()
372 | for epoch in range(args.start_epoch, args.epochs):
373 | if args.distributed:
374 |             train_sampler.set_epoch(epoch)  # reshuffle the distributed sampler each epoch
375 | train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
376 | lr_scheduler.step()
377 | evaluate(model, criterion, data_loader_test, device=device)
378 | if model_ema:
379 | evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
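        # Save both a per-epoch snapshot (model_{epoch}.pth) and a rolling checkpoint.pth
        # that --resume can point at.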
380 | if args.output_dir:
381 | checkpoint = {
382 | "model": model_without_ddp.state_dict(),
383 | "optimizer": optimizer.state_dict(),
384 | "lr_scheduler": lr_scheduler.state_dict(),
385 | "epoch": epoch,
386 | "args": args,
387 | }
388 | if model_ema:
389 | checkpoint["model_ema"] = model_ema.state_dict()
390 | if scaler:
391 | checkpoint["scaler"] = scaler.state_dict()
392 | utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
393 | utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
394 |
395 | total_time = time.time() - start_time
396 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
397 | print(f"Training time {total_time_str}")
398 |
399 |
400 | def get_args_parser(add_help=True):
401 | import argparse
402 |
403 | parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
404 |
405 | parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
406 | parser.add_argument("--model", default="resnet18", type=str, help="model name")
407 |     parser.add_argument("--device", default="cuda", type=str, help="device to use, cuda or cpu (default: cuda)")
408 | parser.add_argument(
409 | "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
410 | )
411 | parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
412 | parser.add_argument(
413 | "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
414 | )
415 | parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
416 | parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
417 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
418 | parser.add_argument(
419 | "--wd",
420 | "--weight-decay",
421 | default=1e-4,
422 | type=float,
423 | metavar="W",
424 | help="weight decay (default: 1e-4)",
425 | dest="weight_decay",
426 | )
427 | parser.add_argument(
428 | "--norm-weight-decay",
429 | default=None,
430 | type=float,
431 | help="weight decay for Normalization layers (default: None, same value as --wd)",
432 | )
433 | parser.add_argument(
434 | "--bias-weight-decay",
435 | default=None,
436 | type=float,
437 | help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
438 | )
439 | parser.add_argument(
440 | "--transformer-embedding-decay",
441 | default=None,
442 | type=float,
443 | help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
444 | )
445 | parser.add_argument(
446 | "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
447 | )
448 | parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
449 | parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
450 | parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
451 | parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
452 | parser.add_argument(
453 | "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
454 | )
455 |     parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the lr factor applied during warmup (default: 0.01)")
456 | parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
457 | parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
458 | parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
459 | parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
460 | parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
461 | parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
462 | parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
463 | parser.add_argument(
464 | "--cache-dataset",
465 | dest="cache_dataset",
466 | help="Cache the datasets for quicker initialization. It also serializes the transforms",
467 | action="store_true",
468 | )
469 | parser.add_argument(
470 | "--sync-bn",
471 | dest="sync_bn",
472 | help="Use sync batch norm",
473 | action="store_true",
474 | )
475 | parser.add_argument(
476 | "--test-only",
477 | dest="test_only",
478 | help="Only test the model",
479 | action="store_true",
480 | )
481 | parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
482 | parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
483 | parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
484 | parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
485 |
486 | # Mixed precision training parameters
487 | parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
488 |
489 | # distributed training parameters
490 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
491 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
492 | parser.add_argument(
493 | "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
494 | )
495 | parser.add_argument(
496 | "--model-ema-steps",
497 | type=int,
498 | default=32,
499 | help="the number of iterations that controls how often to update the EMA model (default: 32)",
500 | )
501 | parser.add_argument(
502 | "--model-ema-decay",
503 | type=float,
504 | default=0.99998,
505 | help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
506 | )
507 | parser.add_argument(
508 | "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
509 | )
510 | parser.add_argument(
511 | "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
512 | )
513 | parser.add_argument(
514 | "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
515 | )
516 | parser.add_argument(
517 | "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
518 | )
519 | parser.add_argument(
520 | "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
521 | )
522 | parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
523 | parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
524 | parser.add_argument(
525 | "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
526 | )
527 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
528 | parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
529 | parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
530 | return parser
531 |
532 |
533 | if __name__ == "__main__":
534 | args = get_args_parser().parse_args()
535 | main(args)
536 |
--------------------------------------------------------------------------------