├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── datasets.py ├── dynamic_tanh.py ├── engine.py ├── main.py ├── optim_factory.py ├── other_tasks ├── DINO │ ├── README.md │ ├── dynamic-tanh.patch │ └── dynamic_tanh.py ├── DNA │ ├── README.md │ ├── dynamic-tanh.patch │ └── dynamic_tanh.py ├── DiT │ ├── README.md │ ├── dynamic-tanh.patch │ ├── dynamic_tanh.py │ └── learning-rate-fix.patch ├── Efficiency │ ├── README.md │ ├── benchmark.py │ └── dynamic_tanh.py ├── LLaMA │ ├── LICENSE │ ├── README.md │ ├── fms_fsdp │ │ ├── __init__.py │ │ ├── config │ │ │ ├── __init__.py │ │ │ └── training.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── llama.py │ │ ├── policies │ │ │ ├── __init__.py │ │ │ ├── ac_handler.py │ │ │ ├── mixed_precision.py │ │ │ └── wrapping.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── checkpointing_utils.py │ │ │ ├── config_utils.py │ │ │ ├── dataset_utils.py │ │ │ └── train_utils.py │ ├── main_training_llama.py │ └── prepare_data.py ├── MAE │ ├── README.md │ ├── compatibility-fix.patch │ ├── dynamic-tanh.patch │ └── dynamic_tanh.py └── wav2vec2 │ ├── README.md │ ├── dynamic-tanh.patch │ └── wav2vec2_large_librispeech.yaml └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to this repo, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Jiachen Zhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Transformers without Normalization](https://arxiv.org/abs/2503.10622) 2 | 3 | Official PyTorch implementation of **DynamicTanh (DyT)**, from the following paper: 4 | 5 | [Transformers without Normalization](https://arxiv.org/abs/2503.10622). CVPR 2025. \ 6 | [Jiachen Zhu](https://jiachenzhu.github.io), [Xinlei Chen](https://xinleic.xyz/), [Kaiming He](https://people.csail.mit.edu/kaiming/), [Yann LeCun](http://yann.lecun.com) and [Zhuang Liu](https://liuzhuang13.github.io) \ 7 | FAIR, NYU, MIT, Princeton \ 8 | [[`arXiv`](https://arxiv.org/abs/2503.10622)][[`project page`](https://jiachenzhu.github.io/DyT/)] 9 | 10 | --- 11 | 12 |

13 | 14 | 15 |

16 | 17 | We propose **DynamicTanh (DyT)**, an element-wise operation defined as: DyT(***x***) = tanh($\alpha$***x***), where $\alpha$ is a learnable scalar. 18 | DyT is designed to replace normalization layers in Transformers. Models with DyT achieve similar or better performance than their normalized counterparts. 19 | 20 | 21 | 22 | ## Installation 23 | To reproduce our results, run the following commands to set up the Python environment: 24 | ``` 25 | conda create -n DyT python=3.12 26 | conda activate DyT 27 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia 28 | pip install timm==1.0.15 tensorboard 29 | ``` 30 | 31 | ## Training 32 | 33 | To reproduce our results on ImageNet-1K with ViT and ConvNeXt, run the following commands: \ 34 | (For results with LN, set `--dynamic_tanh` to `false`.) 35 | 36 |
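Under the hood, `--dynamic_tanh true` simply replaces every `nn.LayerNorm` in the timm model with a `DynamicTanh` module via `convert_ln_to_dyt` (see `dynamic_tanh.py`). A minimal, illustrative sketch of applying the same conversion to your own model (the [HowTo](other_tasks/HowTo) guide covers this in more detail):

```
import timm
import torch

from dynamic_tanh import convert_ln_to_dyt

# Build a standard timm ViT, then swap every nn.LayerNorm for DyT.
model = timm.create_model("vit_base_patch16_224", pretrained=False)
model = convert_ln_to_dyt(model)

# The converted model is used exactly like the original one.
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)  # torch.Size([2, 1000])
```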
37 | 38 | ViT-B 39 | 40 | 41 | ``` 42 | torchrun --nnodes=8 --nproc_per_node=8 main.py \ 43 | --model vit_base_patch16_224 \ 44 | --drop_path 0.1 \ 45 | --batch_size 64 \ 46 | --lr 4e-3 \ 47 | --update_freq 1 \ 48 | --model_ema true \ 49 | --model_ema_eval true \ 50 | --data_path /path/to/imagenet \ 51 | --output_dir /path/to/saving_dir \ 52 | --dynamic_tanh true 53 | ``` 54 |
55 |
56 | 57 | ViT-L 58 | 59 | 60 | ``` 61 | torchrun --nnodes=8 --nproc_per_node=8 main.py \ 62 | --model vit_large_patch16_224 \ 63 | --drop_path 0.4 \ 64 | --batch_size 64 \ 65 | --lr 4e-3 \ 66 | --update_freq 1 \ 67 | --model_ema true \ 68 | --model_ema_eval true \ 69 | --opt_betas 0.9 0.95 \ 70 | --data_path /path/to/imagenet \ 71 | --output_dir /path/to/saving_dir \ 72 | --dynamic_tanh true 73 | ``` 74 |
75 |
76 | 77 | ConvNeXt-B 78 | 79 | 80 | ``` 81 | torchrun --nnodes=8 --nproc_per_node=8 main.py \ 82 | --model convnext_base \ 83 | --drop_path 0.5 \ 84 | --batch_size 64 \ 85 | --lr 4e-3 \ 86 | --update_freq 1 \ 87 | --model_ema true \ 88 | --model_ema_eval true \ 89 | --data_path /path/to/imagenet \ 90 | --output_dir /path/to/saving_dir \ 91 | --dynamic_tanh true 92 | ``` 93 |
94 |
95 | 96 | ConvNeXt-L 97 | 98 | 99 | ``` 100 | torchrun --nnodes=8 --nproc_per_node=8 main.py \ 101 | --model convnext_large \ 102 | --drop_path 0.5 \ 103 | --batch_size 64 \ 104 | --lr 4e-3 \ 105 | --update_freq 1 \ 106 | --model_ema true \ 107 | --model_ema_eval true \ 108 | --data_path /path/to/imagenet \ 109 | --output_dir /path/to/saving_dir \ 110 | --dynamic_tanh true 111 | ``` 112 |
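All four recipes above use an effective batch size of nnodes × nproc_per_node × batch_size × update_freq = 8 × 8 × 64 × 1 = 4096, which is the total batch size that the default `--lr 4e-3` in `main.py` is specified for.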
113 | 114 | ## ImageNet-1K Results 115 | 116 | | name | acc@1 (LN) | acc@1 (DyT) | 117 | |:---:|:---:|:---:| 118 | | ViT-B | 82.3% | 82.5% | 119 | | ViT-L | 83.1% | 83.6% | 120 | | ConvNeXt-B | 83.7% | 83.7% | 121 | | ConvNeXt-L | 84.3% | 84.4% | 122 | 123 | ## Other Tasks 124 | To reproduce results for other tasks, follow the instructions in the respective folders: 125 | - [MAE](other_tasks/MAE) 126 | - [DINO](other_tasks/DINO) 127 | - [DiT](other_tasks/DiT) 128 | - [LLaMA](other_tasks/LLaMA) 129 | - [wav2vec 2.0](other_tasks/wav2vec2) 130 | - [DNA](other_tasks/DNA) 131 | 132 | To apply DyT to your own models, see the [HowTo](other_tasks/HowTo) guide. 133 | 134 | ## Efficiency 135 | To reproduce the computational efficiency results in *Section 6.1*, follow the instructions in the [Efficiency](other_tasks/Efficiency) folder. 136 | 137 | ## Acknowledgement 138 | This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library and [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repository. 139 | 140 | ## License 141 | This project is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information. 142 | 143 | ## Citation 144 | If you find this repository helpful, please consider citing: 145 | ``` 146 | @inproceedings{Zhu2025DyT, 147 | title={Transformers without Normalization}, 148 | author={Zhu, Jiachen and Chen, Xinlei and He, Kaiming and LeCun, Yann and Liu, Zhuang}, 149 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 150 | year={2025} 151 | } 152 | ``` 153 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
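# Overview: build_dataset(is_train, args) returns (dataset, nb_classes) for CIFAR-100,
# ImageNet (ImageFolder under {data_path}/train or {data_path}/val), or a generic image_folder,
# while build_transform(is_train, args) builds the matching timm training or evaluation transform.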
7 | 8 | 9 | import os 10 | from torchvision import datasets, transforms 11 | 12 | from timm.data.constants import \ 13 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 14 | from timm.data import create_transform 15 | 16 | def build_dataset(is_train, args): 17 | transform = build_transform(is_train, args) 18 | 19 | print("Transform = ") 20 | if isinstance(transform, tuple): 21 | for trans in transform: 22 | print(" - - - - - - - - - - ") 23 | for t in trans.transforms: 24 | print(t) 25 | else: 26 | for t in transform.transforms: 27 | print(t) 28 | print("---------------------------") 29 | 30 | if args.data_set == 'CIFAR': 31 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 32 | nb_classes = 100 33 | elif args.data_set == 'IMNET': 34 | print("reading from datapath", args.data_path) 35 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 36 | dataset = datasets.ImageFolder(root, transform=transform) 37 | nb_classes = 1000 38 | elif args.data_set == "image_folder": 39 | root = args.data_path if is_train else args.eval_data_path 40 | dataset = datasets.ImageFolder(root, transform=transform) 41 | nb_classes = args.nb_classes 42 | assert len(dataset.class_to_idx) == nb_classes 43 | else: 44 | raise NotImplementedError() 45 | print("Number of the class = %d" % nb_classes) 46 | 47 | return dataset, nb_classes 48 | 49 | 50 | def build_transform(is_train, args): 51 | resize_im = args.input_size > 32 52 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 53 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 54 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 55 | 56 | if is_train: 57 | # this should always dispatch to transforms_imagenet_train 58 | transform = create_transform( 59 | input_size=args.input_size, 60 | is_training=True, 61 | color_jitter=args.color_jitter, 62 | auto_augment=args.aa, 63 | interpolation=args.train_interpolation, 64 | re_prob=args.reprob, 65 | re_mode=args.remode, 66 | re_count=args.recount, 67 | mean=mean, 68 | std=std, 69 | ) 70 | if not resize_im: 71 | transform.transforms[0] = transforms.RandomCrop( 72 | args.input_size, padding=4) 73 | return transform 74 | 75 | t = [] 76 | if resize_im: 77 | # warping (no cropping) when evaluated at 384 or larger 78 | if args.input_size >= 384: 79 | t.append( 80 | transforms.Resize((args.input_size, args.input_size), 81 | interpolation=transforms.InterpolationMode.BICUBIC), 82 | ) 83 | print(f"Warping {args.input_size} size input images...") 84 | else: 85 | if args.crop_pct is None: 86 | args.crop_pct = 224 / 256 87 | size = int(args.input_size / args.crop_pct) 88 | t.append( 89 | # to maintain same ratio w.r.t. 
224 images 90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 91 | ) 92 | t.append(transforms.CenterCrop(args.input_size)) 93 | 94 | t.append(transforms.ToTensor()) 95 | t.append(transforms.Normalize(mean, std)) 96 | return transforms.Compose(t) 97 | -------------------------------------------------------------------------------- /dynamic_tanh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.layers import LayerNorm2d 4 | 5 | 6 | class DynamicTanh(nn.Module): 7 | def __init__(self, normalized_shape, channels_last, alpha_init_value=0.5): 8 | super().__init__() 9 | self.normalized_shape = normalized_shape 10 | self.alpha_init_value = alpha_init_value 11 | self.channels_last = channels_last 12 | 13 | self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) 14 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 15 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 16 | 17 | def forward(self, x): 18 | x = torch.tanh(self.alpha * x) 19 | if self.channels_last: 20 | x = x * self.weight + self.bias 21 | else: 22 | x = x * self.weight[:, None, None] + self.bias[:, None, None] 23 | return x 24 | 25 | def extra_repr(self): 26 | return f"normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}, channels_last={self.channels_last}" 27 | 28 | 29 | def convert_ln_to_dyt(module): 30 | module_output = module 31 | if isinstance(module, nn.LayerNorm): 32 | module_output = DynamicTanh(module.normalized_shape, not isinstance(module, LayerNorm2d)) 33 | for name, child in module.named_children(): 34 | module_output.add_module(name, convert_ln_to_dyt(child)) 35 | del module 36 | return module_output 37 | 38 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
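# Overview: train_one_epoch() runs a single epoch with optional AMP, gradient accumulation
# (update_freq), and per-step LR / weight-decay scheduling; evaluate() reports top-1 / top-5
# accuracy and loss averaged across all processes.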
7 | 8 | 9 | import math 10 | from typing import Iterable, Optional 11 | import torch 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | 15 | import utils 16 | 17 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 18 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 19 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 20 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 21 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 22 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False): 23 | model.train(True) 24 | metric_logger = utils.MetricLogger(delimiter=" ") 25 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 27 | header = 'Epoch: [{}]'.format(epoch) 28 | print_freq = 10 29 | 30 | optimizer.zero_grad() 31 | 32 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 33 | step = data_iter_step // update_freq 34 | if step >= num_training_steps_per_epoch: 35 | continue 36 | it = start_steps + step # global training iteration 37 | # Update LR & WD for the first acc 38 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 39 | for i, param_group in enumerate(optimizer.param_groups): 40 | if lr_schedule_values is not None: 41 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 42 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 43 | param_group["weight_decay"] = wd_schedule_values[it] 44 | 45 | samples = samples.to(device, non_blocking=True) 46 | targets = targets.to(device, non_blocking=True) 47 | 48 | if mixup_fn is not None: 49 | samples, targets = mixup_fn(samples, targets) 50 | 51 | if use_amp: 52 | with torch.cuda.amp.autocast(): 53 | output = model(samples) 54 | loss = criterion(output, targets) 55 | else: # full precision 56 | output = model(samples) 57 | loss = criterion(output, targets) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): # this could trigger if using AMP 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | assert math.isfinite(loss_value) 64 | 65 | if use_amp: 66 | # this attribute is added by timm on one optimizer (adahessian) 67 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 68 | loss /= update_freq 69 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 70 | parameters=model.parameters(), create_graph=is_second_order, 71 | update_grad=(data_iter_step + 1) % update_freq == 0) 72 | if (data_iter_step + 1) % update_freq == 0: 73 | optimizer.zero_grad() 74 | if model_ema is not None: 75 | model_ema.update(model) 76 | else: # full precision 77 | loss /= update_freq 78 | loss.backward() 79 | if (data_iter_step + 1) % update_freq == 0: 80 | optimizer.step() 81 | optimizer.zero_grad() 82 | if model_ema is not None: 83 | model_ema.update(model) 84 | 85 | torch.cuda.synchronize() 86 | 87 | if mixup_fn is None: 88 | class_acc = (output.max(-1)[-1] == targets).float().mean() 89 | else: 90 | class_acc = None 91 | metric_logger.update(loss=loss_value) 92 | metric_logger.update(class_acc=class_acc) 93 | min_lr = 10. 94 | max_lr = 0. 
95 | for group in optimizer.param_groups: 96 | min_lr = min(min_lr, group["lr"]) 97 | max_lr = max(max_lr, group["lr"]) 98 | 99 | metric_logger.update(lr=max_lr) 100 | metric_logger.update(min_lr=min_lr) 101 | weight_decay_value = None 102 | for group in optimizer.param_groups: 103 | if group["weight_decay"] > 0: 104 | weight_decay_value = group["weight_decay"] 105 | metric_logger.update(weight_decay=weight_decay_value) 106 | if use_amp: 107 | metric_logger.update(grad_norm=grad_norm) 108 | 109 | if log_writer is not None: 110 | log_writer.update(loss=loss_value, head="loss") 111 | log_writer.update(class_acc=class_acc, head="loss") 112 | log_writer.update(lr=max_lr, head="opt") 113 | log_writer.update(min_lr=min_lr, head="opt") 114 | log_writer.update(weight_decay=weight_decay_value, head="opt") 115 | if use_amp: 116 | log_writer.update(grad_norm=grad_norm, head="opt") 117 | log_writer.set_step() 118 | 119 | if wandb_logger: 120 | wandb_logger._wandb.log({ 121 | 'Rank-0 Batch Wise/train_loss': loss_value, 122 | 'Rank-0 Batch Wise/train_max_lr': max_lr, 123 | 'Rank-0 Batch Wise/train_min_lr': min_lr 124 | }, commit=False) 125 | if class_acc: 126 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False) 127 | if use_amp: 128 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False) 129 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it}) 130 | 131 | 132 | # gather the stats from all processes 133 | metric_logger.synchronize_between_processes() 134 | print("Averaged stats:", metric_logger) 135 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 136 | 137 | @torch.no_grad() 138 | def evaluate(data_loader, model, device, use_amp=False): 139 | criterion = torch.nn.CrossEntropyLoss() 140 | 141 | metric_logger = utils.MetricLogger(delimiter=" ") 142 | header = 'Test:' 143 | 144 | # switch to evaluation mode 145 | model.eval() 146 | for batch in metric_logger.log_every(data_loader, 10, header): 147 | images = batch[0] 148 | target = batch[-1] 149 | 150 | images = images.to(device, non_blocking=True) 151 | target = target.to(device, non_blocking=True) 152 | 153 | # compute output 154 | if use_amp: 155 | with torch.cuda.amp.autocast(): 156 | output = model(images) 157 | loss = criterion(output, target) 158 | else: 159 | output = model(images) 160 | loss = criterion(output, target) 161 | 162 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 163 | 164 | batch_size = images.shape[0] 165 | metric_logger.update(loss=loss.item()) 166 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 167 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 168 | # gather the stats from all processes 169 | metric_logger.synchronize_between_processes() 170 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 171 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 172 | 173 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 174 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
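# Overview: entry point for ImageNet-1K classification. It builds the datasets and distributed
# loaders, creates a ConvNeXt or ViT via timm, optionally converts every LayerNorm to DyT
# (--dynamic_tanh), and trains with cosine LR/WD schedules, mixup/cutmix, optional EMA, and AMP.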
7 | 8 | 9 | import argparse 10 | import datetime 11 | import numpy as np 12 | import time 13 | import torch 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | torch.backends.cuda.matmul.allow_tf32 = True 17 | torch.backends.cudnn.allow_tf32 = True 18 | import json 19 | import os 20 | 21 | from pathlib import Path 22 | 23 | from timm.data.mixup import Mixup 24 | from timm.models import create_model 25 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 26 | from timm.utils import ModelEma 27 | from optim_factory import create_optimizer, LayerDecayValueAssigner 28 | 29 | from datasets import build_dataset 30 | from engine import train_one_epoch, evaluate 31 | 32 | from utils import NativeScalerWithGradNormCount as NativeScaler 33 | import utils 34 | 35 | from dynamic_tanh import convert_ln_to_dyt 36 | 37 | 38 | def str2bool(v): 39 | """ 40 | Converts string to bool type; enables command line 41 | arguments in the format of '--arg1 true --arg2 false' 42 | """ 43 | if isinstance(v, bool): 44 | return v 45 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 46 | return True 47 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 48 | return False 49 | else: 50 | raise argparse.ArgumentTypeError('Boolean value expected.') 51 | 52 | def get_args_parser(): 53 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 54 | parser.add_argument('--batch_size', default=64, type=int, 55 | help='Per GPU batch size') 56 | parser.add_argument('--epochs', default=300, type=int) 57 | parser.add_argument('--update_freq', default=1, type=int, 58 | help='gradient accumulation steps') 59 | 60 | # Model parameters 61 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL', 62 | help='Name of model to train') 63 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', 64 | help='Drop path rate (default: 0.0)') 65 | parser.add_argument('--input_size', default=224, type=int, 66 | help='image input size') 67 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 68 | help="Layer scale initial values") 69 | 70 | # EMA related parameters 71 | parser.add_argument('--model_ema', type=str2bool, default=False) 72 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 73 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 74 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 75 | 76 | # Optimization parameters 77 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 78 | help='Optimizer (default: "adamw"') 79 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 80 | help='Optimizer Epsilon (default: 1e-8)') 81 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 82 | help='Optimizer Betas (default: None, use opt default)') 83 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 84 | help='Clip gradient norm (default: None, no clipping)') 85 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 86 | help='SGD momentum (default: 0.9)') 87 | parser.add_argument('--weight_decay', type=float, default=0.05, 88 | help='weight decay (default: 0.05)') 89 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 90 | weight decay. 
We use a cosine schedule for WD and using a larger decay by 91 | the end of training improves performance for ViTs.""") 92 | 93 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR', 94 | help='learning rate (default: 4e-3), with total batch size 4096') 95 | parser.add_argument('--layer_decay', type=float, default=1.0) 96 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 97 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 98 | parser.add_argument('--warmup_epochs', type=int, default=20, metavar='N', 99 | help='epochs to warmup LR, if scheduler supports') 100 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 101 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 102 | 103 | # Augmentation parameters 104 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 105 | help='Color jitter factor (default: 0.4)') 106 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 107 | help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), 108 | parser.add_argument('--smoothing', type=float, default=0.1, 109 | help='Label smoothing (default: 0.1)') 110 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 111 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 112 | 113 | # Evaluation parameters 114 | parser.add_argument('--crop_pct', type=float, default=None) 115 | 116 | # * Random Erase params 117 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 118 | help='Random erase prob (default: 0.25)') 119 | parser.add_argument('--remode', type=str, default='pixel', 120 | help='Random erase mode (default: "pixel")') 121 | parser.add_argument('--recount', type=int, default=1, 122 | help='Random erase count (default: 1)') 123 | parser.add_argument('--resplit', type=str2bool, default=False, 124 | help='Do not random erase first (clean) augmentation split') 125 | 126 | # * Mixup params 127 | parser.add_argument('--mixup', type=float, default=0.8, 128 | help='mixup alpha, mixup enabled if > 0.') 129 | parser.add_argument('--cutmix', type=float, default=1.0, 130 | help='cutmix alpha, cutmix enabled if > 0.') 131 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 132 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 133 | parser.add_argument('--mixup_prob', type=float, default=1.0, 134 | help='Probability of performing mixup or cutmix when either/both is enabled') 135 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 136 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 137 | parser.add_argument('--mixup_mode', type=str, default='batch', 138 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 139 | 140 | # * Finetuning params 141 | parser.add_argument('--finetune', default='', 142 | help='finetune from checkpoint') 143 | parser.add_argument('--head_init_scale', default=1.0, type=float, 144 | help='classifier head initial scale, typically adjusted in fine-tuning') 145 | parser.add_argument('--model_key', default='model|module', type=str, 146 | help='which key to load from saved state dict, usually model or model_ema') 147 | parser.add_argument('--model_prefix', default='', type=str) 148 | 149 | # Dataset parameters 150 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 151 | help='dataset path') 152 | parser.add_argument('--eval_data_path', default=None, type=str, 153 | help='dataset path for evaluation') 154 | parser.add_argument('--nb_classes', default=1000, type=int, 155 | help='number of the classification types') 156 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 157 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'], 158 | type=str, help='ImageNet dataset path') 159 | parser.add_argument('--output_dir', default='', 160 | help='path where to save, empty for no saving') 161 | parser.add_argument('--log_dir', default=None, 162 | help='path where to tensorboard log') 163 | parser.add_argument('--device', default='cuda', 164 | help='device to use for training / testing') 165 | parser.add_argument('--seed', default=0, type=int) 166 | 167 | parser.add_argument('--resume', default='', 168 | help='resume from checkpoint') 169 | parser.add_argument('--auto_resume', type=str2bool, default=True) 170 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 171 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 172 | parser.add_argument('--save_ckpt_num', default=3, type=int) 173 | 174 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 175 | help='start epoch') 176 | parser.add_argument('--eval', type=str2bool, default=False, 177 | help='Perform evaluation only') 178 | parser.add_argument('--dist_eval', type=str2bool, default=True, 179 | help='Enabling distributed evaluation') 180 | parser.add_argument('--disable_eval', type=str2bool, default=False, 181 | help='Disabling evaluation during training') 182 | parser.add_argument('--num_workers', default=10, type=int) 183 | parser.add_argument('--pin_mem', type=str2bool, default=True, 184 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 185 | 186 | # distributed training parameters 187 | parser.add_argument('--world_size', default=1, type=int, 188 | help='number of distributed processes') 189 | parser.add_argument('--local_rank', default=-1, type=int) 190 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 191 | parser.add_argument('--dist_url', default='env://', 192 | help='url used to set up distributed training') 193 | 194 | parser.add_argument('--use_amp', type=str2bool, default=False, 195 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 196 | 197 | # Weights and Biases arguments 198 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 199 | help="enable logging to Weights and Biases") 200 | parser.add_argument('--project', default='convnext', type=str, 201 | help="The name of the W&B project where you're sending the new run.") 202 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 203 | help="Save model checkpoints as W&B 
Artifacts.") 204 | 205 | parser.add_argument('--dynamic_tanh', type=str2bool, default=False) 206 | 207 | return parser 208 | 209 | def main(args): 210 | utils.init_distributed_mode(args) 211 | print(args) 212 | device = torch.device(args.device) 213 | 214 | # fix the seed for reproducibility 215 | seed = args.seed + utils.get_rank() 216 | torch.manual_seed(seed) 217 | np.random.seed(seed) 218 | cudnn.benchmark = True 219 | 220 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 221 | if args.disable_eval: 222 | args.dist_eval = False 223 | dataset_val = None 224 | else: 225 | dataset_val, _ = build_dataset(is_train=False, args=args) 226 | 227 | num_tasks = utils.get_world_size() 228 | global_rank = utils.get_rank() 229 | 230 | sampler_train = torch.utils.data.DistributedSampler( 231 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 232 | ) 233 | print("Sampler_train = %s" % str(sampler_train)) 234 | if args.dist_eval: 235 | if len(dataset_val) % num_tasks != 0: 236 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 237 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 238 | 'equal num of samples per-process.') 239 | sampler_val = torch.utils.data.DistributedSampler( 240 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 241 | else: 242 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 243 | 244 | if global_rank == 0 and args.log_dir is not None: 245 | os.makedirs(args.log_dir, exist_ok=True) 246 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 247 | else: 248 | log_writer = None 249 | 250 | if global_rank == 0 and args.enable_wandb: 251 | wandb_logger = utils.WandbLogger(args) 252 | else: 253 | wandb_logger = None 254 | 255 | data_loader_train = torch.utils.data.DataLoader( 256 | dataset_train, sampler=sampler_train, 257 | batch_size=args.batch_size, 258 | num_workers=args.num_workers, 259 | pin_memory=args.pin_mem, 260 | drop_last=True, 261 | ) 262 | 263 | if dataset_val is not None: 264 | data_loader_val = torch.utils.data.DataLoader( 265 | dataset_val, sampler=sampler_val, 266 | batch_size=int(1.5 * args.batch_size), 267 | num_workers=args.num_workers, 268 | pin_memory=args.pin_mem, 269 | drop_last=False 270 | ) 271 | else: 272 | data_loader_val = None 273 | 274 | mixup_fn = None 275 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 276 | if mixup_active: 277 | print("Mixup is activated!") 278 | mixup_fn = Mixup( 279 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 280 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 281 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 282 | 283 | if "convnext" in args.model: 284 | model = create_model( 285 | args.model, 286 | pretrained=False, 287 | num_classes=args.nb_classes, 288 | drop_path_rate=args.drop_path, 289 | ls_init_value=args.layer_scale_init_value, 290 | head_init_scale=args.head_init_scale, 291 | ) 292 | elif "vit" in args.model: 293 | model = create_model( 294 | args.model, 295 | pretrained=False, 296 | num_classes=args.nb_classes, 297 | global_pool='avg', 298 | drop_path_rate=args.drop_path, 299 | ) 300 | else: 301 | raise ValueError(f"Unrecognized model: {args.model}") 302 | 303 | if args.dynamic_tanh: 304 | model = convert_ln_to_dyt(model) 305 | 306 | if args.finetune: 307 | if args.finetune.startswith('https'): 308 | checkpoint = torch.hub.load_state_dict_from_url( 309 | args.finetune, map_location='cpu', check_hash=True) 310 | else: 311 | checkpoint = torch.load(args.finetune, map_location='cpu') 312 | 313 | print("Load ckpt from %s" % args.finetune) 314 | checkpoint_model = None 315 | for model_key in args.model_key.split('|'): 316 | if model_key in checkpoint: 317 | checkpoint_model = checkpoint[model_key] 318 | print("Load state_dict by model_key = %s" % model_key) 319 | break 320 | if checkpoint_model is None: 321 | checkpoint_model = checkpoint 322 | state_dict = model.state_dict() 323 | for k in ['head.weight', 'head.bias']: 324 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 325 | print(f"Removing key {k} from pretrained checkpoint") 326 | del checkpoint_model[k] 327 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 328 | model.to(device) 329 | 330 | model_ema = None 331 | if args.model_ema: 332 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 333 | model_ema = ModelEma( 334 | model, 335 | decay=args.model_ema_decay, 336 | device='cpu' if args.model_ema_force_cpu else '', 337 | resume='') 338 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 339 | 340 | model_without_ddp = model 341 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 342 | 343 | print("Model = %s" % str(model_without_ddp)) 344 | print('number of params:', n_parameters) 345 | 346 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 347 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 348 | print("LR = %.8f" % args.lr) 349 | print("Batch size = %d" % total_batch_size) 350 | print("Update frequent = %d" % args.update_freq) 351 | print("Number of training examples = %d" % len(dataset_train)) 352 | print("Number of training training per epoch = %d" % num_training_steps_per_epoch) 353 | 354 | if args.layer_decay < 1.0 or args.layer_decay > 1.0: 355 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value. 
356 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \ 357 | "Layer Decay impl only supports convnext_small/base/large/xlarge" 358 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 359 | else: 360 | assigner = None 361 | 362 | if assigner is not None: 363 | print("Assigned values = %s" % str(assigner.values)) 364 | 365 | if args.distributed: 366 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 367 | model_without_ddp = model.module 368 | 369 | optimizer = create_optimizer( 370 | args, model_without_ddp, skip_list=None, 371 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 372 | get_layer_scale=assigner.get_scale if assigner is not None else None) 373 | 374 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used 375 | 376 | print("Use Cosine LR scheduler") 377 | lr_schedule_values = utils.cosine_scheduler( 378 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 379 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 380 | ) 381 | 382 | if args.weight_decay_end is None: 383 | args.weight_decay_end = args.weight_decay 384 | wd_schedule_values = utils.cosine_scheduler( 385 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 386 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 387 | 388 | if mixup_fn is not None: 389 | # smoothing is handled with mixup label transform 390 | criterion = SoftTargetCrossEntropy() 391 | elif args.smoothing > 0.: 392 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 393 | else: 394 | criterion = torch.nn.CrossEntropyLoss() 395 | 396 | print("criterion = %s" % str(criterion)) 397 | 398 | utils.auto_load_model( 399 | args=args, model=model, model_without_ddp=model_without_ddp, 400 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 401 | 402 | if args.eval: 403 | print(f"Eval only mode") 404 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 405 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 406 | return 407 | 408 | max_accuracy = 0.0 409 | if args.model_ema and args.model_ema_eval: 410 | max_accuracy_ema = 0.0 411 | 412 | print("Start training for %d epochs" % args.epochs) 413 | start_time = time.time() 414 | for epoch in range(args.start_epoch, args.epochs): 415 | if args.distributed: 416 | data_loader_train.sampler.set_epoch(epoch) 417 | if log_writer is not None: 418 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 419 | if wandb_logger: 420 | wandb_logger.set_steps() 421 | train_stats = train_one_epoch( 422 | model, criterion, data_loader_train, optimizer, 423 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 424 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch, 425 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, 426 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 427 | use_amp=args.use_amp 428 | ) 429 | if args.output_dir and args.save_ckpt: 430 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 431 | utils.save_model( 432 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 433 | loss_scaler=loss_scaler, 
epoch=epoch, model_ema=model_ema) 434 | if data_loader_val is not None: 435 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 436 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 437 | if max_accuracy < test_stats["acc1"]: 438 | max_accuracy = test_stats["acc1"] 439 | if args.output_dir and args.save_ckpt: 440 | utils.save_model( 441 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 442 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 443 | print(f'Max accuracy: {max_accuracy:.2f}%') 444 | 445 | if log_writer is not None: 446 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 447 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 448 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 449 | 450 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 451 | **{f'test_{k}': v for k, v in test_stats.items()}, 452 | 'epoch': epoch, 453 | 'n_parameters': n_parameters} 454 | 455 | # repeat testing routines for EMA, if ema eval is turned on 456 | if args.model_ema and args.model_ema_eval: 457 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp) 458 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 459 | if max_accuracy_ema < test_stats_ema["acc1"]: 460 | max_accuracy_ema = test_stats_ema["acc1"] 461 | if args.output_dir and args.save_ckpt: 462 | utils.save_model( 463 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 464 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 465 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 466 | if log_writer is not None: 467 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 468 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 469 | else: 470 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 471 | 'epoch': epoch, 472 | 'n_parameters': n_parameters} 473 | 474 | if args.output_dir and utils.is_main_process(): 475 | if log_writer is not None: 476 | log_writer.flush() 477 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 478 | f.write(json.dumps(log_stats) + "\n") 479 | 480 | if wandb_logger: 481 | wandb_logger.log_epoch_metrics(log_stats) 482 | 483 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir: 484 | wandb_logger.log_checkpoints() 485 | 486 | 487 | total_time = time.time() - start_time 488 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 489 | print('Training time {}'.format(total_time_str)) 490 | 491 | if __name__ == '__main__': 492 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()]) 493 | args = parser.parse_args() 494 | if args.output_dir: 495 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 496 | main(args) 497 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
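# Overview: get_parameter_groups() splits parameters into decay / no-decay (and, optionally,
# per-layer) groups, each carrying an lr_scale; LayerDecayValueAssigner maps ConvNeXt parameter
# names to 12 depth groups for layer-wise LR decay; create_optimizer() builds the optimizer
# named by args.opt on top of those groups.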
7 | 8 | 9 | import torch 10 | from torch import optim as optim 11 | 12 | from timm.optim.adafactor import Adafactor 13 | from timm.optim.adahessian import Adahessian 14 | from timm.optim.adamp import AdamP 15 | from timm.optim.lookahead import Lookahead 16 | from timm.optim.nvnovograd import NvNovoGrad 17 | from timm.optim.rmsprop_tf import RMSpropTF 18 | from timm.optim.sgdp import SGDP 19 | 20 | import json 21 | 22 | try: 23 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 24 | has_apex = True 25 | except ImportError: 26 | has_apex = False 27 | 28 | 29 | def get_num_layer_for_convnext(var_name): 30 | """ 31 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 32 | consecutive blocks, including possible neighboring downsample layers; 33 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 34 | """ 35 | num_max_layer = 12 36 | if var_name.startswith("downsample_layers"): 37 | stage_id = int(var_name.split('.')[1]) 38 | if stage_id == 0: 39 | layer_id = 0 40 | elif stage_id == 1 or stage_id == 2: 41 | layer_id = stage_id + 1 42 | elif stage_id == 3: 43 | layer_id = 12 44 | return layer_id 45 | 46 | elif var_name.startswith("stages"): 47 | stage_id = int(var_name.split('.')[1]) 48 | block_id = int(var_name.split('.')[2]) 49 | if stage_id == 0 or stage_id == 1: 50 | layer_id = stage_id + 1 51 | elif stage_id == 2: 52 | layer_id = 3 + block_id // 3 53 | elif stage_id == 3: 54 | layer_id = 12 55 | return layer_id 56 | else: 57 | return num_max_layer + 1 58 | 59 | class LayerDecayValueAssigner(object): 60 | def __init__(self, values): 61 | self.values = values 62 | 63 | def get_scale(self, layer_id): 64 | return self.values[layer_id] 65 | 66 | def get_layer_id(self, var_name): 67 | return get_num_layer_for_convnext(var_name) 68 | 69 | 70 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 71 | parameter_group_names = {} 72 | parameter_group_vars = {} 73 | 74 | for name, param in model.named_parameters(): 75 | if not param.requires_grad: 76 | continue # frozen weights 77 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 78 | group_name = "no_decay" 79 | this_weight_decay = 0. 80 | else: 81 | group_name = "decay" 82 | this_weight_decay = weight_decay 83 | if get_num_layer is not None: 84 | layer_id = get_num_layer(name) 85 | group_name = "layer_%d_%s" % (layer_id, group_name) 86 | else: 87 | layer_id = None 88 | 89 | if group_name not in parameter_group_names: 90 | if get_layer_scale is not None: 91 | scale = get_layer_scale(layer_id) 92 | else: 93 | scale = 1. 
94 | 95 | parameter_group_names[group_name] = { 96 | "weight_decay": this_weight_decay, 97 | "params": [], 98 | "lr_scale": scale 99 | } 100 | parameter_group_vars[group_name] = { 101 | "weight_decay": this_weight_decay, 102 | "params": [], 103 | "lr_scale": scale 104 | } 105 | 106 | parameter_group_vars[group_name]["params"].append(param) 107 | parameter_group_names[group_name]["params"].append(name) 108 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 109 | return list(parameter_group_vars.values()) 110 | 111 | 112 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 113 | opt_lower = args.opt.lower() 114 | weight_decay = args.weight_decay 115 | # if weight_decay and filter_bias_and_bn: 116 | if filter_bias_and_bn: 117 | skip = {} 118 | if skip_list is not None: 119 | skip = skip_list 120 | elif hasattr(model, 'no_weight_decay'): 121 | skip = model.no_weight_decay() 122 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 123 | weight_decay = 0. 124 | else: 125 | parameters = model.parameters() 126 | 127 | if 'fused' in opt_lower: 128 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 129 | 130 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 131 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 132 | opt_args['eps'] = args.opt_eps 133 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 134 | opt_args['betas'] = args.opt_betas 135 | 136 | opt_split = opt_lower.split('_') 137 | opt_lower = opt_split[-1] 138 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 139 | opt_args.pop('eps', None) 140 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 141 | elif opt_lower == 'momentum': 142 | opt_args.pop('eps', None) 143 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 144 | elif opt_lower == 'adam': 145 | optimizer = optim.Adam(parameters, **opt_args) 146 | elif opt_lower == 'adamw': 147 | optimizer = optim.AdamW(parameters, **opt_args) 148 | elif opt_lower == 'nadam': 149 | optimizer = Nadam(parameters, **opt_args) 150 | elif opt_lower == 'radam': 151 | optimizer = RAdam(parameters, **opt_args) 152 | elif opt_lower == 'adamp': 153 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 154 | elif opt_lower == 'sgdp': 155 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 156 | elif opt_lower == 'adadelta': 157 | optimizer = optim.Adadelta(parameters, **opt_args) 158 | elif opt_lower == 'adafactor': 159 | if not args.lr: 160 | opt_args['lr'] = None 161 | optimizer = Adafactor(parameters, **opt_args) 162 | elif opt_lower == 'adahessian': 163 | optimizer = Adahessian(parameters, **opt_args) 164 | elif opt_lower == 'rmsprop': 165 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 166 | elif opt_lower == 'rmsproptf': 167 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 168 | elif opt_lower == 'novograd': 169 | optimizer = NovoGrad(parameters, **opt_args) 170 | elif opt_lower == 'nvnovograd': 171 | optimizer = NvNovoGrad(parameters, **opt_args) 172 | elif opt_lower == 'fusedsgd': 173 | opt_args.pop('eps', None) 174 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 175 | elif opt_lower == 'fusedmomentum': 176 | opt_args.pop('eps', None) 177 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 178 | elif opt_lower == 'fusedadam': 179 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 180 | elif opt_lower == 'fusedadamw': 181 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 182 | elif opt_lower == 'fusedlamb': 183 | optimizer = FusedLAMB(parameters, **opt_args) 184 | elif opt_lower == 'fusednovograd': 185 | opt_args.setdefault('betas', (0.95, 0.98)) 186 | optimizer = FusedNovoGrad(parameters, **opt_args) 187 | else: 188 | assert False and "Invalid optimizer" 189 | 190 | if len(opt_split) > 1: 191 | if opt_split[0] == 'lookahead': 192 | optimizer = Lookahead(optimizer) 193 | 194 | return optimizer 195 | -------------------------------------------------------------------------------- /other_tasks/DINO/README.md: -------------------------------------------------------------------------------- 1 | # DINO with DyT 2 | 3 | This guide provides instructions for reproducing the DINO results with our proposed modifications, as presented in our paper. Follow the steps below to set up the environment, apply the patches, and run the experiments. 4 | 5 | ## 1. Clone the DINO Repository 6 | 7 | Clone the official DINO repository from GitHub: 8 | ``` 9 | git clone https://github.com/facebookresearch/dino.git 10 | ``` 11 | 12 | ## 2. Set Up the Python Environment 13 | 14 | Create and activate a Conda environment with the required dependencies: 15 | ``` 16 | conda create -n DINO python=3.12 17 | conda activate DINO 18 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia 19 | ``` 20 | 21 | ## 3. Apply DynamicTanh Patch (Optional) 22 | *(Skip this step if you want to reproduce the baseline results.)* \ 23 | To reproduce the results using Dynamic Tanh (DyT), apply the following patches: 24 | ``` 25 | cp dynamic_tanh.py dino 26 | cp dynamic-tanh.patch dino 27 | cd dino 28 | git apply dynamic-tanh.patch 29 | ``` 30 | 31 | ## 3. Run Experiments 32 | 33 | You can reproduce the DINO pretraining results using the following command: 34 | 35 | ### ViT-B with Patch Size 16 36 | 37 | This configuration follows the arguments from the original DINO documentation at [ViT-B-16](https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/args.txt). 38 | ``` 39 | torchrun --nnodes=4 --nproc_per_node=8 main_dino.py \ 40 | --arch vit_base \ 41 | --patch_size 16 \ 42 | --out_dim 65536 \ 43 | --norm_last_layer true \ 44 | --warmup_teacher_temp 0.04 \ 45 | --teacher_temp 0.07 \ 46 | --warmup_teacher_temp_epochs 50 \ 47 | --use_fp16 false \ 48 | --weight_decay 0.04 \ 49 | --weight_decay_end 0.4 \ 50 | --clip_grad 0.3 \ 51 | --batch_size_per_gpu 32 \ 52 | --epochs 400 \ 53 | --freeze_last_layer 3 \ 54 | --lr 0.00075 \ 55 | --warmup_epochs 10 \ 56 | --min_lr 2e-06 \ 57 | --global_crops_scale 0.25 1.0 \ 58 | --local_crops_scale 0.05 0.25 \ 59 | --local_crops_number 10 \ 60 | --seed 0 \ 61 | --num_workers 10 \ 62 | --optimizer adamw \ 63 | --momentum_teacher 0.996 \ 64 | --use_bn_in_head false \ 65 | --drop_path_rate 0.1 \ 66 | --data_path /path/to/imagenet/train \ 67 | --output_dir /path/to/saving_dir 68 | ``` 69 | 70 | ### ViT-B with Patch Size 8 71 | 72 | This configuration follows the arguments from the original DINO documentation at [ViT-B-8](https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/args.txt). 
73 | ``` 74 | torchrun --nnodes=22 --nproc_per_node=8 main_dino.py \ 75 | --arch vit_base \ 76 | --patch_size 8 \ 77 | --out_dim 65536 \ 78 | --norm_last_layer true \ 79 | --warmup_teacher_temp 0.03 \ 80 | --teacher_temp 0.07 \ 81 | --warmup_teacher_temp_epochs 50 \ 82 | --use_fp16 false \ 83 | --weight_decay 0.04 \ 84 | --weight_decay_end 0.4 \ 85 | --clip_grad 3.0 \ 86 | --batch_size_per_gpu 6 \ 87 | --epochs 300 \ 88 | --freeze_last_layer 3 \ 89 | --lr 0.0005 \ 90 | --warmup_epochs 10 \ 91 | --min_lr 2e-06 \ 92 | --global_crops_scale 0.25 1.0 \ 93 | --local_crops_scale 0.05 0.25 \ 94 | --local_crops_number 10 \ 95 | --seed 0 \ 96 | --num_workers 10 \ 97 | --optimizer adamw \ 98 | --momentum_teacher 0.996 \ 99 | --use_bn_in_head false \ 100 | --drop_path_rate 0.1 \ 101 | --data_path /path/to/imagenet/train \ 102 | --output_dir /path/to/saving_dir 103 | ``` 104 | 105 | 106 | ## 5. Evaluation 107 | *(Since DINO does not provide fine-tuning code, we use the MAE code for fine-tuning.)* \ 108 | For fine-tuning and evaluation of pretrained models, refer to the original MAE documentation: [FINETUNE](https://github.com/facebookresearch/mae/blob/main/FINETUNE.md). 109 | -------------------------------------------------------------------------------- /other_tasks/DINO/dynamic-tanh.patch: -------------------------------------------------------------------------------- 1 | From c3b4316f39449bcf386d126ad5710b3d130994cc Mon Sep 17 00:00:00 2001 2 | From: Jiachen Zhu 3 | Date: Wed, 12 Mar 2025 16:28:04 +0000 4 | Subject: [PATCH] dynamic-tanh 5 | 6 | --- 7 | main_dino.py | 3 +++ 8 | 1 file changed, 3 insertions(+) 9 | 10 | diff --git a/main_dino.py b/main_dino.py 11 | index cade987..33239f0 100644 12 | --- a/main_dino.py 13 | +++ b/main_dino.py 14 | @@ -33,6 +33,7 @@ from torchvision import models as torchvision_models 15 | import utils 16 | import vision_transformer as vits 17 | from vision_transformer import DINOHead 18 | +from dynamic_tanh import convert_ln_to_dyt 19 | 20 | torchvision_archs = sorted(name for name in torchvision_models.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | @@ -164,6 +165,8 @@ def train_dino(args): 23 | drop_path_rate=args.drop_path_rate, # stochastic depth 24 | ) 25 | teacher = vits.__dict__[args.arch](patch_size=args.patch_size) 26 | + student = convert_ln_to_dyt(student) 27 | + teacher = convert_ln_to_dyt(teacher) 28 | embed_dim = student.embed_dim 29 | # if the network is a XCiT 30 | elif args.arch in torch.hub.list("facebookresearch/xcit:main"): 31 | -- 32 | 2.34.1 33 | 34 | -------------------------------------------------------------------------------- /other_tasks/DINO/dynamic_tanh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DynamicTanh(nn.Module): 6 | def __init__(self, normalized_shape, alpha_init_value=0.5): 7 | super().__init__() 8 | self.normalized_shape = normalized_shape 9 | self.alpha_init_value = alpha_init_value 10 | self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) 11 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 12 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 13 | 14 | def forward(self, x): 15 | return self.weight * torch.tanh(self.alpha * x) + self.bias 16 | 17 | def extra_repr(self): 18 | return f"normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}" 19 | 20 | 21 | def convert_ln_to_dyt(module): 22 | module_output = module 23 | if isinstance(module, nn.LayerNorm): 24 | 
module_output = DynamicTanh(module.normalized_shape) 25 | for name, child in module.named_children(): 26 | module_output.add_module(name, convert_ln_to_dyt(child)) 27 | del module 28 | return module_output -------------------------------------------------------------------------------- /other_tasks/DNA/README.md: -------------------------------------------------------------------------------- 1 | # DNA Sequence Modeling with DyT 2 | 3 | This guide provides instructions for reproducing the DNA sequence modeling results with our proposed DynamicTanh (DyT) modifications, as presented in our paper. 4 | 5 | ## 1. Clone the Caduceus Repository 6 | 7 | Clone the official Caduceus repository from GitHub: 8 | 9 | ```bash 10 | git clone https://github.com/kuleshov-group/caduceus.git 11 | ``` 12 | 13 | ## 2. Set Up the Python Environment and Datasets 14 | 15 | Follow the instructions in the original [Caduceus README](https://github.com/kuleshov-group/caduceus/blob/main/README.md) to: 16 | - Set up the Python environment with required dependencies 17 | - Download and prepare the necessary datasets for DNA sequence modeling 18 | 19 | ## 3. Apply DynamicTanh (DyT) Patch 20 | 21 | *(Skip this step if you want to reproduce the baseline results without DyT modifications.)* 22 | 23 | To reproduce the results using Dynamic Tanh (DyT), apply the following patches: 24 | 25 | ```bash 26 | cp dynamic_tanh.py caduceus/ 27 | cp dynamic-tanh.patch caduceus/ 28 | cd caduceus 29 | git apply dynamic-tanh.patch 30 | ``` 31 | 32 | ## 4. Run Experiments 33 | 34 | You can reproduce our DNA Sequence Modeling results using the provided SLURM scripts. You may need to edit these scripts to adapt them to your computing environment. 35 | 36 | ### Caduceus Model Training 37 | 38 | ```bash 39 | cd slurm_scripts 40 | sbatch run_pretrain_caduceus.sh 41 | ``` 42 | 43 | ### HyenaDNA Model Training 44 | 45 | ```bash 46 | cd slurm_scripts 47 | sbatch run_pretrain_hyena.sh 48 | ``` 49 | 50 | ### Model Evaluation 51 | 52 | ```bash 53 | cd slurm_scripts 54 | bash wrapper_run_genomics.sh 55 | ``` 56 | 57 | -------------------------------------------------------------------------------- /other_tasks/DNA/dynamic-tanh.patch: -------------------------------------------------------------------------------- 1 | From e8a4fd96f43ca22e953ec5053bf65398ba949a6f Mon Sep 17 00:00:00 2001 2 | From: Jiachen Zhu 3 | Date: Mon, 17 Mar 2025 22:36:03 +0000 4 | Subject: [PATCH] dynamic-tanh 5 | 6 | --- 7 | train.py | 4 +++- 8 | 1 file changed, 3 insertions(+), 1 deletion(-) 9 | 10 | diff --git a/train.py b/train.py 11 | index c49b878..29bf26d 100644 12 | --- a/train.py 13 | +++ b/train.py 14 | @@ -33,7 +33,7 @@ import torch.backends 15 | 16 | torch.backends.cuda.matmul.allow_tf32 = True 17 | torch.backends.cudnn.allow_tf32 = True 18 | - 19 | +from dynamic_tanh import convert_ln_to_dyt 20 | OmegaConf.register_new_resolver('eval', eval) 21 | OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y) 22 | OmegaConf.register_new_resolver('min', lambda x, y: min([x, y])) 23 | @@ -202,6 +202,8 @@ class SequenceLightningModule(pl.LightningModule): 24 | self.model = utils.instantiate(registry.model, model_hparams) 25 | else: 26 | self.model = utils.instantiate(registry.model, self.hparams.model) 27 | + self.model = convert_ln_to_dyt(self.model, alpha_init_value) 28 | + print(self.model) 29 | if (name := self.hparams.train.post_init_hook['_name_']) is not None: 30 | kwargs = self.hparams.train.post_init_hook.copy() 31 | del kwargs['_name_'] 32 | -- 33 | 
2.34.1 34 | 35 | -------------------------------------------------------------------------------- /other_tasks/DNA/dynamic_tanh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DynamicTanh(nn.Module): 6 | def __init__(self, normalized_shape, alpha_init_value=0.5): 7 | super().__init__() 8 | self.normalized_shape = normalized_shape 9 | self.alpha_init_value = alpha_init_value 10 | self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) 11 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 12 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 13 | 14 | def forward(self, x): 15 | return self.weight * torch.tanh(self.alpha * x) + self.bias 16 | 17 | def extra_repr(self): 18 | return f'normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}' 19 | 20 | 21 | def convert_ln_to_dyt(module): 22 | module_output = module 23 | if isinstance(module, nn.LayerNorm): 24 | module_output = DynamicTanh(module.normalized_shape) 25 | for name, child in module.named_children(): 26 | module_output.add_module(name, convert_ln_to_dyt(child)) 27 | del module 28 | return module_output 29 | -------------------------------------------------------------------------------- /other_tasks/DiT/README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Transformers (DiTs) with DyT 2 | 3 | This guide provides instructions for reproducing the DiT results with our proposed modifications, as presented in our paper. Follow the steps below to set up the environment, apply the patches, and run the experiments. 4 | 5 | ## 1. Clone the DiT Repository 6 | 7 | Clone the official DiT repository from GitHub: 8 | ``` 9 | git clone https://github.com/facebookresearch/DiT.git 10 | ``` 11 | 12 | ## 2. Set Up the Python Environment 13 | 14 | Set up the Python environment with the following commands: 15 | ``` 16 | conda create -n DiT python=3.12 17 | conda activate DiT 18 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia 19 | pip install timm==1.0.15 diffusers==0.32.2 accelerate==1.4.0 20 | ``` 21 | 22 | ## 3. Apply Learning Rate Fix 23 | 24 | Update the original DiT code to accept learning rate argument by applying the provided patch: 25 | ``` 26 | cp learning-rate-fix.patch DiT 27 | cd DiT 28 | git apply learning-rate-fix.patch 29 | ``` 30 | 31 | ## 4. Apply DynamicTanh Patch (Optional) 32 | *(Skip this step if you wish to reproduce the baseline results.)* \ 33 | To reproduce the results using Dynamic Tanh (DyT), apply the following patches: 34 | ``` 35 | cp dynamic_tanh.py DiT 36 | cp dynamic-tanh.patch DiT 37 | cd DiT 38 | git apply dynamic-tanh.patch 39 | ``` 40 | 41 | ## 5. Run Experiments 42 | 43 | After applying the patches, run the DiT pretraining with the following command: 44 | ``` 45 | srun torchrun --nnodes=1 --nproc_per_node=8 train.py \ 46 | --model $MODEL \ 47 | --lr $LEARNING_RATE \ 48 | --data-path /path/to/imagenet/train \ 49 | --results-dir /path/to/saving_dir 50 | ``` 51 | Replace `$MODEL` with one of the following options: `DiT-B/4`, `DiT-L/4`, or `DiT-XL/2`. 52 | Repace `$LEARNING_RATE` with one of the following options: `1e-4`, `2e-4`, or `4e-4`. 53 | 54 | 55 | ## 6. 
Evaluation 56 | 57 | Follow the [DiT evaluation guide](https://github.com/facebookresearch/DiT) to: 58 | - Generate samples 59 | - Compute evaluation metrics using the TensorFlow evaluation suite provided in the repository. 60 | 61 | -------------------------------------------------------------------------------- /other_tasks/DiT/dynamic-tanh.patch: -------------------------------------------------------------------------------- 1 | From 01dc036d356c11ef0cd298de550e2802f928c5f8 Mon Sep 17 00:00:00 2001 2 | From: Jiachen Zhu 3 | Date: Mon, 17 Mar 2025 19:42:55 +0000 4 | Subject: [PATCH] dynamic-tanh 5 | 6 | --- 7 | models.py | 10 ---------- 8 | train.py | 2 ++ 9 | 2 files changed, 2 insertions(+), 10 deletions(-) 10 | 11 | diff --git a/models.py b/models.py 12 | index c90eeba..5b0750a 100644 13 | --- a/models.py 14 | +++ b/models.py 15 | @@ -204,16 +204,6 @@ class DiT(nn.Module): 16 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 17 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 18 | 19 | - # Zero-out adaLN modulation layers in DiT blocks: 20 | - for block in self.blocks: 21 | - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 22 | - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 23 | - 24 | - # Zero-out output layers: 25 | - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 26 | - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 27 | - nn.init.constant_(self.final_layer.linear.weight, 0) 28 | - nn.init.constant_(self.final_layer.linear.bias, 0) 29 | 30 | def unpatchify(self, x): 31 | """ 32 | diff --git a/train.py b/train.py 33 | index 3bc8c87..c9e5a24 100644 34 | --- a/train.py 35 | +++ b/train.py 36 | @@ -30,6 +30,7 @@ import os 37 | from models import DiT_models 38 | from diffusion import create_diffusion 39 | from diffusers.models import AutoencoderKL 40 | +from dynamic_tanh import convert_ln_to_dyt 41 | 42 | 43 | ################################################################################# 44 | @@ -143,6 +144,7 @@ def main(args): 45 | input_size=latent_size, 46 | num_classes=args.num_classes 47 | ) 48 | + model = convert_ln_to_dyt(model) 49 | # Note that parameter initialization is done within the DiT constructor 50 | ema = deepcopy(model).to(device) # Create an EMA of the model for use after training 51 | requires_grad(ema, False) 52 | -- 53 | 2.34.1 54 | 55 | -------------------------------------------------------------------------------- /other_tasks/DiT/dynamic_tanh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DynamicTanh(nn.Module): 6 | def __init__(self, normalized_shape, elementwise_affine, alpha_init_value=0.5): 7 | super().__init__() 8 | self.normalized_shape = normalized_shape 9 | self.elementwise_affine = elementwise_affine 10 | self.alpha_init_value = alpha_init_value 11 | self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) 12 | if elementwise_affine: 13 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 14 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 15 | 16 | def forward(self, x): 17 | if self.elementwise_affine: 18 | return self.weight * torch.tanh(self.alpha * x) + self.bias 19 | else: 20 | return torch.tanh(self.alpha * x) 21 | 22 | def extra_repr(self): 23 | return f"normalized_shape={self.normalized_shape}, elementwise_affine={self.elementwise_affine}, alpha_init_value={self.alpha_init_value}" 24 | 25 | 26 | def convert_ln_to_dyt(module): 27 | module_output = module 
28 | if isinstance(module, nn.LayerNorm): 29 | module_output = DynamicTanh(module.normalized_shape, module.elementwise_affine) 30 | for name, child in module.named_children(): 31 | module_output.add_module(name, convert_ln_to_dyt(child)) 32 | del module 33 | return module_output -------------------------------------------------------------------------------- /other_tasks/DiT/learning-rate-fix.patch: -------------------------------------------------------------------------------- 1 | From a78277316a4c58e0e40e5506dbacdc346090597e Mon Sep 17 00:00:00 2001 2 | From: Jiachen Zhu 3 | Date: Mon, 17 Mar 2025 19:33:17 +0000 4 | Subject: [PATCH] learning-rate-fix 5 | 6 | --- 7 | train.py | 3 ++- 8 | 1 file changed, 2 insertions(+), 1 deletion(-) 9 | 10 | diff --git a/train.py b/train.py 11 | index 7cfee80..3bc8c87 100644 12 | --- a/train.py 13 | +++ b/train.py 14 | @@ -152,7 +152,7 @@ def main(args): 15 | logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}") 16 | 17 | # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper): 18 | - opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0) 19 | + opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0) 20 | 21 | # Setup data: 22 | transform = transforms.Compose([ 23 | @@ -265,5 +265,6 @@ if __name__ == "__main__": 24 | parser.add_argument("--num-workers", type=int, default=4) 25 | parser.add_argument("--log-every", type=int, default=100) 26 | parser.add_argument("--ckpt-every", type=int, default=50_000) 27 | + parser.add_argument("--lr", type=float, default=4e-4) 28 | args = parser.parse_args() 29 | main(args) 30 | -- 31 | 2.34.1 32 | 33 | -------------------------------------------------------------------------------- /other_tasks/Efficiency/README.md: -------------------------------------------------------------------------------- 1 | # Efficiency of DyT 2 | 3 | This guide provides instructions for reproducing the latency benchmarks reported in Section 6.1 of the original paper. 4 | 5 | ## 1. Set Up the Python Environment 6 | Set up the Python environment using the following commands: 7 | ``` 8 | conda create -n DyT python=3.12 9 | conda activate DyT 10 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia 11 | pip install transformers 12 | ``` 13 | 14 | ## 2. Benchmark Latency 15 | 16 | After setting up the environment, run the benchmark script using the following command: 17 | ``` 18 | python benchmark.py --layer $LAYER --training 19 | ``` 20 | Replace `$LAYER` with one of the following options: 21 | - `DyT` - DynamicTanh 22 | - `RMSNorm` - RMSNorm 23 | - `Identity` - Identity 24 | 25 | To benchmark latency for the forward pass only, omit the `--training` flag. 26 | 27 | 28 | ## 3. Notes and Limitations 29 | 30 | This benchmark is preliminary and does not include any optimization tricks for the forward or backward pass. Therefore, the results should be interpreted as indicative rather than conclusive. 
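
As a complementary sanity check before running the full-model benchmark above, it can also help to time a single layer in isolation. The sketch below is not part of this repository and rests on a few assumptions: it imports `DynamicTanh` from the `dynamic_tanh.py` in this folder, uses `torch.nn.RMSNorm` (available in the pinned PyTorch release), and picks an assumed LLaMA-2 7B-like activation shape of 1 × 4096 × 4096. Treat the numbers as illustrative only; single-layer timings ignore the fusion and memory effects present in the end-to-end benchmark.

```python
import time

import torch
import torch.nn as nn

from dynamic_tanh import DynamicTanh  # DyT layer defined in this folder


@torch.no_grad()
def time_layer(layer, x, iters=100, warmup=20):
    layer = layer.to(device=x.device, dtype=x.dtype)
    for _ in range(warmup):  # warm up CUDA kernels before measuring
        layer(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        layer(x)
    torch.cuda.synchronize()
    return (time.time() - start) / iters


if __name__ == "__main__":
    # Assumed shape: batch 1, 4096 tokens, hidden width 4096 (LLaMA-2 7B-like).
    x = torch.randn(1, 4096, 4096, device="cuda", dtype=torch.bfloat16)
    for name, layer in [("DyT", DynamicTanh(4096)), ("RMSNorm", nn.RMSNorm(4096))]:
        print(f"{name}: {time_layer(layer, x) * 1e3:.3f} ms per forward pass")
```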
31 | 32 | -------------------------------------------------------------------------------- /other_tasks/Efficiency/benchmark.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | torch.set_float32_matmul_precision('high') 8 | import transformers 9 | from dynamic_tanh import convert_rms_to_dyt, convert_rms_to_identity 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser(description="Benchmark the latency of a LLaMA-2 7B.") 14 | parser.add_argument("--layer", default="DyT", help="The layer to benchmark.") 15 | parser.add_argument("--training", action="store_true", help="Whether to benchmark training.") 16 | args = parser.parse_args() 17 | 18 | assert args.layer.lower() in ["dyt", "identity", "rmsnorm"] 19 | 20 | model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") 21 | if args.layer.lower() == "dyt": 22 | model = convert_rms_to_dyt(model) 23 | elif args.layer.lower() == "identity": 24 | model = convert_rms_to_identity(model) 25 | elif args.layer.lower() == "rmsnorm": 26 | pass 27 | else: 28 | raise ValueError("Invalid layer. Must be dyt, identity, or rmsnorm.") 29 | print(model) 30 | 31 | model.to(device=0, dtype=torch.bfloat16) 32 | 33 | samples = [] 34 | for _ in range(200): 35 | samples.append(torch.randint(0, 32000, (1, 4096), dtype=torch.long, device=0)) 36 | 37 | torch.cuda.synchronize() 38 | if args.training: 39 | for sample in samples[:100]: 40 | out = model(sample) 41 | loss = F.cross_entropy(out.logits.view(-1, out.logits.size(-1)), sample.view(-1)) 42 | loss.backward() 43 | else: 44 | for sample in samples[:100]: 45 | with torch.no_grad(): 46 | out = model(sample) 47 | loss = F.cross_entropy(out.logits.view(-1, out.logits.size(-1)), sample.view(-1)) 48 | torch.cuda.synchronize() 49 | 50 | torch.cuda.synchronize() 51 | time_1 = time.time() 52 | if args.training: 53 | for sample in samples[100:]: 54 | out = model(sample) 55 | loss = F.cross_entropy(out.logits.view(-1, out.logits.size(-1)), sample.view(-1)) 56 | loss.backward() 57 | else: 58 | for sample in samples[100:]: 59 | with torch.no_grad(): 60 | out = model(sample) 61 | loss = F.cross_entropy(out.logits.view(-1, out.logits.size(-1)), sample.view(-1)) 62 | torch.cuda.synchronize() 63 | time_2 = time.time() 64 | 65 | print(f"{args.layer}, {'inference' if not args.training else 'training'} Time: {time_2 - time_1:.2f} seconds") 66 | 67 | 68 | -------------------------------------------------------------------------------- /other_tasks/Efficiency/dynamic_tanh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers.models.llama.modeling_llama import LlamaRMSNorm 5 | 6 | 7 | class DynamicTanh(nn.Module): 8 | def __init__(self, normalized_shape): 9 | super().__init__() 10 | self.alpha = nn.Parameter(torch.ones(1) * 0.5) 11 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 12 | 13 | def forward(self, x): 14 | return self.weight * torch.tanh(self.alpha * x) 15 | 16 | 17 | def convert_rms_to_dyt(module): 18 | module_output = module 19 | if isinstance(module, LlamaRMSNorm): 20 | module_output = DynamicTanh(normalized_shape=module.weight.shape[0]) 21 | for name, child in module.named_children(): 22 | module_output.add_module(name, convert_rms_to_dyt(child)) 23 | del module 24 | return 
module_output 25 | 26 | 27 | def convert_rms_to_identity(module): 28 | module_output = module 29 | if isinstance(module, LlamaRMSNorm): 30 | module_output = nn.Identity() 31 | for name, child in module.named_children(): 32 | module_output.add_module(name, convert_rms_to_identity(child)) 33 | del module 34 | return module_output 35 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /other_tasks/LLaMA/README.md: -------------------------------------------------------------------------------- 1 | # LLaMA with DyT 2 | 3 | This guide provides instructions for reproducing the LLaMA results with our proposed DynamicTanh (DyT) modifications, as presented in our paper. We use the [fms-fsdp](https://github.com/foundation-model-stack/fms-fsdp/tree/main) framework to train our LLaMA models. Follow the steps below to reproduce the results. 4 | 5 | ## 1. 
Set Up the Python Environment 6 | 7 | Follow the original [fms-fsdp](https://github.com/foundation-model-stack/fms-fsdp/tree/main) documentation to set up the required Python environment for the project. 8 | 9 | You'll need one additional Python library to save dataloader checkpoints in case you need to resume training without repeating the training data: 10 | 11 | ```bash 12 | pip install torchdata 13 | ``` 14 | 15 | ## 2. Generate Tokenized Data 16 | Since the original fms-fsdp repository does not provide a standard dataset, we used the Pile dataset. To simplify the process, we've included a script that generates tokenized Arrow files from the original dataset. 17 | First, determine the world size you want to use, as this will dictate the total number of Arrow files generated. In general, you should create the same number of Arrow files as the maximum total number of GPUs you plan to use for model training. 18 | After deciding on your world size, run the following command: 19 | ```bash 20 | python prepare_data.py \ 21 | --rank $RANK \ 22 | --world_size $WORLD_SIZE \ 23 | --data_path /path/to/data \ 24 | --output_path /path/to/output_dir \ 25 | --max_num_tokens $MAX_NUM_TOKENS \ 26 | --tokenizer $TOKENIZER 27 | ``` 28 | 29 | 30 | 31 | **Important**: You need to run this command `$WORLD_SIZE` times, each with a different `$RANK` ranging from `0` to `$WORLD_SIZE - 1`. 32 | 33 | For large datasets where you only want to tokenize a subset, set `$MAX_NUM_TOKENS` appropriately. The total tokens in your tokenized data will be `$MAX_NUM_TOKENS × $WORLD_SIZE`. For the tokenizer, we use "meta-llama/Llama-2-7b-chat-hf". 34 | 35 | ## 3. Run the Experiments 36 | 37 | Below are the commands for training various sizes of LLaMA models with DyT. 38 | 39 | 40 |
41 | LLaMA2 7B Training Command 42 | 43 | ```bash 44 | MODEL_ARGS="\ 45 | --model_variant=llama2_7b \ 46 | --ckpt_load_path=/checkpoint/path \ 47 | --ckpt_save_path=/checkpoint/path \ 48 | --data_path=/dataset/path \ 49 | --sharding_strategy=hsdp \ 50 | --fsdp_activation_checkpointing=False \ 51 | --selective_checkpointing=1 \ 52 | --mixed_precision=True \ 53 | --low_cpu_fsdp=True \ 54 | --batch_size=2 \ 55 | --learning_rate=0.0003 \ 56 | --checkpoint_interval=5000 \ 57 | --tracker=wandb \ 58 | --tracker_dir=/tracker/path \ 59 | --tracker_project_name=tracker_project_name \ 60 | --tracker_run_name=llama2_dyt_7b \ 61 | --attn_alpha_init_value=0.8 \ 62 | --ffn_alpha_init_value=0.2 \ 63 | --dec_alpha_init_value=0.2 64 | " 65 | srun torchrun --nnodes=64 --nproc_per_node=8 main_training_llama.py ${MODEL_ARGS} 66 | ``` 67 | 68 |
69 | 70 | 71 |
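
The last three flags in the command above are the only DyT-specific arguments; the rest follows the original fms-fsdp recipe. They set the initial value of the DyT scale alpha for, respectively, the attention pre-norm and the feed-forward pre-norm in every transformer block, and the final decoder norm (see `fms_fsdp/models/llama.py` later in this listing). The rough sketch below shows where they land; it assumes the fms-fsdp environment from step 1 is installed and is for illustration only, since the training script builds this config from the command-line flags itself.

```python
# Illustration only: how the three DyT flags map onto the model definition in
# fms_fsdp/models/llama.py. Each transformer block builds two DyT layers
# (LayerNormParameterized) -- one before attention, one before the feed-forward
# sublayer -- and the decoder adds a final DyT before the output head.
from fms_fsdp.models.llama import LLaMAConfig

cfg = LLaMAConfig(
    attn_alpha_init_value=0.8,  # alpha init for every block's attention DyT (7B setting)
    ffn_alpha_init_value=0.2,   # alpha init for every block's feed-forward DyT
    dec_alpha_init_value=0.2,   # alpha init for the final decoder DyT
)
print(cfg.attn_alpha_init_value, cfg.ffn_alpha_init_value, cfg.dec_alpha_init_value)
```

The commands for the wider 13B, 34B, and 70B models below use smaller initial alpha values.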
72 | LLaMA2 13B Training Command 73 | 74 | ```bash 75 | MODEL_ARGS="\ 76 | --model_variant=llama2_13b \ 77 | --ckpt_load_path=/checkpoint/path \ 78 | --ckpt_save_path=/checkpoint/path \ 79 | --data_path=/dataset/path \ 80 | --sharding_strategy=hsdp \ 81 | --fsdp_activation_checkpointing=True \ 82 | --selective_checkpointing=0.5 \ 83 | --mixed_precision=True \ 84 | --low_cpu_fsdp=True \ 85 | --batch_size=2 \ 86 | --learning_rate=0.0003 \ 87 | --checkpoint_interval=2000 \ 88 | --tracker=wandb \ 89 | --tracker_dir=/tracker/path \ 90 | --tracker_project_name=tracker_project_name \ 91 | --tracker_run_name=llama2_dyt_13b \ 92 | --attn_alpha_init_value=0.6 \ 93 | --ffn_alpha_init_value=0.15 \ 94 | --dec_alpha_init_value=0.15 95 | " 96 | srun torchrun --nnodes=64 --nproc_per_node=8 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $head_node_ip:54224 main_training_llama.py ${MODEL_ARGS} 97 | ``` 98 | 99 |
100 | 101 | 102 | 103 |
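
From the 13B configuration onward, `--fsdp_activation_checkpointing=True` is combined with `--selective_checkpointing`, which controls the fraction of transformer blocks wrapped with activation checkpointing. The snippet below is a stand-alone re-implementation of the selection rule from `fms_fsdp/policies/ac_handler.py` (shown later in this listing), included only to make the spacing concrete; it is not used by the training code.

```python
def checkpointed_blocks(num_blocks: int, p: float) -> list[int]:
    """0-based indices of blocks that get activation checkpointing for ratio p."""
    selected, cut_off = [], 0.5
    for block_idx in range(1, num_blocks + 1):  # mirrors ac_handler's running block counter
        if block_idx * p >= cut_off:
            selected.append(block_idx - 1)
            cut_off += 1
    return selected


# LLaMA2 13B has 40 transformer blocks; p = 0.5 checkpoints every other block.
print(checkpointed_blocks(40, 0.5))  # [0, 2, 4, ..., 38]
```

With `p = 1`, as in the 70B command, every block is checkpointed.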
104 | LLaMA2 34B Training Command 105 | 106 | ```bash 107 | MODEL_ARGS="\ 108 | --model_variant=llama2_34b \ 109 | --ckpt_load_path=/checkpoint/path \ 110 | --ckpt_save_path=/checkpoint/path \ 111 | --data_path=/dataset/path \ 112 | --sharding_strategy=fsdp \ 113 | --fsdp_activation_checkpointing=True \ 114 | --selective_checkpointing=0.5 \ 115 | --mixed_precision=True \ 116 | --low_cpu_fsdp=True \ 117 | --batch_size=1 \ 118 | --learning_rate=0.00015 \ 119 | --checkpoint_interval=2000 \ 120 | --tracker=wandb \ 121 | --tracker_dir=/tracker/path \ 122 | --tracker_project_name=tracker_project_name \ 123 | --tracker_run_name=llama2_dyt_34b \ 124 | --attn_alpha_init_value=0.2 \ 125 | --ffn_alpha_init_value=0.05 \ 126 | --dec_alpha_init_value=0.05 127 | " 128 | srun torchrun --nnodes=128 --nproc_per_node=8 main_training_llama.py ${MODEL_ARGS} 129 | ``` 130 | 131 |
132 | 133 | 134 | 135 |
136 | LLaMA2 70B Training Command 137 | 138 | ```bash 139 | MODEL_ARGS="\ 140 | --model_variant=llama2_70b \ 141 | --ckpt_load_path=/checkpoint/path \ 142 | --ckpt_save_path=/checkpoint/path \ 143 | --data_path=/dataset/path \ 144 | --sharding_strategy=fsdp \ 145 | --fsdp_activation_checkpointing=True \ 146 | --selective_checkpointing=1 \ 147 | --mixed_precision=True \ 148 | --low_cpu_fsdp=True \ 149 | --batch_size=1 \ 150 | --learning_rate=0.00015 \ 151 | --checkpoint_interval=2000 \ 152 | --tracker=wandb \ 153 | --tracker_dir=/tracker/path \ 154 | --tracker_project_name=tracker_project_name \ 155 | --tracker_run_name=llama2_dyt_70b \ 156 | --attn_alpha_init_value=0.2 \ 157 | --ffn_alpha_init_value=0.05 \ 158 | --dec_alpha_init_value=0.05 159 | " 160 | srun torchrun --nnodes=128 --nproc_per_node=8 main_training_llama.py ${MODEL_ARGS} 161 | ``` 162 | 163 |
164 | 165 | To reproduce the baseline results, follow the original [fms-fsdp](https://github.com/foundation-model-stack/fms-fsdp) repository using the same command, excluding the last three arguments, which are specific to DyT. 166 | 167 | 168 | ## Acknowledgement 169 | This repository is built using the [Foundation Model Stack](https://github.com/foundation-model-stack/foundation-model-stack) library and [fms-fsdp](https://github.com/foundation-model-stack/fms-fsdp) repository. 170 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiachenzhu/DyT/aab5dde0bc1bdd4410f687dad404c87b31808f90/other_tasks/LLaMA/fms_fsdp/__init__.py -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .training import train_config 2 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/config/training.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union 3 | 4 | 5 | @dataclass 6 | class train_config: 7 | # model 8 | model_variant: str = "7b" 9 | ckpt_load_path: str = "/fsx/output/ckpt" 10 | ckpt_save_path: str = "/fsx/output/ckpt" 11 | 12 | # dataset and dataloader 13 | data_path: str = "/fsx/data" 14 | seq_length: int = 4096 15 | vocab_size: int = 32000 16 | bos_token: int = 1 17 | eos_token: int = 2 18 | 19 | # fsdp policies 20 | sharding_strategy: str = "hsdp" 21 | fsdp_activation_checkpointing: bool = False 22 | selective_checkpointing: Union[float, str] = 1 # percentage of blocks to apply ac 23 | mixed_precision: bool = True 24 | low_cpu_fsdp: bool = False 25 | 26 | # training spec 27 | batch_size: int = 2 28 | num_steps: int = 1000000 29 | training_stage: str = "initial" 30 | learning_rate: float = 3e-4 31 | grad_clip_thresh: float = 1.0 32 | seed: int = 2023 33 | 34 | # profiling 35 | use_profiler: bool = False 36 | profiler_rank0_only: bool = True 37 | 38 | # logging 39 | report_interval: int = 10 40 | checkpoint_interval: int = 10000 41 | tracker: Optional[str] = None # None, "wandb", "aim" 42 | tracker_dir: str = "/fsx/aim_logs/llama" 43 | tracker_project_name: str = "llama" # project name for a group of runs 44 | tracker_run_name: Optional[str] = None 45 | tracker_run_id: Optional[str] = None # run id, for job resume purpose 46 | 47 | # compile 48 | use_torch_compile: bool = True 49 | 50 | # dyt set up 51 | attn_alpha_init_value: float = 1.0 52 | ffn_alpha_init_value: float = 1.0 53 | dec_alpha_init_value: float = 1.0 54 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiachenzhu/DyT/aab5dde0bc1bdd4410f687dad404c87b31808f90/other_tasks/LLaMA/fms_fsdp/models/__init__.py -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/models/llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from dataclasses import dataclass 4 | from typing import Any, List, Mapping, Optional, 
Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from fms.distributed.strategy import ( 10 | DistributedStrategy, 11 | NoOpStrategy, 12 | TensorParallelStrategy, 13 | ) 14 | from fms.modules.attention import MultiHeadAttention 15 | from fms.modules.embedding import WordEmbedding 16 | from fms.modules.feedforward import GatedLinearUnit 17 | from fms.modules.positions import RotaryEmbedding 18 | from fms.utils.activation import str_to_activation 19 | from fms.utils.config import ModelConfig 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | # params emb_dim heads layers lr 26 | # 7B 4096 32 32 3.0E-04 27 | # 13B 5120 40 40 3.0E-04 28 | # 33B 6656 52 60 1.5.E-04 29 | # 65B 8192 64 80 1.5.E-04 30 | 31 | 32 | class LayerNormParameterized(nn.Module): 33 | def __init__(self, normalized_shape, alpha_init_value): 34 | super(LayerNormParameterized, self).__init__() 35 | self.normalized_shape = normalized_shape 36 | self.alpha_init_value = alpha_init_value 37 | 38 | self.alpha = nn.Parameter(torch.empty(1)) 39 | self.weight = nn.Parameter(torch.empty(normalized_shape)) 40 | 41 | def reset_parameters(self): 42 | self.alpha.data.fill_(self.alpha_init_value) 43 | self.weight.data.fill_(1) 44 | 45 | def forward(self, x): 46 | return self.weight * torch.tanh(self.alpha * x) 47 | 48 | def extra_repr(self): 49 | return f"normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}" 50 | 51 | 52 | @dataclass 53 | class LLaMAConfig(ModelConfig): 54 | src_vocab_size: int = 32_000 # can be set by tokenizer 55 | emb_dim: int = 4096 56 | norm_eps: float = 1e-5 57 | nheads: int = 32 58 | kvheads: int = 0 59 | nlayers: int = 32 60 | pad_id: int = -1 61 | hidden_grow_factor: float = 8 / 3 62 | multiple_of: int = 256 63 | activation_fn: str = "swish" 64 | p_dropout: float = 0.0 65 | max_expected_seq_len: int = 4096 66 | ntk_scaling: bool = False 67 | attn_bias: bool = False 68 | mlp_bias: bool = False 69 | tie_heads: bool = False 70 | rope_theta: float = 10_000.0 71 | linear_config: Optional[Mapping[str, Any]] = None 72 | fused_weights: bool = True 73 | attn_alpha_init_value: float = 1.0 74 | ffn_alpha_init_value: float = 1.0 75 | dec_alpha_init_value: float = 1.0 76 | 77 | 78 | class LLaMABlock(nn.Module): 79 | def __init__(self, config: LLaMAConfig, rotary_emb: RotaryEmbedding): 80 | super(LLaMABlock, self).__init__() 81 | self.config = config 82 | emb_kq = self.config.emb_dim // self.config.nheads 83 | emb_v = self.config.emb_dim // self.config.nheads 84 | 85 | self.ln = LayerNormParameterized( 86 | self.config.emb_dim, 87 | self.config.attn_alpha_init_value, 88 | ) 89 | self.ff_ln = LayerNormParameterized( 90 | self.config.emb_dim, 91 | self.config.ffn_alpha_init_value, 92 | ) 93 | 94 | if self.config.kvheads == 0: 95 | kvheads = self.config.nheads 96 | else: 97 | kvheads = self.config.kvheads 98 | assert self.config.nheads % self.config.kvheads == 0 99 | 100 | self.attn = MultiHeadAttention( 101 | self.config.emb_dim, 102 | emb_kq, 103 | emb_v, 104 | self.config.nheads, 105 | kvheads, 106 | p_dropout=self.config.p_dropout, 107 | use_bias=self.config.attn_bias, 108 | position_encoder=rotary_emb, 109 | fused=self.config.fused_weights, 110 | linear_config=self.config.linear_config, 111 | ) 112 | self.ff_sub_layer = GatedLinearUnit( 113 | self.config.emb_dim, 114 | hidden_grow_factor=self.config.hidden_grow_factor, 115 | multiple_of=self.config.multiple_of, 116 | activation_fn=str_to_activation(self.config.activation_fn), 117 | p_dropout=self.config.p_dropout, 118 | 
use_bias=self.config.mlp_bias, 119 | fused=self.config.fused_weights, 120 | linear_config=self.config.linear_config, 121 | ) 122 | 123 | if self.config.p_dropout != 0: 124 | self.dropout = nn.Dropout(self.config.p_dropout) 125 | 126 | def forward( 127 | self, 128 | x, 129 | *, 130 | mask=None, 131 | position_ids=None, 132 | past_key_value_state=None, 133 | use_cache=False, 134 | is_causal_mask=False, 135 | attn_algorithm=None, 136 | ): 137 | # if the cache is not empty, we need to get the kv cache for self and cross attention 138 | self_attn_past_key_value = past_key_value_state 139 | # if past_key_value_state is not None: 140 | # self_attn_past_key_value = past_key_value_state[:2] 141 | # else: 142 | # self_attn_past_key_value = None 143 | 144 | # first we do MHA and Add&Norm 145 | residual = x 146 | x = self.ln(x) 147 | x = self.attn( 148 | q=x, 149 | mask=mask, 150 | position_ids=position_ids, 151 | attn_algorithm=attn_algorithm, 152 | past_key_value_state=self_attn_past_key_value, 153 | use_cache=use_cache, 154 | is_self=True, 155 | is_causal_mask=is_causal_mask, 156 | ) 157 | cache = None 158 | if use_cache: 159 | x, cache = x 160 | if self.config.p_dropout != 0: 161 | x = self.dropout(x) 162 | # residual connection 163 | x = x + residual 164 | 165 | # then we do FF and Add&Norm 166 | residual = x 167 | x = self.ff_ln(x) 168 | x = self.ff_sub_layer(x) 169 | if self.config.p_dropout != 0: 170 | x = self.dropout(x) 171 | # another residual 172 | x = x + residual 173 | 174 | if use_cache: 175 | return (x, cache) 176 | else: 177 | return x 178 | 179 | 180 | class LLaMA(nn.Module): 181 | def __init__( 182 | self, 183 | config: Optional[LLaMAConfig] = None, 184 | distributed_strategy: DistributedStrategy = NoOpStrategy, 185 | **kwargs, 186 | ): 187 | super(LLaMA, self).__init__() 188 | if config is not None: 189 | self.config = config 190 | else: 191 | self.config = LLaMAConfig() 192 | self.config = self.config.updated(**kwargs) 193 | self.distributed_strategy = distributed_strategy 194 | 195 | self.width = self.config.emb_dim 196 | self.pad_id = self.config.pad_id 197 | self.max_expected_seq_len = self.config.max_expected_seq_len 198 | 199 | shared = WordEmbedding( 200 | self.config.src_vocab_size, 201 | self.config.emb_dim, 202 | padding_idx=self.config.pad_id, 203 | abs_pos=False, 204 | reversible=True, 205 | tie_weights=self.config.tie_heads, 206 | bias=False, 207 | ) 208 | shared_scale = nn.Parameter(torch.empty(1)) 209 | self.shared_scale_init_value = math.sqrt(self.config.emb_dim) 210 | 211 | # TP does not work with tied weights 212 | if ( 213 | not isinstance(self.distributed_strategy, TensorParallelStrategy) 214 | or not self.config.tie_heads 215 | ): 216 | self.shared = self.distributed_strategy.distribute_module(shared) 217 | self.shared_scale = self.distributed_strategy.distribute_module(shared_scale) 218 | else: 219 | logger.warning( 220 | "You're using TP on a model with tied weights between head and embedding. " 221 | "The tied weights won't be sharded, which can result in unexpected OOMs." 
222 | ) 223 | self.shared = shared 224 | 225 | self.rot_emb = RotaryEmbedding( 226 | dim=self.config.emb_dim // self.config.nheads, 227 | ntk_scaling=self.config.ntk_scaling, 228 | max_seq_len=self.config.max_expected_seq_len, 229 | ratio=self.config.rope_theta, 230 | ) 231 | # RoPE init 232 | for device in set( 233 | [param.device for param in self.parameters()] 234 | + [buffer.device for buffer in self.buffers()] 235 | ): 236 | self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) 237 | 238 | layers = [] 239 | for i in range(self.config.nlayers): 240 | block: nn.Module = LLaMABlock(self.config, self.rot_emb) 241 | block = self.distributed_strategy.distribute_layer(block, i) 242 | layers.append(block) 243 | self.layers = nn.ModuleList(layers) 244 | 245 | dec_norm = LayerNormParameterized( 246 | self.config.emb_dim, 247 | self.config.dec_alpha_init_value, 248 | ) 249 | self.dec_norm = self.distributed_strategy.distribute_module( 250 | dec_norm, final_layers=True 251 | ) 252 | 253 | if self.config.p_dropout: 254 | self.dropout = nn.Dropout(self.config.p_dropout) 255 | 256 | def get_config(self) -> LLaMAConfig: 257 | return self.config 258 | 259 | @classmethod 260 | def from_config(cls, config: LLaMAConfig) -> "LLaMA": 261 | return cls(config) 262 | 263 | def reset_parameters(self): 264 | # Call reset_parameters for relevant sub-layers 265 | for m in self.modules(): 266 | if ( 267 | isinstance(m, MultiHeadAttention) 268 | or isinstance(m, WordEmbedding) 269 | or isinstance(m, GatedLinearUnit) 270 | or isinstance(m, LayerNormParameterized) 271 | ): 272 | m.reset_parameters() 273 | 274 | if isinstance(m, LLaMA): 275 | m.shared_scale.data.fill_(self.shared_scale_init_value) 276 | 277 | def validate_reset_parameters(self): 278 | # Verifies that the above self.reset_parameters() executed correctly. 279 | # This may not always be the case for distributed settings with sharded tensors, 280 | # such as FSDP or TP. Note that performing this check may require unsharding / 281 | # re-materializing the full model on a single rank to access the underlying tensors. 
282 | tolerance = 1e-3 283 | 284 | def check_close(x): 285 | assert x.mean().abs() < tolerance 286 | assert x.std().sub(0.02).abs() < tolerance 287 | 288 | with torch.no_grad(): 289 | for p in self.parameters(): 290 | assert p.isnan().int().sum() == 0 291 | assert p.isinf().int().sum() == 0 292 | for m in self.modules(): 293 | if isinstance(LayerNormParameterized): 294 | if m.elementwise_scale: 295 | assert m.weight.sum() == m.weight.numel() 296 | if m.elementwise_shift: 297 | assert m.bias.add(1).sum() == m.bias.numel() 298 | elif isinstance(WordEmbedding): 299 | check_close(m.emb.weight) 300 | check_close(m.head.weight) 301 | elif isinstance(GatedLinearUnit): 302 | check_close(m.w1.weight) 303 | check_close(m.w2.weight) 304 | check_close(m.wg.weight) 305 | elif isinstance(MultiHeadAttention): 306 | check_close(m.query.weight) 307 | check_close(m.key.weight) 308 | check_close(m.value.weight) 309 | check_close(m.dense.weight) 310 | 311 | def _clean_up_rot_emb_cache( 312 | self, 313 | cached_freqs: dict[Optional[torch.device], dict[int, torch.Tensor]], 314 | max_seq_len_cached: dict[Optional[torch.device], int], 315 | ): 316 | # remove meta tensors from cached_freqs 317 | for dev in list(cached_freqs.keys()): 318 | for alp in list(cached_freqs[dev].keys()): 319 | if cached_freqs[dev][alp].device == torch.device("meta"): 320 | del cached_freqs[dev][alp] 321 | if len(cached_freqs[dev]) == 0: 322 | del cached_freqs[dev] 323 | del max_seq_len_cached[dev] 324 | 325 | def post_init(self): 326 | # This function is called in `get_model` after the model is 327 | # fully initalized on the correct device 328 | 329 | # if this model ties weights, they are tied here 330 | if self.config.tie_heads: 331 | # handle assignment of non-meta weights to meta parameters 332 | if self.shared.head.weight.device == torch.device("meta"): 333 | self.shared.head.weight = self.shared.emb.weight 334 | else: 335 | self.shared.emb.weight = self.shared.head.weight 336 | 337 | self._clean_up_rot_emb_cache( 338 | self.rot_emb.cached_freqs, 339 | self.rot_emb.max_seq_len_cached, 340 | ) 341 | 342 | # init RoPE on the right device(s) 343 | for device in set( 344 | [param.device for param in self.parameters()] 345 | + [buffer.device for buffer in self.buffers()] 346 | ): 347 | self.rot_emb.compute_freqs_cis(device, self.config.max_expected_seq_len) 348 | 349 | def _helper( 350 | self, 351 | x_in, 352 | mask=None, 353 | position_ids=None, 354 | past_key_value_states=None, 355 | use_cache=False, 356 | attn_algorithm=None, 357 | ): 358 | # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified 359 | # x_in: batch_size x seq_len 360 | # mask: batch_size x seq_len x seq_len 361 | # bias: nheads x seq_len x seq_len 362 | if past_key_value_states is None or len(past_key_value_states) == 0: 363 | past_key_value_states = [None for _ in range(len(self.layers))] 364 | 365 | qlen = x_in.size(1) 366 | klen = x_in.size(1) 367 | 368 | # if we are using the cache, the key length needs to be extended with the past keys length 369 | if use_cache and past_key_value_states[0] is not None: 370 | klen += past_key_value_states[0][0].size(-2) 371 | 372 | # if mask is none, we need to specify causal mask 373 | if mask is None: 374 | # we are caching and can assume all 1s in the mask 375 | if use_cache and klen != 1 and qlen == 1: 376 | # b x h x qlen x kvlen 377 | is_causal_mask = False 378 | else: 379 | is_causal_mask = True 380 | else: 381 | is_causal_mask = False 382 | 383 | x_in = 
self.shared(x_in) * self.shared_scale 384 | 385 | # this is the output cache for all the decoder layers 386 | present_key_value_states = [] 387 | 388 | for i, layer in enumerate(self.layers): 389 | output = layer( 390 | x=x_in, 391 | mask=mask, 392 | position_ids=position_ids, 393 | past_key_value_state=past_key_value_states[i], 394 | use_cache=use_cache, 395 | is_causal_mask=is_causal_mask, 396 | attn_algorithm=attn_algorithm, 397 | ) 398 | 399 | if use_cache: 400 | x_in, present_key_value_state = output 401 | present_key_value_states.append(present_key_value_state) 402 | 403 | else: 404 | x_in = output 405 | 406 | dec_out = x_in 407 | dec_out = self.dec_norm(dec_out) 408 | if self.config.p_dropout: 409 | dec_out = self.dropout(dec_out) 410 | 411 | return dec_out, present_key_value_states 412 | 413 | def forward( 414 | self, 415 | x: torch.Tensor, 416 | mask: Optional[torch.Tensor] = None, 417 | position_ids: Optional[torch.Tensor] = None, 418 | past_key_value_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, 419 | use_cache: bool = False, 420 | only_last_token: bool = False, 421 | attn_algorithm: Optional[str] = None, 422 | ): 423 | output, cache = self._helper( 424 | x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm 425 | ) 426 | 427 | if only_last_token: 428 | output = output[:, -1, :] 429 | preds = self.shared(output, reverse=True) 430 | 431 | if use_cache: 432 | return preds, cache 433 | else: 434 | return preds 435 | 436 | 437 | def param_init_function(module): 438 | if ( 439 | isinstance(module, MultiHeadAttention) 440 | or isinstance(module, WordEmbedding) 441 | or isinstance(module, GatedLinearUnit) 442 | or isinstance(module, LayerNormParameterized) 443 | ): 444 | module.to_empty(device=torch.cuda.current_device()) 445 | with torch.no_grad(): 446 | module.reset_parameters() 447 | 448 | if isinstance(module, LLaMA): 449 | module.to_empty(device=torch.cuda.current_device(), recurse=False) 450 | with torch.no_grad(): 451 | module.shared_scale.data.fill_(module.shared_scale_init_value) 452 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from .ac_handler import apply_fsdp_checkpointing 2 | from .mixed_precision import * 3 | from .wrapping import get_wrapper 4 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/policies/ac_handler.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 4 | CheckpointImpl, 5 | apply_activation_checkpointing, 6 | checkpoint_wrapper, 7 | ) 8 | 9 | 10 | non_reentrant_wrapper = partial( 11 | checkpoint_wrapper, 12 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 13 | ) 14 | 15 | 16 | def apply_fsdp_checkpointing(model, block, p): 17 | """ 18 | Apply selective activation checkpointing. 19 | 20 | Selectivity is defined as a percentage p, which means we apply ac 21 | on p of the total blocks. p is a floating number in the range of 22 | [0, 1]. 23 | 24 | Some examples: 25 | p = 0: no ac for all blocks. same as `fsdp_activation_checkpointing=False` 26 | p = 1: apply ac on every block. i.e. "full ac". 27 | p = 1/2: [ac, no-ac, ac, no-ac, ...] 28 | p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...] 29 | p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...] 
30 | Since blocks are homogeneous, we make ac blocks evenly spaced among 31 | all blocks. 32 | 33 | Implementation: 34 | For a given ac ratio p, we should essentially apply ac on every "1/p" 35 | blocks. The first ac block can be as early as the 0th block, or as 36 | late as the "1/p"th block, and we pick the middle one: (0.5p)th block. 37 | Therefore, we are essentially to apply ac on: 38 | (0.5/p)th block, (1.5/p)th block, (2.5/p)th block, etc., and of course, 39 | with these values rounding to integers. 40 | Since ac is applied recursively, we can simply use the following math 41 | in the code to apply ac on corresponding blocks. 42 | """ 43 | block_idx = 0 44 | cut_off = 1 / 2 45 | # when passing p as a fraction number (e.g. 1/3), it will be interpreted 46 | # as a string in argv, thus we need eval("1/3") here for fractions. 47 | p = eval(p) if isinstance(p, str) else p 48 | 49 | def selective_checkpointing(submodule): 50 | nonlocal block_idx 51 | nonlocal cut_off 52 | 53 | if isinstance(submodule, block): 54 | block_idx += 1 55 | if block_idx * p >= cut_off: 56 | cut_off += 1 57 | return True 58 | return False 59 | 60 | apply_activation_checkpointing( 61 | model, 62 | checkpoint_wrapper_fn=non_reentrant_wrapper, 63 | check_fn=selective_checkpointing, 64 | ) 65 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/policies/mixed_precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributed.fsdp import MixedPrecision 3 | 4 | 5 | fpSixteen = MixedPrecision( 6 | param_dtype=torch.float16, 7 | reduce_dtype=torch.float16, 8 | buffer_dtype=torch.float16, 9 | ) 10 | 11 | bfSixteen = MixedPrecision( 12 | param_dtype=torch.bfloat16, 13 | reduce_dtype=torch.bfloat16, 14 | buffer_dtype=torch.bfloat16, 15 | ) 16 | 17 | bfSixteen_working = MixedPrecision( 18 | param_dtype=torch.float32, 19 | reduce_dtype=torch.bfloat16, 20 | buffer_dtype=torch.bfloat16, 21 | ) 22 | 23 | fp32_policy = MixedPrecision( 24 | param_dtype=torch.float32, 25 | reduce_dtype=torch.float32, 26 | buffer_dtype=torch.float32, 27 | ) 28 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/policies/wrapping.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 4 | 5 | 6 | def get_wrapper(block): 7 | auto_wrap_policy = functools.partial( 8 | transformer_auto_wrap_policy, 9 | transformer_layer_cls={ 10 | block, 11 | }, 12 | ) 13 | 14 | return auto_wrap_policy 15 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiachenzhu/DyT/aab5dde0bc1bdd4410f687dad404c87b31808f90/other_tasks/LLaMA/fms_fsdp/utils/__init__.py -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/utils/checkpointing_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from pathlib import Path 5 | 6 | import torch 7 | from torch.distributed._shard.checkpoint import ( 8 | FileSystemReader, 9 | FileSystemWriter, 10 | load as load_state_dict, 11 | save as save_state_dict, 12 | ) 13 | from 
torch.distributed.checkpoint.default_planner import ( 14 | DefaultLoadPlanner, 15 | DefaultSavePlanner, 16 | ) 17 | from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict 18 | from torch.distributed.fsdp import FullStateDictConfig 19 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 20 | from torch.distributed.fsdp import StateDictType 21 | 22 | 23 | def get_latest(targdir, qualifier=lambda x: True, key=os.path.getctime): 24 | """ 25 | Fetch the full path of the latest file or folder written to target directory, 26 | subject to name passing the qualifier fn. 27 | Optional key fn can be used for custom sorting. 28 | Both functions take full path arguments. 29 | If directory is empty or nonexistent or no items qualify, return None. 30 | """ 31 | if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: 32 | latest = max( 33 | [ 34 | os.path.join(targdir, x) 35 | for x in os.listdir(targdir) 36 | if qualifier(os.path.join(targdir, x)) 37 | ], 38 | key=key, 39 | ) 40 | return latest 41 | return None 42 | 43 | 44 | def get_oldest(targdir, qualifier=lambda x: True, key=os.path.getctime): 45 | """ 46 | Fetch the full path of the oldest file or folder written to target directory, 47 | subject to name passing the qualifier fn. 48 | Optional key fn can be used for custom sorting. 49 | Both functions take full path arguments. 50 | If directory is empty or nonexistent or no items qualify, return None. 51 | """ 52 | if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: 53 | oldest = min( 54 | [ 55 | os.path.join(targdir, x) 56 | for x in os.listdir(targdir) 57 | if qualifier(os.path.join(targdir, x)) 58 | ], 59 | key=key, 60 | ) 61 | return oldest 62 | return None 63 | 64 | 65 | class Checkpointer: 66 | """ 67 | Manages the checkpoint directory. Saves new checkpoints and deletes old ones after the specified number are written. 68 | Also handles loading and saving of checkpoints in sharded and unsharded formats. 69 | Assumes model and optimizer inputs are in FSDP. 70 | ... 71 | Args 72 | ---- 73 | ckpdir : str 74 | Absolute path to desired save location. Creates a new 'checkpoints/' subfolder at that location. 75 | n_to_save : int 76 | Number of volatile checkpoints to maintain at any given time. 77 | parallel_mode : str 78 | Write sharded folder ckps (when sharded: 'fsdp' or 'hsdp') or unsharded file ckps (when sharded: 'ddp') 79 | report_fn : Callable or None 80 | Optional function for reporting or logging status updates. Expected to handle arbitrary *args, **kwargs. 81 | Defaults to self._selective_print(). 82 | model_auto_placement : bool 83 | Optional; If True, auto detect GPU device to move model to, as set in device mesh init 84 | 85 | Methods 86 | ------- 87 | save : keyword args -> str | None 88 | Saves dictionary of keyword arg key/value pairs to specified checkpoint directory, deleting old checkpoints 89 | as necessary. If a checkpoint is deleted, returns the filename of that checkpoint. 
90 | load : 91 | See docstring for individual function below 92 | """ 93 | 94 | def __init__( 95 | self, 96 | ckpdir, 97 | n_to_save, 98 | parallel_mode, 99 | rank, 100 | local_rank, 101 | report_fn=None, 102 | model_auto_placement=False, 103 | ): 104 | self.max_ckps = n_to_save 105 | self.rank = rank 106 | self.local_rank = local_rank 107 | self.ckp_path = os.path.join(ckpdir, "checkpoints/") 108 | os.makedirs(self.ckp_path, exist_ok=True) 109 | self.p_mode = parallel_mode 110 | assert parallel_mode in ["fsdp", "hsdp", "ddp"] 111 | self.report = self._selective_print if report_fn is None else report_fn 112 | self.model_auto_placement = model_auto_placement 113 | 114 | def _selective_print(self, *args, **kwargs): 115 | if self.rank == 0: 116 | print(*args) 117 | for k, v in kwargs.items(): 118 | print(k, "=", v) 119 | 120 | def _cleanup(self): 121 | # Clean old checkpoints. Barrier to keep synchronization correct. 122 | file_to_remove = None 123 | if ( 124 | self.rank == 0 125 | and len([x for x in os.listdir(self.ckp_path) if "tmp" in x]) 126 | > self.max_ckps 127 | ): 128 | ckp_to_remove = Path( 129 | get_oldest(self.ckp_path, qualifier=lambda x: "tmp" in x) 130 | ) 131 | if os.path.isfile(ckp_to_remove): 132 | ckp_to_remove.unlink() 133 | else: 134 | shutil.rmtree(ckp_to_remove) 135 | return file_to_remove 136 | 137 | def _do_save(self, rank, local_rank): # , shard_group, replicate_group): 138 | if self.p_mode == "hsdp": 139 | return rank == local_rank 140 | else: 141 | return True 142 | # TODO: Distributed writing contingent upon the following fix: https://github.com/pytorch/pytorch/issues/104081 143 | # if not is_dist: 144 | # return (rank == local_rank) 145 | # else: 146 | # a = rank % shard_group.size() 147 | # b = rank // shard_group.size() 148 | # return True if a % replicate_group.size() == b else False 149 | # shard_group = model.process_group 150 | # replicate_group = model.__inter_node_state.process_group 151 | 152 | def _write(self, state_dict, loader_state, process_group, save_name, rank): 153 | os.makedirs(save_name, exist_ok=True) 154 | writer = FileSystemWriter(save_name, single_file_per_rank=True) 155 | if state_dict is not None: 156 | save_state_dict( 157 | state_dict=state_dict, 158 | storage_writer=writer, 159 | process_group=process_group, 160 | planner=DefaultSavePlanner(), 161 | ) 162 | if loader_state is not None: 163 | torch.save(loader_state, os.path.join(save_name, f"loader_rank_{rank}.pth")) 164 | 165 | def _validate_ckp_path(self, path): 166 | """Interpret path to appropriate checkpoint. If found, return modified path. If not found, return None.""" 167 | # Does path exist and is it non-empty? 168 | if os.path.exists(path): 169 | # Is this a file? 170 | if os.path.isfile(path): 171 | return path 172 | # Is this a sharded directory? 173 | elif "metadata.pth" in os.listdir(path): 174 | return path 175 | # Is this a path to a set of checkpoints? 176 | elif len(os.listdir(path)) > 0: 177 | latest = get_latest(path) 178 | if os.path.isfile(latest): 179 | return latest 180 | elif "metadata.pth" in os.listdir(latest): 181 | return latest 182 | return None 183 | 184 | def load( 185 | self, 186 | model, 187 | optimizer, 188 | dataloader, 189 | path="", 190 | reset_stepcount=False, 191 | strict=True, 192 | is_compiled=False, 193 | ): 194 | """ 195 | Handle checkpoint loading for model/optimizer/dataloader from given path, according to arguments. 196 | Defaults to save path for locating an appropriate checkpoint. 
If a path is provided, will use 197 | it only if no appropriate checkpoint is found in the save path (in which case it's a job restart). 198 | Reset_stepcount manually resets optimizer and dataloader states, and stat tracking. 199 | Strict determines whether to use strict loading or not FOR SINGLEFILE LOADING ONLY. 200 | Returns model, optimizer, dataloader, current step, and current tokens seen. 201 | """ 202 | is_resuming = False 203 | if self._validate_ckp_path(self.ckp_path) is not None: 204 | path = self.ckp_path 205 | is_resuming = True 206 | load_path = self._validate_ckp_path(path) 207 | if load_path is None: 208 | self.report( 209 | f"No valid checkpoint detected at {path}, starting from scratch." 210 | ) 211 | return model, optimizer, dataloader, 0, 0, False 212 | else: 213 | self.report(f"Prior checkpoint {load_path} detected.") 214 | model_load_time = time.time() 215 | if os.path.isfile(load_path): 216 | checkpoint_data = torch.load(load_path, map_location="cpu") 217 | if is_compiled: 218 | model._orig_mod.load_state_dict( 219 | checkpoint_data.get("model_state"), strict=strict 220 | ) 221 | else: 222 | model.load_state_dict( 223 | checkpoint_data.get("model_state"), strict=strict 224 | ) 225 | if self.model_auto_placement: 226 | model.to("cuda") 227 | else: 228 | model.to(self.local_rank) 229 | self.report( 230 | f"Checkpoint {load_path} is a single-file checkpoint containing only a model. Optimizer and dataloader are from scratch.", 231 | model_load_time=time.time() - model_load_time, 232 | ) 233 | return model, optimizer, dataloader, 0, 0, is_resuming 234 | else: 235 | # Load model 236 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 237 | state_dict = model.state_dict() 238 | model_ckp = {"model_state": state_dict} 239 | load_state_dict( 240 | state_dict=model_ckp, 241 | storage_reader=FileSystemReader(load_path), 242 | planner=DefaultLoadPlanner(), 243 | ) 244 | model.load_state_dict(model_ckp["model_state"]) 245 | if self.model_auto_placement: 246 | model.to("cuda") 247 | else: 248 | model.to(self.local_rank) 249 | self.report(model_load_time=time.time() - model_load_time) 250 | step = 0 251 | ntok = 0 252 | # Load metadata 253 | if is_resuming: 254 | metadata = torch.load(os.path.join(load_path, "metadata.pth")) 255 | step = metadata.get("step", 0) 256 | ntok = metadata.get("tokens_seen", 0) 257 | self.report("Metadata loaded", start_step=step, n_tokens_seen=ntok) 258 | # Load optimizer 259 | if optimizer is not None: 260 | optim_load_time = time.time() 261 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 262 | optim_state = load_sharded_optimizer_state_dict( 263 | model_state_dict=model.state_dict(), 264 | optimizer_key="optimizer_state", 265 | storage_reader=FileSystemReader(load_path), 266 | ) 267 | flattened_osd = FSDP.optim_state_dict_to_load( 268 | model, optimizer, optim_state["optimizer_state"] 269 | ) 270 | optimizer.load_state_dict(flattened_osd) 271 | self.report(optimizer_load_time=time.time() - optim_load_time) 272 | else: 273 | self.report("Skipping optimizer load, no optimizer provided.") 274 | # Load dataset 275 | if dataloader is not None: 276 | data_load_time = time.time() 277 | dataloader.load_state_dict( 278 | torch.load(os.path.join(load_path, f"loader_rank_{self.rank}.pth"), weights_only=True) 279 | ) 280 | self.report(dataset_load_time=time.time() - data_load_time) 281 | else: 282 | self.report("Skipping dataset load, no dataloader provided.") 283 | return model, optimizer, dataloader, step, ntok, 
is_resuming 284 | 285 | def save( 286 | self, 287 | step, 288 | model, 289 | optimizer, 290 | dataloader, 291 | **kwargs, 292 | ): 293 | # Note: metadata kwargs cannot contain any of: 294 | # (step, model, optimizer, dataloader) 295 | rank = self.rank 296 | save_time = time.time() 297 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 298 | model_state = model.state_dict() 299 | optim_state = FSDP.sharded_optim_state_dict(model, optimizer) 300 | dataloader_state = None if dataloader is None else dataloader.state_dict() 301 | 302 | save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp") 303 | state_dict = {"model_state": model_state, "optimizer_state": optim_state} 304 | if self._do_save(rank, self.local_rank): 305 | self._write( 306 | state_dict, dataloader_state, model.process_group, save_name, rank 307 | ) 308 | else: 309 | self._write(None, dataloader_state, None, save_name, rank) 310 | if rank == 0: 311 | metadata = kwargs 312 | metadata["step"] = step 313 | torch.save(metadata, os.path.join(save_name, "metadata.pth")) 314 | self.report( 315 | f"Checkpoint saved in {save_name}", model_save_time=time.time() - save_time 316 | ) 317 | 318 | return self._cleanup() 319 | 320 | def save_single_file( 321 | self, 322 | step, 323 | model, 324 | is_compiled=False, 325 | **kwargs, 326 | ): 327 | # Note: metadata kwargs cannot contain any of: 328 | # (step, model) 329 | save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp.pth") 330 | save_time = time.time() 331 | with FSDP.state_dict_type( 332 | model, 333 | StateDictType.FULL_STATE_DICT, 334 | FullStateDictConfig(offload_to_cpu=True, rank0_only=True), 335 | ): 336 | if is_compiled: 337 | model_state = model._orig_mod.state_dict() 338 | else: 339 | model_state = model.state_dict() 340 | if self.rank == 0: 341 | metadata = kwargs 342 | metadata["step"] = step 343 | metadata["model_state"] = model_state 344 | torch.save(metadata, save_name) 345 | self.report("Checkpoint written", model_save_time=time.time() - save_time) 346 | 347 | return self._cleanup() 348 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | from fms_fsdp.models.llama import LLaMAConfig 2 | 3 | from fms_fsdp.config import train_config 4 | 5 | 6 | def update_config(config, **kwargs): 7 | if isinstance(config, (tuple, list)): 8 | for c in config: 9 | update_config(c, **kwargs) 10 | else: 11 | for k, v in kwargs.items(): 12 | if hasattr(config, k): 13 | setattr(config, k, v) 14 | elif "." 
in k: 15 | config_name, param_name = k.split(".") 16 | if type(config).__name__ == config_name: 17 | if hasattr(config, param_name): 18 | setattr(config, param_name, v) 19 | else: 20 | print(f"Warning: {config_name} does not accept parameter: {k}") 21 | elif isinstance(config, train_config): 22 | print(f"Warning: unknown parameter {k}") 23 | 24 | 25 | def get_model_config(model_variant): 26 | if model_variant == "llama2_70b": 27 | model_config = LLaMAConfig( 28 | emb_dim=8192, 29 | nheads=64, 30 | kvheads=8, 31 | nlayers=80, 32 | multiple_of=4096, 33 | hidden_grow_factor=28672 / 8192, 34 | ) 35 | elif model_variant == "llama2_34b": 36 | model_config = LLaMAConfig( 37 | emb_dim=8192, 38 | nheads=64, 39 | kvheads=8, 40 | nlayers=48, 41 | hidden_grow_factor=22016 / 8192, 42 | ) 43 | elif model_variant == "llama2_13b": 44 | model_config = LLaMAConfig( 45 | emb_dim=5120, 46 | nheads=40, 47 | nlayers=40, 48 | hidden_grow_factor=13824 / 5120, 49 | ) 50 | elif model_variant == "llama2_7b": 51 | model_config = LLaMAConfig( 52 | emb_dim=4096, 53 | nheads=32, 54 | nlayers=32, 55 | hidden_grow_factor=11008 / 4096, 56 | ) 57 | elif model_variant == "llama2_1.4b": 58 | model_config = LLaMAConfig( 59 | emb_dim=2048, 60 | nheads=16, 61 | nlayers=24, 62 | hidden_grow_factor=3, 63 | kvheads=4, 64 | ) 65 | elif model_variant == "llama3_8b": 66 | model_config = LLaMAConfig( 67 | src_vocab_size=128256, 68 | emb_dim=4096, 69 | nheads=32, 70 | kvheads=8, 71 | nlayers=32, 72 | hidden_grow_factor=3.5, 73 | max_expected_seq_len=8192, 74 | rope_theta=500000.0, 75 | ) 76 | elif model_variant == "llama3_8b_4k": 77 | model_config = LLaMAConfig( 78 | src_vocab_size=128256, 79 | emb_dim=4096, 80 | nheads=32, 81 | kvheads=8, 82 | nlayers=32, 83 | hidden_grow_factor=3.5, 84 | max_expected_seq_len=4096, 85 | rope_theta=500000.0, 86 | ) 87 | elif model_variant == "llama3_1.8b": 88 | model_config = LLaMAConfig( 89 | src_vocab_size=128256, 90 | emb_dim=2048, 91 | nheads=16, 92 | kvheads=8, 93 | nlayers=24, 94 | hidden_grow_factor=3.5, 95 | max_expected_seq_len=8192, 96 | rope_theta=500000.0, 97 | ) 98 | elif model_variant == "llama3_1.8b_4k": 99 | model_config = LLaMAConfig( 100 | src_vocab_size=128256, 101 | emb_dim=2048, 102 | nheads=16, 103 | kvheads=8, 104 | nlayers=24, 105 | hidden_grow_factor=3.5, 106 | max_expected_seq_len=4096, 107 | rope_theta=500000.0, 108 | ) 109 | elif model_variant == "llama3_3.2b": 110 | model_config = LLaMAConfig( 111 | src_vocab_size=128256, 112 | emb_dim=3072, 113 | nheads=24, 114 | kvheads=8, 115 | nlayers=24, 116 | hidden_grow_factor=8 / 3, 117 | max_expected_seq_len=8192, 118 | rope_theta=500000.0, 119 | ) 120 | elif model_variant == "llama3_3.2b_4k": 121 | model_config = LLaMAConfig( 122 | src_vocab_size=128256, 123 | emb_dim=3072, 124 | nheads=24, 125 | kvheads=8, 126 | nlayers=24, 127 | hidden_grow_factor=8 / 3, 128 | max_expected_seq_len=4096, 129 | rope_theta=500000.0, 130 | ) 131 | elif model_variant == "llama3_70b": 132 | model_config = LLaMAConfig( 133 | src_vocab_size=128256, 134 | emb_dim=8192, 135 | nheads=64, 136 | kvheads=8, 137 | nlayers=80, 138 | hidden_grow_factor=3.5, 139 | max_expected_seq_len=8192, 140 | rope_theta=500000.0, 141 | ) 142 | elif model_variant == "llama3_70b_4k": 143 | model_config = LLaMAConfig( 144 | src_vocab_size=128256, 145 | emb_dim=8192, 146 | nheads=64, 147 | kvheads=8, 148 | nlayers=80, 149 | hidden_grow_factor=3.5, 150 | max_expected_seq_len=4096, 151 | rope_theta=500000.0, 152 | ) 153 | elif model_variant == "llama3_194m_4k": 154 | model_config = 
LLaMAConfig( 155 | src_vocab_size=128256, 156 | emb_dim=1024, 157 | nheads=8, 158 | nlayers=10, 159 | max_expected_seq_len=4096, 160 | rope_theta=500000.0, 161 | ) 162 | elif model_variant == "mamba_9.8b": 163 | model_config = { 164 | "d_model": 4096, 165 | "d_intermediate": 14336, 166 | "n_layer": 32, 167 | "vocab_size": 128256, 168 | "ssm_cfg": {"layer": "Mamba2"}, 169 | "attn_layer_idx": [9, 18, 27], 170 | "attn_cfg": { 171 | "causal": True, 172 | "d_conv": 0, 173 | "head_dim": 128, 174 | "num_heads": 32, 175 | "num_heads_kv": 8, 176 | "out_proj_bias": False, 177 | "qkv_proj_bias": False, 178 | "rotary_emb_dim": 64, 179 | }, 180 | "rms_norm": True, 181 | "residual_in_fp32": True, 182 | "fused_add_norm": True, 183 | "pad_vocab_size_multiple": 16, 184 | "tie_embeddings": False, 185 | } 186 | else: 187 | raise ValueError(f"model variant {model_variant} not supported.") 188 | 189 | return model_config 190 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import pyarrow as pa 5 | 6 | import torch 7 | 8 | class DistributedDataset(torch.utils.data.IterableDataset): 9 | def __init__(self, root_dir, rank, world_size, batch_size, seq_length, bos_token, eos_token): 10 | super().__init__() 11 | self.root_dir = root_dir 12 | self.rank = rank 13 | self.world_size = world_size 14 | self.batch_size = batch_size 15 | self.seq_length = seq_length 16 | self.bos_token = bos_token 17 | self.eos_token = eos_token 18 | 19 | # the root directory should contain files named rank_0.arrow, rank_1.arrow, ... 20 | # the total number of files should be divisible by the world size 21 | # each rank will read num_files_per_rank files such that rank 0 reads rank_0.arrow, rank_{0 + world_size}.arrow, ... 
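To make the sharding comment above concrete, here is a small standalone sketch (my addition, toy numbers only) of which files each rank ends up opening, mirroring the reader construction just below:
```
# toy numbers: 8 files, 4 ranks -> each rank reads num_files // world_size = 2 files
world_size, num_files = 4, 8
for rank in range(world_size):
    files = [f"rank_{rank + i * world_size}.arrow" for i in range(num_files // world_size)]
    print(rank, files)
# 0 ['rank_0.arrow', 'rank_4.arrow']
# 1 ['rank_1.arrow', 'rank_5.arrow']
# 2 ['rank_2.arrow', 'rank_6.arrow']
# 3 ['rank_3.arrow', 'rank_7.arrow']
```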
22 | num_files = len(glob.glob(os.path.join(self.root_dir, "rank_*.arrow"))) 23 | assert num_files % self.world_size == 0 24 | self.num_files_per_rank = num_files // self.world_size 25 | self.readers = [ 26 | pa.ipc.open_file(pa.memory_map( 27 | os.path.join(self.root_dir, f"rank_{self.rank + i * self.world_size}.arrow") 28 | )) 29 | for i in range(self.num_files_per_rank) 30 | ] 31 | 32 | # state variables 33 | self.buffer = [] 34 | self.current_reader_idx = 0 35 | self.current_batch_idx = 0 36 | 37 | def __iter__(self): 38 | for reader_idx in range(self.current_reader_idx, len(self.readers)): 39 | self.current_reader_idx = reader_idx 40 | 41 | reader = self.readers[reader_idx] 42 | for batch_idx in range(self.current_batch_idx, reader.num_record_batches): 43 | self.current_batch_idx = batch_idx 44 | 45 | sample = reader.get_batch(batch_idx)['input_ids'].to_pylist() 46 | self.buffer += [self.bos_token] + sample + [self.eos_token] 47 | 48 | while len(self.buffer) >= self.batch_size * self.seq_length + 1: 49 | yield torch.LongTensor(self.buffer[:self.batch_size * self.seq_length]).reshape(self.batch_size, self.seq_length), \ 50 | torch.LongTensor(self.buffer[1:self.batch_size * self.seq_length + 1]).reshape(self.batch_size, self.seq_length) 51 | self.buffer = self.buffer[self.batch_size * self.seq_length:] 52 | 53 | self.current_batch_idx = 0 54 | 55 | def state_dict(self): 56 | """Return a dictionary containing the state of the dataset.""" 57 | return { 58 | 'buffer': self.buffer, 59 | 'current_reader_idx': self.current_reader_idx, 60 | 'current_batch_idx': self.current_batch_idx, 61 | } 62 | 63 | def load_state_dict(self, state_dict): 64 | """Load the state of the dataset.""" 65 | self.buffer = state_dict['buffer'] 66 | self.current_reader_idx = state_dict['current_reader_idx'] 67 | self.current_batch_idx = state_dict['current_batch_idx'] 68 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/fms_fsdp/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict 3 | from functools import partial 4 | 5 | 6 | try: 7 | import packaging.version 8 | except ImportError: 9 | from pkg_resources import packaging # type: ignore 10 | 11 | import time 12 | from datetime import timedelta 13 | 14 | import torch.cuda.nccl as nccl 15 | import torch.distributed as dist 16 | from torch.distributed.fsdp import ShardingStrategy 17 | 18 | from fms_fsdp.policies import * 19 | 20 | 21 | def train( 22 | cfg, 23 | model, 24 | local_rank, 25 | rank, 26 | train_loader, 27 | optimizer, 28 | scheduler, 29 | profiler, 30 | checkpointer, 31 | start_step, 32 | tokens_seen, 33 | ): 34 | if cfg.tracker: 35 | if cfg.tracker not in ["wandb", "aim"]: 36 | raise ValueError(f"tracker {cfg.tracker} not supported.") 37 | tracker_dir = cfg.tracker_dir 38 | project_name = cfg.tracker_project_name 39 | run_name = cfg.tracker_run_name 40 | run_id = cfg.tracker_run_id 41 | 42 | if cfg.tracker == "wandb": 43 | try: 44 | import wandb # type: ignore 45 | except ImportError: 46 | raise ImportError("tracker is set to wandb but wandb is not installed.") 47 | if rank == 0: 48 | print(f"--> wandb is enabled!") 49 | try: 50 | wandb.init( 51 | project=project_name, 52 | name=run_name, 53 | dir=tracker_dir, 54 | resume="allow", 55 | id=run_id, 56 | ) 57 | except wandb.errors.UsageError: 58 | raise ValueError( 59 | "wandb failed to init, did you pass your wandb api key via WANDB_API_KEY?" 
60 | ) 61 | wandb.config = asdict(cfg) 62 | 63 | if cfg.tracker == "aim": 64 | try: 65 | from aim import Run # type: ignore 66 | except ImportError: 67 | raise ImportError("tracker is set to aim but aim is not installed.") 68 | if rank == 0: 69 | print(f"--> aim is enabled!") 70 | run = Run( 71 | experiment=project_name, 72 | repo=tracker_dir, 73 | run_hash=run_id, 74 | ) 75 | run["hparams"] = asdict(cfg) 76 | 77 | model.train() 78 | ddp_stats = torch.zeros(3).to(local_rank) 79 | 80 | start = time.time() 81 | loop_start = time.time() 82 | train_loss = -1 83 | for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1): 84 | if batch_idx > cfg.num_steps: 85 | break 86 | input = input.to(local_rank) 87 | label = label.to(local_rank) 88 | 89 | optimizer.zero_grad() 90 | output = model(input) 91 | output = output.logits if hasattr(output, "logits") else output 92 | ce_loss = torch.nn.CrossEntropyLoss() 93 | loss = ce_loss(output.view(-1, output.size(-1)), label.view(-1).long()) 94 | 95 | loss.backward() 96 | ddp_stats[1] += model.clip_grad_norm_(cfg.grad_clip_thresh).item() 97 | optimizer.step() 98 | scheduler.step() 99 | 100 | ddp_stats[0] += loss.item() 101 | ddp_stats[2] += 1 102 | 103 | if profiler: 104 | profiler.step() 105 | 106 | if batch_idx % cfg.report_interval == 0: 107 | dist.all_reduce(ddp_stats, op=dist.ReduceOp.SUM) 108 | train_loss = ddp_stats[0] / ddp_stats[2] 109 | g_norm = ddp_stats[1] / ddp_stats[2] 110 | elapsed_time = time.time() - loop_start 111 | world_size = int(os.environ["WORLD_SIZE"]) 112 | new_tokens_seen = ( 113 | (batch_idx - start_step) * world_size * cfg.batch_size * cfg.seq_length 114 | ) 115 | if rank == 0: 116 | total_tokens_seen = tokens_seen + new_tokens_seen 117 | current_loss = train_loss.item() 118 | current_lr = scheduler.get_last_lr()[0] 119 | current_gnorm = g_norm.item() 120 | current_step_time = (time.time() - start) / cfg.report_interval 121 | overall_step_time = elapsed_time / (batch_idx - start_step) 122 | current_throughput = int( 123 | cfg.batch_size * cfg.seq_length / current_step_time 124 | ) 125 | overall_throughput = int( 126 | cfg.batch_size * cfg.seq_length / overall_step_time 127 | ) 128 | reserved_mem = torch.cuda.max_memory_reserved( 129 | device=torch.cuda.current_device() 130 | ) 131 | allocated_mem = torch.cuda.max_memory_allocated( 132 | device=torch.cuda.current_device() 133 | ) 134 | 135 | print("step:", batch_idx) 136 | print("loss:", current_loss) 137 | print("LR:", current_lr) 138 | print("tokens seen:", total_tokens_seen) 139 | print("gradient norm:", current_gnorm) 140 | print("reserved memory:", reserved_mem) 141 | print("allocated memory:", allocated_mem) 142 | print("current step time:", current_step_time) 143 | print("overall step time:", overall_step_time) 144 | print("current token per gpu per sec:", current_throughput) 145 | print("overall token per gpu per sec:", overall_throughput) 146 | print( 147 | "overall token per day:", 148 | int(new_tokens_seen / elapsed_time * 3600 * 24), 149 | ) 150 | if cfg.tracker: 151 | vals_to_track = { 152 | "learning rate": current_lr, 153 | "loss": current_loss, 154 | "gradient norm": current_gnorm, 155 | "token seen": total_tokens_seen, 156 | "current throughput (token per gpu per sec)": current_throughput, 157 | "overall throughput (token per gpu per sec)": overall_throughput, 158 | "gpu reserved memory": reserved_mem, 159 | "gpu allocated memory": allocated_mem, 160 | } 161 | if cfg.tracker == "wandb": 162 | tracker_fn = wandb.log 163 | elif cfg.tracker == 
"aim": 164 | tracker_fn = run.track 165 | tracker_fn(vals_to_track, step=batch_idx) 166 | 167 | start = time.time() 168 | ddp_stats.zero_() 169 | torch.cuda.reset_peak_memory_stats(device=torch.cuda.current_device()) 170 | 171 | if batch_idx % cfg.checkpoint_interval == 0: 172 | checkpointer.save( 173 | batch_idx, 174 | model, 175 | optimizer, 176 | train_loader, 177 | tokens_seen=tokens_seen + new_tokens_seen, 178 | ) 179 | 180 | return train_loss 181 | 182 | 183 | def setup(): 184 | dist.init_process_group("nccl", timeout=timedelta(seconds=60 * 60)) 185 | 186 | 187 | def setup_environ_flags(): 188 | os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) 189 | os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = str(1) 190 | 191 | 192 | def get_mixed_precision_policy(cfg, rank): 193 | verify_bfloat_support = ( 194 | torch.version.cuda 195 | and torch.cuda.is_bf16_supported() 196 | and packaging.version.parse(torch.version.cuda).release >= (11, 0) 197 | and dist.is_nccl_available() 198 | and nccl.version() >= (2, 10) 199 | ) 200 | 201 | if cfg.mixed_precision: 202 | bf16_ready = verify_bfloat_support 203 | if bf16_ready: 204 | mixed_precision_policy = bfSixteen 205 | if rank == 0: 206 | print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") 207 | else: 208 | mixed_precision_policy = fpSixteen 209 | if rank == 0: 210 | print(f"FP16 enabled") 211 | else: 212 | mixed_precision_policy = None 213 | 214 | return mixed_precision_policy 215 | 216 | 217 | def get_policies(cfg, rank, block): 218 | """Get policies for mixed precision, wrapping, sharding, ac and param init function.""" 219 | 220 | # mixed precision 221 | mixed_precision_policy = get_mixed_precision_policy(cfg, rank) 222 | 223 | # wrapping policy 224 | wrapping_policy = get_wrapper(block) 225 | 226 | # sharding strategy 227 | if cfg.sharding_strategy == "fsdp": 228 | sharding_strategy = ShardingStrategy.FULL_SHARD 229 | elif cfg.sharding_strategy == "hsdp": 230 | sharding_strategy = ShardingStrategy.HYBRID_SHARD 231 | elif cfg.sharding_strategy == "ddp": 232 | sharding_strategy = ShardingStrategy.NO_SHARD 233 | else: 234 | sharding_strategy = ShardingStrategy.FULL_SHARD 235 | if rank == 0: 236 | print(f"Sharding strategy = {cfg.sharding_strategy}") 237 | 238 | # ac handler 239 | apply_selective_ac = partial(apply_fsdp_checkpointing, block=block) 240 | 241 | return ( 242 | mixed_precision_policy, 243 | wrapping_policy, 244 | sharding_strategy, 245 | apply_selective_ac, 246 | ) 247 | 248 | 249 | def get_profiler(cfg, rank): 250 | if not cfg.use_profiler: 251 | return 252 | if cfg.profiler_rank0_only and rank != 0: 253 | return 254 | return torch.profiler.profile( 255 | activities=[ 256 | torch.profiler.ProfilerActivity.CPU, 257 | torch.profiler.ProfilerActivity.CUDA, 258 | ], 259 | schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1), 260 | on_trace_ready=torch.profiler.tensorboard_trace_handler("profile_traces"), 261 | profile_memory=True, 262 | with_stack=False, 263 | record_shapes=True, 264 | ) 265 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/main_training_llama.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import fire 5 | import torch 6 | import torch.optim as optim 7 | from torch import distributed as dist 8 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 9 | from torch.optim.lr_scheduler import LambdaLR 10 | 11 | from fms_fsdp import config 12 | from 
fms_fsdp.utils.checkpointing_utils import Checkpointer 13 | from fms_fsdp.models.llama import LLaMA, LLaMABlock, param_init_function 14 | from fms_fsdp.utils.config_utils import get_model_config, update_config 15 | from fms_fsdp.utils.dataset_utils import DistributedDataset 16 | from torchdata.stateful_dataloader import StatefulDataLoader 17 | from fms_fsdp.utils.train_utils import ( 18 | get_policies, 19 | get_profiler, 20 | setup, 21 | setup_environ_flags, 22 | train, 23 | ) 24 | 25 | 26 | def main(**kwargs): 27 | # get configs 28 | cfg = config.train_config() 29 | update_config(cfg, **kwargs) 30 | 31 | # ensure reproducibility 32 | torch.cuda.manual_seed(cfg.seed) 33 | torch.manual_seed(cfg.seed) 34 | 35 | # torchrun specific 36 | local_rank = int(os.environ["LOCAL_RANK"]) 37 | rank = int(os.environ["RANK"]) 38 | world_size = int(os.environ["WORLD_SIZE"]) 39 | 40 | if rank == 0: 41 | print(f"--> running with these configs {cfg}") 42 | 43 | # some setups 44 | setup() 45 | torch.cuda.set_device(local_rank) 46 | torch.cuda.empty_cache() 47 | setup_environ_flags() 48 | 49 | # get policy 50 | block = LLaMABlock 51 | ( 52 | mixed_precision_policy, 53 | wrapping_policy, 54 | sharding_strategy_policy, 55 | apply_selective_ac, 56 | ) = get_policies(cfg, rank, block) 57 | 58 | # get fms model 59 | llama_config = get_model_config(cfg.model_variant) 60 | llama_config.attn_alpha_init_value = cfg.attn_alpha_init_value 61 | llama_config.ffn_alpha_init_value = cfg.ffn_alpha_init_value 62 | llama_config.dec_alpha_init_value = cfg.dec_alpha_init_value 63 | if rank == 0: 64 | print(f"--> llama config: {llama_config}") 65 | if cfg.low_cpu_fsdp: 66 | with torch.device("meta"): 67 | model = LLaMA(llama_config) 68 | else: 69 | model = LLaMA(llama_config) 70 | model.reset_parameters() 71 | 72 | if rank == 0: 73 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 74 | print(f"\n--> model has {total_params / 1e6} Million params\n") 75 | 76 | # get data loader 77 | if rank == 0: 78 | print("Constructing datasets...") 79 | dataset = DistributedDataset(cfg.data_path, rank, world_size, cfg.batch_size, cfg.seq_length, cfg.bos_token, cfg.eos_token) 80 | train_loader = StatefulDataLoader(dataset, batch_size=None, num_workers=0) 81 | if rank == 0: 82 | print("Datasets constructed!") 83 | 84 | # FSDP 85 | model = FSDP( 86 | model, 87 | auto_wrap_policy=wrapping_policy, 88 | mixed_precision=mixed_precision_policy, 89 | sharding_strategy=sharding_strategy_policy, 90 | use_orig_params=cfg.use_torch_compile, 91 | device_id=torch.cuda.current_device(), 92 | limit_all_gathers=True, 93 | param_init_fn=param_init_function if cfg.low_cpu_fsdp else None, 94 | ) 95 | if rank == 0: 96 | print(model) 97 | 98 | # we need this post-fsdp call to avoid graph break with torch.compile, until we figure out a better solution. 
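The `low_cpu_fsdp` path above builds the model under `torch.device("meta")`, so no real memory is allocated, and lets FSDP materialize each wrapped block through `param_init_fn`, which is what `param_init_function` in `fms_fsdp/models/llama.py` does with `to_empty()` + `reset_parameters()`. A minimal sketch of the same pattern, assuming a toy module rather than the repo's LLaMA (`toy_param_init_fn` is hypothetical, written only for this sketch):
```
import torch
import torch.nn as nn

def toy_param_init_fn(module):
    # materialize this module's meta parameters on the current GPU, then re-initialize them
    module.to_empty(device=torch.cuda.current_device(), recurse=False)
    with torch.no_grad():
        if isinstance(module, nn.Linear):
            module.reset_parameters()

with torch.device("meta"):
    toy_model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
print(next(toy_model.parameters()).device)  # meta: parameters exist but hold no storage yet

# inside an initialized process group one would then wrap it, e.g.
# FSDP(toy_model, param_init_fn=toy_param_init_fn, device_id=torch.cuda.current_device())
```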
99 | model.rot_emb.compute_freqs_cis( 100 | torch.device("cuda", torch.cuda.current_device()), 101 | model.config.max_expected_seq_len, 102 | ) 103 | 104 | # fsdp activation checkpointing 105 | if cfg.fsdp_activation_checkpointing: 106 | if rank == 0: 107 | print(f"--> applying FSDP activation checkpointing...") 108 | apply_selective_ac(model, p=cfg.selective_checkpointing) 109 | 110 | # torch compile 111 | if cfg.use_torch_compile: 112 | if rank == 0: 113 | print(f"--> enabling torch compile...") 114 | # the default accumulated_cache_size_limit=64 is not enough for 70b model, so we make it 128 here 115 | torch._dynamo.config.accumulated_cache_size_limit = 128 116 | model = torch.compile(model) 117 | 118 | # Optimizer 119 | optimizer = optim.AdamW( 120 | model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1 121 | ) 122 | 123 | # optionally load from checkpoint (when continue pretraining) 124 | checkpointer = Checkpointer( 125 | cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank 126 | ) 127 | model, optimizer, train_loader, start_step, tokens_seen, is_resuming = checkpointer.load( 128 | model, 129 | optimizer, 130 | train_loader, 131 | path=os.path.join(cfg.ckpt_load_path, "checkpoints/") 132 | if not os.path.isfile(cfg.ckpt_load_path) 133 | else cfg.ckpt_load_path, 134 | strict=False, 135 | ) 136 | if not is_resuming: 137 | start_step = 0 138 | # Override loaded optim hyperparams with the current values 139 | for g in optimizer.param_groups: 140 | g["initial_lr"] = cfg.learning_rate 141 | 142 | # LR schedule 143 | if cfg.training_stage == "annealing": 144 | schedule = lambda x: 1 - x / cfg.num_steps 145 | else: 146 | warmup_interval = min(2000, cfg.num_steps // 20) 147 | schedule = lambda x: min( 148 | 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, 149 | 0.1 150 | + 0.5 151 | * (1 - 0.1) 152 | * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), 153 | ) 154 | scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) 155 | 156 | # profiler 157 | profiler = get_profiler(cfg, rank) 158 | 159 | # Train 160 | if rank == 0: 161 | print(f"Training for {cfg.num_steps} steps") 162 | train( 163 | cfg, 164 | model, 165 | local_rank, 166 | rank, 167 | train_loader, 168 | optimizer, 169 | scheduler, 170 | profiler, 171 | checkpointer, 172 | start_step, 173 | tokens_seen, 174 | ) 175 | 176 | checkpointer.save_single_file(cfg.num_steps, model) 177 | 178 | dist.barrier() 179 | dist.destroy_process_group() 180 | 181 | 182 | if __name__ == "__main__": 183 | fire.Fire(main) 184 | -------------------------------------------------------------------------------- /other_tasks/LLaMA/prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import argparse 5 | import pyarrow as pa 6 | import transformers 7 | 8 | 9 | def _format_text(text, bos_token, eos_token): 10 | try: 11 | while text[0] == bos_token: text = text[1:] 12 | while text[-1] == eos_token: text = text[:-1] 13 | return text 14 | except: 15 | print(f"Format error: {text}") 16 | return [] 17 | 18 | 19 | def main(args): 20 | file_path_list = glob.glob(os.path.join(args.data_path, '*.chunk.*.jsonl')) 21 | num_files = len(file_path_list) 22 | file_idx = args.rank % num_files 23 | file_path = file_path_list[file_idx] 24 | 25 | assert args.world_size % num_files == 0, 'world_size must be divisible by the number of files' 26 | offset = args.rank // num_files 27 | num_ranks_per_file = args.world_size 
// num_files 28 | 29 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer) 30 | 31 | schema = pa.schema([pa.field('input_ids', pa.uint32())]) 32 | 33 | with pa.ipc.new_file(os.path.join(args.output_path, f"rank_{args.rank}.arrow"), schema) as writer: 34 | with open(file_path, 'r') as file: 35 | current_line = 0 36 | num_tokens = 0 37 | while line := file.readline(): 38 | if current_line % num_ranks_per_file == offset: 39 | text = json.loads(line)['text'] 40 | tokens = tokenizer(text)['input_ids'] 41 | tokens = _format_text(tokens, tokenizer.bos_token_id, tokenizer.eos_token_id) 42 | if tokens: 43 | writer.write(pa.record_batch([tokens], schema=schema)) 44 | num_tokens += len(tokens) 45 | current_line += 1 46 | if num_tokens >= args.max_num_tokens: 47 | break 48 | 49 | if __name__ == "__main__": 50 | parser = argparse.ArgumentParser(description='Prepare data for training.') 51 | parser.add_argument('--rank', type=int, required=True) 52 | parser.add_argument('--world_size', type=int, default=2048) 53 | 54 | parser.add_argument('--data_path', type=str, required=True) 55 | parser.add_argument('--output_path', type=str, required=True) 56 | 57 | parser.add_argument('--max_num_tokens', type=int, default=204_800_000) 58 | 59 | parser.add_argument('--tokenizer', type=str, default='meta-llama/Llama-2-7b-hf') 60 | args = parser.parse_args() 61 | main(args) -------------------------------------------------------------------------------- /other_tasks/MAE/README.md: -------------------------------------------------------------------------------- 1 | # Masked Autoencoders (MAEs) with DyT 2 | 3 | This guide provides instructions for reproducing the MAE results with our proposed modifications, as presented in our paper. Follow the steps below to set up the environment, apply the patches, and run the experiments. 4 | 5 | ## 1. Clone the MAE Repository 6 | 7 | Clone the official MAE repository from GitHub: 8 | ``` 9 | git clone https://github.com/facebookresearch/mae.git 10 | ``` 11 | 12 | ## 2. Set Up the Python Environment 13 | 14 | The original repository relies on outdated dependencies that may be incompatible with newer GPUs. We have updated the dependencies to ensure compatibility while preserving the integrity of the original implementation. 15 | 16 | Set up the Python environment with the following commands: 17 | ``` 18 | conda create -n MAE python=3.12 19 | conda activate MAE 20 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia 21 | pip install timm==1.0.15 tensorboard 22 | ``` 23 | 24 | ## 3. Apply Compatibility Fix 25 | 26 | Update the original MAE code for compatibility by applying the provided patch: 27 | ``` 28 | cp compatibility-fix.patch mae 29 | cd mae 30 | git apply compatibility-fix.patch 31 | ``` 32 | 33 | ## 4. Apply DynamicTanh Patch (Optional) 34 | *(Skip this step if you wish to reproduce the baseline results.)* \ 35 | To reproduce the results using Dynamic Tanh (DyT), apply the following patches: 36 | ``` 37 | cp dynamic_tanh.py mae 38 | cp dynamic-tanh.patch mae 39 | cd mae 40 | git apply dynamic-tanh.patch 41 | ``` 42 | 43 | ## 4. 
Run Experiments 44 | 45 | After applying the patch, run the MAE pretraining with the following command: 46 | ``` 47 | torchrun --nnodes=8 --nproc_per_node=8 main_pretrain.py \ 48 | --output_dir /path/to/saving_dir \ 49 | --batch_size 64 \ 50 | --model $MODEL \ 51 | --norm_pix_loss \ 52 | --mask_ratio 0.75 \ 53 | --epochs 800 \ 54 | --warmup_epochs 40 \ 55 | --blr 1.5e-4 \ 56 | --weight_decay 0.05 \ 57 | --data_path /path/to/imagenet 58 | ``` 59 | Replace `$MODEL` with one of the following options: 60 | - `mae_vit_base_patch16` - base model. 61 | - `mae_vit_large_patch16` - large model. 62 | 63 | 64 | 65 | ## 5. Evaluation 66 | 67 | For fine-tuning and evaluation of pretrained models, refer to the original MAE documentation: [FINETUNE](https://github.com/facebookresearch/mae/blob/main/FINETUNE.md). 68 | -------------------------------------------------------------------------------- /other_tasks/MAE/compatibility-fix.patch: -------------------------------------------------------------------------------- 1 | From add7c6fc1515d68702d5236026522a22671484cb Mon Sep 17 00:00:00 2001 2 | From: Jiachen Zhu 3 | Date: Thu, 6 Mar 2025 06:55:20 +0000 4 | Subject: [PATCH] compatibility fix 5 | 6 | --- 7 | main_pretrain.py | 7 ++++--- 8 | models_mae.py | 4 ++-- 9 | util/misc.py | 3 +-- 10 | util/pos_embed.py | 6 +++--- 11 | 4 files changed, 10 insertions(+), 10 deletions(-) 12 | 13 | diff --git a/main_pretrain.py b/main_pretrain.py 14 | index 58a18c5..beb7810 100644 15 | --- a/main_pretrain.py 16 | +++ b/main_pretrain.py 17 | @@ -17,6 +17,8 @@ import time 18 | from pathlib import Path 19 | 20 | import torch 21 | +torch.backends.cuda.matmul.allow_tf32 = True 22 | +torch.backends.cudnn.allow_tf32 = True 23 | import torch.backends.cudnn as cudnn 24 | from torch.utils.tensorboard import SummaryWriter 25 | import torchvision.transforms as transforms 26 | @@ -24,8 +26,7 @@ import torchvision.datasets as datasets 27 | 28 | import timm 29 | 30 | -assert timm.__version__ == "0.3.2" # version check 31 | -import timm.optim.optim_factory as optim_factory 32 | +import timm.optim 33 | 34 | import util.misc as misc 35 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 36 | @@ -176,7 +177,7 @@ def main(args): 37 | model_without_ddp = model.module 38 | 39 | # following timm: set wd as 0 for bias and norm layers 40 | - param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 41 | + param_groups = timm.optim.param_groups_weight_decay(model_without_ddp, args.weight_decay) 42 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 43 | print(optimizer) 44 | loss_scaler = NativeScaler() 45 | diff --git a/models_mae.py b/models_mae.py 46 | index 880e28f..4a13e22 100644 47 | --- a/models_mae.py 48 | +++ b/models_mae.py 49 | @@ -37,7 +37,7 @@ class MaskedAutoencoderViT(nn.Module): 50 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 51 | 52 | self.blocks = nn.ModuleList([ 53 | - Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 54 | + Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 55 | for i in range(depth)]) 56 | self.norm = norm_layer(embed_dim) 57 | # -------------------------------------------------------------------------- 58 | @@ -51,7 +51,7 @@ class MaskedAutoencoderViT(nn.Module): 59 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos 
embedding 60 | 61 | self.decoder_blocks = nn.ModuleList([ 62 | - Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 63 | + Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 64 | for i in range(decoder_depth)]) 65 | 66 | self.decoder_norm = norm_layer(decoder_embed_dim) 67 | diff --git a/util/misc.py b/util/misc.py 68 | index ad9a786..2963da8 100644 69 | --- a/util/misc.py 70 | +++ b/util/misc.py 71 | @@ -18,7 +18,6 @@ from pathlib import Path 72 | 73 | import torch 74 | import torch.distributed as dist 75 | -from torch._six import inf 76 | 77 | 78 | class SmoothedValue(object): 79 | @@ -285,7 +284,7 @@ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 80 | if len(parameters) == 0: 81 | return torch.tensor(0.) 82 | device = parameters[0].grad.device 83 | - if norm_type == inf: 84 | + if norm_type == float('inf'): 85 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 86 | else: 87 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 88 | diff --git a/util/pos_embed.py b/util/pos_embed.py 89 | index 6acf8bd..ff86f28 100644 90 | --- a/util/pos_embed.py 91 | +++ b/util/pos_embed.py 92 | @@ -23,8 +23,8 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 93 | return: 94 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 95 | """ 96 | - grid_h = np.arange(grid_size, dtype=np.float32) 97 | - grid_w = np.arange(grid_size, dtype=np.float32) 98 | + grid_h = np.arange(grid_size, dtype=float) 99 | + grid_w = np.arange(grid_size, dtype=float) 100 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 101 | grid = np.stack(grid, axis=0) 102 | 103 | @@ -53,7 +53,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 104 | out: (M, D) 105 | """ 106 | assert embed_dim % 2 == 0 107 | - omega = np.arange(embed_dim // 2, dtype=np.float) 108 | + omega = np.arange(embed_dim // 2, dtype=float) 109 | omega /= embed_dim / 2. 110 | omega = 1. 
/ 10000**omega # (D/2,) 111 | 112 | -- 113 | 2.34.1 114 | 115 | -------------------------------------------------------------------------------- /other_tasks/MAE/dynamic-tanh.patch: -------------------------------------------------------------------------------- 1 | From ddf89dfd9e9a5e36a16405d18b2f3b3f1c669901 Mon Sep 17 00:00:00 2001 2 | From: Jiachen Zhu 3 | Date: Thu, 6 Mar 2025 07:26:29 +0000 4 | Subject: [PATCH] dynamic-tanh 5 | 6 | --- 7 | main_pretrain.py | 2 ++ 8 | 1 file changed, 2 insertions(+) 9 | 10 | diff --git a/main_pretrain.py b/main_pretrain.py 11 | index beb7810..f434962 100644 12 | --- a/main_pretrain.py 13 | +++ b/main_pretrain.py 14 | @@ -34,6 +34,7 @@ from util.misc import NativeScalerWithGradNormCount as NativeScaler 15 | import models_mae 16 | 17 | from engine_pretrain import train_one_epoch 18 | +from dynamic_tanh import convert_ln_to_dyt 19 | 20 | 21 | def get_args_parser(): 22 | @@ -155,6 +156,7 @@ def main(args): 23 | 24 | # define the model 25 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss) 26 | + model = convert_ln_to_dyt(model) 27 | 28 | model.to(device) 29 | 30 | -- 31 | 2.34.1 32 | 33 | -------------------------------------------------------------------------------- /other_tasks/MAE/dynamic_tanh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DynamicTanh(nn.Module): 6 | def __init__(self, normalized_shape, alpha_init_value=0.5): 7 | super().__init__() 8 | self.normalized_shape = normalized_shape 9 | self.alpha_init_value = alpha_init_value 10 | self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) 11 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 12 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 13 | 14 | def forward(self, x): 15 | return self.weight * torch.tanh(self.alpha * x) + self.bias 16 | 17 | def extra_repr(self): 18 | return f"normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}" 19 | 20 | 21 | def convert_ln_to_dyt(module): 22 | module_output = module 23 | if isinstance(module, nn.LayerNorm): 24 | module_output = DynamicTanh(module.normalized_shape) 25 | for name, child in module.named_children(): 26 | module_output.add_module(name, convert_ln_to_dyt(child)) 27 | del module 28 | return module_output -------------------------------------------------------------------------------- /other_tasks/wav2vec2/README.md: -------------------------------------------------------------------------------- 1 | # wav2vec 2.0 with DyT 2 | 3 | This guide provides instructions for reproducing the wav2vec 2.0 results with our proposed modifications, as presented in our paper. Follow the steps below to set up the environment, apply the patches, and run the experiments. 4 | 5 | ## 1. Clone the fairseq Repository 6 | 7 | Clone the official fairseq repository from GitHub: 8 | ``` 9 | git clone https://github.com/facebookresearch/fairseq.git 10 | ``` 11 | 12 | ## 2. Set Up the Python Environment 13 | 14 | Create and activate a Conda environment with the required dependencies: 15 | ``` 16 | conda create -n w2v python=3.10 17 | conda activate w2v 18 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia 19 | pip install soundfile 20 | 21 | cd fairseq 22 | pip install --editable ./ 23 | ``` 24 | 25 | *(Fairseq does not provide a config for wav2vec 2.0 Large with LibriSpeech. 
We created our own by following the instructions from the original paper.)* 26 | Copy the configuration file for wav2vec 2.0 Large with LibriSpeech: 27 | ``` 28 | cp wav2vec2_large_librispeech.yaml ./fairseq/examples/wav2vec/config/pretraining/ 29 | ``` 30 | 31 | ## 3. Apply DynamicTanh Patch (Optional) 32 | *(Skip this step if you want to reproduce the baseline results.)* \ 33 | To reproduce the results using Dynamic Tanh (DyT), apply the following patch: 34 | ``` 35 | cp dynamic-tanh.patch fairseq 36 | cd fairseq 37 | git apply dynamic-tanh.patch 38 | ``` 39 | 40 | ## 4. Run Experiments 41 | 42 | You can reproduce the dynamic-tanh pretraining results using the following command: 43 | 44 | ### wav2vec 2.0 Base 45 | 46 | ``` 47 | srun torchrun --nnodes=8 --nproc_per_node=8 fairseq-hydra-train \ 48 | task.data=/path/to/manifest \ 49 | --config-dir ./examples/wav2vec/config/pretraining \ 50 | --config-name wav2vec2_base_librispeech 51 | ``` 52 | 53 | ### wav2vec 2.0 Large 54 | 55 | ``` 56 | srun torchrun --nnodes=16 --nproc_per_node=8 fairseq-hydra-train \ 57 | task.data=/path/to/manifest \ 58 | --config-dir ./examples/wav2vec/config/pretraining \ 59 | --config-name wav2vec2_large_librispeech 60 | ``` 61 | 62 | -------------------------------------------------------------------------------- /other_tasks/wav2vec2/dynamic-tanh.patch: -------------------------------------------------------------------------------- 1 | From e65952277190e602649e687317a3f1df15974992 Mon Sep 17 00:00:00 2001 2 | From: Jiachen Zhu 3 | Date: Mon, 17 Mar 2025 21:59:40 +0000 4 | Subject: [PATCH] dynamic-tanh 5 | 6 | --- 7 | fairseq/models/wav2vec/wav2vec2.py | 29 ++++++++++++++++++++++++----- 8 | 1 file changed, 24 insertions(+), 5 deletions(-) 9 | 10 | diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py 11 | index 0faba77f..7d880f54 100644 12 | --- a/fairseq/models/wav2vec/wav2vec2.py 13 | +++ b/fairseq/models/wav2vec/wav2vec2.py 14 | @@ -305,6 +305,25 @@ class Wav2Vec2Config(FairseqDataclass): 15 | ) 16 | 17 | 18 | +class DynamicTanh(nn.Module): 19 | + def __init__(self, normalized_shape, alpha_init_value=0.5): 20 | + super().__init__() 21 | + self.normalized_shape = normalized_shape 22 | + self.alpha_init_value = alpha_init_value 23 | + 24 | + self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value) 25 | + self.weight = nn.Parameter(torch.ones(normalized_shape)) 26 | + self.bias = nn.Parameter(torch.zeros(normalized_shape)) 27 | + 28 | + def forward(self, x): 29 | + x = torch.tanh(self.alpha * x) 30 | + x = x * self.weight + self.bias 31 | + return x 32 | + 33 | + def extra_repr(self): 34 | + return f"normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}" 35 | + 36 | + 37 | @register_model("wav2vec2", dataclass=Wav2Vec2Config) 38 | class Wav2Vec2Model(BaseFairseqModel): 39 | def __init__(self, cfg: Wav2Vec2Config): 40 | @@ -617,7 +636,7 @@ class Wav2Vec2Model(BaseFairseqModel): 41 | features_pen = features.float().pow(2).mean() 42 | 43 | features = features.transpose(1, 2) 44 | - features = self.layer_norm(features) 45 | + features = self.dynamic_tanh(features) 46 | unmasked_features = features.clone() 47 | 48 | if padding_mask is not None and padding_mask.any(): 49 | @@ -1066,7 +1085,7 @@ class TransformerEncoder(nn.Module): 50 | [self.build_encoder_layer(args, layer_idx=ii) for ii in range(encoder_layers)] 51 | ) 52 | self.layer_norm_first = args.layer_norm_first 53 | - self.layer_norm = LayerNorm(self.embedding_dim) 54 | + self.layer_norm = 
DynamicTanh(self.embedding_dim) 55 | self.layerdrop = args.encoder_layerdrop 56 | 57 | self.apply(init_bert_params) 58 | @@ -1217,7 +1236,7 @@ class ConformerEncoder(TransformerEncoder): 59 | [self.build_encoder_layer(args) for _ in range(args.encoder_layers)] 60 | ) 61 | self.layer_norm_first = args.layer_norm_first 62 | - self.layer_norm = LayerNorm(self.embedding_dim) 63 | + self.layer_norm = DynamicTanh(self.embedding_dim) 64 | self.layerdrop = args.encoder_layerdrop 65 | 66 | self.apply(init_bert_params) 67 | @@ -1305,12 +1324,12 @@ class TransformerSentenceEncoderLayer(nn.Module): 68 | self.layer_norm_first = layer_norm_first 69 | 70 | # layer norm associated with the self attention layer 71 | - self.self_attn_layer_norm = LayerNorm(self.embedding_dim) 72 | + self.self_attn_layer_norm = DynamicTanh(self.embedding_dim) 73 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) 74 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) 75 | 76 | # layer norm associated with the position wise feed-forward NN 77 | - self.final_layer_norm = LayerNorm(self.embedding_dim) 78 | + self.final_layer_norm = DynamicTanh(self.embedding_dim) 79 | 80 | def forward( 81 | self, 82 | -- 83 | 2.34.1 84 | 85 | -------------------------------------------------------------------------------- /other_tasks/wav2vec2/wav2vec2_large_librispeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | 8 | checkpoint: 9 | save_interval_updates: 25000 10 | keep_interval_updates: 1 11 | no_epoch_checkpoints: true 12 | 13 | task: 14 | _name: audio_pretraining 15 | data: ??? 16 | max_sample_size: 320000 17 | min_sample_size: 32000 18 | normalize: false 19 | 20 | dataset: 21 | num_workers: 6 22 | max_tokens: 1200000 23 | skip_invalid_size_inputs_valid_test: true 24 | 25 | distributed_training: 26 | distributed_world_size: 128 27 | ddp_backend: legacy_ddp 28 | 29 | criterion: 30 | _name: wav2vec 31 | infonce: true 32 | log_keys: ["prob_perplexity","code_perplexity","temp"] 33 | loss_weights: [0.1, 10] 34 | 35 | optimization: 36 | max_update: 250000 37 | lr: [0.0003] 38 | 39 | optimizer: 40 | _name: adam 41 | adam_betas: (0.9,0.98) 42 | adam_eps: 1e-06 43 | weight_decay: 0.01 44 | 45 | lr_scheduler: 46 | _name: polynomial_decay 47 | warmup_updates: 20000 48 | 49 | model: 50 | _name: wav2vec2 51 | quantize_targets: true 52 | final_dim: 768 53 | latent_temp: [2.0,0.1,0.999995] 54 | 55 | encoder_layerdrop: 0.2 56 | dropout_input: 0.1 57 | dropout_features: 0.1 58 | dropout: 0.1 59 | attention_dropout: 0.0 60 | activation_dropout: 0.0 61 | 62 | encoder_layers: 24 63 | encoder_embed_dim: 1024 64 | encoder_ffn_embed_dim: 4096 65 | encoder_attention_heads: 16 66 | 67 | feature_grad_mult: 0.1 68 | 69 | layer_norm_first: true -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
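As a brief aside before the shared training utilities below: the `DynamicTanh` module and `convert_ln_to_dyt` helper shown above (in `other_tasks/MAE/dynamic_tanh.py` and the wav2vec2 patch) are drop-in replacements for `nn.LayerNorm`. A small usage sketch on a toy model (my addition; assumes `dynamic_tanh.py` is on the import path, and the toy transformer is not one of the repo's models):
```
import torch.nn as nn
from dynamic_tanh import DynamicTanh, convert_ln_to_dyt

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
model = convert_ln_to_dyt(nn.TransformerEncoder(layer, num_layers=2))

print(any(isinstance(m, nn.LayerNorm) for m in model.modules()))  # False: all LayerNorms replaced
print(sum(isinstance(m, DynamicTanh) for m in model.modules()))   # one DyT per former LayerNorm
```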
7 | 8 | 9 | import os 10 | import math 11 | import time 12 | from collections import defaultdict, deque 13 | import datetime 14 | import numpy as np 15 | from timm.utils import get_state_dict 16 | 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = 
self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | class TensorboardLogger(object): 171 | def __init__(self, log_dir): 172 | self.writer = SummaryWriter(logdir=log_dir) 173 | self.step = 0 174 | 175 | def set_step(self, step=None): 176 | if step is not None: 177 | self.step = step 178 | else: 179 | self.step += 1 180 | 181 | def update(self, head='scalar', step=None, **kwargs): 182 | for k, v in kwargs.items(): 183 | if v is None: 184 | continue 185 | if isinstance(v, torch.Tensor): 186 | v = v.item() 187 | assert isinstance(v, (float, int)) 188 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 189 | 190 | def flush(self): 191 | self.writer.flush() 192 | 193 | 194 | class WandbLogger(object): 195 | def __init__(self, args): 196 | self.args = args 197 | 198 | try: 199 | import wandb 200 | self._wandb = wandb 201 | except ImportError: 202 | raise ImportError( 203 | "To use the Weights and Biases Logger please install wandb." 204 | "Run `pip install wandb` to install it." 205 | ) 206 | 207 | # Initialize a W&B run 208 | if self._wandb.run is None: 209 | self._wandb.init( 210 | project=args.project, 211 | config=args 212 | ) 213 | 214 | def log_epoch_metrics(self, metrics, commit=True): 215 | """ 216 | Log train/test metrics onto W&B. 
217 | """ 218 | # Log number of model parameters as W&B summary 219 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None) 220 | metrics.pop('n_parameters', None) 221 | 222 | # Log current epoch 223 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False) 224 | metrics.pop('epoch') 225 | 226 | for k, v in metrics.items(): 227 | if 'train' in k: 228 | self._wandb.log({f'Global Train/{k}': v}, commit=False) 229 | elif 'test' in k: 230 | self._wandb.log({f'Global Test/{k}': v}, commit=False) 231 | 232 | self._wandb.log({}) 233 | 234 | def log_checkpoints(self): 235 | output_dir = self.args.output_dir 236 | model_artifact = self._wandb.Artifact( 237 | self._wandb.run.id + "_model", type="model" 238 | ) 239 | 240 | model_artifact.add_dir(output_dir) 241 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"]) 242 | 243 | def set_steps(self): 244 | # Set global training step 245 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step') 246 | # Set epoch-wise step 247 | self._wandb.define_metric('Global Train/*', step_metric='epoch') 248 | self._wandb.define_metric('Global Test/*', step_metric='epoch') 249 | 250 | 251 | def setup_for_distributed(is_master): 252 | """ 253 | This function disables printing when not in master process 254 | """ 255 | import builtins as __builtin__ 256 | builtin_print = __builtin__.print 257 | 258 | def print(*args, **kwargs): 259 | force = kwargs.pop('force', False) 260 | if is_master or force: 261 | builtin_print(*args, **kwargs) 262 | 263 | __builtin__.print = print 264 | 265 | 266 | def is_dist_avail_and_initialized(): 267 | if not dist.is_available(): 268 | return False 269 | if not dist.is_initialized(): 270 | return False 271 | return True 272 | 273 | 274 | def get_world_size(): 275 | if not is_dist_avail_and_initialized(): 276 | return 1 277 | return dist.get_world_size() 278 | 279 | 280 | def get_rank(): 281 | if not is_dist_avail_and_initialized(): 282 | return 0 283 | return dist.get_rank() 284 | 285 | 286 | def is_main_process(): 287 | return get_rank() == 0 288 | 289 | 290 | def save_on_master(*args, **kwargs): 291 | if is_main_process(): 292 | torch.save(*args, **kwargs) 293 | 294 | 295 | def init_distributed_mode(args): 296 | 297 | if args.dist_on_itp: 298 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 299 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 300 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 301 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 302 | os.environ['LOCAL_RANK'] = str(args.gpu) 303 | os.environ['RANK'] = str(args.rank) 304 | os.environ['WORLD_SIZE'] = str(args.world_size) 305 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 306 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 307 | args.rank = int(os.environ["RANK"]) 308 | args.world_size = int(os.environ['WORLD_SIZE']) 309 | args.gpu = int(os.environ['LOCAL_RANK']) 310 | elif 'SLURM_PROCID' in os.environ: 311 | args.rank = int(os.environ['SLURM_PROCID']) 312 | args.gpu = args.rank % torch.cuda.device_count() 313 | 314 | os.environ['RANK'] = str(args.rank) 315 | os.environ['LOCAL_RANK'] = str(args.gpu) 316 | os.environ['WORLD_SIZE'] = str(args.world_size) 317 | else: 318 | print('Not using distributed mode') 319 | args.distributed = False 320 | return 321 | 322 | args.distributed = True 323 | 324 | torch.cuda.set_device(args.gpu) 325 | args.dist_backend = 'nccl' 326 | print('| distributed init 
(rank {}): {}, gpu {}'.format( 327 | args.rank, args.dist_url, args.gpu), flush=True) 328 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 329 | world_size=args.world_size, rank=args.rank) 330 | torch.distributed.barrier() 331 | setup_for_distributed(args.rank == 0) 332 | 333 | 334 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 335 | missing_keys = [] 336 | unexpected_keys = [] 337 | error_msgs = [] 338 | # copy state_dict so _load_from_state_dict can modify it 339 | metadata = getattr(state_dict, '_metadata', None) 340 | state_dict = state_dict.copy() 341 | if metadata is not None: 342 | state_dict._metadata = metadata 343 | 344 | def load(module, prefix=''): 345 | local_metadata = {} if metadata is None else metadata.get( 346 | prefix[:-1], {}) 347 | module._load_from_state_dict( 348 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 349 | for name, child in module._modules.items(): 350 | if child is not None: 351 | load(child, prefix + name + '.') 352 | 353 | load(model, prefix=prefix) 354 | 355 | warn_missing_keys = [] 356 | ignore_missing_keys = [] 357 | for key in missing_keys: 358 | keep_flag = True 359 | for ignore_key in ignore_missing.split('|'): 360 | if ignore_key in key: 361 | keep_flag = False 362 | break 363 | if keep_flag: 364 | warn_missing_keys.append(key) 365 | else: 366 | ignore_missing_keys.append(key) 367 | 368 | missing_keys = warn_missing_keys 369 | 370 | if len(missing_keys) > 0: 371 | print("Weights of {} not initialized from pretrained model: {}".format( 372 | model.__class__.__name__, missing_keys)) 373 | if len(unexpected_keys) > 0: 374 | print("Weights from pretrained model not used in {}: {}".format( 375 | model.__class__.__name__, unexpected_keys)) 376 | if len(ignore_missing_keys) > 0: 377 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 378 | model.__class__.__name__, ignore_missing_keys)) 379 | if len(error_msgs) > 0: 380 | print('\n'.join(error_msgs)) 381 | 382 | 383 | class NativeScalerWithGradNormCount: 384 | state_dict_key = "amp_scaler" 385 | 386 | def __init__(self): 387 | self._scaler = torch.cuda.amp.GradScaler() 388 | 389 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 390 | self._scaler.scale(loss).backward(create_graph=create_graph) 391 | if update_grad: 392 | if clip_grad is not None: 393 | assert parameters is not None 394 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 395 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 396 | else: 397 | self._scaler.unscale_(optimizer) 398 | norm = get_grad_norm_(parameters) 399 | self._scaler.step(optimizer) 400 | self._scaler.update() 401 | else: 402 | norm = None 403 | return norm 404 | 405 | def state_dict(self): 406 | return self._scaler.state_dict() 407 | 408 | def load_state_dict(self, state_dict): 409 | self._scaler.load_state_dict(state_dict) 410 | 411 | 412 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 413 | if isinstance(parameters, torch.Tensor): 414 | parameters = [parameters] 415 | parameters = [p for p in parameters if p.grad is not None] 416 | norm_type = float(norm_type) 417 | if len(parameters) == 0: 418 | return torch.tensor(0.) 
419 | device = parameters[0].grad.device 420 | if norm_type == float('inf'): 421 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 422 | else: 423 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 424 | return total_norm 425 | 426 | 427 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 428 | start_warmup_value=0, warmup_steps=-1): 429 | warmup_schedule = np.array([]) 430 | warmup_iters = warmup_epochs * niter_per_ep 431 | if warmup_steps > 0: 432 | warmup_iters = warmup_steps 433 | print("Set warmup steps = %d" % warmup_iters) 434 | if warmup_epochs > 0: 435 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 436 | 437 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 438 | schedule = np.array( 439 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 440 | 441 | schedule = np.concatenate((warmup_schedule, schedule)) 442 | 443 | assert len(schedule) == epochs * niter_per_ep 444 | return schedule 445 | 446 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 447 | output_dir = Path(args.output_dir) 448 | epoch_name = str(epoch) 449 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 450 | for checkpoint_path in checkpoint_paths: 451 | to_save = { 452 | 'model': model_without_ddp.state_dict(), 453 | 'optimizer': optimizer.state_dict(), 454 | 'epoch': epoch, 455 | 'scaler': loss_scaler.state_dict(), 456 | 'args': args, 457 | } 458 | 459 | if model_ema is not None: 460 | to_save['model_ema'] = get_state_dict(model_ema) 461 | 462 | save_on_master(to_save, checkpoint_path) 463 | 464 | if is_main_process() and isinstance(epoch, int): 465 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq 466 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del) 467 | if os.path.exists(old_ckpt): 468 | os.remove(old_ckpt) 469 | 470 | 471 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 472 | output_dir = Path(args.output_dir) 473 | if args.auto_resume and len(args.resume) == 0: 474 | import glob 475 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 476 | latest_ckpt = -1 477 | for ckpt in all_checkpoints: 478 | t = ckpt.split('-')[-1].split('.')[0] 479 | if t.isdigit(): 480 | latest_ckpt = max(int(t), latest_ckpt) 481 | if latest_ckpt >= 0: 482 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 483 | print("Auto resume checkpoint: %s" % args.resume) 484 | 485 | if args.resume: 486 | if args.resume.startswith('https'): 487 | checkpoint = torch.hub.load_state_dict_from_url( 488 | args.resume, map_location='cpu', check_hash=True) 489 | else: 490 | checkpoint = torch.load(args.resume, map_location='cpu') 491 | model_without_ddp.load_state_dict(checkpoint['model']) 492 | print("Resume checkpoint %s" % args.resume) 493 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 494 | optimizer.load_state_dict(checkpoint['optimizer']) 495 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 496 | args.start_epoch = checkpoint['epoch'] + 1 497 | else: 498 | assert args.eval, 'Does not support resuming with checkpoint-best' 499 | if hasattr(args, 'model_ema') and args.model_ema: 500 | if 'model_ema' in checkpoint.keys(): 501 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 502 
| else: 503 | model_ema.ema.load_state_dict(checkpoint['model']) 504 | if 'scaler' in checkpoint: 505 | loss_scaler.load_state_dict(checkpoint['scaler']) 506 | print("With optim & sched!") 507 | --------------------------------------------------------------------------------
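
For readers who want to try the module outside fairseq, the sketch below reuses the `DynamicTanh` definition added by `other_tasks/wav2vec2/dynamic-tanh.patch` above and shows one way to swap it in for `nn.LayerNorm`. This is a minimal illustrative sketch: the `replace_layernorm_with_dyt` helper and the toy encoder layer are assumptions for demonstration, not part of the repository's files or its patches.

```python
# Minimal sketch: DynamicTanh (DyT) as a drop-in replacement for nn.LayerNorm.
# DynamicTanh mirrors the class introduced by dynamic-tanh.patch;
# replace_layernorm_with_dyt is a hypothetical helper, not a repository API.
import torch
import torch.nn as nn


class DynamicTanh(nn.Module):
    def __init__(self, normalized_shape, alpha_init_value=0.5):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.alpha_init_value = alpha_init_value
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init_value)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        # Element-wise tanh squashing with a learnable scale, followed by an affine transform.
        return torch.tanh(self.alpha * x) * self.weight + self.bias

    def extra_repr(self):
        return f"normalized_shape={self.normalized_shape}, alpha_init_value={self.alpha_init_value}"


def replace_layernorm_with_dyt(module: nn.Module) -> nn.Module:
    # Recursively swap every nn.LayerNorm for a DynamicTanh of the same width,
    # analogous to what the patches under other_tasks/ do by hand per model.
    for name, child in module.named_children():
        if isinstance(child, nn.LayerNorm):
            setattr(module, name, DynamicTanh(child.normalized_shape[-1]))
        else:
            replace_layernorm_with_dyt(child)
    return module


if __name__ == "__main__":
    # Toy usage: convert the LayerNorms inside a standard Transformer encoder layer.
    layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True)
    layer = replace_layernorm_with_dyt(layer)
    out = layer(torch.randn(2, 16, 768))
    print(out.shape)  # torch.Size([2, 16, 768])
```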