├── gpt ├── assets │ ├── nanogpt.jpg │ └── gpt2_124M_loss.png ├── data │ ├── shakespeare │ │ ├── readme.md │ │ └── prepare.py │ ├── shakespeare_char │ │ ├── readme.md │ │ └── prepare.py │ └── openwebtext │ │ ├── readme.md │ │ └── prepare.py ├── config │ ├── eval_gpt2.py │ ├── eval_gpt2_xl.py │ ├── eval_gpt2_large.py │ ├── eval_gpt2_medium.py │ ├── finetune_shakespeare.py │ ├── train_gpt2.py │ └── train_shakespeare_char.py ├── LICENSE ├── README.md ├── configurator.py ├── sample.py ├── bench.py ├── train.py ├── transformer_sizing.ipynb └── model.py ├── src └── adopt │ ├── __init__.py │ └── adopt.py ├── imagenet ├── README.md ├── sampler.py ├── presets.py ├── transforms.py ├── train_quantization.py ├── utils.py └── train.py ├── pyproject.toml ├── README.md ├── .gitignore └── LICENSE /gpt/assets/nanogpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iShohei220/adopt/HEAD/gpt/assets/nanogpt.jpg -------------------------------------------------------------------------------- /gpt/assets/gpt2_124M_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iShohei220/adopt/HEAD/gpt/assets/gpt2_124M_loss.png -------------------------------------------------------------------------------- /src/adopt/__init__.py: -------------------------------------------------------------------------------- 1 | from .adopt import ADOPT, adopt 2 | 3 | __version__ = "0.1.0" 4 | __all__ = ["ADOPT", "adopt"] 5 | -------------------------------------------------------------------------------- /gpt/data/shakespeare/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 301,966 tokens 9 | - val.bin has 36,059 tokens 10 | -------------------------------------------------------------------------------- /gpt/config/eval_gpt2.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=12, n_head=12, n_embd=768 3 | # 124M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2' 9 | -------------------------------------------------------------------------------- /gpt/config/eval_gpt2_xl.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=48, n_head=25, n_embd=1600 3 | # 1558M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-xl' 9 | -------------------------------------------------------------------------------- /gpt/config/eval_gpt2_large.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=36, n_head=20, n_embd=1280 3 | # 774M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-large' 9 | -------------------------------------------------------------------------------- /gpt/config/eval_gpt2_medium.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=24, n_head=16, n_embd=1024 3 | # 350M parameters 4 | 
batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-medium' 9 | -------------------------------------------------------------------------------- /gpt/data/shakespeare_char/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare, character-level 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level. 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 1,003,854 tokens 9 | - val.bin has 111,540 tokens 10 | -------------------------------------------------------------------------------- /gpt/data/openwebtext/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## openwebtext dataset 3 | 4 | after running `prepare.py` (preprocess) we get: 5 | 6 | - train.bin is ~17GB, val.bin ~8.5MB 7 | - train has ~9B tokens (9,035,582,198) 8 | - val has ~4M tokens (4,434,897) 9 | 10 | this came from 8,013,769 documents in total. 11 | 12 | references: 13 | 14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset 16 | -------------------------------------------------------------------------------- /gpt/config/finetune_shakespeare.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | out_dir = 'out-shakespeare' 4 | eval_interval = 5 5 | eval_iters = 40 6 | wandb_log = False # feel free to turn on 7 | wandb_project = 'shakespeare' 8 | wandb_run_name = 'ft-' + str(time.time()) 9 | 10 | dataset = 'shakespeare' 11 | init_from = 'gpt2-xl' # this is the largest GPT-2 model 12 | 13 | # only save checkpoints if the validation loss improves 14 | always_save_checkpoint = False 15 | 16 | # the number of examples per iter: 17 | # 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter 18 | # shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters 19 | batch_size = 1 20 | gradient_accumulation_steps = 32 21 | max_iters = 20 22 | 23 | # finetune at constant LR 24 | learning_rate = 3e-5 25 | decay_lr = False 26 | -------------------------------------------------------------------------------- /gpt/config/train_gpt2.py: -------------------------------------------------------------------------------- 1 | # config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB 2 | # launch as the following (e.g. 
in a screen session) and wait ~5 days: 3 | # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 4 | 5 | wandb_log = True 6 | wandb_project = 'owt' 7 | wandb_run_name='gpt2-124M' 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 * 8 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | -------------------------------------------------------------------------------- /gpt/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | This code is based on the official training recipe for ImageNet classification provided by [Torchvision](https://github.com/pytorch/vision/tree/main/references/classification). 2 | 3 | # Image classification reference training scripts 4 | 5 | This folder contains reference training scripts for image classification. 6 | They serve as a log of how to train specific models, as well as provide baseline 7 | training and evaluation scripts to quickly bootstrap research. 8 | 9 | ### SwinTransformer 10 | ``` 11 | torchrun --nproc_per_node=8 train.py\ 12 | --model $MODEL --epochs 300 --batch-size 128 --opt adoptw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 224 13 | ``` 14 | Here `$MODEL` is one of `swin_t`, `swin_s` or `swin_b`. 15 | Note that `--val-resize-size` was optimized in a post-training step; see the corresponding `Weights` entry in Torchvision for the exact value.
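Below is a minimal, illustrative sketch of the optimizer setup that the `--opt adoptw` flag above is understood to select: ADOPT with decoupled (AdamW-style) weight decay. It is not a copy of `train.py`; the `weight_decay` keyword is assumed by analogy with `AdamW`, and `decouple=True` follows the usage shown in the top-level README.

```python
# Illustrative sketch only; assumptions noted in the text above.
import torch
from adopt import ADOPT

model = torch.nn.Linear(8, 2)  # stand-in for the Swin model built by train.py

optimizer = ADOPT(
    model.parameters(),
    lr=1e-3,            # --lr 0.001
    weight_decay=0.05,  # --weight-decay 0.05 (assumed keyword, analogous to AdamW)
    decouple=True,      # decoupled weight decay, as described in the top-level README
)
```

In the actual recipe, parameters are additionally split into groups so that `--norm-weight-decay 0.0`, `--bias-weight-decay 0.0`, and `--transformer-embedding-decay 0.0` can turn weight decay off for those groups.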
16 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "torch-adopt" 7 | version = "0.1.0" 8 | authors = [ 9 | { name = "Shohei Taniguchi", email = "ishohei220@gmail.com" } 10 | ] 11 | description = "ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "Programming Language :: Python :: 3.8", 17 | "Programming Language :: Python :: 3.9", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.11", 20 | "License :: OSI Approved :: Apache Software License", 21 | "Operating System :: OS Independent", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | "Intended Audience :: Science/Research", 24 | ] 25 | dependencies = [ 26 | "torch>=2.5.0", 27 | ] 28 | 29 | [project.urls] 30 | "Homepage" = "https://github.com/iShohei220/adopt" 31 | "Bug Tracker" = "https://github.com/iShohei220/adopt/issues" 32 | "Documentation" = "https://github.com/iShohei220/adopt" 33 | "Research Paper" = "https://arxiv.org/abs/2411.02853" 34 | 35 | [tool.hatch.build.targets.wheel] 36 | packages = ["src/adopt"] 37 | -------------------------------------------------------------------------------- /gpt/data/shakespeare/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tiktoken 4 | import numpy as np 5 | 6 | # download the tiny shakespeare dataset 7 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 8 | if not os.path.exists(input_file_path): 9 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 10 | with open(input_file_path, 'w', encoding='utf-8') as f: 11 | f.write(requests.get(data_url).text) 12 | 13 | with open(input_file_path, 'r', encoding='utf-8') as f: 14 | data = f.read() 15 | n = len(data) 16 | train_data = data[:int(n*0.9)] 17 | val_data = data[int(n*0.9):] 18 | 19 | # encode with tiktoken gpt2 bpe 20 | enc = tiktoken.get_encoding("gpt2") 21 | train_ids = enc.encode_ordinary(train_data) 22 | val_ids = enc.encode_ordinary(val_data) 23 | print(f"train has {len(train_ids):,} tokens") 24 | print(f"val has {len(val_ids):,} tokens") 25 | 26 | # export to bin files 27 | train_ids = np.array(train_ids, dtype=np.uint16) 28 | val_ids = np.array(val_ids, dtype=np.uint16) 29 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 30 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 31 | 32 | # train.bin has 301,966 tokens 33 | # val.bin has 36,059 tokens 34 | -------------------------------------------------------------------------------- /gpt/config/train_shakespeare_char.py: -------------------------------------------------------------------------------- 1 | # train a miniature character-level shakespeare model 2 | # good for debugging and playing on macbooks and such 3 | 4 | out_dir = 'out-shakespeare-char' 5 | eval_interval = 250 # keep frequent because we'll overfit 6 | eval_iters = 200 7 | log_interval = 10 # don't print too too often 8 | 9 | # we expect to overfit on this small dataset, so only save when val improves 10 | always_save_checkpoint = False 11 | 12 | wandb_log = False # override 
via command line if you like 13 | wandb_project = 'shakespeare-char' 14 | wandb_run_name = 'mini-gpt' 15 | 16 | dataset = 'shakespeare_char' 17 | gradient_accumulation_steps = 1 18 | batch_size = 64 19 | block_size = 256 # context of up to 256 previous characters 20 | 21 | # baby GPT model :) 22 | n_layer = 6 23 | n_head = 6 24 | n_embd = 384 25 | dropout = 0.2 26 | 27 | learning_rate = 1e-3 # with baby networks can afford to go a bit higher 28 | max_iters = 5000 29 | lr_decay_iters = 5000 # make equal to max_iters usually 30 | min_lr = 1e-4 # learning_rate / 10 usually 31 | beta2 = 0.99 # make a bit bigger because number of tokens per iter is small 32 | 33 | warmup_iters = 100 # not super necessary potentially 34 | 35 | # on macbook also add 36 | # device = 'cpu' # run on cpu only 37 | # compile = False # do not torch compile the model 38 | -------------------------------------------------------------------------------- /gpt/README.md: -------------------------------------------------------------------------------- 1 | This code is based on [nanoGPT](https://github.com/karpathy/nanoGPT) repository. 2 | 3 | ## Install 4 | 5 | ``` 6 | pip install torch numpy transformers datasets tiktoken wandb tqdm 7 | ``` 8 | 9 | Dependencies: 10 | 11 | - [pytorch](https://pytorch.org) <3 12 | - [numpy](https://numpy.org/install/) <3 13 | - `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints) 14 | - `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText) 15 | - `tiktoken` for OpenAI's fast BPE code <3 16 | - `wandb` for optional logging <3 17 | - `tqdm` for progress bars <3 18 | 19 | ## Train GPT-2 20 | 21 | We first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText: 22 | 23 | ```sh 24 | python data/openwebtext/prepare.py 25 | ``` 26 | 27 | This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node and run: 28 | 29 | ```sh 30 | torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py --optim_type=adopt 31 | ``` 32 | -------------------------------------------------------------------------------- /gpt/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 
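Another illustrative example, using a config from this repo (the device/compile
overrides are the ones suggested in the comments of config/train_shakespeare_char.py):
$ python train.py config/train_shakespeare_char.py --device=cpu --compile=False
This first applies the char-level Shakespeare config, then overrides the existing
globals `device` and `compile` from the command line.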
15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /gpt/data/shakespeare_char/prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prepare the Shakespeare dataset for character-level language modeling. 3 | So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. 4 | Will save train.bin, val.bin containing the ids, and meta.pkl containing the 5 | encoder and decoder and some other related info. 6 | """ 7 | import os 8 | import pickle 9 | import requests 10 | import numpy as np 11 | 12 | # download the tiny shakespeare dataset 13 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 14 | if not os.path.exists(input_file_path): 15 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 16 | with open(input_file_path, 'w') as f: 17 | f.write(requests.get(data_url).text) 18 | 19 | with open(input_file_path, 'r') as f: 20 | data = f.read() 21 | print(f"length of dataset in characters: {len(data):,}") 22 | 23 | # get all the unique characters that occur in this text 24 | chars = sorted(list(set(data))) 25 | vocab_size = len(chars) 26 | print("all the unique characters:", ''.join(chars)) 27 | print(f"vocab size: {vocab_size:,}") 28 | 29 | # create a mapping from characters to integers 30 | stoi = { ch:i for i,ch in enumerate(chars) } 31 | itos = { i:ch for i,ch in enumerate(chars) } 32 | def encode(s): 33 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers 34 | def decode(l): 35 | return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 36 | 37 | # create the train and test splits 38 | n = len(data) 39 | train_data = data[:int(n*0.9)] 40 | val_data = data[int(n*0.9):] 41 | 42 | # encode both to integers 43 | train_ids = encode(train_data) 44 | val_ids = encode(val_data) 45 | print(f"train has {len(train_ids):,} tokens") 46 | print(f"val has {len(val_ids):,} tokens") 47 | 48 | # export to bin files 49 | train_ids = np.array(train_ids, dtype=np.uint16) 50 | val_ids = np.array(val_ids, dtype=np.uint16) 51 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 52 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 53 | 54 | # save the meta information as well, to help us encode/decode later 55 | meta = { 56 | 'vocab_size': vocab_size, 57 | 'itos': itos, 58 | 'stoi': stoi, 59 | } 60 | with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f: 61 | 
pickle.dump(meta, f) 62 | 63 | # length of dataset in characters: 1115394 64 | # all the unique characters: 65 | # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 66 | # vocab size: 65 67 | # train has 1003854 tokens 68 | # val has 111540 tokens 69 | -------------------------------------------------------------------------------- /imagenet/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class RASampler(torch.utils.data.Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset for distributed, 9 | with repeated augmentation. 10 | It ensures that different each augmented version of a sample will be visible to a 11 | different process (GPU). 12 | Heavily based on 'torch.utils.data.DistributedSampler'. 13 | 14 | This is borrowed from the DeiT Repo: 15 | https://github.com/facebookresearch/deit/blob/main/samplers.py 16 | """ 17 | 18 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3): 19 | if num_replicas is None: 20 | if not dist.is_available(): 21 | raise RuntimeError("Requires distributed package to be available!") 22 | num_replicas = dist.get_world_size() 23 | if rank is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available!") 26 | rank = dist.get_rank() 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.epoch = 0 31 | self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas)) 32 | self.total_size = self.num_samples * self.num_replicas 33 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 34 | self.shuffle = shuffle 35 | self.seed = seed 36 | self.repetitions = repetitions 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # Deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.seed + self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 44 | else: 45 | indices = list(range(len(self.dataset))) 46 | 47 | # Add extra samples to make it evenly divisible 48 | indices = [ele for ele in indices for i in range(self.repetitions)] 49 | indices += indices[: (self.total_size - len(indices))] 50 | assert len(indices) == self.total_size 51 | 52 | # Subsample 53 | indices = indices[self.rank : self.total_size : self.num_replicas] 54 | assert len(indices) == self.num_samples 55 | 56 | return iter(indices[: self.num_selected_samples]) 57 | 58 | def __len__(self): 59 | return self.num_selected_samples 60 | 61 | def set_epoch(self, epoch): 62 | self.epoch = epoch 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ADOPT: Modified Adam Can Converge with Any $β_2$ with the Optimal Rate 2 | Official Implementation of "[ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853)", which is presented at NeurIPS 2024. 3 | 4 | ## Update on Nov 22, 2024 5 | 6 | Based on feedbacks from some practitioners, we have updated the implementation to improve the stability of our ADOPT algorithm. 7 | In the original version, ADOPT sometimes gets unstable especially in the early stage of training. 
8 | This seems to be because near-zero division by the second moment estimate occurs when some elements of the parameter gradient are near zero at initialization. 9 | For example, when some parameters (e.g., the last layer of a neural net) are initialized with zero, which is an often-used technique in deep learning, a near-zero gradient is observed at the first parameter update. 10 | To avoid such near-zero divisions, we have decided to add a clipping operation to the momentum update. 11 | 12 | ![clipped_adopt](https://github.com/user-attachments/assets/244cb934-8c73-4f89-b9c4-7b94f292af5f) 13 | 14 | Even when clipping is applied, the theoretical convergence guarantee is maintained by properly scheduling the clipping value (see the updated arXiv paper). 15 | In our implementation, the clipping value is controlled by the argument `clip_lambda`, a callable that determines the clipping value as a function of the number of gradient steps. 16 | By default, the clipping value is set to `step**0.25`, which aligns with the theory and ensures convergence. 17 | We observe that clipped ADOPT works much more stably than the original algorithm, so we recommend using it over the unclipped version. 18 | If you want to reproduce the behavior of the original version, set `clip_lambda = None`. 19 | 20 | ## Requirements 21 | 22 | ADOPT requires PyTorch 2.5.0 or later. 23 | 24 | ## Installation 25 | 26 | ```bash 27 | pip install torch-adopt 28 | ``` 29 | 30 | ## Usage 31 | 32 | You can use ADOPT just like any other PyTorch optimizer by importing the `ADOPT` class. 33 | 34 | When you replace the `Adam` optimizer with our `ADOPT`, you just need to swap the optimizer as follows: 35 | 36 | ```python3 37 | from adopt import ADOPT 38 | # optimizer = Adam(model.parameters(), lr=1e-3) 39 | optimizer = ADOPT(model.parameters(), lr=1e-3) 40 | ``` 41 | 42 | When you are using `AdamW` as your default optimizer, you should set `decouple=True` for our `ADOPT`: 43 | 44 | ```python3 45 | # optimizer = AdamW(model.parameters(), lr=1e-3) 46 | optimizer = ADOPT(model.parameters(), lr=1e-3, decouple=True) 47 | ``` 48 | 49 | ## Citation 50 | If you use ADOPT in your research, please cite the paper. 51 | ```text 52 | @inproceedings{taniguchi2024adopt, 53 | author={Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka}, 54 | booktitle = {Advances in Neural Information Processing Systems}, 55 | title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate}, 56 | year = {2024} 57 | } 58 | ``` 59 | 60 | ## License 61 | [Apache 2.0](./LICENSE) 62 | -------------------------------------------------------------------------------- /gpt/data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # number of workers in load_dataset() call 15 | # best number might be different from num_proc above as it also depends on NW speed.
16 | # it is better than 1 usually though 17 | num_proc_load_dataset = num_proc 18 | 19 | enc = tiktoken.get_encoding("gpt2") 20 | 21 | if __name__ == '__main__': 22 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 23 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) 24 | 25 | # owt by default only contains the 'train' split, so create a test split 26 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 27 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 28 | 29 | # this results in: 30 | # >>> split_dataset 31 | # DatasetDict({ 32 | # train: Dataset({ 33 | # features: ['text'], 34 | # num_rows: 8009762 35 | # }) 36 | # val: Dataset({ 37 | # features: ['text'], 38 | # num_rows: 4007 39 | # }) 40 | # }) 41 | 42 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 43 | def process(example): 44 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 45 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 46 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 47 | out = {'ids': ids, 'len': len(ids)} 48 | return out 49 | 50 | # tokenize the dataset 51 | tokenized = split_dataset.map( 52 | process, 53 | remove_columns=['text'], 54 | desc="tokenizing the splits", 55 | num_proc=num_proc, 56 | ) 57 | 58 | # concatenate all the ids in each dataset into one large file we can use for training 59 | for split, dset in tokenized.items(): 60 | arr_len = np.sum(dset['len'], dtype=np.uint64) 61 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 62 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 63 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 64 | total_batches = 1024 65 | 66 | idx = 0 67 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 68 | # Batch together samples for faster write 69 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 70 | arr_batch = np.concatenate(batch['ids']) 71 | # Write into mmap 72 | arr[idx : idx + len(arr_batch)] = arr_batch 73 | idx += len(arr_batch) 74 | arr.flush() 75 | 76 | # train.bin is ~17GB, val.bin ~8.5MB 77 | # train has ~9B tokens (9,035,582,198) 78 | # val has ~4M tokens (4,434,897) 79 | 80 | # to read the bin files later, e.g. with numpy: 81 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /gpt/sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from a trained model 3 | """ 4 | import os 5 | import pickle 6 | from contextlib import nullcontext 7 | import torch 8 | import tiktoken 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl') 13 | out_dir = 'out' # ignored if init_from is not 'resume' 14 | start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt" 15 | num_samples = 10 # number of samples to draw 16 | max_new_tokens = 500 # number of tokens generated in each sample 17 | temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 18 | top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability 19 | seed = 1337 20 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 21 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 22 | compile = False # use PyTorch 2.0 to compile the model to be faster 23 | exec(open('configurator.py').read()) # overrides from command line or config file 24 | # ----------------------------------------------------------------------------- 25 | 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 29 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 30 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 31 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 32 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 33 | 34 | # model 35 | if init_from == 'resume': 36 | # init from a model saved in a specific directory 37 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 38 | checkpoint = torch.load(ckpt_path, map_location=device) 39 | gptconf = GPTConfig(**checkpoint['model_args']) 40 | model = GPT(gptconf) 41 | state_dict = checkpoint['model'] 42 | unwanted_prefix = '_orig_mod.' 43 | for k,v in list(state_dict.items()): 44 | if k.startswith(unwanted_prefix): 45 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 46 | model.load_state_dict(state_dict) 47 | elif init_from.startswith('gpt2'): 48 | # init from a given GPT-2 model 49 | model = GPT.from_pretrained(init_from, dict(dropout=0.0)) 50 | 51 | model.eval() 52 | model.to(device) 53 | if compile: 54 | model = torch.compile(model) # requires PyTorch 2.0 (optional) 55 | 56 | # look for the meta pickle in case it is available in the dataset folder 57 | load_meta = False 58 | if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these... 
59 | meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl') 60 | load_meta = os.path.exists(meta_path) 61 | if load_meta: 62 | print(f"Loading meta from {meta_path}...") 63 | with open(meta_path, 'rb') as f: 64 | meta = pickle.load(f) 65 | # TODO want to make this more general to arbitrary encoder/decoder schemes 66 | stoi, itos = meta['stoi'], meta['itos'] 67 | encode = lambda s: [stoi[c] for c in s] 68 | decode = lambda l: ''.join([itos[i] for i in l]) 69 | else: 70 | # ok let's assume gpt-2 encodings by default 71 | print("No meta.pkl found, assuming GPT-2 encodings...") 72 | enc = tiktoken.get_encoding("gpt2") 73 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) 74 | decode = lambda l: enc.decode(l) 75 | 76 | # encode the beginning of the prompt 77 | if start.startswith('FILE:'): 78 | with open(start[5:], 'r', encoding='utf-8') as f: 79 | start = f.read() 80 | start_ids = encode(start) 81 | x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) 82 | 83 | # run generation 84 | with torch.no_grad(): 85 | with ctx: 86 | for k in range(num_samples): 87 | y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) 88 | print(decode(y[0].tolist())) 89 | print('---------------') 90 | -------------------------------------------------------------------------------- /imagenet/presets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms.functional import InterpolationMode 3 | 4 | 5 | def get_module(use_v2): 6 | # We need a protected import to avoid the V2 warning in case just V1 is used 7 | if use_v2: 8 | import torchvision.transforms.v2 9 | 10 | return torchvision.transforms.v2 11 | else: 12 | import torchvision.transforms 13 | 14 | return torchvision.transforms 15 | 16 | 17 | class ClassificationPresetTrain: 18 | # Note: this transform assumes that the input to forward() are always PIL 19 | # images, regardless of the backend parameter. We may change that in the 20 | # future though, if we change the output type from the dataset. 
21 | def __init__( 22 | self, 23 | *, 24 | crop_size, 25 | mean=(0.485, 0.456, 0.406), 26 | std=(0.229, 0.224, 0.225), 27 | interpolation=InterpolationMode.BILINEAR, 28 | hflip_prob=0.5, 29 | auto_augment_policy=None, 30 | ra_magnitude=9, 31 | augmix_severity=3, 32 | random_erase_prob=0.0, 33 | backend="pil", 34 | use_v2=False, 35 | ): 36 | T = get_module(use_v2) 37 | 38 | transforms = [] 39 | backend = backend.lower() 40 | if backend == "tensor": 41 | transforms.append(T.PILToTensor()) 42 | elif backend != "pil": 43 | raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") 44 | 45 | transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True)) 46 | if hflip_prob > 0: 47 | transforms.append(T.RandomHorizontalFlip(hflip_prob)) 48 | if auto_augment_policy is not None: 49 | if auto_augment_policy == "ra": 50 | transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) 51 | elif auto_augment_policy == "ta_wide": 52 | transforms.append(T.TrivialAugmentWide(interpolation=interpolation)) 53 | elif auto_augment_policy == "augmix": 54 | transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity)) 55 | else: 56 | aa_policy = T.AutoAugmentPolicy(auto_augment_policy) 57 | transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation)) 58 | 59 | if backend == "pil": 60 | transforms.append(T.PILToTensor()) 61 | 62 | transforms.extend( 63 | [ 64 | T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float), 65 | T.Normalize(mean=mean, std=std), 66 | ] 67 | ) 68 | if random_erase_prob > 0: 69 | transforms.append(T.RandomErasing(p=random_erase_prob)) 70 | 71 | if use_v2: 72 | transforms.append(T.ToPureTensor()) 73 | 74 | self.transforms = T.Compose(transforms) 75 | 76 | def __call__(self, img): 77 | return self.transforms(img) 78 | 79 | 80 | class ClassificationPresetEval: 81 | def __init__( 82 | self, 83 | *, 84 | crop_size, 85 | resize_size=256, 86 | mean=(0.485, 0.456, 0.406), 87 | std=(0.229, 0.224, 0.225), 88 | interpolation=InterpolationMode.BILINEAR, 89 | backend="pil", 90 | use_v2=False, 91 | ): 92 | T = get_module(use_v2) 93 | transforms = [] 94 | backend = backend.lower() 95 | if backend == "tensor": 96 | transforms.append(T.PILToTensor()) 97 | elif backend != "pil": 98 | raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") 99 | 100 | transforms += [ 101 | T.Resize(resize_size, interpolation=interpolation, antialias=True), 102 | T.CenterCrop(crop_size), 103 | ] 104 | 105 | if backend == "pil": 106 | transforms.append(T.PILToTensor()) 107 | 108 | transforms += [ 109 | T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float), 110 | T.Normalize(mean=mean, std=std), 111 | ] 112 | 113 | if use_v2: 114 | transforms.append(T.ToPureTensor()) 115 | 116 | self.transforms = T.Compose(transforms) 117 | 118 | def __call__(self, img): 119 | return self.transforms(img) 120 | -------------------------------------------------------------------------------- /gpt/bench.py: -------------------------------------------------------------------------------- 1 | """ 2 | A much shorter version of train.py for benchmarking 3 | """ 4 | import os 5 | from contextlib import nullcontext 6 | import numpy as np 7 | import time 8 | import torch 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | batch_size = 12 13 | block_size = 1024 14 | bias = False 15 | real_data = 
True 16 | seed = 1337 17 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 18 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 19 | compile = True # use PyTorch 2.0 to compile the model to be faster 20 | profile = False # use pytorch profiler, or just simple benchmarking? 21 | exec(open('configurator.py').read()) # overrides from command line or config file 22 | # ----------------------------------------------------------------------------- 23 | 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 27 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 28 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 29 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 30 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 31 | 32 | # data loading init 33 | if real_data: 34 | dataset = 'openwebtext' 35 | data_dir = os.path.join('data', dataset) 36 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 37 | def get_batch(split): 38 | data = train_data # note ignore split in benchmarking script 39 | ix = torch.randint(len(data) - block_size, (batch_size,)) 40 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 41 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 42 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 43 | return x, y 44 | else: 45 | # alternatively, if fixed data is desired to not care about data loading 46 | x = torch.randint(50304, (batch_size, block_size), device=device) 47 | y = torch.randint(50304, (batch_size, block_size), device=device) 48 | get_batch = lambda split: (x, y) 49 | 50 | # model init 51 | gptconf = GPTConfig( 52 | block_size = block_size, # how far back does the model look? i.e. 
context size 53 | n_layer = 12, n_head = 12, n_embd = 768, # size of the model 54 | dropout = 0, # for determinism 55 | bias = bias, 56 | ) 57 | model = GPT(gptconf) 58 | model.to(device) 59 | 60 | optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) 61 | 62 | if compile: 63 | print("Compiling model...") 64 | model = torch.compile(model) # pytorch 2.0 65 | 66 | if profile: 67 | # useful docs on pytorch profiler: 68 | # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html 69 | # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile 70 | wait, warmup, active = 5, 5, 5 71 | num_steps = wait + warmup + active 72 | with torch.profiler.profile( 73 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], 74 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), 75 | on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), 76 | record_shapes=False, 77 | profile_memory=False, 78 | with_stack=False, # incurs an additional overhead, disable if not needed 79 | with_flops=True, 80 | with_modules=False, # only for torchscript models atm 81 | ) as prof: 82 | 83 | X, Y = get_batch('train') 84 | for k in range(num_steps): 85 | with ctx: 86 | logits, loss = model(X, Y) 87 | X, Y = get_batch('train') 88 | optimizer.zero_grad(set_to_none=True) 89 | loss.backward() 90 | optimizer.step() 91 | lossf = loss.item() 92 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 93 | 94 | prof.step() # notify the profiler at end of each step 95 | 96 | else: 97 | 98 | # simple benchmarking 99 | torch.cuda.synchronize() 100 | for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark 101 | t0 = time.time() 102 | X, Y = get_batch('train') 103 | for k in range(num_steps): 104 | with ctx: 105 | logits, loss = model(X, Y) 106 | X, Y = get_batch('train') 107 | optimizer.zero_grad(set_to_none=True) 108 | loss.backward() 109 | optimizer.step() 110 | lossf = loss.item() 111 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 112 | torch.cuda.synchronize() 113 | t1 = time.time() 114 | dt = t1-t0 115 | mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt) 116 | if stage == 1: 117 | print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%") 118 | -------------------------------------------------------------------------------- /imagenet/transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from presets import get_module 6 | from torch import Tensor 7 | from torchvision.transforms import functional as F 8 | 9 | 10 | def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2): 11 | transforms_module = get_module(use_v2) 12 | 13 | mixup_cutmix = [] 14 | if mixup_alpha > 0: 15 | mixup_cutmix.append( 16 | transforms_module.MixUp(alpha=mixup_alpha, num_classes=num_classes) 17 | if use_v2 18 | else RandomMixUp(num_classes=num_classes, p=1.0, alpha=mixup_alpha) 19 | ) 20 | if cutmix_alpha > 0: 21 | mixup_cutmix.append( 22 | transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes) 23 | if use_v2 24 | else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha) 25 | ) 26 | if not mixup_cutmix: 27 | return None 28 | 29 | return transforms_module.RandomChoice(mixup_cutmix) 30 | 31 | 32 | class RandomMixUp(torch.nn.Module): 33 | """Randomly apply MixUp to the provided batch and 
targets. 34 | The class implements the data augmentations as described in the paper 35 | `"mixup: Beyond Empirical Risk Minimization" `_. 36 | 37 | Args: 38 | num_classes (int): number of classes used for one-hot encoding. 39 | p (float): probability of the batch being transformed. Default value is 0.5. 40 | alpha (float): hyperparameter of the Beta distribution used for mixup. 41 | Default value is 1.0. 42 | inplace (bool): boolean to make this transform inplace. Default set to False. 43 | """ 44 | 45 | def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: 46 | super().__init__() 47 | 48 | if num_classes < 1: 49 | raise ValueError( 50 | f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}" 51 | ) 52 | 53 | if alpha <= 0: 54 | raise ValueError("Alpha param can't be zero.") 55 | 56 | self.num_classes = num_classes 57 | self.p = p 58 | self.alpha = alpha 59 | self.inplace = inplace 60 | 61 | def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: 62 | """ 63 | Args: 64 | batch (Tensor): Float tensor of size (B, C, H, W) 65 | target (Tensor): Integer tensor of size (B, ) 66 | 67 | Returns: 68 | Tensor: Randomly transformed batch. 69 | """ 70 | if batch.ndim != 4: 71 | raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") 72 | if target.ndim != 1: 73 | raise ValueError(f"Target ndim should be 1. Got {target.ndim}") 74 | if not batch.is_floating_point(): 75 | raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") 76 | if target.dtype != torch.int64: 77 | raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") 78 | 79 | if not self.inplace: 80 | batch = batch.clone() 81 | target = target.clone() 82 | 83 | if target.ndim == 1: 84 | target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) 85 | 86 | if torch.rand(1).item() >= self.p: 87 | return batch, target 88 | 89 | # It's faster to roll the batch by one instead of shuffling it to create image pairs 90 | batch_rolled = batch.roll(1, 0) 91 | target_rolled = target.roll(1, 0) 92 | 93 | # Implemented as on mixup paper, page 3. 94 | lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) 95 | batch_rolled.mul_(1.0 - lambda_param) 96 | batch.mul_(lambda_param).add_(batch_rolled) 97 | 98 | target_rolled.mul_(1.0 - lambda_param) 99 | target.mul_(lambda_param).add_(target_rolled) 100 | 101 | return batch, target 102 | 103 | def __repr__(self) -> str: 104 | s = ( 105 | f"{self.__class__.__name__}(" 106 | f"num_classes={self.num_classes}" 107 | f", p={self.p}" 108 | f", alpha={self.alpha}" 109 | f", inplace={self.inplace}" 110 | f")" 111 | ) 112 | return s 113 | 114 | 115 | class RandomCutMix(torch.nn.Module): 116 | """Randomly apply CutMix to the provided batch and targets. 117 | The class implements the data augmentations as described in the paper 118 | `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" 119 | `_. 120 | 121 | Args: 122 | num_classes (int): number of classes used for one-hot encoding. 123 | p (float): probability of the batch being transformed. Default value is 0.5. 124 | alpha (float): hyperparameter of the Beta distribution used for cutmix. 125 | Default value is 1.0. 126 | inplace (bool): boolean to make this transform inplace. Default set to False. 
127 | """ 128 | 129 | def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: 130 | super().__init__() 131 | if num_classes < 1: 132 | raise ValueError("Please provide a valid positive value for the num_classes.") 133 | if alpha <= 0: 134 | raise ValueError("Alpha param can't be zero.") 135 | 136 | self.num_classes = num_classes 137 | self.p = p 138 | self.alpha = alpha 139 | self.inplace = inplace 140 | 141 | def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: 142 | """ 143 | Args: 144 | batch (Tensor): Float tensor of size (B, C, H, W) 145 | target (Tensor): Integer tensor of size (B, ) 146 | 147 | Returns: 148 | Tensor: Randomly transformed batch. 149 | """ 150 | if batch.ndim != 4: 151 | raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}") 152 | if target.ndim != 1: 153 | raise ValueError(f"Target ndim should be 1. Got {target.ndim}") 154 | if not batch.is_floating_point(): 155 | raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") 156 | if target.dtype != torch.int64: 157 | raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") 158 | 159 | if not self.inplace: 160 | batch = batch.clone() 161 | target = target.clone() 162 | 163 | if target.ndim == 1: 164 | target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) 165 | 166 | if torch.rand(1).item() >= self.p: 167 | return batch, target 168 | 169 | # It's faster to roll the batch by one instead of shuffling it to create image pairs 170 | batch_rolled = batch.roll(1, 0) 171 | target_rolled = target.roll(1, 0) 172 | 173 | # Implemented as on cutmix paper, page 12 (with minor corrections on typos). 174 | lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) 175 | _, H, W = F.get_dimensions(batch) 176 | 177 | r_x = torch.randint(W, (1,)) 178 | r_y = torch.randint(H, (1,)) 179 | 180 | r = 0.5 * math.sqrt(1.0 - lambda_param) 181 | r_w_half = int(r * W) 182 | r_h_half = int(r * H) 183 | 184 | x1 = int(torch.clamp(r_x - r_w_half, min=0)) 185 | y1 = int(torch.clamp(r_y - r_h_half, min=0)) 186 | x2 = int(torch.clamp(r_x + r_w_half, max=W)) 187 | y2 = int(torch.clamp(r_y + r_h_half, max=H)) 188 | 189 | batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] 190 | lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) 191 | 192 | target_rolled.mul_(1.0 - lambda_param) 193 | target.mul_(lambda_param).add_(target_rolled) 194 | 195 | return batch, target 196 | 197 | def __repr__(self) -> str: 198 | s = ( 199 | f"{self.__class__.__name__}(" 200 | f"num_classes={self.num_classes}" 201 | f", p={self.p}" 202 | f", alpha={self.alpha}" 203 | f", inplace={self.inplace}" 204 | f")" 205 | ) 206 | return s 207 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /imagenet/train_quantization.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import os 4 | import time 5 | 6 | import torch 7 | import torch.ao.quantization 8 | import torch.utils.data 9 | import torchvision 10 | import utils 11 | from torch import nn 12 | from train import evaluate, load_data, train_one_epoch 13 | 14 | 15 | def main(args): 16 | if args.output_dir: 17 | utils.mkdir(args.output_dir) 18 | 19 | utils.init_distributed_mode(args) 20 | print(args) 21 | 22 | if args.post_training_quantize and args.distributed: 23 | raise RuntimeError("Post training quantization example should not be performed on distributed mode") 24 | 25 | # Set backend engine to ensure that quantized model runs on the correct kernels 26 | if args.qbackend not in torch.backends.quantized.supported_engines: 27 | raise RuntimeError("Quantized backend not supported: " + str(args.qbackend)) 28 | torch.backends.quantized.engine = args.qbackend 29 | 30 | device = torch.device(args.device) 31 | torch.backends.cudnn.benchmark = True 32 | 33 | # Data loading code 34 | print("Loading data") 35 | train_dir = os.path.join(args.data_path, "train") 36 | val_dir = os.path.join(args.data_path, "val") 37 | 38 | dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) 39 | data_loader = torch.utils.data.DataLoader( 40 | dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True 41 | ) 42 | 43 | data_loader_test = torch.utils.data.DataLoader( 44 | dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True 45 | ) 46 | 47 | print("Creating model", args.model) 48 | # when training quantized models, we always start from a pre-trained fp32 reference model 49 | prefix = "quantized_" 50 | model_name = args.model 51 | if not model_name.startswith(prefix): 52 | model_name = prefix + model_name 53 | model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only) 54 | model.to(device) 55 | 56 | if not (args.test_only or args.post_training_quantize): 57 | model.fuse_model(is_qat=True) 58 | model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend) 59 | torch.ao.quantization.prepare_qat(model, inplace=True) 60 | 61 | if args.distributed and args.sync_bn: 62 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 63 | 64 | optimizer = torch.optim.SGD( 65 | model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay 66 | ) 67 | 68 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) 69 | 70 | criterion = nn.CrossEntropyLoss() 71 | model_without_ddp = model 72 | if args.distributed: 73 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 74 | model_without_ddp = model.module 75 | 76 | if args.resume: 77 | checkpoint = torch.load(args.resume, 
map_location="cpu", weights_only=True) 78 | model_without_ddp.load_state_dict(checkpoint["model"]) 79 | optimizer.load_state_dict(checkpoint["optimizer"]) 80 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 81 | args.start_epoch = checkpoint["epoch"] + 1 82 | 83 | if args.post_training_quantize: 84 | # perform calibration on a subset of the training dataset 85 | # for that, create a subset of the training dataset 86 | ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches))) 87 | data_loader_calibration = torch.utils.data.DataLoader( 88 | ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True 89 | ) 90 | model.eval() 91 | model.fuse_model(is_qat=False) 92 | model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend) 93 | torch.ao.quantization.prepare(model, inplace=True) 94 | # Calibrate first 95 | print("Calibrating") 96 | evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1) 97 | torch.ao.quantization.convert(model, inplace=True) 98 | if args.output_dir: 99 | print("Saving quantized model") 100 | if utils.is_main_process(): 101 | torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth")) 102 | print("Evaluating post-training quantized model") 103 | evaluate(model, criterion, data_loader_test, device=device) 104 | return 105 | 106 | if args.test_only: 107 | evaluate(model, criterion, data_loader_test, device=device) 108 | return 109 | 110 | model.apply(torch.ao.quantization.enable_observer) 111 | model.apply(torch.ao.quantization.enable_fake_quant) 112 | start_time = time.time() 113 | for epoch in range(args.start_epoch, args.epochs): 114 | if args.distributed: 115 | train_sampler.set_epoch(epoch) 116 | print("Starting training for epoch", epoch) 117 | train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args) 118 | lr_scheduler.step() 119 | with torch.inference_mode(): 120 | if epoch >= args.num_observer_update_epochs: 121 | print("Disabling observer for subseq epochs, epoch = ", epoch) 122 | model.apply(torch.ao.quantization.disable_observer) 123 | if epoch >= args.num_batch_norm_update_epochs: 124 | print("Freezing BN for subseq epochs, epoch = ", epoch) 125 | model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) 126 | print("Evaluate QAT model") 127 | 128 | evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT") 129 | quantized_eval_model = copy.deepcopy(model_without_ddp) 130 | quantized_eval_model.eval() 131 | quantized_eval_model.to(torch.device("cpu")) 132 | torch.ao.quantization.convert(quantized_eval_model, inplace=True) 133 | 134 | print("Evaluate Quantized model") 135 | evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu")) 136 | 137 | model.train() 138 | 139 | if args.output_dir: 140 | checkpoint = { 141 | "model": model_without_ddp.state_dict(), 142 | "eval_model": quantized_eval_model.state_dict(), 143 | "optimizer": optimizer.state_dict(), 144 | "lr_scheduler": lr_scheduler.state_dict(), 145 | "epoch": epoch, 146 | "args": args, 147 | } 148 | utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) 149 | utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) 150 | print("Saving models after epoch ", epoch) 151 | 152 | total_time = time.time() - start_time 153 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 154 | print(f"Training time {total_time_str}") 155 | 
156 | 157 | def get_args_parser(add_help=True): 158 | import argparse 159 | 160 | parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help) 161 | 162 | parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path") 163 | parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name") 164 | parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack") 165 | parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") 166 | 167 | parser.add_argument( 168 | "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" 169 | ) 170 | parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation") 171 | parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") 172 | parser.add_argument( 173 | "--num-observer-update-epochs", 174 | default=4, 175 | type=int, 176 | metavar="N", 177 | help="number of total epochs to update observers", 178 | ) 179 | parser.add_argument( 180 | "--num-batch-norm-update-epochs", 181 | default=3, 182 | type=int, 183 | metavar="N", 184 | help="number of total epochs to update batch norm stats", 185 | ) 186 | parser.add_argument( 187 | "--num-calibration-batches", 188 | default=32, 189 | type=int, 190 | metavar="N", 191 | help="number of batches of training set for \ 192 | observer calibration ", 193 | ) 194 | 195 | parser.add_argument( 196 | "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" 197 | ) 198 | parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate") 199 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") 200 | parser.add_argument( 201 | "--wd", 202 | "--weight-decay", 203 | default=1e-4, 204 | type=float, 205 | metavar="W", 206 | help="weight decay (default: 1e-4)", 207 | dest="weight_decay", 208 | ) 209 | parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") 210 | parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") 211 | parser.add_argument("--print-freq", default=10, type=int, help="print frequency") 212 | parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") 213 | parser.add_argument("--resume", default="", type=str, help="path of checkpoint") 214 | parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") 215 | parser.add_argument( 216 | "--cache-dataset", 217 | dest="cache_dataset", 218 | help="Cache the datasets for quicker initialization. 
\ 219 | It also serializes the transforms", 220 | action="store_true", 221 | ) 222 | parser.add_argument( 223 | "--sync-bn", 224 | dest="sync_bn", 225 | help="Use sync batch norm", 226 | action="store_true", 227 | ) 228 | parser.add_argument( 229 | "--test-only", 230 | dest="test_only", 231 | help="Only test the model", 232 | action="store_true", 233 | ) 234 | parser.add_argument( 235 | "--post-training-quantize", 236 | dest="post_training_quantize", 237 | help="Post training quantize the model", 238 | action="store_true", 239 | ) 240 | 241 | # distributed training parameters 242 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") 243 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") 244 | 245 | parser.add_argument( 246 | "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" 247 | ) 248 | parser.add_argument( 249 | "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)" 250 | ) 251 | parser.add_argument( 252 | "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)" 253 | ) 254 | parser.add_argument( 255 | "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)" 256 | ) 257 | parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") 258 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") 259 | 260 | parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") 261 | parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") 262 | 263 | return parser 264 | 265 | 266 | if __name__ == "__main__": 267 | args = get_args_parser().parse_args() 268 | if args.backend in ("fbgemm", "qnnpack"): 269 | raise ValueError( 270 | "The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) " 271 | "instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend." 272 | ) 273 | main(args) 274 | -------------------------------------------------------------------------------- /gpt/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 
4 | 5 | To run on a single GPU, example: 6 | $ python train.py --batch_size=32 --compile=False 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train.py 10 | 11 | To run with DDP on 4 gpus across 2 nodes, example: 12 | - Run on the first (master) node with example IP 123.456.123.456: 13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 14 | - Run on the worker node: 15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 16 | (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1) 17 | """ 18 | 19 | import os 20 | import time 21 | import math 22 | import pickle 23 | from contextlib import nullcontext 24 | 25 | import numpy as np 26 | import torch 27 | from torch.nn.parallel import DistributedDataParallel as DDP 28 | from torch.distributed import init_process_group, destroy_process_group 29 | 30 | from model import GPTConfig, GPT 31 | 32 | 33 | # ----------------------------------------------------------------------------- 34 | # default config values designed to train a gpt2 (124M) on OpenWebText 35 | # I/O 36 | out_dir = 'out' 37 | eval_interval = 2000 38 | log_interval = 1 39 | eval_iters = 200 40 | eval_only = False # if True, script exits right after the first eval 41 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 42 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 43 | # wandb logging 44 | wandb_log = False # disabled by default 45 | wandb_project = 'owt' 46 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 47 | # data 48 | dataset = 'openwebtext' 49 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes 50 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 51 | block_size = 1024 52 | # model 53 | n_layer = 12 54 | n_head = 12 55 | n_embd = 768 56 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 57 | bias = False # do we use bias inside LayerNorm and Linear layers? 58 | # optimizer 59 | optim_type = 'adamw' 60 | learning_rate = 6e-4 # max learning rate 61 | max_iters = 600000 # total number of training iterations 62 | weight_decay = 1e-1 63 | beta1 = 0.9 64 | beta2 = 0.95 65 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 66 | # learning rate decay settings 67 | decay_lr = True # whether to decay the learning rate 68 | warmup_iters = 2000 # how many steps to warm up for 69 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 70 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 71 | # DDP settings 72 | backend = 'nccl' # 'nccl', 'gloo', etc. 
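# note: every default in this config block can be overridden from the command line or a
# config file via configurator.py (exec'd below). An illustrative invocation follows;
# the values are placeholders, and 'adopt' is assumed here to be the alternative
# optim_type handled by model.configure_optimizers in this repo, alongside the
# 'adamw' default:
#   python train.py config/train_gpt2.py --optim_type=adopt --batch_size=32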
73 | # system 74 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 75 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 76 | compile = True # use PyTorch 2.0 to compile the model to be faster 77 | # ----------------------------------------------------------------------------- 78 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 79 | exec(open('configurator.py').read()) # overrides from command line or config file 80 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 81 | # ----------------------------------------------------------------------------- 82 | 83 | # various inits, derived attributes, I/O setup 84 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 85 | if ddp: 86 | init_process_group(backend=backend) 87 | ddp_rank = int(os.environ['RANK']) 88 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 89 | ddp_world_size = int(os.environ['WORLD_SIZE']) 90 | device = f'cuda:{ddp_local_rank}' 91 | torch.cuda.set_device(device) 92 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 93 | seed_offset = ddp_rank # each process gets a different seed 94 | # world_size number of processes will be training simultaneously, so we can scale 95 | # down the desired gradient accumulation iterations per process proportionally 96 | assert gradient_accumulation_steps % ddp_world_size == 0 97 | gradient_accumulation_steps //= ddp_world_size 98 | else: 99 | # if not ddp, we are running on a single gpu, and one process 100 | master_process = True 101 | seed_offset = 0 102 | ddp_world_size = 1 103 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 104 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 105 | 106 | if master_process: 107 | os.makedirs(out_dir, exist_ok=True) 108 | torch.manual_seed(1337 + seed_offset) 109 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 110 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 111 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 112 | # note: float16 data type will automatically use a GradScaler 113 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 114 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 115 | 116 | # poor man's data loader 117 | data_dir = os.path.join('data', dataset) 118 | def get_batch(split): 119 | # We recreate np.memmap every batch to avoid a memory leak, as per 120 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 121 | if split == 'train': 122 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 123 | else: 124 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 125 | ix = torch.randint(len(data) - block_size, (batch_size,)) 126 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 127 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 128 | if device_type == 'cuda': 129 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 130 | x, y = x.pin_memory().to(device, 
non_blocking=True), y.pin_memory().to(device, non_blocking=True) 131 | else: 132 | x, y = x.to(device), y.to(device) 133 | return x, y 134 | 135 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 136 | iter_num = 0 137 | best_val_loss = 1e9 138 | 139 | # attempt to derive vocab_size from the dataset 140 | meta_path = os.path.join(data_dir, 'meta.pkl') 141 | meta_vocab_size = None 142 | if os.path.exists(meta_path): 143 | with open(meta_path, 'rb') as f: 144 | meta = pickle.load(f) 145 | meta_vocab_size = meta['vocab_size'] 146 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 147 | 148 | # model init 149 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 150 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line 151 | if init_from == 'scratch': 152 | # init a new model from scratch 153 | print("Initializing a new model from scratch") 154 | # determine the vocab size we'll use for from-scratch training 155 | if meta_vocab_size is None: 156 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 157 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 158 | gptconf = GPTConfig(**model_args) 159 | model = GPT(gptconf) 160 | elif init_from == 'resume': 161 | print(f"Resuming training from {out_dir}") 162 | # resume training from a checkpoint. 163 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 164 | checkpoint = torch.load(ckpt_path, map_location=device) 165 | checkpoint_model_args = checkpoint['model_args'] 166 | # force these config attributes to be equal otherwise we can't even resume training 167 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 168 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 169 | model_args[k] = checkpoint_model_args[k] 170 | # create the model 171 | gptconf = GPTConfig(**model_args) 172 | model = GPT(gptconf) 173 | state_dict = checkpoint['model'] 174 | # fix the keys of the state dictionary :( 175 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 176 | unwanted_prefix = '_orig_mod.' 177 | for k,v in list(state_dict.items()): 178 | if k.startswith(unwanted_prefix): 179 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 180 | model.load_state_dict(state_dict) 181 | iter_num = checkpoint['iter_num'] 182 | best_val_loss = checkpoint['best_val_loss'] 183 | elif init_from.startswith('gpt2'): 184 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 185 | # initialize from OpenAI GPT-2 weights 186 | override_args = dict(dropout=dropout) 187 | model = GPT.from_pretrained(init_from, override_args) 188 | # read off the created config params, so we can store them into checkpoint correctly 189 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 190 | model_args[k] = getattr(model.config, k) 191 | # crop down the model block size if desired, using model surgery 192 | if block_size < model.config.block_size: 193 | model.crop_block_size(block_size) 194 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 195 | model.to(device) 196 | 197 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 198 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 199 | 200 | # optimizer 201 | optimizer = model.configure_optimizers(optim_type, weight_decay, learning_rate, (beta1, beta2), device_type) 202 | if init_from == 'resume': 203 | optimizer.load_state_dict(checkpoint['optimizer']) 204 | checkpoint = None # free up memory 205 | 206 | # compile the model 207 | if compile: 208 | print("compiling the model... (takes a ~minute)") 209 | unoptimized_model = model 210 | model = torch.compile(model) # requires PyTorch 2.0 211 | 212 | # wrap model into DDP container 213 | if ddp: 214 | model = DDP(model, device_ids=[ddp_local_rank]) 215 | 216 | # helps estimate an arbitrarily accurate loss over either split using many batches 217 | @torch.no_grad() 218 | def estimate_loss(): 219 | out = {} 220 | model.eval() 221 | for split in ['train', 'val']: 222 | losses = torch.zeros(eval_iters) 223 | for k in range(eval_iters): 224 | X, Y = get_batch(split) 225 | with ctx: 226 | logits, loss = model(X, Y) 227 | losses[k] = loss.item() 228 | out[split] = losses.mean() 229 | model.train() 230 | return out 231 | 232 | # learning rate decay scheduler (cosine with warmup) 233 | def get_lr(it): 234 | # 1) linear warmup for warmup_iters steps 235 | if it < warmup_iters: 236 | return learning_rate * it / warmup_iters 237 | # 2) if it > lr_decay_iters, return min learning rate 238 | if it > lr_decay_iters: 239 | return min_lr 240 | # 3) in between, use cosine decay down to min learning rate 241 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 242 | assert 0 <= decay_ratio <= 1 243 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 244 | return min_lr + coeff * (learning_rate - min_lr) 245 | 246 | # logging 247 | if wandb_log and master_process: 248 | import wandb 249 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 250 | 251 | # training loop 252 | X, Y = get_batch('train') # fetch the very first batch 253 | t0 = time.time() 254 | local_iter_num = 0 # number of iterations in the lifetime of this process 255 | raw_model = model.module if ddp else model # unwrap DDP container if needed 256 | running_mfu = -1.0 257 | while True: 258 | 259 | # determine and set the learning rate for this iteration 260 | lr = get_lr(iter_num) if decay_lr else learning_rate 261 | for param_group in optimizer.param_groups: 262 | param_group['lr'] = lr 263 | 264 | # evaluate the loss on train/val sets and write checkpoints 265 | if iter_num % eval_interval == 0 and master_process: 266 | losses = estimate_loss() 267 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 268 | if wandb_log: 269 | wandb.log({ 270 | "iter": iter_num, 271 | "train/loss": losses['train'], 272 | "val/loss": losses['val'], 273 | "lr": lr, 274 | "mfu": running_mfu*100, # convert to percentage 275 | }) 276 | if losses['val'] < best_val_loss or always_save_checkpoint: 277 | best_val_loss = losses['val'] 278 | if iter_num > 0: 279 | checkpoint = { 280 | 'model': raw_model.state_dict(), 281 | 'optimizer': optimizer.state_dict(), 282 | 'model_args': model_args, 283 | 'iter_num': iter_num, 284 | 'best_val_loss': best_val_loss, 285 | 'config': config, 286 | } 287 | print(f"saving checkpoint to {out_dir}") 288 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 289 | if iter_num == 0 and eval_only: 290 | break 291 | 292 | # forward backward update, with optional gradient accumulation to simulate larger batch size 
293 | # and using the GradScaler if data type is float16 294 | for micro_step in range(gradient_accumulation_steps): 295 | if ddp: 296 | # in DDP training we only need to sync gradients at the last micro step. 297 | # the official way to do this is with model.no_sync() context manager, but 298 | # I really dislike that this bloats the code and forces us to repeat code 299 | # looking at the source of that context manager, it just toggles this variable 300 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 301 | with ctx: 302 | logits, loss = model(X, Y) 303 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 304 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 305 | X, Y = get_batch('train') 306 | # backward pass, with gradient scaling if training in fp16 307 | scaler.scale(loss).backward() 308 | # clip the gradient 309 | if grad_clip != 0.0: 310 | scaler.unscale_(optimizer) 311 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 312 | # step the optimizer and scaler if training in fp16 313 | scaler.step(optimizer) 314 | scaler.update() 315 | # flush the gradients as soon as we can, no need for this memory anymore 316 | optimizer.zero_grad(set_to_none=True) 317 | 318 | # timing and logging 319 | t1 = time.time() 320 | dt = t1 - t0 321 | t0 = t1 322 | if iter_num % log_interval == 0 and master_process: 323 | # get loss as float. note: this is a CPU-GPU sync point 324 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 325 | lossf = loss.item() * gradient_accumulation_steps 326 | if local_iter_num >= 5: # let the training loop settle a bit 327 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 328 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 329 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 330 | iter_num += 1 331 | local_iter_num += 1 332 | 333 | # termination conditions 334 | if iter_num > max_iters: 335 | break 336 | 337 | if ddp: 338 | destroy_process_group() 339 | -------------------------------------------------------------------------------- /gpt/transformer_sizing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "### Transformer Theoretical Model\n", 9 | "\n", 10 | "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from collections import OrderedDict" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# config_args = {\n", 29 | "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n", 30 | "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n", 31 | "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n", 32 | "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n", 33 | "# }[model_type]\n", 34 | "\n", 35 | "block_size = 1024\n", 36 | "vocab_size = 50257\n", 37 | "n_layer = 12\n", 38 | "n_head = 12\n", 39 | "n_embd = 768\n", 40 | "bias = False\n", 41 | "assert not bias, \"this notebook assumes bias=False just for simplicity\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "we see: 124337664, expected: 124337664, match: True\n", 54 | "name params ratio (%) \n", 55 | "emebedding/position 786432 0.6325\n", 56 | "embedding/token 38597376 31.0424\n", 57 | "embedding 39383808 31.6749\n", 58 | "attention/ln 768 0.0006\n", 59 | "attention/kqv 1769472 1.4231\n", 60 | "attention/proj 589824 0.4744\n", 61 | "attention 2360064 1.8981\n", 62 | "mlp/ln 768 0.0006\n", 63 | "mlp/ffw 2359296 1.8975\n", 64 | "mlp/proj 2359296 1.8975\n", 65 | "mlp 4719360 3.7956\n", 66 | "block 7079424 5.6937\n", 67 | "transformer 84953088 68.3245\n", 68 | "ln_f 768 0.0006\n", 69 | "dense 0 0.0000\n", 70 | "total 124337664 100.0000\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "def params():\n", 76 | " \"\"\" estimates the number of parameters in the model\"\"\"\n", 77 | " out = OrderedDict()\n", 78 | "\n", 79 | " # token and position embeddings\n", 80 | " out['emebedding/position'] = n_embd * block_size\n", 81 | " out['embedding/token'] = n_embd * vocab_size\n", 82 | " out['embedding'] = out['emebedding/position'] + out['embedding/token']\n", 83 | "\n", 84 | " # attention blocks\n", 85 | " out['attention/ln'] = n_embd # note, bias=False in our LN\n", 86 | " out['attention/kqv'] = n_embd * 3*n_embd\n", 87 | " out['attention/proj'] = n_embd**2\n", 88 | " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n", 89 | "\n", 90 | " # MLP blocks\n", 91 | " ffw_size = 4*n_embd # feed forward size\n", 92 | " out['mlp/ln'] = n_embd\n", 93 | " out['mlp/ffw'] = n_embd * ffw_size\n", 94 | " out['mlp/proj'] = ffw_size * n_embd\n", 95 | " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n", 96 | " \n", 97 | " # the transformer and the rest of it\n", 98 | " out['block'] = out['attention'] + out['mlp']\n", 99 | " out['transformer'] = n_layer * out['block']\n", 100 | " out['ln_f'] = n_embd # final layernorm\n", 101 | " out['dense'] = 0 # 0 because of parameter sharing. 
This layer uses the weights from the embedding layer\n", 102 | "\n", 103 | " # total\n", 104 | " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n", 105 | "\n", 106 | " return out\n", 107 | "\n", 108 | "# compare our param count to that reported by PyTorch\n", 109 | "p = params()\n", 110 | "params_total = p['total']\n", 111 | "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n", 112 | "# create a header\n", 113 | "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n", 114 | "for k,v in p.items():\n", 115 | " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n", 116 | " " 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "est checkpoint size: 1.49 GB\n", 129 | "measured with wc -c ckpt.pt: 1542470366\n", 130 | "fluff ratio: 103.38%\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "# we can now calculate the size of each checkpoint\n", 136 | "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n", 137 | "params_bytes = params_total*4\n", 138 | "params_and_buffers_bytes = params_bytes + 2*params_bytes\n", 139 | "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n", 140 | "measured_bytes = 1542470366 # from wc -c ckpt.pt\n", 141 | "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n", 142 | "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")" 143 | ] 144 | }, 145 | { 146 | "attachments": {}, 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 5, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "memory ratio taken up just for parameters: 3.73%\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n", 168 | "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")" 169 | ] 170 | }, 171 | { 172 | "attachments": {}, 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models." 177 | ] 178 | }, 179 | { 180 | "attachments": {}, 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Let's estimate FLOPs for a single forward pass." 
185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "name flops ratio (%) \n", 197 | "attention/kqv 3623878656 1.2426\n", 198 | "attention/scores 1610612736 0.5522\n", 199 | "attention/reduce 1610612736 0.5522\n", 200 | "attention/proj 1207959552 0.4142\n", 201 | "attention 8053063680 2.7612\n", 202 | "mlp/ffw1 4831838208 1.6567\n", 203 | "mlp/ffw2 4831838208 1.6567\n", 204 | "mlp 9663676416 3.3135\n", 205 | "block 17716740096 6.0747\n", 206 | "transformer 212600881152 72.8963\n", 207 | "dense 79047426048 27.1037\n", 208 | "forward_total 291648307200 100.0000\n", 209 | "backward_total 583296614400 200.0000\n", 210 | "total 874944921600 300.0000\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "def flops():\n", 216 | " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n", 217 | " # we count actual FLOPs, not MACs. Hence 2* all over the place\n", 218 | " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n", 219 | "\n", 220 | " out = OrderedDict()\n", 221 | " head_size = n_embd // n_head\n", 222 | "\n", 223 | " # attention blocks\n", 224 | " # 1) the projection to key, query, values\n", 225 | " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n", 226 | " # 2) calculating the attention scores\n", 227 | " out['attention/scores'] = 2 * block_size * block_size * n_embd\n", 228 | " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n", 229 | " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n", 230 | " # 4) the final linear projection\n", 231 | " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n", 232 | " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n", 233 | "\n", 234 | " # MLP blocks\n", 235 | " ffw_size = 4*n_embd # feed forward size\n", 236 | " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n", 237 | " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n", 238 | " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n", 239 | "\n", 240 | " # the transformer and the rest of it\n", 241 | " out['block'] = out['attention'] + out['mlp']\n", 242 | " out['transformer'] = n_layer * out['block']\n", 243 | " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n", 244 | "\n", 245 | " # forward,backward,total\n", 246 | " out['forward_total'] = out['transformer'] + out['dense']\n", 247 | " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n", 248 | " out['total'] = out['forward_total'] + out['backward_total']\n", 249 | "\n", 250 | " return out\n", 251 | " \n", 252 | "# compare our param count to that reported by PyTorch\n", 253 | "f = flops()\n", 254 | "flops_total = f['forward_total']\n", 255 | "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n", 256 | "for k,v in f.items():\n", 257 | " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n", 258 | " " 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 7, 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n" 271 | ] 272 | } 273 | ], 274 | "source": [ 275 | "# now here is an estimate copy pasted from the PaLM paper\n", 276 | "# this formula is often used to calculate MFU (model flops 
utilization)\n", 277 | "def palm_flops():\n", 278 | " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n", 279 | " # non-embedding model parameters. note that we do not subtract the\n", 280 | " # embedding/token params because those are tied and get used in the last layer.\n", 281 | " N = params()['total'] - params()['emebedding/position']\n", 282 | " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n", 283 | " mf_per_token = 6*N + 12*L*H*Q*T\n", 284 | " mf = mf_per_token * block_size\n", 285 | " return mf\n", 286 | "\n", 287 | "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")" 288 | ] 289 | }, 290 | { 291 | "attachments": {}, 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. We get:" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 8, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "fraction of A100 used: 37.14%\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "# here is what we currently roughly measure\n", 313 | "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n", 314 | "measured_time = 0.755 # in seconds per iteration\n", 315 | "measured_throughput = batch_size / measured_time\n", 316 | "flops_achieved = f['total'] * measured_throughput\n", 317 | "\n", 318 | "# A100 is cited to be 312 TFLOPS of bloat16 running on tensor cores\n", 319 | "a100_flops_promised = 312e12\n", 320 | "\n", 321 | "# the fraction of the A100 that we are using:\n", 322 | "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")" 323 | ] 324 | }, 325 | { 326 | "attachments": {}, 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU." 
331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 9, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "time needed to train the model: 3.46 days\n" 343 | ] 344 | } 345 | ], 346 | "source": [ 347 | "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n", 348 | "model_size = params()['total'] # this is number of parameters, N\n", 349 | "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n", 350 | "a100_flops = 312e12 # 312 TFLOPS\n", 351 | "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n", 352 | "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n", 353 | "flops_needed = 6 * model_size * tokens_num # 6ND\n", 354 | "time_needed_s = flops_needed / flops_throughput # in seconds\n", 355 | "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")" 356 | ] 357 | }, 358 | { 359 | "attachments": {}, 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)." 364 | ] 365 | }, 366 | { 367 | "attachments": {}, 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later." 372 | ] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "pytorch2", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.10.8" 392 | }, 393 | "orig_nbformat": 4, 394 | "vscode": { 395 | "interpreter": { 396 | "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222" 397 | } 398 | } 399 | }, 400 | "nbformat": 4, 401 | "nbformat_minor": 2 402 | } 403 | -------------------------------------------------------------------------------- /imagenet/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import errno 4 | import hashlib 5 | import os 6 | import time 7 | from collections import defaultdict, deque, OrderedDict 8 | from typing import List, Optional, Tuple 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | 14 | class SmoothedValue: 15 | """Track a series of values and provide access to smoothed values over a 16 | window or the global series average. 17 | """ 18 | 19 | def __init__(self, window_size=20, fmt=None): 20 | if fmt is None: 21 | fmt = "{median:.4f} ({global_avg:.4f})" 22 | self.deque = deque(maxlen=window_size) 23 | self.total = 0.0 24 | self.count = 0 25 | self.fmt = fmt 26 | 27 | def update(self, value, n=1): 28 | self.deque.append(value) 29 | self.count += n 30 | self.total += value * n 31 | 32 | def synchronize_between_processes(self): 33 | """ 34 | Warning: does not synchronize the deque! 
35 | """ 36 | t = reduce_across_processes([self.count, self.total]) 37 | t = t.tolist() 38 | self.count = int(t[0]) 39 | self.total = t[1] 40 | 41 | @property 42 | def median(self): 43 | d = torch.tensor(list(self.deque)) 44 | return d.median().item() 45 | 46 | @property 47 | def avg(self): 48 | d = torch.tensor(list(self.deque), dtype=torch.float32) 49 | return d.mean().item() 50 | 51 | @property 52 | def global_avg(self): 53 | return self.total / self.count 54 | 55 | @property 56 | def max(self): 57 | return max(self.deque) 58 | 59 | @property 60 | def value(self): 61 | return self.deque[-1] 62 | 63 | def __str__(self): 64 | return self.fmt.format( 65 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value 66 | ) 67 | 68 | 69 | class MetricLogger: 70 | def __init__(self, delimiter="\t"): 71 | self.meters = defaultdict(SmoothedValue) 72 | self.delimiter = delimiter 73 | 74 | def update(self, **kwargs): 75 | for k, v in kwargs.items(): 76 | if isinstance(v, torch.Tensor): 77 | v = v.item() 78 | assert isinstance(v, (float, int)) 79 | self.meters[k].update(v) 80 | 81 | def __getattr__(self, attr): 82 | if attr in self.meters: 83 | return self.meters[attr] 84 | if attr in self.__dict__: 85 | return self.__dict__[attr] 86 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") 87 | 88 | def __str__(self): 89 | loss_str = [] 90 | for name, meter in self.meters.items(): 91 | loss_str.append(f"{name}: {str(meter)}") 92 | return self.delimiter.join(loss_str) 93 | 94 | def synchronize_between_processes(self): 95 | for meter in self.meters.values(): 96 | meter.synchronize_between_processes() 97 | 98 | def add_meter(self, name, meter): 99 | self.meters[name] = meter 100 | 101 | def log_every(self, iterable, print_freq, header=None): 102 | i = 0 103 | if not header: 104 | header = "" 105 | start_time = time.time() 106 | end = time.time() 107 | iter_time = SmoothedValue(fmt="{avg:.4f}") 108 | data_time = SmoothedValue(fmt="{avg:.4f}") 109 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 110 | if torch.cuda.is_available(): 111 | log_msg = self.delimiter.join( 112 | [ 113 | header, 114 | "[{0" + space_fmt + "}/{1}]", 115 | "eta: {eta}", 116 | "{meters}", 117 | "time: {time}", 118 | "data: {data}", 119 | "max mem: {memory:.0f}", 120 | ] 121 | ) 122 | else: 123 | log_msg = self.delimiter.join( 124 | [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] 125 | ) 126 | MB = 1024.0 * 1024.0 127 | for obj in iterable: 128 | data_time.update(time.time() - end) 129 | yield obj 130 | iter_time.update(time.time() - end) 131 | if i % print_freq == 0: 132 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 133 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 134 | if torch.cuda.is_available(): 135 | print( 136 | log_msg.format( 137 | i, 138 | len(iterable), 139 | eta=eta_string, 140 | meters=str(self), 141 | time=str(iter_time), 142 | data=str(data_time), 143 | memory=torch.cuda.max_memory_allocated() / MB, 144 | ) 145 | ) 146 | else: 147 | print( 148 | log_msg.format( 149 | i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) 150 | ) 151 | ) 152 | i += 1 153 | end = time.time() 154 | total_time = time.time() - start_time 155 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 156 | print(f"{header} Total time: {total_time_str}") 157 | 158 | 159 | class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): 160 | 
"""Maintains moving averages of model parameters using an exponential decay. 161 | ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` 162 | `torch.optim.swa_utils.AveragedModel `_ 163 | is used to compute the EMA. 164 | """ 165 | 166 | def __init__(self, model, decay, device="cpu"): 167 | def ema_avg(avg_model_param, model_param, num_averaged): 168 | return decay * avg_model_param + (1 - decay) * model_param 169 | 170 | super().__init__(model, device, ema_avg, use_buffers=True) 171 | 172 | 173 | def accuracy(output, target, topk=(1,)): 174 | """Computes the accuracy over the k top predictions for the specified values of k""" 175 | with torch.inference_mode(): 176 | maxk = max(topk) 177 | batch_size = target.size(0) 178 | if target.ndim == 2: 179 | target = target.max(dim=1)[1] 180 | 181 | _, pred = output.topk(maxk, 1, True, True) 182 | pred = pred.t() 183 | correct = pred.eq(target[None]) 184 | 185 | res = [] 186 | for k in topk: 187 | correct_k = correct[:k].flatten().sum(dtype=torch.float32) 188 | res.append(correct_k * (100.0 / batch_size)) 189 | return res 190 | 191 | 192 | def mkdir(path): 193 | try: 194 | os.makedirs(path) 195 | except OSError as e: 196 | if e.errno != errno.EEXIST: 197 | raise 198 | 199 | 200 | def setup_for_distributed(is_master): 201 | """ 202 | This function disables printing when not in master process 203 | """ 204 | import builtins as __builtin__ 205 | 206 | builtin_print = __builtin__.print 207 | 208 | def print(*args, **kwargs): 209 | force = kwargs.pop("force", False) 210 | if is_master or force: 211 | builtin_print(*args, **kwargs) 212 | 213 | __builtin__.print = print 214 | 215 | 216 | def is_dist_avail_and_initialized(): 217 | if not dist.is_available(): 218 | return False 219 | if not dist.is_initialized(): 220 | return False 221 | return True 222 | 223 | 224 | def get_world_size(): 225 | if not is_dist_avail_and_initialized(): 226 | return 1 227 | return dist.get_world_size() 228 | 229 | 230 | def get_rank(): 231 | if not is_dist_avail_and_initialized(): 232 | return 0 233 | return dist.get_rank() 234 | 235 | 236 | def is_main_process(): 237 | return get_rank() == 0 238 | 239 | 240 | def save_on_master(*args, **kwargs): 241 | if is_main_process(): 242 | torch.save(*args, **kwargs) 243 | 244 | 245 | def init_distributed_mode(args): 246 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 247 | args.rank = int(os.environ["RANK"]) 248 | args.world_size = int(os.environ["WORLD_SIZE"]) 249 | args.gpu = int(os.environ["LOCAL_RANK"]) 250 | elif "SLURM_PROCID" in os.environ: 251 | args.rank = int(os.environ["SLURM_PROCID"]) 252 | args.gpu = args.rank % torch.cuda.device_count() 253 | elif hasattr(args, "rank"): 254 | pass 255 | else: 256 | print("Not using distributed mode") 257 | args.distributed = False 258 | return 259 | 260 | args.distributed = True 261 | 262 | torch.cuda.set_device(args.gpu) 263 | args.dist_backend = "nccl" 264 | print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) 265 | torch.distributed.init_process_group( 266 | backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank 267 | ) 268 | torch.distributed.barrier() 269 | setup_for_distributed(args.rank == 0) 270 | 271 | 272 | def average_checkpoints(inputs): 273 | """Loads checkpoints from inputs and returns a model with averaged weights. 
Original implementation taken from: 274 | https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16 275 | 276 | Args: 277 | inputs (List[str]): An iterable of string paths of checkpoints to load from. 278 | Returns: 279 | A dict of string keys mapping to various values. The 'model' key 280 | from the returned dict should correspond to an OrderedDict mapping 281 | string parameter names to torch Tensors. 282 | """ 283 | params_dict = OrderedDict() 284 | params_keys = None 285 | new_state = None 286 | num_models = len(inputs) 287 | for fpath in inputs: 288 | with open(fpath, "rb") as f: 289 | state = torch.load( 290 | f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True 291 | ) 292 | # Copies over the settings from the first checkpoint 293 | if new_state is None: 294 | new_state = state 295 | model_params = state["model"] 296 | model_params_keys = list(model_params.keys()) 297 | if params_keys is None: 298 | params_keys = model_params_keys 299 | elif params_keys != model_params_keys: 300 | raise KeyError( 301 | f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}" 302 | ) 303 | for k in params_keys: 304 | p = model_params[k] 305 | if isinstance(p, torch.HalfTensor): 306 | p = p.float() 307 | if k not in params_dict: 308 | params_dict[k] = p.clone() 309 | # NOTE: clone() is needed in case of p is a shared parameter 310 | else: 311 | params_dict[k] += p 312 | averaged_params = OrderedDict() 313 | for k, v in params_dict.items(): 314 | averaged_params[k] = v 315 | if averaged_params[k].is_floating_point(): 316 | averaged_params[k].div_(num_models) 317 | else: 318 | averaged_params[k] //= num_models 319 | new_state["model"] = averaged_params 320 | return new_state 321 | 322 | 323 | def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True): 324 | """ 325 | This method can be used to prepare weights files for new models. It receives as 326 | input a model architecture and a checkpoint from the training script and produces 327 | a file with the weights ready for release. 328 | 329 | Examples: 330 | from torchvision import models as M 331 | 332 | # Classification 333 | model = M.mobilenet_v3_large(weights=None) 334 | print(store_model_weights(model, './class.pth')) 335 | 336 | # Quantized Classification 337 | model = M.quantization.mobilenet_v3_large(weights=None, quantize=False) 338 | model.fuse_model(is_qat=True) 339 | model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') 340 | _ = torch.ao.quantization.prepare_qat(model, inplace=True) 341 | print(store_model_weights(model, './qat.pth')) 342 | 343 | # Object Detection 344 | model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None) 345 | print(store_model_weights(model, './obj.pth')) 346 | 347 | # Segmentation 348 | model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True) 349 | print(store_model_weights(model, './segm.pth', strict=False)) 350 | 351 | Args: 352 | model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes. 353 | checkpoint_path (str): The path of the checkpoint we will load. 354 | checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored. 355 | Default: "model". 
356 | strict (bool): whether to strictly enforce that the keys 357 | in :attr:`state_dict` match the keys returned by this module's 358 | :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` 359 | 360 | Returns: 361 | output_path (str): The location where the weights are saved. 362 | """ 363 | # Store the new model next to the checkpoint_path 364 | checkpoint_path = os.path.abspath(checkpoint_path) 365 | output_dir = os.path.dirname(checkpoint_path) 366 | 367 | # Deep copy to avoid side effects on the model object. 368 | model = copy.deepcopy(model) 369 | checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) 370 | 371 | # Load the weights to the model to validate that everything works 372 | # and remove unnecessary weights (such as auxiliaries, etc.) 373 | if checkpoint_key == "model_ema": 374 | del checkpoint[checkpoint_key]["n_averaged"] 375 | torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.") 376 | model.load_state_dict(checkpoint[checkpoint_key], strict=strict) 377 | 378 | tmp_path = os.path.join(output_dir, str(model.__hash__())) 379 | torch.save(model.state_dict(), tmp_path) 380 | 381 | sha256_hash = hashlib.sha256() 382 | with open(tmp_path, "rb") as f: 383 | # Read and update hash string value in blocks of 4K 384 | for byte_block in iter(lambda: f.read(4096), b""): 385 | sha256_hash.update(byte_block) 386 | hh = sha256_hash.hexdigest() 387 | 388 | output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth") 389 | os.replace(tmp_path, output_path) 390 | 391 | return output_path 392 | 393 | 394 | def reduce_across_processes(val): 395 | if not is_dist_avail_and_initialized(): 396 | # nothing to sync, but we still convert to tensor for consistency with the distributed case. 397 | return torch.tensor(val) 398 | 399 | t = torch.tensor(val, device="cuda") 400 | dist.barrier() 401 | dist.all_reduce(t) 402 | return t 403 | 404 | 405 | def set_weight_decay( 406 | model: torch.nn.Module, 407 | weight_decay: float, 408 | norm_weight_decay: Optional[float] = None, 409 | norm_classes: Optional[List[type]] = None, 410 | custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None, 411 | ): 412 | if not norm_classes: 413 | norm_classes = [ 414 | torch.nn.modules.batchnorm._BatchNorm, 415 | torch.nn.LayerNorm, 416 | torch.nn.GroupNorm, 417 | torch.nn.modules.instancenorm._InstanceNorm, 418 | torch.nn.LocalResponseNorm, 419 | ] 420 | norm_classes = tuple(norm_classes) 421 | 422 | params = { 423 | "other": [], 424 | "norm": [], 425 | } 426 | params_weight_decay = { 427 | "other": weight_decay, 428 | "norm": norm_weight_decay, 429 | } 430 | custom_keys = [] 431 | if custom_keys_weight_decay is not None: 432 | for key, weight_decay in custom_keys_weight_decay: 433 | params[key] = [] 434 | params_weight_decay[key] = weight_decay 435 | custom_keys.append(key) 436 | 437 | def _add_params(module, prefix=""): 438 | for name, p in module.named_parameters(recurse=False): 439 | if not p.requires_grad: 440 | continue 441 | is_custom_key = False 442 | for key in custom_keys: 443 | target_name = f"{prefix}.{name}" if prefix != "" and "." 
in key else name 444 | if key == target_name: 445 | params[key].append(p) 446 | is_custom_key = True 447 | break 448 | if not is_custom_key: 449 | if norm_weight_decay is not None and isinstance(module, norm_classes): 450 | params["norm"].append(p) 451 | else: 452 | params["other"].append(p) 453 | 454 | for child_name, child_module in module.named_children(): 455 | child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name 456 | _add_params(child_module, prefix=child_prefix) 457 | 458 | _add_params(model) 459 | 460 | param_groups = [] 461 | for key in params: 462 | if len(params[key]) > 0: 463 | param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) 464 | return param_groups 465 | -------------------------------------------------------------------------------- /gpt/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | import sys 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch.nn import functional as F 18 | 19 | from adopt import ADOPT 20 | 21 | 22 | class LayerNorm(nn.Module): 23 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 24 | 25 | def __init__(self, ndim, bias): 26 | super().__init__() 27 | self.weight = nn.Parameter(torch.ones(ndim)) 28 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 29 | 30 | def forward(self, input): 31 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 32 | 33 | class CausalSelfAttention(nn.Module): 34 | 35 | def __init__(self, config): 36 | super().__init__() 37 | assert config.n_embd % config.n_head == 0 38 | # key, query, value projections for all heads, but in a batch 39 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 40 | # output projection 41 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 42 | # regularization 43 | self.attn_dropout = nn.Dropout(config.dropout) 44 | self.resid_dropout = nn.Dropout(config.dropout) 45 | self.n_head = config.n_head 46 | self.n_embd = config.n_embd 47 | self.dropout = config.dropout 48 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 49 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 50 | if not self.flash: 51 | print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") 52 | # causal mask to ensure that attention is only applied to the left in the input sequence 53 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 54 | .view(1, 1, config.block_size, config.block_size)) 55 | 56 | def forward(self, x): 57 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 58 | 59 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 60 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 61 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 62 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 63 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 64 | 65 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 66 | if self.flash: 67 | # efficient attention using Flash Attention CUDA kernels 68 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 69 | else: 70 | # manual implementation of attention 71 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 72 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 73 | att = F.softmax(att, dim=-1) 74 | att = self.attn_dropout(att) 75 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 76 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 77 | 78 | # output projection 79 | y = self.resid_dropout(self.c_proj(y)) 80 | return y 81 | 82 | class MLP(nn.Module): 83 | 84 | def __init__(self, config): 85 | super().__init__() 86 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 87 | self.gelu = nn.GELU() 88 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 89 | self.dropout = nn.Dropout(config.dropout) 90 | 91 | def forward(self, x): 92 | x = self.c_fc(x) 93 | x = self.gelu(x) 94 | x = self.c_proj(x) 95 | x = self.dropout(x) 96 | return x 97 | 98 | class Block(nn.Module): 99 | 100 | def __init__(self, config): 101 | super().__init__() 102 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 103 | self.attn = CausalSelfAttention(config) 104 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 105 | self.mlp = MLP(config) 106 | 107 | def forward(self, x): 108 | x = x + self.attn(self.ln_1(x)) 109 | x = x + self.mlp(self.ln_2(x)) 110 | return x 111 | 112 | @dataclass 113 | class GPTConfig: 114 | block_size: int = 1024 115 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 116 | n_layer: int = 12 117 | n_head: int = 12 118 | n_embd: int = 768 119 | dropout: float = 0.0 120 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster 121 | 122 | class GPT(nn.Module): 123 | 124 | def __init__(self, config): 125 | super().__init__() 126 | assert config.vocab_size is not None 127 | assert config.block_size is not None 128 | self.config = config 129 | 130 | self.transformer = nn.ModuleDict(dict( 131 | wte = nn.Embedding(config.vocab_size, config.n_embd), 132 | wpe = nn.Embedding(config.block_size, config.n_embd), 133 | drop = nn.Dropout(config.dropout), 134 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 135 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 136 | )) 137 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 138 | # with weight tying when using torch.compile() some warnings get generated: 139 | # "UserWarning: functional_call was passed multiple values for tied weights. 140 | # This behavior is deprecated and will be an error in future versions" 141 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 142 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 143 | 144 | # init all weights 145 | self.apply(self._init_weights) 146 | # apply special scaled init to the residual projections, per GPT-2 paper 147 | for pn, p in self.named_parameters(): 148 | if pn.endswith('c_proj.weight'): 149 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 150 | 151 | # report number of parameters 152 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 153 | 154 | def get_num_params(self, non_embedding=True): 155 | """ 156 | Return the number of parameters in the model. 157 | For non-embedding count (default), the position embeddings get subtracted. 158 | The token embeddings would too, except due to the parameter sharing these 159 | params are actually used as weights in the final layer, so we include them. 
160 | """ 161 | n_params = sum(p.numel() for p in self.parameters()) 162 | if non_embedding: 163 | n_params -= self.transformer.wpe.weight.numel() 164 | return n_params 165 | 166 | def _init_weights(self, module): 167 | if isinstance(module, nn.Linear): 168 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 169 | if module.bias is not None: 170 | torch.nn.init.zeros_(module.bias) 171 | elif isinstance(module, nn.Embedding): 172 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 173 | 174 | def forward(self, idx, targets=None): 175 | device = idx.device 176 | b, t = idx.size() 177 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 178 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 179 | 180 | # forward the GPT model itself 181 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 182 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 183 | x = self.transformer.drop(tok_emb + pos_emb) 184 | for block in self.transformer.h: 185 | x = block(x) 186 | x = self.transformer.ln_f(x) 187 | 188 | if targets is not None: 189 | # if we are given some desired targets also calculate the loss 190 | logits = self.lm_head(x) 191 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 192 | else: 193 | # inference-time mini-optimization: only forward the lm_head on the very last position 194 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 195 | loss = None 196 | 197 | return logits, loss 198 | 199 | def crop_block_size(self, block_size): 200 | # model surgery to decrease the block size if necessary 201 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) 202 | # but want to use a smaller block size for some smaller, simpler model 203 | assert block_size <= self.config.block_size 204 | self.config.block_size = block_size 205 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 206 | for block in self.transformer.h: 207 | if hasattr(block.attn, 'bias'): 208 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 209 | 210 | @classmethod 211 | def from_pretrained(cls, model_type, override_args=None): 212 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 213 | override_args = override_args or {} # default to empty dict 214 | # only dropout can be overridden see more notes below 215 | assert all(k == 'dropout' for k in override_args) 216 | from transformers import GPT2LMHeadModel 217 | print("loading weights from pretrained gpt: %s" % model_type) 218 | 219 | # n_layer, n_head and n_embd are determined from model_type 220 | config_args = { 221 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 222 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 223 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 224 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 225 | }[model_type] 226 | print("forcing vocab_size=50257, block_size=1024, bias=True") 227 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 228 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 229 | config_args['bias'] = True # always True for GPT model checkpoints 230 | # we can override the dropout rate, if desired 231 | if 'dropout' in override_args: 232 | print(f"overriding dropout rate to 
{override_args['dropout']}") 233 | config_args['dropout'] = override_args['dropout'] 234 | # create a from-scratch initialized minGPT model 235 | config = GPTConfig(**config_args) 236 | model = GPT(config) 237 | sd = model.state_dict() 238 | sd_keys = sd.keys() 239 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 240 | 241 | # init a huggingface/transformers model 242 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 243 | sd_hf = model_hf.state_dict() 244 | 245 | # copy while ensuring all of the parameters are aligned and match in names and shapes 246 | sd_keys_hf = sd_hf.keys() 247 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 248 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 249 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 250 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 251 | # this means that we have to transpose these weights when we import them 252 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 253 | for k in sd_keys_hf: 254 | if any(k.endswith(w) for w in transposed): 255 | # special treatment for the Conv1D weights we need to transpose 256 | assert sd_hf[k].shape[::-1] == sd[k].shape 257 | with torch.no_grad(): 258 | sd[k].copy_(sd_hf[k].t()) 259 | else: 260 | # vanilla copy over the other parameters 261 | assert sd_hf[k].shape == sd[k].shape 262 | with torch.no_grad(): 263 | sd[k].copy_(sd_hf[k]) 264 | 265 | return model 266 | 267 | def configure_optimizers(self, optim_type, weight_decay, learning_rate, betas, device_type): 268 | # start with all of the candidate parameters 269 | param_dict = {pn: p for pn, p in self.named_parameters()} 270 | # filter out those that do not require grad 271 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 272 | # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. 273 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 
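        # note: optim_type == 'adopt' below constructs ADOPT on these groups (requesting decoupled
        # weight decay); any other value falls back to AdamW, fused when available on CUDA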
274 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 275 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 276 | optim_groups = [ 277 | {'params': decay_params, 'weight_decay': weight_decay}, 278 | {'params': nodecay_params, 'weight_decay': 0.0} 279 | ] 280 | num_decay_params = sum(p.numel() for p in decay_params) 281 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 282 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 283 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 284 | if optim_type == 'adopt': 285 | optimizer = ADOPT(optim_groups, lr=learning_rate, decoupled=True) 286 | print(f"using ADOPT") 287 | else: 288 | # Create AdamW optimizer and use the fused version if it is available 289 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 290 | use_fused = fused_available and device_type == 'cuda' 291 | extra_args = dict(fused=True) if use_fused else dict() 292 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 293 | print(f"using fused AdamW: {use_fused}") 294 | 295 | return optimizer 296 | 297 | def estimate_mfu(self, fwdbwd_per_iter, dt): 298 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 299 | # first estimate the number of flops we do per iteration. 300 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 301 | N = self.get_num_params() 302 | cfg = self.config 303 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 304 | flops_per_token = 6*N + 12*L*H*Q*T 305 | flops_per_fwdbwd = flops_per_token * T 306 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 307 | # express our flops throughput as ratio of A100 bfloat16 peak flops 308 | flops_achieved = flops_per_iter * (1.0/dt) # per second 309 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 310 | mfu = flops_achieved / flops_promised 311 | return mfu 312 | 313 | @torch.no_grad() 314 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 315 | """ 316 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 317 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 318 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
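        Illustrative example (argument values are arbitrary, not defaults taken from this repo):
            model.eval()
            idx = model.generate(idx, max_new_tokens=100, temperature=0.8, top_k=200)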
319 | """ 320 | for _ in range(max_new_tokens): 321 | # if the sequence context is growing too long we must crop it at block_size 322 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 323 | # forward the model to get the logits for the index in the sequence 324 | logits, _ = self(idx_cond) 325 | # pluck the logits at the final step and scale by desired temperature 326 | logits = logits[:, -1, :] / temperature 327 | # optionally crop the logits to only the top k options 328 | if top_k is not None: 329 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 330 | logits[logits < v[:, [-1]]] = -float('Inf') 331 | # apply softmax to convert logits to (normalized) probabilities 332 | probs = F.softmax(logits, dim=-1) 333 | # sample from the distribution 334 | idx_next = torch.multinomial(probs, num_samples=1) 335 | # append sampled index to the running sequence and continue 336 | idx = torch.cat((idx, idx_next), dim=1) 337 | 338 | return idx 339 | -------------------------------------------------------------------------------- /src/adopt/adopt.py: -------------------------------------------------------------------------------- 1 | # mypy: allow-untyped-decorators 2 | # mypy: allow-untyped-defs 3 | from typing import cast, Callable, List, Optional, Tuple, Union 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from torch.optim.optimizer import ( 9 | _capturable_doc, 10 | _default_to_fused_or_foreach, 11 | _device_dtype_check_for_fused, 12 | _differentiable_doc, 13 | _disable_dynamo_if_unsupported, 14 | _foreach_doc, 15 | _fused_doc, 16 | _get_capturable_supported_devices, 17 | _get_scalar_dtype, 18 | _get_value, 19 | _maximize_doc, 20 | _stack_if_compiling, 21 | _use_grad_for_differentiable, 22 | _view_as_real, 23 | DeviceDict, 24 | Optimizer, 25 | ParamsT, 26 | ) 27 | 28 | 29 | __all__ = ["ADOPT", "adopt"] 30 | 31 | 32 | class ADOPT(Optimizer): 33 | def __init__( 34 | self, 35 | params: ParamsT, 36 | lr: Union[float, Tensor] = 1e-3, 37 | betas: Tuple[float, float] = (0.9, 0.9999), 38 | eps: float = 1e-6, 39 | clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25, 40 | weight_decay: float = 0.0, 41 | decouple: bool = False, 42 | *, 43 | foreach: Optional[bool] = None, 44 | maximize: bool = False, 45 | capturable: bool = False, 46 | differentiable: bool = False, 47 | fused: Optional[bool] = None, 48 | ): 49 | if isinstance(lr, Tensor): 50 | if foreach and not capturable: 51 | raise ValueError( 52 | "lr as a Tensor is not supported for capturable=False and foreach=True" 53 | ) 54 | if lr.numel() != 1: 55 | raise ValueError("Tensor lr must be 1-element") 56 | if not 0.0 <= lr: 57 | raise ValueError(f"Invalid learning rate: {lr}") 58 | if not 0.0 <= eps: 59 | raise ValueError(f"Invalid epsilon value: {eps}") 60 | if not 0.0 <= betas[0] < 1.0: 61 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 62 | if not 0.0 <= betas[1] < 1.0: 63 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 64 | if not 0.0 <= weight_decay: 65 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 66 | 67 | self.clip_lambda = clip_lambda 68 | 69 | defaults = dict( 70 | lr=lr, 71 | betas=betas, 72 | eps=eps, 73 | weight_decay=weight_decay, 74 | decouple=decouple, 75 | maximize=maximize, 76 | foreach=foreach, 77 | capturable=capturable, 78 | differentiable=differentiable, 79 | fused=fused, 80 | ) 81 | super().__init__(params, defaults) 82 | 83 | if fused: 84 | # TODO: support fused 85 | raise 
RuntimeError("`fused` is not currently supported") 86 | 87 | if differentiable: 88 | raise RuntimeError("`fused` does not support `differentiable`") 89 | self._step_supports_amp_scaling = True 90 | # TODO(crcrpar): [low prec params & their higher prec copy] 91 | # Support AMP with FP16/BF16 model params which would need 92 | # higher prec copy of params to do update math in higher prec to 93 | # alleviate the loss of information. 94 | if foreach: 95 | raise RuntimeError("`fused` and `foreach` cannot be `True` together.") 96 | 97 | def __setstate__(self, state): 98 | super().__setstate__(state) 99 | for group in self.param_groups: 100 | group.setdefault("maximize", False) 101 | group.setdefault("foreach", None) 102 | group.setdefault("capturable", False) 103 | group.setdefault("differentiable", False) 104 | fused = group.setdefault("fused", None) 105 | for p in group["params"]: 106 | p_state = self.state.get(p, []) 107 | if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): 108 | step_val = float(p_state["step"]) 109 | p_state["step"] = ( 110 | torch.tensor( 111 | step_val, 112 | dtype=_get_scalar_dtype(is_fused=fused), 113 | device=p.device, 114 | ) 115 | if group["capturable"] or group["fused"] 116 | else torch.tensor(step_val, dtype=_get_scalar_dtype()) 117 | ) 118 | 119 | def _init_group( 120 | self, 121 | group, 122 | params_with_grad, 123 | grads, 124 | exp_avgs, 125 | exp_avg_sqs, 126 | state_steps, 127 | ): 128 | has_complex = False 129 | for p in group["params"]: 130 | if p.grad is not None: 131 | has_complex |= torch.is_complex(p) 132 | params_with_grad.append(p) 133 | if p.grad.is_sparse: 134 | raise RuntimeError( 135 | "ADOPT does not support sparse gradients" 136 | ) 137 | grads.append(p.grad) 138 | 139 | state = self.state[p] 140 | # Lazy state initialization 141 | if len(state) == 0: 142 | if group["fused"]: 143 | _device_dtype_check_for_fused(p) 144 | # note(crcrpar): [special device hosting for step] 145 | # Deliberately host `step` on CPU if both capturable and fused are off. 146 | # This is because kernel launches are costly on CUDA and XLA. 147 | state["step"] = ( 148 | torch.zeros( 149 | (), 150 | dtype=_get_scalar_dtype(is_fused=group["fused"]), 151 | device=p.device, 152 | ) 153 | if group["capturable"] or group["fused"] 154 | else torch.tensor(0.0, dtype=_get_scalar_dtype()) 155 | ) 156 | # Exponential moving average of gradient values 157 | state["exp_avg"] = torch.zeros_like( 158 | p, memory_format=torch.preserve_format 159 | ) 160 | # Exponential moving average of squared gradient values 161 | state["exp_avg_sq"] = torch.zeros_like( 162 | p, memory_format=torch.preserve_format 163 | ) 164 | 165 | exp_avgs.append(state["exp_avg"]) 166 | exp_avg_sqs.append(state["exp_avg_sq"]) 167 | 168 | if group["differentiable"] and state["step"].requires_grad: 169 | raise RuntimeError( 170 | "`requires_grad` is not supported for `step` in differentiable mode" 171 | ) 172 | 173 | # Foreach without capturable does not support a tensor lr 174 | if ( 175 | group["foreach"] 176 | and torch.is_tensor(group["lr"]) 177 | and not group["capturable"] 178 | ): 179 | raise RuntimeError( 180 | "lr as a Tensor is not supported for capturable=False and foreach=True" 181 | ) 182 | 183 | state_steps.append(state["step"]) 184 | return has_complex 185 | 186 | @_use_grad_for_differentiable 187 | def step(self, closure=None): 188 | """Perform a single optimization step. 189 | 190 | Args: 191 | closure (Callable, optional): A closure that reevaluates the model 192 | and returns the loss. 
193 | """ 194 | self._cuda_graph_capture_health_check() 195 | 196 | loss = None 197 | if closure is not None: 198 | with torch.enable_grad(): 199 | loss = closure() 200 | 201 | for group in self.param_groups: 202 | params_with_grad: List[Tensor] = [] 203 | grads: List[Tensor] = [] 204 | exp_avgs: List[Tensor] = [] 205 | exp_avg_sqs: List[Tensor] = [] 206 | state_steps: List[Tensor] = [] 207 | beta1, beta2 = group["betas"] 208 | 209 | has_complex = self._init_group( 210 | group, 211 | params_with_grad, 212 | grads, 213 | exp_avgs, 214 | exp_avg_sqs, 215 | state_steps, 216 | ) 217 | 218 | adopt( 219 | params_with_grad, 220 | grads, 221 | exp_avgs, 222 | exp_avg_sqs, 223 | state_steps, 224 | has_complex=has_complex, 225 | beta1=beta1, 226 | beta2=beta2, 227 | lr=group["lr"], 228 | clip_lambda=self.clip_lambda, 229 | weight_decay=group["weight_decay"], 230 | decouple=group["decouple"], 231 | eps=group["eps"], 232 | maximize=group["maximize"], 233 | foreach=group["foreach"], 234 | capturable=group["capturable"], 235 | differentiable=group["differentiable"], 236 | fused=group["fused"], 237 | grad_scale=getattr(self, "grad_scale", None), 238 | found_inf=getattr(self, "found_inf", None), 239 | ) 240 | 241 | return loss 242 | 243 | 244 | def _single_tensor_adopt( 245 | params: List[Tensor], 246 | grads: List[Tensor], 247 | exp_avgs: List[Tensor], 248 | exp_avg_sqs: List[Tensor], 249 | state_steps: List[Tensor], 250 | grad_scale: Optional[Tensor], 251 | found_inf: Optional[Tensor], 252 | *, 253 | has_complex: bool, 254 | beta1: float, 255 | beta2: float, 256 | lr: Union[float, Tensor], 257 | clip_lambda: Optional[Callable[[int], float]], 258 | weight_decay: float, 259 | decouple: bool, 260 | eps: float, 261 | maximize: bool, 262 | capturable: bool, 263 | differentiable: bool, 264 | ): 265 | assert grad_scale is None and found_inf is None 266 | 267 | if torch.jit.is_scripting(): 268 | # this assert is due to JIT being dumb and not realizing that the ops below 269 | # have overloads to handle both float and Tensor lrs, so we just assert it's 270 | # a float since most people using JIT are using floats 271 | assert isinstance(lr, float) 272 | 273 | for i, param in enumerate(params): 274 | grad = grads[i] if not maximize else -grads[i] 275 | exp_avg = exp_avgs[i] 276 | exp_avg_sq = exp_avg_sqs[i] 277 | step_t = state_steps[i] 278 | 279 | # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 280 | if not torch._utils.is_compiling() and capturable: 281 | capturable_supported_devices = _get_capturable_supported_devices() 282 | assert ( 283 | param.device.type == step_t.device.type 284 | and param.device.type in capturable_supported_devices 285 | ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
286 | 287 | step = step_t if capturable or differentiable else _get_value(step_t) 288 | 289 | if weight_decay != 0 and not decouple: 290 | grad = grad.add(param, alpha=weight_decay) 291 | 292 | if torch.is_complex(param): 293 | grad = torch.view_as_real(grad) 294 | if exp_avg is not None: 295 | exp_avg = torch.view_as_real(exp_avg) 296 | if exp_avg_sq is not None: 297 | exp_avg_sq = torch.view_as_real(exp_avg_sq) 298 | param = torch.view_as_real(param) 299 | 300 | if step == 0: 301 | exp_avg_sq.addcmul_(grad, grad.conj()) 302 | # update step 303 | step_t += 1 304 | continue 305 | 306 | if weight_decay != 0 and decouple: 307 | param.add_(param, alpha=-lr*weight_decay) 308 | 309 | denom = torch.clamp(exp_avg_sq.sqrt(), eps) 310 | normed_grad = grad.div(denom) 311 | if clip_lambda is not None: 312 | clip = clip_lambda(step) 313 | normed_grad.clamp_(-clip, clip) 314 | 315 | exp_avg.lerp_(normed_grad, 1 - beta1) 316 | 317 | param.add_(exp_avg, alpha=-lr) 318 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 319 | 320 | # update step 321 | step_t += 1 322 | 323 | 324 | def _multi_tensor_adopt( 325 | params: List[Tensor], 326 | grads: List[Tensor], 327 | exp_avgs: List[Tensor], 328 | exp_avg_sqs: List[Tensor], 329 | state_steps: List[Tensor], 330 | grad_scale: Optional[Tensor], 331 | found_inf: Optional[Tensor], 332 | *, 333 | has_complex: bool, 334 | beta1: float, 335 | beta2: float, 336 | lr: Union[float, Tensor], 337 | clip_lambda: Optional[Callable[[int], float]], 338 | weight_decay: float, 339 | decouple: bool, 340 | eps: float, 341 | maximize: bool, 342 | capturable: bool, 343 | differentiable: bool, 344 | ): 345 | if len(params) == 0: 346 | return 347 | 348 | if isinstance(lr, Tensor) and not capturable: 349 | raise RuntimeError( 350 | "lr as a Tensor is not supported for capturable=False and foreach=True" 351 | ) 352 | 353 | # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 354 | if not torch._utils.is_compiling() and capturable: 355 | capturable_supported_devices = _get_capturable_supported_devices( 356 | supports_xla=False 357 | ) 358 | assert all( 359 | p.device.type == step.device.type 360 | and p.device.type in capturable_supported_devices 361 | for p, step in zip(params, state_steps) 362 | ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 
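    # Multi-tensor (foreach) path: the same update rule as _single_tensor_adopt, applied to
    # per-device/per-dtype tensor groups via torch._foreach_* ops to cut down on kernel launches.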
363 | 364 | assert grad_scale is None and found_inf is None 365 | 366 | assert not differentiable, "_foreach ops don't support autograd" 367 | 368 | grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( 369 | [params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item] 370 | ) 371 | for ( 372 | device_params_, 373 | device_grads_, 374 | device_exp_avgs_, 375 | device_exp_avg_sqs_, 376 | device_state_steps_, 377 | ), _ in grouped_tensors.values(): 378 | device_params = cast(List[Tensor], device_params_) 379 | device_grads = cast(List[Tensor], device_grads_) 380 | device_exp_avgs = cast(List[Tensor], device_exp_avgs_) 381 | device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) 382 | device_state_steps = cast(List[Tensor], device_state_steps_) 383 | 384 | # Handle complex parameters 385 | if has_complex: 386 | _view_as_real( 387 | device_params, device_grads, device_exp_avgs, device_exp_avg_sqs 388 | ) 389 | 390 | if maximize: 391 | device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] 392 | 393 | if weight_decay != 0 and not decouple: 394 | # Re-use the intermediate memory (device_grads) already allocated for maximize 395 | if maximize: 396 | torch._foreach_add_(device_grads, device_params, alpha=weight_decay) 397 | else: 398 | device_grads = torch._foreach_add( # type: ignore[assignment] 399 | device_grads, device_params, alpha=weight_decay 400 | ) 401 | 402 | if device_state_steps[0] == 0: 403 | torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads) 404 | 405 | # Update steps 406 | # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over 407 | # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just 408 | # wrapped it once now. The alpha is required to assure we go to the right overload. 409 | if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: 410 | torch._foreach_add_( 411 | device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 412 | ) 413 | else: 414 | torch._foreach_add_(device_state_steps, 1) 415 | 416 | continue 417 | 418 | if weight_decay != 0 and decouple: 419 | torch._foreach_add_(device_params, device_params, alpha=-lr*weight_decay) 420 | 421 | exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) 422 | torch._foreach_maximum_(exp_avg_sq_sqrt, eps) 423 | 424 | normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt) 425 | if clip_lambda is not None: 426 | clip = clip_lambda(device_state_steps[0]) 427 | torch._foreach_maximum_(normed_grad, -clip) 428 | torch._foreach_minimum_(normed_grad, clip) 429 | 430 | torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1) 431 | 432 | torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr) 433 | torch._foreach_mul_(device_exp_avg_sqs, beta2) 434 | torch._foreach_addcmul_( 435 | device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2 436 | ) 437 | 438 | # Update steps 439 | # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over 440 | # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just 441 | # wrapped it once now. The alpha is required to assure we go to the right overload. 
442 | if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: 443 | torch._foreach_add_( 444 | device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 445 | ) 446 | else: 447 | torch._foreach_add_(device_state_steps, 1) 448 | 449 | 450 | def _fused_adopt( 451 | params: List[Tensor], 452 | grads: List[Tensor], 453 | exp_avgs: List[Tensor], 454 | exp_avg_sqs: List[Tensor], 455 | state_steps: List[Tensor], 456 | grad_scale: Optional[Tensor], 457 | found_inf: Optional[Tensor], 458 | *, 459 | has_complex: bool, # Needed for consistency. 460 | beta1: float, 461 | beta2: float, 462 | lr: Union[float, Tensor], 463 | clip_lambda: Optional[Callable[[int], float]], 464 | weight_decay: float, 465 | decouple: bool, 466 | eps: float, 467 | maximize: bool, 468 | capturable: bool, # Needed for consistency. 469 | differentiable: bool, 470 | ) -> None: 471 | raise NotImplementedError 472 | 473 | 474 | @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) 475 | def adopt( 476 | params: List[Tensor], 477 | grads: List[Tensor], 478 | exp_avgs: List[Tensor], 479 | exp_avg_sqs: List[Tensor], 480 | state_steps: List[Tensor], 481 | # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 482 | # setting this as kwarg for now as functional API is compiled by torch/distributed/optim 483 | foreach: Optional[bool] = None, 484 | capturable: bool = False, 485 | differentiable: bool = False, 486 | fused: Optional[bool] = None, 487 | grad_scale: Optional[Tensor] = None, 488 | found_inf: Optional[Tensor] = None, 489 | has_complex: bool = False, 490 | *, 491 | beta1: float, 492 | beta2: float, 493 | lr: Union[float, Tensor], 494 | clip_lambda: Optional[Callable[[int], float]], 495 | weight_decay: float, 496 | decouple: bool, 497 | eps: float, 498 | maximize: bool, 499 | ): 500 | r"""Functional API that performs ADOPT algorithm computation. 501 | 502 | """ 503 | # Respect when the user inputs False/True for foreach or fused. We only want to change 504 | # the default when neither have been user-specified. Note that we default to foreach 505 | # and pass False to use_fused. This is not a mistake--we want to give the fused impl 506 | # bake-in time before making it the default, even if it is typically faster. 507 | if fused is None and foreach is None: 508 | _, foreach = _default_to_fused_or_foreach( 509 | params, differentiable, use_fused=False 510 | ) 511 | # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. 
512 | if foreach and isinstance(lr, Tensor) and not capturable: 513 | foreach = False 514 | if fused is None: 515 | fused = False 516 | if foreach is None: 517 | foreach = False 518 | 519 | # this check is slow during compilation, so we skip it 520 | # if it's strictly needed we can add this check back in dynamo 521 | if not torch._utils.is_compiling() and not all( 522 | isinstance(t, torch.Tensor) for t in state_steps 523 | ): 524 | raise RuntimeError( 525 | "API has changed, `state_steps` argument must contain a list of singleton tensors" 526 | ) 527 | 528 | if foreach and torch.jit.is_scripting(): 529 | raise RuntimeError("torch.jit.script not supported with foreach optimizers") 530 | if fused and torch.jit.is_scripting(): 531 | raise RuntimeError("torch.jit.script not supported with fused optimizers") 532 | 533 | if fused and not torch.jit.is_scripting(): 534 | func = _fused_adopt 535 | elif foreach and not torch.jit.is_scripting(): 536 | func = _multi_tensor_adopt 537 | else: 538 | func = _single_tensor_adopt 539 | 540 | func( 541 | params, 542 | grads, 543 | exp_avgs, 544 | exp_avg_sqs, 545 | state_steps, 546 | has_complex=has_complex, 547 | beta1=beta1, 548 | beta2=beta2, 549 | lr=lr, 550 | clip_lambda=clip_lambda, 551 | weight_decay=weight_decay, 552 | decouple=decouple, 553 | eps=eps, 554 | maximize=maximize, 555 | capturable=capturable, 556 | differentiable=differentiable, 557 | grad_scale=grad_scale, 558 | found_inf=found_inf, 559 | ) 560 | -------------------------------------------------------------------------------- /imagenet/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import sys 4 | import time 5 | import warnings 6 | 7 | import presets 8 | import torch 9 | import torch.utils.data 10 | import torchvision 11 | import torchvision.transforms 12 | import utils 13 | from sampler import RASampler 14 | from torch import nn 15 | from torch.utils.data.dataloader import default_collate 16 | from torchvision.transforms.functional import InterpolationMode 17 | from transforms import get_mixup_cutmix 18 | 19 | from adopt import ADOPT 20 | 21 | 22 | def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): 23 | model.train() 24 | metric_logger = utils.MetricLogger(delimiter=" ") 25 | metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) 26 | metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}")) 27 | 28 | header = f"Epoch: [{epoch}]" 29 | for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): 30 | start_time = time.time() 31 | image, target = image.to(device), target.to(device) 32 | with torch.cuda.amp.autocast(enabled=scaler is not None): 33 | output = model(image) 34 | loss = criterion(output, target) 35 | 36 | optimizer.zero_grad() 37 | if scaler is not None: 38 | scaler.scale(loss).backward() 39 | if args.clip_grad_norm is not None: 40 | # we should unscale the gradients of optimizer's assigned params if do gradient clipping 41 | scaler.unscale_(optimizer) 42 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 43 | scaler.step(optimizer) 44 | scaler.update() 45 | else: 46 | loss.backward() 47 | if args.clip_grad_norm is not None: 48 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 49 | optimizer.step() 50 | 51 | if model_ema and i % args.model_ema_steps == 0: 52 | model_ema.update_parameters(model) 53 | if epoch 
< args.lr_warmup_epochs: 54 | # Reset ema buffer to keep copying weights during warmup period 55 | model_ema.n_averaged.fill_(0) 56 | 57 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 58 | batch_size = image.shape[0] 59 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 60 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 61 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 62 | metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) 63 | 64 | 65 | def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""): 66 | model.eval() 67 | metric_logger = utils.MetricLogger(delimiter=" ") 68 | header = f"Test: {log_suffix}" 69 | 70 | num_processed_samples = 0 71 | with torch.inference_mode(): 72 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 73 | image = image.to(device, non_blocking=True) 74 | target = target.to(device, non_blocking=True) 75 | output = model(image) 76 | loss = criterion(output, target) 77 | 78 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 79 | # FIXME need to take into account that the datasets 80 | # could have been padded in distributed setup 81 | batch_size = image.shape[0] 82 | metric_logger.update(loss=loss.item()) 83 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 84 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 85 | num_processed_samples += batch_size 86 | # gather the stats from all processes 87 | 88 | num_processed_samples = utils.reduce_across_processes(num_processed_samples) 89 | if ( 90 | hasattr(data_loader.dataset, "__len__") 91 | and len(data_loader.dataset) != num_processed_samples 92 | and torch.distributed.get_rank() == 0 93 | ): 94 | # See FIXME above 95 | warnings.warn( 96 | f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " 97 | "samples were used for the validation, which might bias the results. " 98 | "Try adjusting the batch size and / or the world size. " 99 | "Setting the world size to 1 is always a safe bet." 100 | ) 101 | 102 | metric_logger.synchronize_between_processes() 103 | 104 | print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") 105 | return metric_logger.acc1.global_avg 106 | 107 | 108 | def _get_cache_path(filepath): 109 | import hashlib 110 | 111 | h = hashlib.sha1(filepath.encode()).hexdigest() 112 | cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") 113 | cache_path = os.path.expanduser(cache_path) 114 | return cache_path 115 | 116 | 117 | def load_data(traindir, valdir, args): 118 | # Data loading code 119 | print("Loading data") 120 | val_resize_size, val_crop_size, train_crop_size = ( 121 | args.val_resize_size, 122 | args.val_crop_size, 123 | args.train_crop_size, 124 | ) 125 | interpolation = InterpolationMode(args.interpolation) 126 | 127 | print("Loading training data") 128 | st = time.time() 129 | cache_path = _get_cache_path(traindir) 130 | if args.cache_dataset and os.path.exists(cache_path): 131 | # Attention, as the transforms are also cached! 132 | print(f"Loading dataset_train from {cache_path}") 133 | # TODO: this could probably be weights_only=True 134 | dataset, _ = torch.load(cache_path, weights_only=False) 135 | else: 136 | # We need a default value for the variables below because args may come 137 | # from train_quantization.py which doesn't define them. 
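        # these optional augmentation settings are simply forwarded to presets.ClassificationPresetTrain below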
138 | auto_augment_policy = getattr(args, "auto_augment", None) 139 | random_erase_prob = getattr(args, "random_erase", 0.0) 140 | ra_magnitude = getattr(args, "ra_magnitude", None) 141 | augmix_severity = getattr(args, "augmix_severity", None) 142 | dataset = torchvision.datasets.ImageFolder( 143 | traindir, 144 | presets.ClassificationPresetTrain( 145 | crop_size=train_crop_size, 146 | interpolation=interpolation, 147 | auto_augment_policy=auto_augment_policy, 148 | random_erase_prob=random_erase_prob, 149 | ra_magnitude=ra_magnitude, 150 | augmix_severity=augmix_severity, 151 | backend=args.backend, 152 | use_v2=args.use_v2, 153 | ), 154 | ) 155 | if args.cache_dataset: 156 | print(f"Saving dataset_train to {cache_path}") 157 | utils.mkdir(os.path.dirname(cache_path)) 158 | utils.save_on_master((dataset, traindir), cache_path) 159 | print("Took", time.time() - st) 160 | 161 | print("Loading validation data") 162 | cache_path = _get_cache_path(valdir) 163 | if args.cache_dataset and os.path.exists(cache_path): 164 | # Attention, as the transforms are also cached! 165 | print(f"Loading dataset_test from {cache_path}") 166 | # TODO: this could probably be weights_only=True 167 | dataset_test, _ = torch.load(cache_path, weights_only=False) 168 | else: 169 | if args.weights and args.test_only: 170 | weights = torchvision.models.get_weight(args.weights) 171 | preprocessing = weights.transforms(antialias=True) 172 | if args.backend == "tensor": 173 | preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing]) 174 | 175 | else: 176 | preprocessing = presets.ClassificationPresetEval( 177 | crop_size=val_crop_size, 178 | resize_size=val_resize_size, 179 | interpolation=interpolation, 180 | backend=args.backend, 181 | use_v2=args.use_v2, 182 | ) 183 | 184 | dataset_test = torchvision.datasets.ImageFolder( 185 | valdir, 186 | preprocessing, 187 | ) 188 | if args.cache_dataset: 189 | print(f"Saving dataset_test to {cache_path}") 190 | utils.mkdir(os.path.dirname(cache_path)) 191 | utils.save_on_master((dataset_test, valdir), cache_path) 192 | 193 | print("Creating data loaders") 194 | if args.distributed: 195 | if hasattr(args, "ra_sampler") and args.ra_sampler: 196 | train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps) 197 | else: 198 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 199 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) 200 | else: 201 | train_sampler = torch.utils.data.RandomSampler(dataset) 202 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 203 | 204 | return dataset, dataset_test, train_sampler, test_sampler 205 | 206 | 207 | def main(args): 208 | if args.output_dir: 209 | utils.mkdir(args.output_dir) 210 | 211 | utils.init_distributed_mode(args) 212 | print(args) 213 | 214 | device = torch.device(args.device) 215 | 216 | if args.use_deterministic_algorithms: 217 | torch.backends.cudnn.benchmark = False 218 | torch.use_deterministic_algorithms(True) 219 | else: 220 | torch.backends.cudnn.benchmark = True 221 | 222 | train_dir = os.path.join(args.data_path, "train") 223 | val_dir = os.path.join(args.data_path, "val") 224 | dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) 225 | 226 | num_classes = len(dataset.classes) 227 | mixup_cutmix = get_mixup_cutmix( 228 | mixup_alpha=args.mixup_alpha, cutmix_alpha=args.cutmix_alpha, num_classes=num_classes, use_v2=args.use_v2 229 | ) 230 | if 
mixup_cutmix is not None: 231 | 232 | def collate_fn(batch): 233 | return mixup_cutmix(*default_collate(batch)) 234 | 235 | else: 236 | collate_fn = default_collate 237 | 238 | data_loader = torch.utils.data.DataLoader( 239 | dataset, 240 | batch_size=args.batch_size, 241 | sampler=train_sampler, 242 | num_workers=args.workers, 243 | pin_memory=True, 244 | collate_fn=collate_fn, 245 | ) 246 | data_loader_test = torch.utils.data.DataLoader( 247 | dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True 248 | ) 249 | 250 | print("Creating model") 251 | model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes) 252 | model.to(device) 253 | 254 | if args.distributed and args.sync_bn: 255 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 256 | 257 | criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) 258 | 259 | custom_keys_weight_decay = [] 260 | if args.bias_weight_decay is not None: 261 | custom_keys_weight_decay.append(("bias", args.bias_weight_decay)) 262 | if args.transformer_embedding_decay is not None: 263 | for key in ["class_token", "position_embedding", "relative_position_bias_table"]: 264 | custom_keys_weight_decay.append((key, args.transformer_embedding_decay)) 265 | parameters = utils.set_weight_decay( 266 | model, 267 | args.weight_decay, 268 | norm_weight_decay=args.norm_weight_decay, 269 | custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None, 270 | ) 271 | 272 | opt_name = args.opt.lower() 273 | if opt_name.startswith("sgd"): 274 | optimizer = torch.optim.SGD( 275 | parameters, 276 | lr=args.lr, 277 | momentum=args.momentum, 278 | weight_decay=args.weight_decay, 279 | nesterov="nesterov" in opt_name, 280 | ) 281 | elif opt_name == "rmsprop": 282 | optimizer = torch.optim.RMSprop( 283 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9 284 | ) 285 | elif opt_name == "adamw": 286 | optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay) 287 | elif opt_name == "adopt": 288 | optimizer = ADOPT(parameters, lr=args.lr, weight_decay=args.weight_decay) 289 | elif opt_name == "adoptw": 290 | optimizer = ADOPT(parameters, lr=args.lr, weight_decay=args.weight_decay, decoupled=True) 291 | else: 292 | raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop, AdamW and ADOPT are supported.") 293 | 294 | scaler = torch.cuda.amp.GradScaler() if args.amp else None 295 | 296 | args.lr_scheduler = args.lr_scheduler.lower() 297 | if args.lr_scheduler == "steplr": 298 | main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) 299 | elif args.lr_scheduler == "cosineannealinglr": 300 | main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 301 | optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min 302 | ) 303 | elif args.lr_scheduler == "exponentiallr": 304 | main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) 305 | else: 306 | raise RuntimeError( 307 | f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR " 308 | "are supported." 
309 |         )
310 | 
311 |     if args.lr_warmup_epochs > 0:
312 |         if args.lr_warmup_method == "linear":
313 |             warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
314 |                 optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
315 |             )
316 |         elif args.lr_warmup_method == "constant":
317 |             warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
318 |                 optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
319 |             )
320 |         else:
321 |             raise RuntimeError(
322 |                 f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
323 |             )
324 |         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
325 |             optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
326 |         )
327 |     else:
328 |         lr_scheduler = main_lr_scheduler
329 | 
330 |     model_without_ddp = model
331 |     if args.distributed:
332 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
333 |         model_without_ddp = model.module
334 | 
335 |     model_ema = None
336 |     if args.model_ema:
337 |         # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
338 |         # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
339 |         #
340 |         # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
341 |         # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
342 |         # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
343 |         adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
344 |         alpha = 1.0 - args.model_ema_decay
345 |         alpha = min(1.0, alpha * adjust)
346 |         model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
347 | 
348 |     if args.resume:
349 |         checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
350 |         model_without_ddp.load_state_dict(checkpoint["model"])
351 |         if not args.test_only:
352 |             optimizer.load_state_dict(checkpoint["optimizer"])
353 |             lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
354 |             args.start_epoch = checkpoint["epoch"] + 1
355 |         if model_ema:
356 |             model_ema.load_state_dict(checkpoint["model_ema"])
357 |         if scaler:
358 |             scaler.load_state_dict(checkpoint["scaler"])
359 | 
360 |     if args.test_only:
361 |         # We disable the cudnn benchmarking because it can noticeably affect the accuracy
362 |         torch.backends.cudnn.benchmark = False
363 |         torch.backends.cudnn.deterministic = True
364 |         if model_ema:
365 |             evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
366 |         else:
367 |             evaluate(model, criterion, data_loader_test, device=device)
368 |         return
369 | 
370 |     print("Start training")
371 |     start_time = time.time()
372 |     for epoch in range(args.start_epoch, args.epochs):
373 |         if args.distributed:
374 |             train_sampler.set_epoch(epoch)
375 |         train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
376 |         lr_scheduler.step()
377 |         evaluate(model, criterion, data_loader_test, device=device)
378 |         if model_ema:
379 |             evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
380 |         if args.output_dir:
381 |             checkpoint = {
382 |                 "model": model_without_ddp.state_dict(),
383 |                 "optimizer": optimizer.state_dict(),
384 |                 "lr_scheduler": lr_scheduler.state_dict(),
385 |                 "epoch": epoch,
386 |                 "args": args,
387 |             }
388 |             if model_ema:
389 |                 checkpoint["model_ema"] = model_ema.state_dict()
390 |             if scaler:
391 |                 checkpoint["scaler"] = scaler.state_dict()
392 |             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
393 |             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
394 | 
395 |     total_time = time.time() - start_time
396 |     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
397 |     print(f"Training time {total_time_str}")
398 | 
399 | 
400 | def get_args_parser(add_help=True):
401 |     import argparse
402 | 
403 |     parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
404 | 
405 |     parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
406 |     parser.add_argument("--model", default="resnet18", type=str, help="model name")
407 |     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
408 |     parser.add_argument(
409 |         "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
410 |     )
411 |     parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
412 |     parser.add_argument(
413 |         "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
414 |     )
415 |     parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
416 |     parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
417 |     parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
418 |     parser.add_argument(
419 |         "--wd",
420 |         "--weight-decay",
421 |         default=1e-4,
422 |         type=float,
423 |         metavar="W",
424 |         help="weight decay (default: 1e-4)",
425 |         dest="weight_decay",
426 |     )
427 |     parser.add_argument(
428 |         "--norm-weight-decay",
429 |         default=None,
430 |         type=float,
431 |         help="weight decay for Normalization layers (default: None, same value as --wd)",
432 |     )
433 |     parser.add_argument(
434 |         "--bias-weight-decay",
435 |         default=None,
436 |         type=float,
437 |         help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
438 |     )
439 |     parser.add_argument(
440 |         "--transformer-embedding-decay",
441 |         default=None,
442 |         type=float,
443 |         help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
444 |     )
445 |     parser.add_argument(
446 |         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
447 |     )
448 |     parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
449 |     parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
450 |     parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
451 |     parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
452 |     parser.add_argument(
453 |         "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
454 |     )
455 |     parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
456 |     parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
457 |     parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
458 |     parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
459 |     parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
460 |     parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
461 |     parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
462 |     parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
463 |     parser.add_argument(
464 |         "--cache-dataset",
465 |         dest="cache_dataset",
466 |         help="Cache the datasets for quicker initialization. It also serializes the transforms",
467 |         action="store_true",
468 |     )
469 |     parser.add_argument(
470 |         "--sync-bn",
471 |         dest="sync_bn",
472 |         help="Use sync batch norm",
473 |         action="store_true",
474 |     )
475 |     parser.add_argument(
476 |         "--test-only",
477 |         dest="test_only",
478 |         help="Only test the model",
479 |         action="store_true",
480 |     )
481 |     parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
482 |     parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
483 |     parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
484 |     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
485 | 
486 |     # Mixed precision training parameters
487 |     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
488 | 
489 |     # distributed training parameters
490 |     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
491 |     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
492 |     parser.add_argument(
493 |         "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
494 |     )
495 |     parser.add_argument(
496 |         "--model-ema-steps",
497 |         type=int,
498 |         default=32,
499 |         help="the number of iterations that controls how often to update the EMA model (default: 32)",
500 |     )
501 |     parser.add_argument(
502 |         "--model-ema-decay",
503 |         type=float,
504 |         default=0.99998,
505 |         help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
506 |     )
507 |     parser.add_argument(
508 |         "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
509 |     )
510 |     parser.add_argument(
511 |         "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
512 |     )
513 |     parser.add_argument(
514 |         "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
515 |     )
516 |     parser.add_argument(
517 |         "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
518 |     )
519 |     parser.add_argument(
520 |         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
521 |     )
522 |     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
523 |     parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
524 |     parser.add_argument(
525 |         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
526 |     )
527 |     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
528 |     parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
529 |     parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
530 |     return parser
531 | 
532 | 
533 | if __name__ == "__main__":
534 |     args = get_args_parser().parse_args()
535 |     main(args)
536 | 
--------------------------------------------------------------------------------
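A quick numeric illustration of the EMA decay adjustment in main() above, using the defaults from get_args_parser() (world_size=1, batch_size=32, model_ema_steps=32, epochs=90, model_ema_decay=0.99998). This is only a sketch; the variable names below mirror the argument names and the snippet is not part of train.py:

    # Reproduce the adjustment: adjust ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
    world_size, batch_size, model_ema_steps, epochs = 1, 32, 32, 90
    model_ema_decay = 0.99998

    adjust = world_size * batch_size * model_ema_steps / epochs  # 1024 / 90 ~= 11.38
    alpha = min(1.0, (1.0 - model_ema_decay) * adjust)           # ~2.28e-4 per EMA update
    effective_decay = 1.0 - alpha                                # ~0.999772, passed to the EMA wrapper

    # With 8 GPUs (and the same per-GPU batch size) each epoch has 8x fewer EMA
    # updates, so adjust grows to ~91, alpha to ~1.82e-3, and the effective decay
    # drops to ~0.99818 -- keeping the total amount of averaging per epoch roughly
    # constant, which is the point of the adjustment.
    print(adjust, alpha, effective_decay)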