├── .gitignore
├── README.md
├── datasets.py
├── drop_scheduler.py
├── engine.py
├── engine_soft.py
├── llama2
    └── weight_selection.py
├── logs
    ├── illama-base-in1k-384-83.0.txt
    ├── illama-base-in1k-81.6.txt
    ├── illama-base-in21k-checkpoint-89.txt
    ├── illama-base-in21kin1k-224-83.6.txt
    ├── illama-base-in21kin1k-384-85.0.txt
    ├── illama-large-in21k-checkpoint-89.txt
    ├── illama-large-in21kin1k-224-84.8.txt
    ├── illama-large-in21kin1k-384-86.0.txt
    ├── illama-small-in1k-79.9.txt
    └── illama-tiny-in1k-75.0.txt
├── main.py
├── main_soft.py
├── main_soft_fthr.py
├── mask_scheduler.py
├── models
    ├── __init__.py
    └── illama.py
├── optim_factory.py
├── scripts
    ├── eval_illama_in1k_224.sh
    ├── eval_illama_in1k_384.sh
    ├── train_illama_base_from_llama2.sh
    ├── train_illama_base_in1k.sh
    ├── train_illama_small_from_llama2.sh
    ├── train_illama_small_in1k.sh
    ├── train_illama_tiny_from_llama2.sh
    └── train_illama_tiny_in1k.sh
└── utils.py


/.gitignore:
--------------------------------------------------------------------------------
# pyc
__pycache__/

# iLLaMA
output/
scripts/debug.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# [Adapting LLaMA Decoder to Vision Transformer](https://arxiv.org/pdf/2404.06773)
Image credit: DALL·E

This is a PyTorch implementation of iLLaMA, proposed in our paper "[Adapting LLaMA Decoder to Vision Transformer](https://arxiv.org/abs/2404.06773)".

![iLLaMA first figure](https://github.com/hpcaitech/Open-Sora/assets/48375204/59f7af9a-679c-46ea-a428-c7bf27c0ecea)
Figure 1: Left: iLLaMA architecture. Right: our design roadmap. Colored and gray bars represent the results of the tiny and base regimes, with the red line depicting the training loss of the tiny regime. iLLaMA strives to process visual tokens using standard LLaMA components, e.g., causal self-attention. The proposed PS [cls] and soft mask strategies help overcome training challenges.
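
The soft mask strategy named in the Figure 1 caption (and described under Figure 3) can be sketched in a few lines. This is a minimal sketch under stated assumptions, not the repository's implementation. `mask_scheduler.py` and `models/illama.py` are not shown here, so the helper names `soft_mask_schedule` and `soft_mask` are hypothetical. The sketch makes two assumptions. First, the per-iteration rate follows an "early" schedule like `drop_scheduler` in `drop_scheduler.py`: constant or linearly decaying until a cutoff epoch, then 0. Second, the rate `alpha` weights the future (upper-triangular) positions, so `alpha = 1` gives a bi-directional mask and `alpha = 0` gives the standard causal mask.

```python
# Hypothetical sketch of a soft-mask schedule (not the repo's mask_scheduler.py).

def soft_mask_schedule(rate, epochs, iters_per_epoch, cutoff_epoch, schedule="linear"):
    """Per-iteration soft-mask rate: active until cutoff_epoch, then 0 (pure causal)."""
    early = cutoff_epoch * iters_per_epoch
    late = (epochs - cutoff_epoch) * iters_per_epoch
    if schedule == "constant":
        head = [rate] * early
    elif schedule == "linear":
        # decay from `rate` to 0 with both endpoints included, like np.linspace(rate, 0, early)
        head = [rate * (1 - i / (early - 1)) for i in range(early)] if early > 1 else [rate] * early
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    return head + [0.0] * late


def soft_mask(n, alpha):
    """n x n mask: 1 at and below the diagonal, alpha above it (assumed blend).

    Equals alpha * bidirectional + (1 - alpha) * causal, so alpha=1 is a
    bi-directional mask and alpha=0 is the standard causal mask.
    """
    return [[1.0 if j <= i else alpha for j in range(n)] for i in range(n)]


sched = soft_mask_schedule(1.0, epochs=3, iters_per_epoch=2, cutoff_epoch=2)
print([round(a, 3) for a in sched])  # [1.0, 0.667, 0.333, 0.0, 0.0, 0.0]
for row in soft_mask(3, sched[1]):
    print([round(v, 3) for v in row])
```

Once the rate reaches 0 after the cutoff epoch, every remaining iteration in this sketch uses the plain causal mask.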

![iLLaMA second figure](https://github.com/hpcaitech/Open-Sora/assets/48375204/6dffefaa-cb27-49ba-a258-1953bdaa7330)
Figure 2: (a) Mask in causal self-attention. (b) Mask in causal self-attention with our post-sequence class token (PS [cls]) method. (c) Modified causal mask.
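
A small sketch shows why the position of the class token matters under a causal mask. It is illustrative only and makes three assumptions: boolean masks where `True` means "may attend", a 224×224 input with 16×16 patches (the `patch_size=16` used in `llama2/weight_selection.py`), and hypothetical helper names. It does not reproduce the modified mask of Figure 2(c) or the code in `models/illama.py`. When [cls] is prepended, as in a standard ViT, the causal mask lets it see only itself. When it is appended after the patch tokens (PS [cls]), its row of the mask is fully open, so it can aggregate every patch.

```python
# Hypothetical illustration of why a post-sequence [cls] token suits a causal mask.

def causal_mask(n):
    """mask[i][j] is True when query i may attend to key j, i.e. j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]


def keys_visible_to_cls(num_patches, cls_position):
    """Number of tokens the [cls] query can attend to in a sequence of patches + [cls]."""
    mask = causal_mask(num_patches + 1)
    return sum(mask[cls_position])


num_patches = (224 // 16) ** 2  # 196 patches for a 224x224 image with 16x16 patches
print(keys_visible_to_cls(num_patches, 0))            # [cls] first: 1 (only itself)
print(keys_visible_to_cls(num_patches, num_patches))  # [cls] last: 197 (all patches + itself)
```

Under this reading, a classifier head would read the output at the last position rather than the first.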

![iLLaMA third figure](https://github.com/hpcaitech/Open-Sora/assets/48375204/f3b46c50-c807-4997-81d4-257b6168e5f7)
Figure 3: (a) Soft mask gradually transitions from a bi-directional mask into a causal mask during training through a constant or linear schedule. (b) Ablation results of training loss and test accuracy.

## Requirements
PyTorch and timm 0.5.4 (`pip install timm==0.5.4`).

Data preparation: organize ImageNet with the following folder structure. You can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).

```
│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
```

## Models
### iLLaMA on ImageNet-1K
| Model | Pre-trained dataset | Resolution | Params | MACs | Top-1 Acc (%) |
| :--- | :--- | :---: | :---: | :---: | :---: |
| [illama_tiny](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-tiny-in1k-75.0.pth?download=true) | - | 224 | 5.7M | 1.3G | 75.0 |
| [illama_small](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-small-in1k-79.9.pth?download=true) | - | 224 | 21.9M | 4.6G | 79.9 |
| [illama_base](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-base-in1k-81.6.pth?download=true) | - | 224 | 86.3M | 17.6G | 81.6 |
| [illama_base](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-base-in1k-384-83.0.pth?download=true) | - | 384 | 86.3M | 55.5G | 83.0 |
| [illama_base](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-base-in21kin1k-224-83.6.pth?download=true) | ImageNet-21K | 224 | 86.3M | 17.6G | 83.6 |
| [illama_base](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-base-in21kin1k-384-85.0.pth?download=true) | ImageNet-21K | 384 | 86.3M | 55.5G | 85.0 |
| [illama_large](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-large-in21kin1k-224-84.8.pth?download=true) | ImageNet-21K | 224 | 310.2M | 62.8G | 84.8 |
| [illama_large](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-large-in21kin1k-384-86.0.pth?download=true) | ImageNet-21K | 384 | 310.2M | 194.7G | 86.0 |


## Evaluate

To evaluate models at 224 resolution, run:

```bash
MODEL=illama_tiny
RESUME='/your/path/to/model.pth'
root_imagenet='/your/path/to/imagenet'

python -m torch.distributed.launch --nproc_per_node=2 main.py \
  --model $MODEL --eval true \
  --data_path $root_imagenet \
  --resume $RESUME
```

To evaluate models at 384 resolution, run:

```bash
MODEL=illama_base
RESUME='/your/path/to/model.pth'
root_imagenet='/your/path/to/imagenet'

python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \
  --model $MODEL --input_size 384 --eval true \
  --data_path $root_imagenet \
  --resume $RESUME
```

## Train
We use a total batch size of 4096 by default on 8 GPUs (per-GPU batch size × number of GPUs × `update_freq`, the gradient-accumulation factor in `engine.py`).

```bash
bash scripts/train_illama_tiny_in1k.sh
```
Training scripts for the other models are in [scripts](/scripts/).


## Initialization Using LLaMA2-7B (Optional)
We use the weight selection method to select weights from LLaMA2-7B:

```bash
python llama2/weight_selection.py
```

The script reads the Hugging Face LLaMA2-7B safetensors shards from `llama2-7b-hf/` and writes the selected weights to `llama2/pretrained/illama_ws_{tiny,small,base}.pth`; adjust the hard-coded paths in the script to your setup.

Then we use the selected weights to initialize our iLLaMA-T/S/B:

```bash
bash scripts/train_illama_tiny_from_llama2.sh
```
Training scripts for the other models are in [scripts](/scripts/).
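
The per-dimension index rule in `uniform_element_selection` (in `llama2/weight_selection.py`) can be reproduced without PyTorch to see which teacher rows and columns survive. This pure-Python sketch mirrors that logic. When the teacher dimension is a multiple of the student dimension, it keeps indices at a fixed stride. Otherwise it takes evenly spaced positions over `[0, teacher_dim - 1]` and rounds them to the nearest index. Python's `round` and `torch.round` both round halves to even. The script's `torch.linspace`, however, works in float32, so an index could differ at an exact rounding boundary. The helper names are hypothetical.

```python
# Pure-Python mirror of the index rule in uniform_element_selection (illustrative only).

def select_indices(teacher_dim, student_dim):
    """Indices kept along one dimension when shrinking teacher_dim -> student_dim."""
    assert teacher_dim >= student_dim, "teacher dimension must not be smaller"
    if teacher_dim % student_dim == 0:
        # evenly divisible: take every `step`-th element starting at 0
        step = teacher_dim // student_dim
        return [i * step for i in range(student_dim)]
    # otherwise: evenly spaced positions over [0, teacher_dim - 1], rounded to nearest index
    span = (teacher_dim - 1) / (student_dim - 1)
    return [round(i * span) for i in range(student_dim)]


def uniform_select(matrix, out_rows, out_cols):
    """Apply the selection to a 2-D nested list: first rows, then columns."""
    rows = select_indices(len(matrix), out_rows)
    cols = select_indices(len(matrix[0]), out_cols)
    return [[matrix[r][c] for c in cols] for r in rows]


print(select_indices(8, 4))   # [0, 2, 4, 6]  (divisible: fixed stride)
print(select_indices(10, 4))  # [0, 3, 6, 9]  (not divisible: rounded linspace)
```

For iLLaMA-T, for example, a 4096-entry LLaMA2 norm weight shrinks to 192 entries. Because 4096 is not a multiple of 192, the rounded-linspace branch applies. Note also that the script treats the concatenated q/k/v weight as a single row dimension when selecting.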

## Bibtex
```
@article{wang2024adapting,
  title={Adapting LLaMA Decoder to Vision Transformer},
  author={Wang, Jiahao and Shao, Wenqi and Chen, Mengzhao and Wu, Chengyue and Liu, Yong and Zhang, Kaipeng and Zhang, Songyang and Chen, Kai and Luo, Ping},
  journal={arXiv preprint arXiv:2404.06773},
  year={2024}
}
```

## Acknowledgment

Our implementation is based on [pytorch-image-models](https://github.com/huggingface/pytorch-image-models), [llama](https://github.com/meta-llama/llama), [dropout](https://github.com/facebookresearch/dropout), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt), [weight-selection](https://github.com/OscarXZQ/weight-selection), and [MambaOut](https://github.com/yuweihao/MambaOut).

--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import os
from torchvision import datasets, transforms

from timm.data.constants import \
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data import create_transform

def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    print("Transform = ")
    if isinstance(transform, tuple):
        for trans in transform:
            print(" - - - - - - - - - - ")
            for t in trans.transforms:
                print(t)
    else:
        for t in transform.transforms:
            print(t)
    print("---------------------------")

    if args.data_set == 'CIFAR10':
        dataset = datasets.CIFAR10(args.data_path, train=is_train, download=True, transform=transform)
        nb_classes = 10
    elif args.data_set == 'CIFAR100':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, download=True, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        print("reading from datapath", args.data_path)
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'IMNET21K':
        print("reading from datapath", args.data_path)
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = args.nb_classes
        assert len(dataset.class_to_idx) == nb_classes
    elif args.data_set == "image_folder":
        root = args.data_path if is_train else args.eval_data_path
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = args.nb_classes
        assert len(dataset.class_to_idx) == nb_classes
    else:
        raise NotImplementedError()
    print("Number of the class = %d" % nb_classes)

    return dataset, nb_classes


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
    mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
    std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not resize_im:
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        # warping (no cropping) when evaluated at 384 or larger
        if args.input_size >= 384:
            t.append(
                transforms.Resize((args.input_size, args.input_size),
                                  interpolation=transforms.InterpolationMode.BICUBIC),
            )
            print(f"Warping {args.input_size} size input images...")
        else:
            if args.crop_pct is None:
                args.crop_pct = 224 / 256
            size = int(args.input_size / args.crop_pct)
            t.append(
                # to maintain same ratio w.r.t. 224 images
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            )
            t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
--------------------------------------------------------------------------------
/drop_scheduler.py:
--------------------------------------------------------------------------------
import numpy as np

def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode="standard", schedule="constant"):
    assert mode in ["standard", "early", "late"]
    if mode == "standard":
        return np.full(epochs * niter_per_ep, drop_rate)

    early_iters = cutoff_epoch * niter_per_ep
    late_iters = (epochs - cutoff_epoch) * niter_per_ep

    if mode == "early":
        assert schedule in ["constant", "linear"]
        if schedule == 'constant':
            early_schedule = np.full(early_iters, drop_rate)
        elif schedule == 'linear':
            early_schedule = np.linspace(drop_rate, 0, early_iters)
        final_schedule = np.concatenate((early_schedule, np.full(late_iters, 0)))

    elif mode == "late":
        assert schedule in ["constant"]
        early_schedule = np.full(early_iters, 0)
        final_schedule = np.concatenate((early_schedule, np.full(late_iters, drop_rate)))

    assert len(final_schedule) == epochs * niter_per_ep
    return final_schedule
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
import math
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma

import utils

def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
                    wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, schedules={},
                    num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    optimizer.zero_grad()

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # global training iteration
        # Update LR & WD for the first acc
        if data_iter_step % update_freq == 0:
            if lr_schedule_values is not None or wd_schedule_values is not None:
                for i, param_group in enumerate(optimizer.param_groups):
                    if lr_schedule_values is not None:
                        param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                    if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                        param_group["weight_decay"] = wd_schedule_values[it]
            if 'dp' in schedules:
                model.module.update_drop_path(schedules['dp'][it])
            if 'do' in schedules:
                model.module.update_dropout(schedules['do'][it])

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if use_amp:
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = criterion(output, targets)
        else:  # full precision
            output = model(samples)
            loss = criterion(output, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):  # this could trigger if using AMP
            print("Loss is {}, stopping training".format(loss_value))
            assert math.isfinite(loss_value)

        if use_amp:
            # this attribute is added by timm on one optimizer (adahessian)
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=(data_iter_step + 1) % update_freq == 0)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
        else:  # full precision
            loss /= update_freq
            loss.backward()
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.step()
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)

        torch.cuda.synchronize()

        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        else:
            class_acc = None
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)

        if 'dp' in schedules:
            metric_logger.update(drop_path=model.module.drop_path)

        if 'do' in schedules:
            metric_logger.update(dropout=model.module.drop_rate)

        if use_amp:
            metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            if use_amp:
                log_writer.update(grad_norm=grad_norm, head="opt")
            log_writer.set_step()

        if wandb_logger:
            wandb_logger._wandb.log({
                'Rank-0 Batch Wise/train_loss': loss_value,
                'Rank-0 Batch Wise/train_max_lr': max_lr,
                'Rank-0 Batch Wise/train_min_lr': min_lr
            }, commit=False)
            if class_acc:
                wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
            if use_amp:
                wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
            wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})


    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(data_loader, model, device, use_amp=False):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()
    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0]
        target = batch[-1]

        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        if use_amp:
            with torch.cuda.amp.autocast():
                output = model(images)
                loss = criterion(output, target)
        else:
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/engine_soft.py:
--------------------------------------------------------------------------------
import math
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma

import utils

def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
                    wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, schedules={},
                    num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    optimizer.zero_grad()

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # global training iteration
        # Update LR & WD for the first acc
        if data_iter_step % update_freq == 0:
            if lr_schedule_values is not None or wd_schedule_values is not None:
                for i, param_group in enumerate(optimizer.param_groups):
                    if lr_schedule_values is not None:
                        param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                    if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                        param_group["weight_decay"] = wd_schedule_values[it]
            if 'dp' in schedules:
                model.module.update_drop_path(schedules['dp'][it])
            if 'do' in schedules:
                model.module.update_dropout(schedules['do'][it])
            if 'sm' in schedules:
                model.module.update_soft_mask(schedules['sm'][it])

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if use_amp:
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = criterion(output, targets)
        else:  # full precision
            output = model(samples)
            loss = criterion(output, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):  # this could trigger if using AMP
            print("Loss is {}, stopping training".format(loss_value))
            assert math.isfinite(loss_value)

        if use_amp:
            # this attribute is added by timm on one optimizer (adahessian)
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=(data_iter_step + 1) % update_freq == 0)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
        else:  # full precision
            loss /= update_freq
            loss.backward()
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.step()
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)

        torch.cuda.synchronize()

        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        else:
            class_acc = None
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)

        if 'dp' in schedules:
            metric_logger.update(drop_path=model.module.drop_path)

        if 'do' in schedules:
            metric_logger.update(dropout=model.module.drop_rate)

        if 'sm' in schedules:
            metric_logger.update(soft_mask_rate=model.module.soft_mask_rate)

        if use_amp:
            metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            if use_amp:
                log_writer.update(grad_norm=grad_norm, head="opt")
            log_writer.set_step()

        if wandb_logger:
            wandb_logger._wandb.log({
                'Rank-0 Batch Wise/train_loss': loss_value,
                'Rank-0 Batch Wise/train_max_lr': max_lr,
                'Rank-0 Batch Wise/train_min_lr': min_lr
            }, commit=False)
            if class_acc:
                wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
            if use_amp:
                wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
            wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})


    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(data_loader, model, device, use_amp=False):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()
    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0]
        target = batch[-1]

        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        if use_amp:
            with torch.cuda.amp.autocast():
                output = model(images)
                loss = criterion(output, target)
        else:
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
--------------------------------------------------------------------------------
/llama2/weight_selection.py:
--------------------------------------------------------------------------------
from safetensors import safe_open
import os
import torch
import collections
import sys
from functools import partial
# make the repo root (parent of llama2/) importable so that `models.illama` resolves
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.illama import VisionTransformer as iLLaMa
from models.illama import RMSNorm


def uniform_element_selection(wt, s_shape):
    assert wt.dim() == len(s_shape), "Tensors have different number of dimensions"
    ws = wt.clone()
    for dim in range(wt.dim()):
        assert wt.shape[dim] >= s_shape[dim], "Teacher's dimension should not be smaller than student's dimension"  # determine whether teacher is larger than student on this dimension
        if wt.shape[dim] % s_shape[dim] == 0:
            step = wt.shape[dim] // s_shape[dim]
            indices = torch.arange(s_shape[dim]) * step
        else:
            indices = torch.round(torch.linspace(0, wt.shape[dim]-1, s_shape[dim])).to(torch.int64)
        ws = torch.index_select(ws, dim, indices)
    assert ws.shape == s_shape
    return ws


# show iLLaMA keys
# illama_t = iLLaMa(
#     patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
#     norm_layer=partial(RMSNorm, eps=1e-6), num_classes=1000, drop_path_rate=0.0, drop_rate=0.0)
# adjust this path to a local iLLaMA checkpoint; it is only used to print the key names
tensors_illama = torch.load('/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-tiny-in1k-75.0.pth', map_location='cpu')
if 'model' in tensors_illama.keys():
    tensors_illama = tensors_illama['model']
print("illama keys:")
print("\n".join(tensors_illama.keys()))


# load llama2-7b-hf
tensors_llama2 = {}
path="llama2-7b-hf"
for n in ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]:
    file_name = os.path.join(path, n)
    with safe_open(file_name, framework="pt", device="cpu") as f:
        for k in f.keys():
            tensors_llama2[k] = f.get_tensor(k)
print("llama2 keys:")
print("\n".join(tensors_llama2.keys()))
print("load done")

# (576, 192)
print(tensors_illama["blocks.8.attn.qkv.weight"].shape)
# (4096, 4096)
print(tensors_llama2["model.layers.12.self_attn.q_proj.weight"].shape)
# (11008, 4096)
print(tensors_llama2["model.layers.12.mlp.up_proj.weight"].shape)
# (11008, 4096)
print(tensors_llama2["model.layers.12.mlp.gate_proj.weight"].shape)
# (4096)
print(tensors_llama2["model.layers.12.input_layernorm.weight"].shape)


# illama:
# blocks.8.norm1.weight
# blocks.8.attn.qkv.weight
# blocks.8.attn.qkv.bias
# blocks.8.attn.proj.weight
# blocks.8.attn.proj.bias
# blocks.8.norm2.weight
# blocks.8.mlp.fc1.weight
# blocks.8.mlp.fc2.weight
# blocks.8.mlp.fc3.weight

# llama2:
# model.layers.12.input_layernorm.weight
# model.layers.12.mlp.down_proj.weight
# model.layers.12.mlp.gate_proj.weight
# model.layers.12.mlp.up_proj.weight
# model.layers.12.post_attention_layernorm.weight
# model.layers.12.self_attn.k_proj.weight
# model.layers.12.self_attn.o_proj.weight
# model.layers.12.self_attn.q_proj.weight
# model.layers.12.self_attn.rotary_emb.inv_freq
# model.layers.12.self_attn.v_proj.weight

tensors_llama2_to_illama = collections.OrderedDict()
# for k,v in tensors_llama2.items():
for i in range(12):
    # norm
    tensors_llama2_to_illama["blocks." + str(i) + ".norm1.weight"] = \
        tensors_llama2["model.layers." + str(i) + ".input_layernorm.weight"]
    tensors_llama2_to_illama["blocks." + str(i) + ".norm2.weight"] = \
        tensors_llama2["model.layers." + str(i) + ".post_attention_layernorm.weight"]
    # attn
    tensors_llama2_to_illama["blocks." + str(i) + ".attn.qkv.weight"] = \
        torch.cat((tensors_llama2["model.layers." + str(i) + ".self_attn.q_proj.weight"],
                   tensors_llama2["model.layers." + str(i) + ".self_attn.k_proj.weight"],
                   tensors_llama2["model.layers." + str(i) + ".self_attn.v_proj.weight"]), dim=0)
    tensors_llama2_to_illama["blocks." + str(i) + ".attn.proj.weight"] = \
        tensors_llama2["model.layers." + str(i) + ".self_attn.o_proj.weight"]
    # ffn
    tensors_llama2_to_illama["blocks." + str(i) + ".mlp.fc1.weight"] = \
        tensors_llama2["model.layers." + str(i) + ".mlp.gate_proj.weight"]
    tensors_llama2_to_illama["blocks." + str(i) + ".mlp.fc2.weight"] = \
        tensors_llama2["model.layers." + str(i) + ".mlp.down_proj.weight"]
    tensors_llama2_to_illama["blocks." + str(i) + ".mlp.fc3.weight"] = \
        tensors_llama2["model.layers." + str(i) + ".mlp.up_proj.weight"]
print("llama2_to_illama keys:")
print("\n".join(tensors_llama2_to_illama.keys()))

# save new
# torch.save(tensors_llama2_to_illama, 'llama2/1.pth')



# weight selection from llama2 to illama
student_illama_tiny = iLLaMa(
    patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
    norm_layer=partial(RMSNorm, eps=1e-6), num_classes=1000, drop_path_rate=0.0, drop_rate=0.0)
student_illama_small = iLLaMa(
    patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
    norm_layer=partial(RMSNorm, eps=1e-6), num_classes=1000, drop_path_rate=0.0, drop_rate=0.0)
student_illama_base = iLLaMa(
    patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
    norm_layer=partial(RMSNorm, eps=1e-6), num_classes=1000, drop_path_rate=0.0, drop_rate=0.0)

teacher_weights = tensors_llama2_to_illama
student_weights_tiny = student_illama_tiny.state_dict()
student_weights_small = student_illama_small.state_dict()
student_weights_base = student_illama_base.state_dict()

# illama tiny weight selection
student_weight_selection_tiny = collections.OrderedDict()
for key in student_weights_tiny.keys():
    if "norm1.weight" in key or "norm2.weight" in key or "attn.qkv.weight" in key or "attn.proj.weight" in key or "mlp.fc1.weight" in key or "mlp.fc2.weight" in key or "mlp.fc3.weight" in key:
        print(f"initializing {key}")
        print("teacher_weights shape:", teacher_weights[key].shape)
        print("student_weights_tiny shape:", student_weights_tiny[key].shape)
        student_weight_selection_tiny[key] = uniform_element_selection(teacher_weights[key], student_weights_tiny[key].shape)
print("student_weights_tiny initialization done")

# illama small weight selection
student_weight_selection_small = collections.OrderedDict()
for key in student_weights_small.keys():
    if "norm1.weight" in key or "norm2.weight" in key or "attn.qkv.weight" in key or "attn.proj.weight" in key or "mlp.fc1.weight" in key or "mlp.fc2.weight" in key or "mlp.fc3.weight" in key:
        print(f"initializing {key}")
        print("teacher_weights shape:", teacher_weights[key].shape)
        print("student_weights_small shape:", student_weights_small[key].shape)
        student_weight_selection_small[key] = uniform_element_selection(teacher_weights[key], student_weights_small[key].shape)
print("student_weights_small initialization done")

# illama base weight selection
student_weight_selection_base = collections.OrderedDict()
for key in student_weights_base.keys():
    if "norm1.weight" in key or "norm2.weight" in key or "attn.qkv.weight" in key or "attn.proj.weight" in key or "mlp.fc1.weight" in key or "mlp.fc2.weight" in key or "mlp.fc3.weight" in key:
        print(f"initializing {key}")
        print("teacher_weights shape:", teacher_weights[key].shape)
        print("student_weights_base shape:", student_weights_base[key].shape)
        student_weight_selection_base[key] = uniform_element_selection(teacher_weights[key], student_weights_base[key].shape)
print("student_weights_base initialization done")


# check keys for selected tiny small base
for key in student_weight_selection_tiny.keys():
    assert key in student_weights_tiny, f"Key {key} not found in Model iLLaMA-T"
    assert student_weight_selection_tiny[key].shape == student_weights_tiny[key].shape, f"Shape mismatch for key {key}: {student_weight_selection_tiny[key].shape} != {student_weights_tiny[key].shape}"
for key in student_weight_selection_small.keys():
    assert key in student_weights_small, f"Key {key} not found in Model iLLaMA-S"
    assert student_weight_selection_small[key].shape == student_weights_small[key].shape, f"Shape mismatch for key {key}: {student_weight_selection_small[key].shape} != {student_weights_small[key].shape}"
for key in student_weight_selection_base.keys():
    assert key in student_weights_base, f"Key {key} not found in Model iLLaMA-B"
    assert student_weight_selection_base[key].shape == student_weights_base[key].shape, f"Shape mismatch for key {key}: {student_weight_selection_base[key].shape} != {student_weights_base[key].shape}"

print("All keys checked, you can use student_weight_selection_tiny, student_weight_selection_small, student_weight_selection_base.")

# save weight selection for illama (create the output directory if it does not exist yet)
os.makedirs('llama2/pretrained', exist_ok=True)
torch.save(student_weight_selection_tiny, 'llama2/pretrained/illama_ws_tiny.pth')
torch.save(student_weight_selection_small, 'llama2/pretrained/illama_ws_small.pth')
torch.save(student_weight_selection_base, 'llama2/pretrained/illama_ws_base.pth')
--------------------------------------------------------------------------------
/logs/illama-base-in1k-384-83.0.txt:
--------------------------------------------------------------------------------
{"train_lr": 7.992788848282348e-05, "train_min_lr": 7.992788848282348e-05, "train_loss": 2.473320546398441, "train_class_acc": 0.648087779776179, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.733891578917285, "test_acc1": 81.76600249023437, "test_acc5": 95.97600270751953, "epoch": 0, "n_parameters": 86794216}
{"train_lr": 7.949599477065516e-05, "train_min_lr": 7.949599477065516e-05, "train_loss": 2.2826324206866997, "train_class_acc": 0.6871768210431655, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0,
"test_loss": 0.7196540747989566, "test_acc1": 82.07200255126953, "test_acc5": 96.05600268554687, "epoch": 1, "n_parameters": 86794216} 3 | {"train_lr": 7.863685277934232e-05, "train_min_lr": 7.863685277934232e-05, "train_loss": 2.2307381229601697, "train_class_acc": 0.6979994129696243, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7166183656650303, "test_acc1": 82.18800258789062, "test_acc5": 96.01400269775391, "epoch": 2, "n_parameters": 86794216} 4 | {"train_lr": 7.735987544832995e-05, "train_min_lr": 7.735987544832995e-05, "train_loss": 2.199187022056892, "train_class_acc": 0.7046206784572342, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7158498064906315, "test_acc1": 82.26000260009765, "test_acc5": 96.04400275634765, "epoch": 3, "n_parameters": 86794216} 5 | {"train_lr": 7.567905360848068e-05, "train_min_lr": 7.567905360848068e-05, "train_loss": 2.176901005824907, "train_class_acc": 0.7092958445743405, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7096047366308119, "test_acc1": 82.44200252441406, "test_acc5": 96.15400270996093, "epoch": 4, "n_parameters": 86794216} 6 | {"train_lr": 7.361280269560707e-05, "train_min_lr": 7.361280269560707e-05, "train_loss": 2.155886616605482, "train_class_acc": 0.714189585831335, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7127004508711113, "test_acc1": 82.37200253417969, "test_acc5": 96.11600269287109, "epoch": 5, "n_parameters": 86794216} 7 | {"train_lr": 7.118376098710012e-05, "train_min_lr": 7.118376098710012e-05, "train_loss": 2.143780262636052, "train_class_acc": 0.7167828237410072, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, 
"test_loss": 0.712347365247498, "test_acc1": 82.32000246337891, "test_acc5": 96.13600271484376, "epoch": 6, "n_parameters": 86794216} 8 | {"train_lr": 6.841854157222974e-05, "train_min_lr": 6.841854157222974e-05, "train_loss": 2.1295641221516997, "train_class_acc": 0.7199880720423661, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7088617964413102, "test_acc1": 82.33000249267577, "test_acc5": 96.1380027319336, "epoch": 7, "n_parameters": 86794216} 9 | {"train_lr": 6.534744077356283e-05, "train_min_lr": 6.534744077356283e-05, "train_loss": 2.117899560379229, "train_class_acc": 0.7227358738009593, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7082845338243451, "test_acc1": 82.44600254394531, "test_acc5": 96.18600267333984, "epoch": 8, "n_parameters": 86794216} 10 | {"train_lr": 6.200410621411962e-05, "train_min_lr": 6.200410621411962e-05, "train_loss": 2.105851644964265, "train_class_acc": 0.7254204448940847, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7060970708599869, "test_acc1": 82.48400255371094, "test_acc5": 96.21400270996094, "epoch": 9, "n_parameters": 86794216} 11 | {"train_lr": 5.8425168166971116e-05, "train_min_lr": 5.8425168166971116e-05, "train_loss": 2.0965708284486206, "train_class_acc": 0.7277209482414069, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7031197266738493, "test_acc1": 82.65600259033204, "test_acc5": 96.1560027368164, "epoch": 10, "n_parameters": 86794216} 12 | {"train_lr": 5.464983822630283e-05, "train_min_lr": 5.464983822630283e-05, "train_loss": 2.087975359384081, "train_class_acc": 0.7299106027677857, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 
0.0, "test_loss": 0.7062890046126856, "test_acc1": 82.49000248291016, "test_acc5": 96.24600264404297, "epoch": 11, "n_parameters": 86794216} 13 | {"train_lr": 5.0719479696983124e-05, "train_min_lr": 5.0719479696983124e-05, "train_loss": 2.0755441978746516, "train_class_acc": 0.7325998576139089, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7021110344550195, "test_acc1": 82.65200256591797, "test_acc5": 96.22800266845704, "epoch": 12, "n_parameters": 86794216} 14 | {"train_lr": 4.667715440953944e-05, "train_min_lr": 4.667715440953944e-05, "train_loss": 2.069228977078109, "train_class_acc": 0.7342329261590728, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6990430768266434, "test_acc1": 82.67000250732421, "test_acc5": 96.24800268310547, "epoch": 13, "n_parameters": 86794216} 15 | {"train_lr": 4.256715092573226e-05, "train_min_lr": 4.256715092573226e-05, "train_loss": 2.060961584470493, "train_class_acc": 0.7357473396282974, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6991592321104108, "test_acc1": 82.67000253417969, "test_acc5": 96.28400275634766, "epoch": 14, "n_parameters": 86794216} 16 | {"train_lr": 3.843449930380363e-05, "train_min_lr": 3.843449930380363e-05, "train_loss": 2.0526455630340594, "train_class_acc": 0.7377301283972821, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7010124566189888, "test_acc1": 82.69000256835938, "test_acc5": 96.27200267578125, "epoch": 15, "n_parameters": 86794216} 17 | {"train_lr": 3.432447773973649e-05, "train_min_lr": 3.432447773973649e-05, "train_loss": 2.0455116031558918, "train_class_acc": 0.7400072129796164, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, 
"train_soft_mask_rate": 0.0, "test_loss": 0.6991047438248309, "test_acc1": 82.73400261962891, "test_acc5": 96.31600272216797, "epoch": 16, "n_parameters": 86794216} 18 | {"train_lr": 3.0282116489863717e-05, "train_min_lr": 3.0282116489863717e-05, "train_loss": 2.036234163123069, "train_class_acc": 0.7421242693345323, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6974382813067268, "test_acc1": 82.86800261962891, "test_acc5": 96.26000268310547, "epoch": 17, "n_parameters": 86794216} 19 | {"train_lr": 2.635170450995756e-05, "train_min_lr": 2.635170450995756e-05, "train_loss": 2.031415472404181, "train_class_acc": 0.7433147232214229, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6950523978355163, "test_acc1": 82.86200258544922, "test_acc5": 96.27200266845703, "epoch": 18, "n_parameters": 86794216} 20 | {"train_lr": 2.2576304216161565e-05, "train_min_lr": 2.2576304216161565e-05, "train_loss": 2.026128534757667, "train_class_acc": 0.7445496727617905, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6956147605212022, "test_acc1": 82.88400256835938, "test_acc5": 96.3520026928711, "epoch": 19, "n_parameters": 86794216} 21 | {"train_lr": 1.8997279684147784e-05, "train_min_lr": 1.8997279684147784e-05, "train_loss": 2.020787144066404, "train_class_acc": 0.7456807991107114, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6950576499118718, "test_acc1": 82.85400256835938, "test_acc5": 96.28000270019531, "epoch": 20, "n_parameters": 86794216} 22 | {"train_lr": 1.5653843455648113e-05, "train_min_lr": 1.5653843455648113e-05, "train_loss": 2.0178974027381384, "train_class_acc": 0.7466074015787371, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 
0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6941298640596162, "test_acc1": 82.87400258544922, "test_acc5": 96.32000274169921, "epoch": 21, "n_parameters": 86794216} 23 | {"train_lr": 1.2582626917640712e-05, "train_min_lr": 1.2582626917640712e-05, "train_loss": 2.0134329836723284, "train_class_acc": 0.7475152690347722, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6930451631574458, "test_acc1": 82.90600259521484, "test_acc5": 96.32400272460937, "epoch": 22, "n_parameters": 86794216} 24 | {"train_lr": 9.817278961209553e-06, "train_min_lr": 9.817278961209553e-06, "train_loss": 2.010525077039437, "train_class_acc": 0.7482654501398881, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6927108393967835, "test_acc1": 82.94400260009766, "test_acc5": 96.34200279785156, "epoch": 23, "n_parameters": 86794216} 25 | {"train_lr": 7.388097317251548e-06, "train_min_lr": 7.388097317251548e-06, "train_loss": 2.007418536215568, "train_class_acc": 0.7493333458233413, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6922111400680346, "test_acc1": 82.94600257324218, "test_acc5": 96.33200274169921, "epoch": 24, "n_parameters": 86794216} 26 | {"train_lr": 5.321696608196864e-06, "train_min_lr": 5.321696608196864e-06, "train_loss": 2.002901304453659, "train_class_acc": 0.7501826663669064, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.692061053752501, "test_acc1": 82.95000255615234, "test_acc5": 96.33200280029297, "epoch": 25, "n_parameters": 86794216} 27 | {"train_lr": 3.6407167526360116e-06, "train_min_lr": 3.6407167526360116e-06, "train_loss": 2.001760026268798, "train_class_acc": 0.7506760216826539, "train_weight_decay": 1.0000000000001084e-08, 
"train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6917145001268341, "test_acc1": 82.95800252929688, "test_acc5": 96.34800278320313, "epoch": 26, "n_parameters": 86794216} 28 | {"train_lr": 2.3635749176341855e-06, "train_min_lr": 2.3635749176341855e-06, "train_loss": 2.002015811621214, "train_class_acc": 0.7507548648581135, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6918871564264516, "test_acc1": 82.95200257324218, "test_acc5": 96.33200277832032, "epoch": 27, "n_parameters": 86794216} 29 | {"train_lr": 1.5042637363947687e-06, "train_min_lr": 1.5042637363947687e-06, "train_loss": 1.9995884048453958, "train_class_acc": 0.7511397132294164, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6912183112443517, "test_acc1": 82.9560025805664, "test_acc5": 96.33800274902343, "epoch": 28, "n_parameters": 86794216} 30 | {"train_lr": 1.0721980020418422e-06, "train_min_lr": 1.0721980020418422e-06, "train_loss": 2.000189851820588, "train_class_acc": 0.7509312862210232, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6910126191272422, "test_acc1": 82.99800246337891, "test_acc5": 96.33600276855469, "epoch": 29, "n_parameters": 86794216} 31 | -------------------------------------------------------------------------------- /logs/illama-base-in21kin1k-224-83.6.txt: -------------------------------------------------------------------------------- 1 | {"train_lr": 7.992788848282348e-05, "train_min_lr": 7.992788848282348e-05, "train_loss": 2.8800309649784026, "train_class_acc": 0.6066779388988809, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.8331461997404126, "test_acc1": 79.12400257568359, "test_acc5": 95.1300026611328, "epoch": 0, 
"n_parameters": 86502376} 2 | {"train_lr": 7.949599477065516e-05, "train_min_lr": 7.949599477065516e-05, "train_loss": 2.2193321749556074, "train_class_acc": 0.6949448253896883, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.7763625056916521, "test_acc1": 80.35400242919921, "test_acc5": 95.79000268554688, "epoch": 1, "n_parameters": 86502376} 3 | {"train_lr": 7.863685277934232e-05, "train_min_lr": 7.863685277934232e-05, "train_loss": 2.145499920530571, "train_class_acc": 0.7092076338928857, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.7521247481145249, "test_acc1": 81.05800251953124, "test_acc5": 96.04000268554688, "epoch": 2, "n_parameters": 86502376} 4 | {"train_lr": 7.735987544832995e-05, "train_min_lr": 7.735987544832995e-05, "train_loss": 2.0983406849008954, "train_class_acc": 0.718855384442446, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.7348379289077325, "test_acc1": 81.4920025769043, "test_acc5": 96.17200262939453, "epoch": 3, "n_parameters": 86502376} 5 | {"train_lr": 7.567905360848068e-05, "train_min_lr": 7.567905360848068e-05, "train_loss": 2.0673020389720165, "train_class_acc": 0.7251214653277378, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.71815004223915, "test_acc1": 81.904002578125, "test_acc5": 96.32000264404297, "epoch": 4, "n_parameters": 86502376} 6 | {"train_lr": 7.361280269560707e-05, "train_min_lr": 7.361280269560707e-05, "train_loss": 2.039417670573667, "train_class_acc": 0.7317458533173461, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.710230432720926, "test_acc1": 82.15400252319336, "test_acc5": 96.39000264648438, "epoch": 
5, "n_parameters": 86502376} 7 | {"train_lr": 7.118376098710012e-05, "train_min_lr": 7.118376098710012e-05, "train_loss": 2.0196545109164705, "train_class_acc": 0.735794957783773, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.7012611285975302, "test_acc1": 82.18800258666992, "test_acc5": 96.47600269042968, "epoch": 6, "n_parameters": 86502376} 8 | {"train_lr": 6.841854157222974e-05, "train_min_lr": 6.841854157222974e-05, "train_loss": 2.0017266717883087, "train_class_acc": 0.7400204836131095, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.697616238310535, "test_acc1": 82.52600250366211, "test_acc5": 96.49200275146484, "epoch": 7, "n_parameters": 86502376} 9 | {"train_lr": 6.534744077356283e-05, "train_min_lr": 6.534744077356283e-05, "train_loss": 1.9866541259699493, "train_class_acc": 0.7436441471822542, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6930679365324269, "test_acc1": 82.61600248535156, "test_acc5": 96.51000275634766, "epoch": 8, "n_parameters": 86502376} 10 | {"train_lr": 6.200410621411962e-05, "train_min_lr": 6.200410621411962e-05, "train_loss": 1.9714939381977876, "train_class_acc": 0.7470945118904876, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6908177238825055, "test_acc1": 82.52600251464844, "test_acc5": 96.59200271240235, "epoch": 9, "n_parameters": 86502376} 11 | {"train_lr": 5.8425168166971116e-05, "train_min_lr": 5.8425168166971116e-05, "train_loss": 1.960133994753055, "train_class_acc": 0.7494847871702638, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6893090736263126, "test_acc1": 82.65200252685547, "test_acc5": 
96.58000280029297, "epoch": 10, "n_parameters": 86502376} 12 | {"train_lr": 5.464983822630283e-05, "train_min_lr": 5.464983822630283e-05, "train_loss": 1.946019628276523, "train_class_acc": 0.7532965814848122, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6812137730903071, "test_acc1": 82.84400247802735, "test_acc5": 96.63000275146484, "epoch": 11, "n_parameters": 86502376} 13 | {"train_lr": 5.0719479696983124e-05, "train_min_lr": 5.0719479696983124e-05, "train_loss": 1.9346819410054423, "train_class_acc": 0.7557016886490807, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6739478692668313, "test_acc1": 83.09200239746093, "test_acc5": 96.69000272460937, "epoch": 12, "n_parameters": 86502376} 14 | {"train_lr": 4.667715440953944e-05, "train_min_lr": 4.667715440953944e-05, "train_loss": 1.9267961068881405, "train_class_acc": 0.75771492181255, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6747824225560514, "test_acc1": 83.08600253417968, "test_acc5": 96.75600277099609, "epoch": 13, "n_parameters": 86502376} 15 | {"train_lr": 4.256715092573226e-05, "train_min_lr": 4.256715092573226e-05, "train_loss": 1.916994124752202, "train_class_acc": 0.7603752935151878, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6742744893688737, "test_acc1": 83.1600024609375, "test_acc5": 96.69000269287109, "epoch": 14, "n_parameters": 86502376} 16 | {"train_lr": 3.843449930380363e-05, "train_min_lr": 3.843449930380363e-05, "train_loss": 1.9078451729953574, "train_class_acc": 0.7626750162370104, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6682798480686113, "test_acc1": 
83.12800247558594, "test_acc5": 96.66800276855469, "epoch": 15, "n_parameters": 86502376} 17 | {"train_lr": 3.432447773973649e-05, "train_min_lr": 3.432447773973649e-05, "train_loss": 1.9017058968645253, "train_class_acc": 0.7641628884392486, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6697903494741625, "test_acc1": 83.13600244628907, "test_acc5": 96.710002734375, "epoch": 16, "n_parameters": 86502376} 18 | {"train_lr": 3.0282116489863717e-05, "train_min_lr": 3.0282116489863717e-05, "train_loss": 1.892893716779282, "train_class_acc": 0.7661409934552358, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6674242881304435, "test_acc1": 83.25400260986328, "test_acc5": 96.72800265625, "epoch": 17, "n_parameters": 86502376} 19 | {"train_lr": 2.635170450995756e-05, "train_min_lr": 2.635170450995756e-05, "train_loss": 1.8847338850770017, "train_class_acc": 0.7681589103717026, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.664745727070529, "test_acc1": 83.29600263183593, "test_acc5": 96.77000273681641, "epoch": 18, "n_parameters": 86502376} 20 | {"train_lr": 2.2576304216161565e-05, "train_min_lr": 2.2576304216161565e-05, "train_loss": 1.8793799648854992, "train_class_acc": 0.7697154776179057, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6642096964713499, "test_acc1": 83.43800253173828, "test_acc5": 96.77800274169923, "epoch": 19, "n_parameters": 86502376} 21 | {"train_lr": 1.8997279684147784e-05, "train_min_lr": 1.8997279684147784e-05, "train_loss": 1.8746926908661945, "train_class_acc": 0.7709285696442846, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 
0.6633108184452503, "test_acc1": 83.33400262695312, "test_acc5": 96.75200265869141, "epoch": 20, "n_parameters": 86502376} 22 | {"train_lr": 1.5653843455648113e-05, "train_min_lr": 1.5653843455648113e-05, "train_loss": 1.8700135239421558, "train_class_acc": 0.7721518098021583, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6620311849956749, "test_acc1": 83.49600257568359, "test_acc5": 96.81000268554688, "epoch": 21, "n_parameters": 86502376} 23 | {"train_lr": 1.2582626917640712e-05, "train_min_lr": 1.2582626917640712e-05, "train_loss": 1.8679893643503709, "train_class_acc": 0.7726155013489209, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6601207956515309, "test_acc1": 83.5120025805664, "test_acc5": 96.7720026977539, "epoch": 22, "n_parameters": 86502376} 24 | {"train_lr": 9.817278961209553e-06, "train_min_lr": 9.817278961209553e-06, "train_loss": 1.8622591408203355, "train_class_acc": 0.7740495103916867, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.660049624840608, "test_acc1": 83.48000255615234, "test_acc5": 96.81000270751953, "epoch": 23, "n_parameters": 86502376} 25 | {"train_lr": 7.388097317251548e-06, "train_min_lr": 7.388097317251548e-06, "train_loss": 1.8585193322585831, "train_class_acc": 0.7752017136290967, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6595075904526784, "test_acc1": 83.55000259033203, "test_acc5": 96.78800273681641, "epoch": 24, "n_parameters": 86502376} 26 | {"train_lr": 5.321696608196864e-06, "train_min_lr": 5.321696608196864e-06, "train_loss": 1.8578995852280769, "train_class_acc": 0.7751564373501199, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, 
"train_soft_mask_rate": 0.0, "test_loss": 0.657806609423106, "test_acc1": 83.60400248779297, "test_acc5": 96.79800274414063, "epoch": 25, "n_parameters": 86502376} 27 | {"train_lr": 3.6407167526360116e-06, "train_min_lr": 3.6407167526360116e-06, "train_loss": 1.8572985059429081, "train_class_acc": 0.7757184877098321, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6586373214545937, "test_acc1": 83.59400256347656, "test_acc5": 96.78200276123047, "epoch": 26, "n_parameters": 86502376} 28 | {"train_lr": 2.3635749176341855e-06, "train_min_lr": 2.3635749176341855e-06, "train_loss": 1.8557780712819119, "train_class_acc": 0.7756099807653877, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6584229432330787, "test_acc1": 83.53800257568359, "test_acc5": 96.79200271484375, "epoch": 27, "n_parameters": 86502376} 29 | {"train_lr": 1.5042637363947687e-06, "train_min_lr": 1.5042637363947687e-06, "train_loss": 1.8542869655306724, "train_class_acc": 0.7759206697142286, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6582214436884827, "test_acc1": 83.57600252685548, "test_acc5": 96.79600274169921, "epoch": 28, "n_parameters": 86502376} 30 | {"train_lr": 1.0721980020418422e-06, "train_min_lr": 1.0721980020418422e-06, "train_loss": 1.8537212703302561, "train_class_acc": 0.7762789768185452, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6580269141762539, "test_acc1": 83.56400259277343, "test_acc5": 96.79200272949218, "epoch": 29, "n_parameters": 86502376} 31 | -------------------------------------------------------------------------------- /logs/illama-base-in21kin1k-384-85.0.txt: -------------------------------------------------------------------------------- 1 
| {"train_lr": 0.00010990050436237771, "train_min_lr": 0.00010990050436237771, "train_loss": 2.873497698356291, "train_class_acc": 0.5845987272681854, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.7941034439580079, "test_acc1": 80.10400251403809, "test_acc5": 95.64000197753906, "epoch": 0, "n_parameters": 86794216} 2 | {"train_lr": 0.0001093046003797633, "train_min_lr": 0.0001093046003797633, "train_loss": 2.1891729478314104, "train_class_acc": 0.6994349832633893, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.7331855757112019, "test_acc1": 81.6440025, "test_acc5": 96.29000178466796, "epoch": 1, "n_parameters": 86794216} 3 | {"train_lr": 0.00010811920193605486, "train_min_lr": 0.00010811920193605486, "train_loss": 2.0989335492375396, "train_class_acc": 0.7191855890287769, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6977391123586116, "test_acc1": 82.28000253540038, "test_acc5": 96.6640016796875, "epoch": 2, "n_parameters": 86794216} 4 | {"train_lr": 0.00010635729650465759, "train_min_lr": 0.00010635729650465759, "train_loss": 2.044830535443948, "train_class_acc": 0.7307005957733813, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6783538484995849, "test_acc1": 82.9180024987793, "test_acc5": 96.83400157958984, "epoch": 3, "n_parameters": 86794216} 5 | {"train_lr": 0.00010403818789018106, "train_min_lr": 0.00010403818789018106, "train_loss": 2.0058495105370033, "train_class_acc": 0.7397300909272582, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.661780134467872, "test_acc1": 83.26400256591796, "test_acc5": 96.93800155395508, "epoch": 4, "n_parameters": 86794216} 6 | 
{"train_lr": 0.00010118728473191276, "train_min_lr": 0.00010118728473191276, "train_loss": 1.974607850857478, "train_class_acc": 0.7472795201338929, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6588617378553211, "test_acc1": 83.33800248962402, "test_acc5": 96.9840015612793, "epoch": 5, "n_parameters": 86794216} 7 | {"train_lr": 9.783582212144216e-05, "train_min_lr": 9.783582212144216e-05, "train_loss": 1.9508491760591784, "train_class_acc": 0.7529296875, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6439973158098158, "test_acc1": 83.67800259216308, "test_acc5": 97.14000151123047, "epoch": 6, "n_parameters": 86794216} 8 | {"train_lr": 9.40205193844685e-05, "train_min_lr": 9.40205193844685e-05, "train_loss": 1.927743834557293, "train_class_acc": 0.7581224083233413, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6432044282341483, "test_acc1": 83.71600251037597, "test_acc5": 97.12800151611329, "epoch": 7, "n_parameters": 86794216} 9 | {"train_lr": 8.978317777618145e-05, "train_min_lr": 8.978317777618145e-05, "train_loss": 1.909357096816746, "train_class_acc": 0.7624626861011191, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6321660951921766, "test_acc1": 83.95000252563477, "test_acc5": 97.20400146850587, "epoch": 8, "n_parameters": 86794216} 10 | {"train_lr": 8.517022249796265e-05, "train_min_lr": 8.517022249796265e-05, "train_loss": 1.8913158439174236, "train_class_acc": 0.7667912544964028, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6311975168742211, "test_acc1": 84.0760025769043, "test_acc5": 97.18400147705078, "epoch": 9, "n_parameters": 86794216} 11 | 
{"train_lr": 8.023219405316332e-05, "train_min_lr": 8.023219405316332e-05, "train_loss": 1.8768250636360366, "train_class_acc": 0.7702845536071143, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6237856106889716, "test_acc1": 84.27800251464843, "test_acc5": 97.26600143920898, "epoch": 10, "n_parameters": 86794216} 12 | {"train_lr": 7.502319451477216e-05, "train_min_lr": 7.502319451477216e-05, "train_loss": 1.8619147903106505, "train_class_acc": 0.7740620003996802, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6272280179463698, "test_acc1": 84.36800256347657, "test_acc5": 97.22600145751953, "epoch": 11, "n_parameters": 86794216} 13 | {"train_lr": 6.960029477178692e-05, "train_min_lr": 6.960029477178692e-05, "train_loss": 1.849825464002937, "train_class_acc": 0.7773375049960032, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6230855795636887, "test_acc1": 84.43800255615234, "test_acc5": 97.23800145874023, "epoch": 12, "n_parameters": 86794216} 14 | {"train_lr": 6.402290924860439e-05, "train_min_lr": 6.402290924860439e-05, "train_loss": 1.8378819944076448, "train_class_acc": 0.7804053632094324, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.622640575275139, "test_acc1": 84.46200261657715, "test_acc5": 97.28800141845703, "epoch": 13, "n_parameters": 86794216} 15 | {"train_lr": 5.8352144948162407e-05, "train_min_lr": 5.8352144948162407e-05, "train_loss": 1.827634315547659, "train_class_acc": 0.783148481215028, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6218125294836887, "test_acc1": 84.50000257263184, "test_acc5": 97.27200144287109, "epoch": 14, "n_parameters": 
86794216} 16 | {"train_lr": 5.265013195081766e-05, "train_min_lr": 5.265013195081766e-05, "train_loss": 1.8147839116398616, "train_class_acc": 0.7861304706235012, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6165422496325211, "test_acc1": 84.60000260375976, "test_acc5": 97.31800141357422, "epoch": 15, "n_parameters": 86794216} 17 | {"train_lr": 4.6979342704193364e-05, "train_min_lr": 4.6979342704193364e-05, "train_loss": 1.8073137391844003, "train_class_acc": 0.7881702450539568, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6163868690409138, "test_acc1": 84.58200257385253, "test_acc5": 97.33000140991211, "epoch": 16, "n_parameters": 86794216} 18 | {"train_lr": 4.1401907561964056e-05, "train_min_lr": 4.1401907561964056e-05, "train_loss": 1.7976006015250794, "train_class_acc": 0.7906003322342127, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6167496893897482, "test_acc1": 84.6180025604248, "test_acc5": 97.35400142089844, "epoch": 17, "n_parameters": 86794216} 19 | {"train_lr": 3.597893407070102e-05, "train_min_lr": 3.597893407070102e-05, "train_loss": 1.7888045026217219, "train_class_acc": 0.7930522769284573, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6136899749388012, "test_acc1": 84.5920025415039, "test_acc5": 97.39400140869141, "epoch": 18, "n_parameters": 86794216} 20 | {"train_lr": 3.076983746280512e-05, "train_min_lr": 3.076983746280512e-05, "train_loss": 1.783275410302168, "train_class_acc": 0.7947243767486011, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6135569893937002, "test_acc1": 84.66600256591796, "test_acc5": 97.36400138305665, "epoch": 19, 
"n_parameters": 86794216} 21 | {"train_lr": 2.5831689690786165e-05, "train_min_lr": 2.5831689690786165e-05, "train_loss": 1.7752880633955332, "train_class_acc": 0.7968812450039968, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6124859313405622, "test_acc1": 84.76600255981445, "test_acc5": 97.40800135131836, "epoch": 20, "n_parameters": 86794216} 22 | {"train_lr": 2.1218594135007993e-05, "train_min_lr": 2.1218594135007993e-05, "train_loss": 1.770902261998561, "train_class_acc": 0.7981021432853717, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6106539139188684, "test_acc1": 84.81200262145997, "test_acc5": 97.38000138061524, "epoch": 21, "n_parameters": 86794216} 23 | {"train_lr": 1.69810928357321e-05, "train_min_lr": 1.69810928357321e-05, "train_loss": 1.7643373457746183, "train_class_acc": 0.7997430180855316, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6096352562558092, "test_acc1": 84.88000256713867, "test_acc5": 97.38400138061523, "epoch": 22, "n_parameters": 86794216} 24 | {"train_lr": 1.3165612743947417e-05, "train_min_lr": 1.3165612743947417e-05, "train_loss": 1.7615033509440536, "train_class_acc": 0.8006266861510791, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6073408679915342, "test_acc1": 84.89200262695313, "test_acc5": 97.43200135620117, "epoch": 23, "n_parameters": 86794216} 25 | {"train_lr": 9.81395705798004e-06, "train_min_lr": 9.81395705798004e-06, "train_loss": 1.757708400523634, "train_class_acc": 0.8013019272082335, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.608022909829129, "test_acc1": 84.94000262390136, "test_acc5": 97.41000136230468, 
"epoch": 24, "n_parameters": 86794216} 26 | {"train_lr": 6.962847218904466e-06, "train_min_lr": 6.962847218904466e-06, "train_loss": 1.7527120730532206, "train_class_acc": 0.8027710643984812, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6065240236306099, "test_acc1": 84.92000258117676, "test_acc5": 97.44800136108398, "epoch": 25, "n_parameters": 86794216} 27 | {"train_lr": 4.643520582750963e-06, "train_min_lr": 4.643520582750963e-06, "train_loss": 1.750563392784896, "train_class_acc": 0.8033440435151878, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6067701281401617, "test_acc1": 84.99000257385254, "test_acc5": 97.44600135864258, "epoch": 26, "n_parameters": 86794216} 28 | {"train_lr": 2.8813881774952494e-06, "train_min_lr": 2.8813881774952494e-06, "train_loss": 1.7516952332535045, "train_class_acc": 0.8030598958333334, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6067123004089028, "test_acc1": 84.98800256896973, "test_acc5": 97.43000136108398, "epoch": 27, "n_parameters": 86794216} 29 | {"train_lr": 1.6957562945193597e-06, "train_min_lr": 1.6957562945193597e-06, "train_loss": 1.7505262671343738, "train_class_acc": 0.8033760491606715, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6069099562442452, "test_acc1": 84.98800257080079, "test_acc5": 97.43000136108398, "epoch": 28, "n_parameters": 86794216} 30 | {"train_lr": 1.0996149648425372e-06, "train_min_lr": 1.0996149648425372e-06, "train_loss": 1.7485280318288423, "train_class_acc": 0.8038975069944044, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6071232186801646, "test_acc1": 84.98400258239747, "test_acc5": 
97.4180013684082, "epoch": 29, "n_parameters": 86794216} 31 | -------------------------------------------------------------------------------- /logs/illama-large-in21kin1k-224-84.8.txt: -------------------------------------------------------------------------------- 1 | {"train_lr": 5.99461445631214e-05, "train_min_lr": 5.99461445631214e-05, "train_loss": 2.7198627313895285, "train_class_acc": 0.6766891174560352, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6804418705816146, "test_acc1": 82.96000253662109, "test_acc5": 96.45800264160157, "epoch": 0, "n_parameters": 310445032} 2 | {"train_lr": 5.9623591031248656e-05, "train_min_lr": 5.9623591031248656e-05, "train_loss": 1.9470356172538823, "train_class_acc": 0.7654407723820943, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6517717496150721, "test_acc1": 83.69000252197266, "test_acc5": 96.84800266845703, "epoch": 1, "n_parameters": 310445032} 3 | {"train_lr": 5.898195334153406e-05, "train_min_lr": 5.898195334153406e-05, "train_loss": 1.8859988836188921, "train_class_acc": 0.7761657861211031, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.644474639695934, "test_acc1": 83.92600262451172, "test_acc5": 96.99000259765624, "epoch": 2, "n_parameters": 310445032} 4 | {"train_lr": 5.802826141077773e-05, "train_min_lr": 5.802826141077773e-05, "train_loss": 1.848835029488178, "train_class_acc": 0.7838791466826539, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6354520462439147, "test_acc1": 84.10400263671875, "test_acc5": 97.09200262695313, "epoch": 3, "n_parameters": 310445032} 5 | {"train_lr": 5.6772964087346344e-05, "train_min_lr": 5.6772964087346344e-05, "train_loss": 1.8186543585239745, 
"train_class_acc": 0.7905823778477218, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6252702991869158, "test_acc1": 84.28800256591796, "test_acc5": 97.17200262451172, "epoch": 4, "n_parameters": 310445032} 6 | {"train_lr": 5.522981467140261e-05, "train_min_lr": 5.522981467140261e-05, "train_loss": 1.7934198190125343, "train_class_acc": 0.7961583857913669, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6280702026916369, "test_acc1": 84.43800261230469, "test_acc5": 97.20400264160156, "epoch": 5, "n_parameters": 310445032} 7 | {"train_lr": 5.34157202308725e-05, "train_min_lr": 5.34157202308725e-05, "train_loss": 1.7750700143107074, "train_class_acc": 0.8004088916366906, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.626661965431546, "test_acc1": 84.5240024584961, "test_acc5": 97.25000264648438, "epoch": 6, "n_parameters": 310445032} 8 | {"train_lr": 5.135055636407023e-05, "train_min_lr": 5.135055636407023e-05, "train_loss": 1.7532443771009252, "train_class_acc": 0.8054899830135891, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6199031330730397, "test_acc1": 84.57000253662109, "test_acc5": 97.33400265136719, "epoch": 7, "n_parameters": 310445032} 9 | {"train_lr": 4.905694943848367e-05, "train_min_lr": 4.905694943848367e-05, "train_loss": 1.7408968822585402, "train_class_acc": 0.8087194307054356, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.62204097242642, "test_acc1": 84.67200252197266, "test_acc5": 97.26200261230468, "epoch": 8, "n_parameters": 310445032} 10 | {"train_lr": 4.656002869155787e-05, "train_min_lr": 4.656002869155787e-05, "train_loss": 
1.7256579877798268, "train_class_acc": 0.8119886902977618, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6242762830620276, "test_acc1": 84.61400246337891, "test_acc5": 97.28200252685546, "epoch": 9, "n_parameters": 310445032} 11 | {"train_lr": 4.38871509095102e-05, "train_min_lr": 4.38871509095102e-05, "train_loss": 1.7129281256064999, "train_class_acc": 0.8155319494404476, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6209116319393497, "test_acc1": 84.7500024584961, "test_acc5": 97.33800256347656, "epoch": 10, "n_parameters": 310445032} 12 | {"train_lr": 4.1067600700656485e-05, "train_min_lr": 4.1067600700656485e-05, "train_loss": 1.6997128057864597, "train_class_acc": 0.8191392198741008, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6250483979620324, "test_acc1": 84.59800246826173, "test_acc5": 97.28200264648437, "epoch": 11, "n_parameters": 310445032} 13 | {"train_lr": 3.813226964711385e-05, "train_min_lr": 3.813226964711385e-05, "train_loss": 1.6904406517053203, "train_class_acc": 0.8210681454836131, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6233285891446676, "test_acc1": 84.73200243408203, "test_acc5": 97.26800258789062, "epoch": 12, "n_parameters": 310445032} 14 | {"train_lr": 3.511331785016224e-05, "train_min_lr": 3.511331785016224e-05, "train_loss": 1.6803397394651227, "train_class_acc": 0.8235848820943246, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6287899412897252, "test_acc1": 84.66200260009765, "test_acc5": 97.27600256835937, "epoch": 13, "n_parameters": 310445032} 15 | {"train_lr": 3.2043821577445613e-05, "train_min_lr": 
3.2043821577445613e-05, "train_loss": 1.672017324668207, "train_class_acc": 0.8262663306854516, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6278802271832833, "test_acc1": 84.64600250732421, "test_acc5": 97.3100025805664, "epoch": 14, "n_parameters": 310445032} 16 | {"train_lr": 2.8957410872461055e-05, "train_min_lr": 2.8957410872461055e-05, "train_loss": 1.66248219554349, "train_class_acc": 0.8286121103117506, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6275539799444093, "test_acc1": 84.6720025366211, "test_acc5": 97.27400262207031, "epoch": 15, "n_parameters": 310445032} 17 | {"train_lr": 2.5887901096765108e-05, "train_min_lr": 2.5887901096765108e-05, "train_loss": 1.6523609593552318, "train_class_acc": 0.8313396158073542, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6307366626081234, "test_acc1": 84.72800237548829, "test_acc5": 97.29800254394532, "epoch": 16, "n_parameters": 310445032} 18 | {"train_lr": 2.2868922441796946e-05, "train_min_lr": 2.2868922441796946e-05, "train_loss": 1.6468906652965516, "train_class_acc": 0.8327057104316546, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6287269541951082, "test_acc1": 84.78000256347656, "test_acc5": 97.2840025390625, "epoch": 17, "n_parameters": 310445032} 19 | {"train_lr": 1.9933551469461998e-05, "train_min_lr": 1.9933551469461998e-05, "train_loss": 1.6423381217610207, "train_class_acc": 0.834031212529976, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6298858197619215, "test_acc1": 84.80200260498047, "test_acc5": 97.32000254882813, "epoch": 18, "n_parameters": 310445032} 20 | {"train_lr": 
1.7113948718399234e-05, "train_min_lr": 1.7113948718399234e-05, "train_loss": 1.6362359538751778, "train_class_acc": 0.8353895008992805, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6317386629045693, "test_acc1": 84.75400255859375, "test_acc5": 97.27000256835937, "epoch": 19, "n_parameters": 310445032} 21 | {"train_lr": 1.4441006346388868e-05, "train_min_lr": 1.4441006346388868e-05, "train_loss": 1.6317886930557606, "train_class_acc": 0.8367150029976019, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.631641619341121, "test_acc1": 84.77200250244141, "test_acc5": 97.27600258300781, "epoch": 20, "n_parameters": 310445032} 22 | {"train_lr": 1.194400966940797e-05, "train_min_lr": 1.194400966940797e-05, "train_loss": 1.6257266758186975, "train_class_acc": 0.8386322192246203, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6315950325811068, "test_acc1": 84.7240025366211, "test_acc5": 97.31400260498047, "epoch": 21, "n_parameters": 310445032} 23 | {"train_lr": 9.650316305579721e-06, "train_min_lr": 9.650316305579721e-06, "train_loss": 1.6226078195424436, "train_class_acc": 0.8390529763689049, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6312967267896952, "test_acc1": 84.78600256591797, "test_acc5": 97.27000255371094, "epoch": 22, "n_parameters": 310445032} 24 | {"train_lr": 7.585056439384332e-06, "train_min_lr": 7.585056439384332e-06, "train_loss": 1.6201602544495575, "train_class_acc": 0.8399928494704236, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.631438076019799, "test_acc1": 84.79600241455078, "test_acc5": 97.27800255615234, "epoch": 23, "n_parameters": 
310445032} 25 | {"train_lr": 5.770857490099259e-06, "train_min_lr": 5.770857490099259e-06, "train_loss": 1.618003329766883, "train_class_acc": 0.8402816809052758, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6306168148649558, "test_acc1": 84.76200250732421, "test_acc5": 97.29400259765625, "epoch": 24, "n_parameters": 310445032} 26 | {"train_lr": 4.227596201058397e-06, "train_min_lr": 4.227596201058397e-06, "train_loss": 1.6154780858676496, "train_class_acc": 0.8410123463729017, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.632067259187803, "test_acc1": 84.79600251708985, "test_acc5": 97.28800254638672, "epoch": 25, "n_parameters": 310445032} 27 | {"train_lr": 2.9721808658927067e-06, "train_min_lr": 2.9721808658927067e-06, "train_loss": 1.6140473674830915, "train_class_acc": 0.8416587042865707, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6312973588449127, "test_acc1": 84.83000244628906, "test_acc5": 97.27400254394531, "epoch": 26, "n_parameters": 310445032} 28 | {"train_lr": 2.0183660777267993e-06, "train_min_lr": 2.0183660777267993e-06, "train_loss": 1.6156177416944104, "train_class_acc": 0.8415580035971223, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6316051020763302, "test_acc1": 84.79000246826172, "test_acc5": 97.27800256347656, "epoch": 27, "n_parameters": 310445032} 29 | {"train_lr": 1.3766020309783612e-06, "train_min_lr": 1.3766020309783612e-06, "train_loss": 1.6134667002307854, "train_class_acc": 0.8415853254896083, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.632169377478195, "test_acc1": 84.79400240234375, "test_acc5": 97.29200255859375, 
"epoch": 28, "n_parameters": 310445032} 30 | {"train_lr": 1.053920026841383e-06, "train_min_lr": 1.053920026841383e-06, "train_loss": 1.6139339337841594, "train_class_acc": 0.8416836843025579, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6321293162964, "test_acc1": 84.78600248046875, "test_acc5": 97.30600255371094, "epoch": 29, "n_parameters": 310445032} 31 | -------------------------------------------------------------------------------- /logs/illama-large-in21kin1k-384-86.0.txt: -------------------------------------------------------------------------------- 1 | {"train_lr": 3.496896466349359e-05, "train_min_lr": 3.496896466349359e-05, "train_loss": 3.4040759515127452, "train_class_acc": 0.5474323666067147, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.7216167944301476, "test_acc1": 82.15600245239258, "test_acc5": 95.91400194091797, "epoch": 0, "n_parameters": 310834152} 2 | {"train_lr": 3.4783086356990544e-05, "train_min_lr": 3.4783086356990544e-05, "train_loss": 2.052658562197817, "train_class_acc": 0.7441101806055156, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.6449875789212769, "test_acc1": 83.74600245239257, "test_acc5": 96.87000160522462, "epoch": 1, "n_parameters": 310834152} 3 | {"train_lr": 3.441332904427394e-05, "train_min_lr": 3.441332904427394e-05, "train_loss": 1.9515358309367958, "train_class_acc": 0.7633596248001598, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.6230926276809069, "test_acc1": 84.35800254638671, "test_acc5": 97.18800145507812, "epoch": 2, "n_parameters": 310834152} 4 | {"train_lr": 3.386374386383809e-05, "train_min_lr": 3.386374386383809e-05, "train_loss": 1.897901961399759, "train_class_acc": 
0.7735717675859313, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.6125124040170362, "test_acc1": 84.66400259338378, "test_acc5": 97.27800142578126, "epoch": 3, "n_parameters": 310834152} 5 | {"train_lr": 3.314035218592863e-05, "train_min_lr": 3.314035218592863e-05, "train_loss": 1.8600160417219194, "train_class_acc": 0.7820524830135891, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5977429341974443, "test_acc1": 85.0540025817871, "test_acc5": 97.46000132568359, "epoch": 4, "n_parameters": 310834152} 6 | {"train_lr": 3.225107964114749e-05, "train_min_lr": 3.225107964114749e-05, "train_loss": 1.830572905389668, "train_class_acc": 0.7881803931854516, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5893618462757132, "test_acc1": 85.29800263305664, "test_acc5": 97.53000135253906, "epoch": 5, "n_parameters": 310834152} 7 | {"train_lr": 3.1205669285587655e-05, "train_min_lr": 3.1205669285587655e-05, "train_loss": 1.8076860456690966, "train_class_acc": 0.7934660084432454, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5866689494850249, "test_acc1": 85.39600262695312, "test_acc5": 97.58600130737305, "epoch": 6, "n_parameters": 310834152} 8 | {"train_lr": 3.0015574853870956e-05, "train_min_lr": 3.0015574853870956e-05, "train_loss": 1.7887820138318313, "train_class_acc": 0.7975385316746603, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5859094357039508, "test_acc1": 85.33000260742187, "test_acc5": 97.65400125488281, "epoch": 7, "n_parameters": 310834152} 9 | {"train_lr": 2.8693835269634953e-05, "train_min_lr": 2.8693835269634953e-05, "train_loss": 1.7704070983923597, 
"train_class_acc": 0.801730490607514, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5785367794234977, "test_acc1": 85.49200261840821, "test_acc5": 97.74800122314453, "epoch": 8, "n_parameters": 310834152} 10 | {"train_lr": 2.7254931788355584e-05, "train_min_lr": 2.7254931788355584e-05, "train_loss": 1.7566198213152486, "train_class_acc": 0.8048787532474021, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5768480996861769, "test_acc1": 85.70000265258788, "test_acc5": 97.69600127685547, "epoch": 9, "n_parameters": 310834152} 11 | {"train_lr": 2.5714629337683663e-05, "train_min_lr": 2.5714629337683663e-05, "train_loss": 1.7435758125980934, "train_class_acc": 0.8079981327438049, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5750291858268081, "test_acc1": 85.58600266601563, "test_acc5": 97.75000124389648, "epoch": 10, "n_parameters": 310834152} 12 | {"train_lr": 2.4089803793598676e-05, "train_min_lr": 2.4089803793598676e-05, "train_loss": 1.7315405576774996, "train_class_acc": 0.8110628684552358, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5750140335475062, "test_acc1": 85.6760026525879, "test_acc5": 97.78000123291015, "epoch": 11, "n_parameters": 310834152} 13 | {"train_lr": 2.239825708477748e-05, "train_min_lr": 2.239825708477748e-05, "train_loss": 1.7212515797236745, "train_class_acc": 0.8138005220823341, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5737644145800465, "test_acc1": 85.70400269042969, "test_acc5": 97.79000122314453, "epoch": 12, "n_parameters": 310834152} 14 | {"train_lr": 2.0658522150940823e-05, "train_min_lr": 2.0658522150940823e-05, 
"train_loss": 1.7101750767258979, "train_class_acc": 0.8160253047561951, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5718247301669585, "test_acc1": 85.7860026928711, "test_acc5": 97.8000012097168, "epoch": 13, "n_parameters": 310834152} 15 | {"train_lr": 1.8889659892087196e-05, "train_min_lr": 1.8889659892087196e-05, "train_loss": 1.7035075478094945, "train_class_acc": 0.8176162195243805, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5718602961343195, "test_acc1": 85.9500026171875, "test_acc5": 97.8180012133789, "epoch": 14, "n_parameters": 310834152} 16 | {"train_lr": 1.7111050333282612e-05, "train_min_lr": 1.7111050333282612e-05, "train_loss": 1.6949291081650915, "train_class_acc": 0.8199417028876899, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5712083624390407, "test_acc1": 85.84400260742187, "test_acc5": 97.82400120239258, "epoch": 15, "n_parameters": 310834152} 17 | {"train_lr": 1.534218029305105e-05, "train_min_lr": 1.534218029305105e-05, "train_loss": 1.6870110205409647, "train_class_acc": 0.8220735911270983, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.570508421045439, "test_acc1": 85.84200260864257, "test_acc5": 97.85600119018555, "epoch": 16, "n_parameters": 310834152} 18 | {"train_lr": 1.3602429881713486e-05, "train_min_lr": 1.3602429881713486e-05, "train_loss": 1.6825196964221059, "train_class_acc": 0.82322267186251, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5710876270579138, "test_acc1": 85.90600260253906, "test_acc5": 97.82000119018555, "epoch": 17, "n_parameters": 310834152} 19 | {"train_lr": 1.1910860168842462e-05, "train_min_lr": 
1.1910860168842462e-05, "train_loss": 1.6754372675912819, "train_class_acc": 0.8250719736710631, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5708488073499723, "test_acc1": 85.8620026220703, "test_acc5": 97.81800118774414, "epoch": 18, "n_parameters": 310834152} 20 | {"train_lr": 1.0286004346196181e-05, "train_min_lr": 1.0286004346196181e-05, "train_loss": 1.6718671010906914, "train_class_acc": 0.8259642286171063, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5703325133941982, "test_acc1": 85.83000261474609, "test_acc5": 97.82800118896485, "epoch": 19, "n_parameters": 310834152} 21 | {"train_lr": 8.745664674190202e-06, "train_min_lr": 8.745664674190202e-06, "train_loss": 1.667895621664042, "train_class_acc": 0.8266925522082335, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5726590922309293, "test_acc1": 85.74800260009765, "test_acc5": 97.81800119018554, "epoch": 20, "n_parameters": 310834152} 22 | {"train_lr": 7.306717436607989e-06, "train_min_lr": 7.306717436607989e-06, "train_loss": 1.663098725781357, "train_class_acc": 0.8280797237210232, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5705665085086478, "test_acc1": 85.86000262939453, "test_acc5": 97.84200119140625, "epoch": 21, "n_parameters": 310834152} 23 | {"train_lr": 5.984928040503593e-06, "train_min_lr": 5.984928040503593e-06, "train_loss": 1.6612017263682912, "train_class_acc": 0.8286690959732215, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5713938242859311, "test_acc1": 85.8300025793457, "test_acc5": 97.79600120483398, "epoch": 22, "n_parameters": 310834152} 24 | {"train_lr": 
4.794778287102872e-06, "train_min_lr": 4.794778287102872e-06, "train_loss": 1.657341641054558, "train_class_acc": 0.8294949977517986, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.570363612395936, "test_acc1": 85.92000260131836, "test_acc5": 97.81800120117188, "epoch": 23, "n_parameters": 310834152} 25 | {"train_lr": 3.7493077061589098e-06, "train_min_lr": 3.7493077061589098e-06, "train_loss": 1.6565849094641962, "train_class_acc": 0.8301234012789768, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5699021836521525, "test_acc1": 85.85400256835938, "test_acc5": 97.83400119506835, "epoch": 24, "n_parameters": 310834152} 26 | {"train_lr": 2.8599706921353547e-06, "train_min_lr": 2.8599706921353547e-06, "train_loss": 1.6534634246502062, "train_class_acc": 0.8304403352318146, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5695552553810265, "test_acc1": 85.8520026208496, "test_acc5": 97.84200118652343, "epoch": 25, "n_parameters": 310834152} 27 | {"train_lr": 2.1365110074636167e-06, "train_min_lr": 2.1365110074636167e-06, "train_loss": 1.652146488183932, "train_class_acc": 0.8312748238908872, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5698677959719894, "test_acc1": 85.84400259277344, "test_acc5": 97.85800118774414, "epoch": 26, "n_parameters": 310834152} 28 | {"train_lr": 1.58685502784256e-06, "train_min_lr": 1.58685502784256e-06, "train_loss": 1.6519211692472013, "train_class_acc": 0.8312717013888888, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5704294863902746, "test_acc1": 85.81200258178711, "test_acc5": 97.83800119262695, "epoch": 27, "n_parameters": 
310834152} 29 | {"train_lr": 1.2170248992078712e-06, "train_min_lr": 1.2170248992078712e-06, "train_loss": 1.6494522876918412, "train_class_acc": 0.831925865557554, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5699063584572454, "test_acc1": 85.86200257080078, "test_acc5": 97.84400118896484, "epoch": 28, "n_parameters": 310834152} 30 | {"train_lr": 1.0310725578407872e-06, "train_min_lr": 1.0310725578407872e-06, "train_loss": 1.6503264461305263, "train_class_acc": 0.8314910571542766, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5699177861563376, "test_acc1": 85.86400259155273, "test_acc5": 97.83600119018554, "epoch": 29, "n_parameters": 310834152} 31 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.backends.cudnn as cudnn 8 | import json 9 | import os 10 | 11 | from pathlib import Path 12 | 13 | from timm.data.mixup import Mixup 14 | from timm.models import create_model 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.utils import ModelEma 17 | from optim_factory import create_optimizer, LayerDecayValueAssigner 18 | 19 | from datasets import build_dataset 20 | from engine import train_one_epoch, evaluate 21 | 22 | from utils import NativeScalerWithGradNormCount as NativeScaler 23 | import utils 24 | from drop_scheduler import drop_scheduler 25 | 26 | import models 27 | 28 | def str2bool(v): 29 | """ 30 | Converts string to bool type; enables command line 31 | arguments in the format of '--arg1 true --arg2 false' 32 | """ 33 | if isinstance(v, bool): 34 | return v 35 | if v.lower() in ('yes', 'true', 
't', 'y', '1'): 36 | return True 37 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 38 | return False 39 | else: 40 | raise argparse.ArgumentTypeError('Boolean value expected.') 41 | 42 | def get_args_parser(): 43 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 44 | parser.add_argument('--batch_size', default=64, type=int, 45 | help='Per-GPU batch size') 46 | parser.add_argument('--epochs', default=300, type=int) 47 | parser.add_argument('--update_freq', default=1, type=int, 48 | help='gradient accumulation steps') 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | parser.add_argument('--input_size', default=224, type=int, 54 | help='image input size') 55 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 56 | help="Layer scale initial values") 57 | 58 | ########################## settings specific to this project ########################## 59 | 60 | # dropout and stochastic depth drop rate; set at most one to non-zero 61 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT', 62 | help='Dropout rate (default: 0.0)') 63 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', 64 | help='Drop path rate (default: 0.0)') 65 | 66 | # early / late dropout and stochastic depth settings 67 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode') 68 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'], 69 | help='drop schedule for early dropout / s.d.
only') 70 | parser.add_argument('--cutoff_epoch', type=int, default=0, 71 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts') 72 | 73 | ####################################################################################### 74 | 75 | # EMA related parameters 76 | parser.add_argument('--model_ema', type=str2bool, default=False) 77 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 78 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 79 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 80 | 81 | # Optimization parameters 82 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 83 | help='Optimizer (default: "adamw"') 84 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 85 | help='Optimizer Epsilon (default: 1e-8)') 86 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 87 | help='Optimizer Betas (default: None, use opt default)') 88 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 89 | help='Clip gradient norm (default: None, no clipping)') 90 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 91 | help='SGD momentum (default: 0.9)') 92 | parser.add_argument('--weight_decay', type=float, default=0.05, 93 | help='weight decay (default: 0.05)') 94 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 95 | weight decay. 
We use a cosine schedule for WD; using a larger decay by 96 | the end of training improves performance for ViTs.""") 97 | 98 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR', 99 | help='learning rate (default: 4e-3), with total batch size 4096') 100 | parser.add_argument('--layer_decay', type=float, default=1.0) 101 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 102 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 103 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N', 104 | help='epochs to warmup LR, if scheduler supports') 105 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 106 | help='num of steps to warmup LR, overrides warmup_epochs if set > 0') 107 | 108 | # Augmentation parameters 109 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 110 | help='Color jitter factor (default: 0.4)') 111 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 112 | help='Use AutoAugment policy. "v0" or "original". 
(default: rand-m9-mstd0.5-inc1)') 113 | parser.add_argument('--smoothing', type=float, default=0.1, 114 | help='Label smoothing (default: 0.1)') 115 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 116 | help='Training interpolation (random, bilinear, bicubic; default: "bicubic")') 117 | 118 | # Evaluation parameters 119 | parser.add_argument('--crop_pct', type=float, default=None) 120 | 121 | # * Random Erase params 122 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 123 | help='Random erase prob (default: 0.25)') 124 | parser.add_argument('--remode', type=str, default='pixel', 125 | help='Random erase mode (default: "pixel")') 126 | parser.add_argument('--recount', type=int, default=1, 127 | help='Random erase count (default: 1)') 128 | parser.add_argument('--resplit', type=str2bool, default=False, 129 | help='Do not random erase first (clean) augmentation split') 130 | 131 | # * Mixup params 132 | parser.add_argument('--mixup', type=float, default=0.8, 133 | help='mixup alpha, mixup enabled if > 0.') 134 | parser.add_argument('--cutmix', type=float, default=1.0, 135 | help='cutmix alpha, cutmix enabled if > 0.') 136 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 137 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 138 | parser.add_argument('--mixup_prob', type=float, default=1.0, 139 | help='Probability of performing mixup or cutmix when either/both is enabled') 140 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 141 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 142 | parser.add_argument('--mixup_mode', type=str, default='batch', 143 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 144 | 145 | # * Finetuning params 146 | parser.add_argument('--finetune', default='', 147 | help='finetune from checkpoint') 148 | parser.add_argument('--head_init_scale', default=1.0, type=float, 149 | help='classifier head initial scale, typically adjusted in fine-tuning') 150 | parser.add_argument('--model_key', default='model|module', type=str, 151 | help='which key to load from saved state dict, usually model or model_ema') 152 | parser.add_argument('--model_prefix', default='', type=str) 153 | 154 | # Dataset parameters 155 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 156 | help='dataset path') 157 | parser.add_argument('--eval_data_path', default=None, type=str, 158 | help='dataset path for evaluation') 159 | parser.add_argument('--nb_classes', default=1000, type=int, 160 | help='number of classes') 161 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 162 | parser.add_argument('--data_set', default='IMNET', 163 | type=str, help='dataset name (default: IMNET)') 164 | parser.add_argument('--output_dir', default='', 165 | help='path where to save, empty for no saving') 166 | parser.add_argument('--log_dir', default=None, 167 | help='path to save TensorBoard logs') 168 | parser.add_argument('--device', default='cuda', 169 | help='device to use for training / testing') 170 | parser.add_argument('--seed', default=0, type=int) 171 | 172 | parser.add_argument('--resume', default='', 173 | help='resume from checkpoint') 174 | parser.add_argument('--auto_resume', type=str2bool, default=True) 175 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 176 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 177 | parser.add_argument('--save_ckpt_num', default=3, type=int) 178 | 179 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 180 | help='start epoch') 181 | parser.add_argument('--eval', 
type=str2bool, default=False, 182 | help='Perform evaluation only') 183 | parser.add_argument('--dist_eval', type=str2bool, default=True, 184 | help='Enabling distributed evaluation') 185 | parser.add_argument('--disable_eval', type=str2bool, default=False, 186 | help='Disabling evaluation during training') 187 | parser.add_argument('--num_workers', default=10, type=int) 188 | parser.add_argument('--pin_mem', type=str2bool, default=True, 189 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 190 | 191 | # distributed training parameters 192 | parser.add_argument('--world_size', default=1, type=int, 193 | help='number of distributed processes') 194 | parser.add_argument('--local_rank', default=-1, type=int) 195 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 196 | parser.add_argument('--dist_url', default='env://', 197 | help='url used to set up distributed training') 198 | 199 | parser.add_argument('--use_amp', type=str2bool, default=False, 200 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 201 | 202 | # Weights and Biases arguments 203 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 204 | help="enable logging to Weights and Biases") 205 | parser.add_argument('--project', default='convnext', type=str, 206 | help="The name of the W&B project where you're sending the new run.") 207 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 208 | help="Save model checkpoints as W&B Artifacts.") 209 | 210 | return parser 211 | 212 | def main(args): 213 | utils.init_distributed_mode(args) 214 | print(args) 215 | device = torch.device(args.device) 216 | 217 | # fix the seed for reproducibility 218 | seed = args.seed + utils.get_rank() 219 | torch.manual_seed(seed) 220 | np.random.seed(seed) 221 | cudnn.benchmark = True 222 | 223 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 224 | if args.disable_eval: 225 | args.dist_eval = False 226 | 
dataset_val = None 227 | else: 228 | dataset_val, _ = build_dataset(is_train=False, args=args) 229 | 230 | num_tasks = utils.get_world_size() 231 | global_rank = utils.get_rank() 232 | 233 | sampler_train = torch.utils.data.DistributedSampler( 234 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 235 | ) 236 | print("Sampler_train = %s" % str(sampler_train)) 237 | if args.dist_eval: 238 | if len(dataset_val) % num_tasks != 0: 239 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 240 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 241 | 'equal num of samples per-process.') 242 | sampler_val = torch.utils.data.DistributedSampler( 243 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 244 | else: 245 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 246 | 247 | if global_rank == 0 and args.log_dir is not None: 248 | os.makedirs(args.log_dir, exist_ok=True) 249 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 250 | else: 251 | log_writer = None 252 | 253 | if global_rank == 0 and args.enable_wandb: 254 | wandb_logger = utils.WandbLogger(args) 255 | else: 256 | wandb_logger = None 257 | 258 | data_loader_train = torch.utils.data.DataLoader( 259 | dataset_train, sampler=sampler_train, 260 | batch_size=args.batch_size, 261 | num_workers=args.num_workers, 262 | pin_memory=args.pin_mem, 263 | drop_last=True, 264 | ) 265 | 266 | if dataset_val is not None: 267 | data_loader_val = torch.utils.data.DataLoader( 268 | dataset_val, sampler=sampler_val, 269 | batch_size=int(1.5 * args.batch_size), 270 | num_workers=args.num_workers, 271 | pin_memory=args.pin_mem, 272 | drop_last=False 273 | ) 274 | else: 275 | data_loader_val = None 276 | 277 | mixup_fn = None 278 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 279 | if mixup_active: 280 | print("Mixup is activated!") 281 | mixup_fn = Mixup( 282 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 283 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 284 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 285 | 286 | model = utils.build_model(args) 287 | if args.finetune: 288 | if args.finetune.startswith('https'): 289 | checkpoint = torch.hub.load_state_dict_from_url( 290 | args.finetune, map_location='cpu', check_hash=True) 291 | else: 292 | checkpoint = torch.load(args.finetune, map_location='cpu') 293 | 294 | print("Load ckpt from %s" % args.finetune) 295 | checkpoint_model = None 296 | for model_key in args.model_key.split('|'): 297 | if model_key in checkpoint: 298 | checkpoint_model = checkpoint[model_key] 299 | print("Load state_dict by model_key = %s" % model_key) 300 | break 301 | if checkpoint_model is None: 302 | checkpoint_model = checkpoint 303 | state_dict = model.state_dict() 304 | for k in ['head.weight', 'head.bias']: 305 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 306 | print(f"Removing key {k} from pretrained checkpoint") 307 | del checkpoint_model[k] 308 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 309 | model.to(device) 310 | 311 | model_ema = None 312 | if args.model_ema: 313 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 314 | model_ema = ModelEma( 315 | model, 316 | decay=args.model_ema_decay, 317 | device='cpu' if args.model_ema_force_cpu else '', 318 | resume='') 319 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 320 | 321 | model_without_ddp = model 322 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 323 | 324 | print("Model = %s" % str(model_without_ddp)) 325 | print('number of params:', n_parameters) 326 | 327 | 
total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 328 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 329 | print("LR = %.8f" % args.lr) 330 | print("Batch size = %d" % total_batch_size) 331 | print("Update frequency = %d" % args.update_freq) 332 | print("Number of training examples = %d" % len(dataset_train)) 333 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch) 334 | 335 | if args.layer_decay < 1.0 or args.layer_decay > 1.0: 336 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value. 337 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \ 338 | "Layer Decay impl only supports convnext_small/base/large/xlarge" 339 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 340 | else: 341 | assigner = None 342 | 343 | if assigner is not None: 344 | print("Assigned values = %s" % str(assigner.values)) 345 | 346 | if args.distributed: 347 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 348 | model_without_ddp = model.module 349 | 350 | optimizer = create_optimizer( 351 | args, model_without_ddp, skip_list=None, 352 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 353 | get_layer_scale=assigner.get_scale if assigner is not None else None) 354 | 355 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used 356 | 357 | print("Use Cosine LR scheduler") 358 | lr_schedule_values = utils.cosine_scheduler( 359 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 360 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 361 | ) 362 | 363 | if args.weight_decay_end is None: 364 | args.weight_decay_end = args.weight_decay 365 | wd_schedule_values = utils.cosine_scheduler( 366 | args.weight_decay, 
args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 367 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 368 | 369 | schedules = {} 370 | 371 | # At most one of dropout and stochastic depth should be enabled. 372 | assert(args.dropout == 0 or args.drop_path == 0) 373 | # ConvNeXt does not support dropout. 374 | assert(args.dropout == 0 if args.model.startswith("convnext") else True) 375 | 376 | if args.dropout > 0: 377 | schedules['do'] = drop_scheduler( 378 | args.dropout, args.epochs, num_training_steps_per_epoch, 379 | args.cutoff_epoch, args.drop_mode, args.drop_schedule) 380 | print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do']))) 381 | 382 | if args.drop_path > 0: 383 | schedules['dp'] = drop_scheduler( 384 | args.drop_path, args.epochs, num_training_steps_per_epoch, 385 | args.cutoff_epoch, args.drop_mode, args.drop_schedule) 386 | print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp']))) 387 | 388 | if mixup_fn is not None: 389 | # smoothing is handled with mixup label transform 390 | criterion = SoftTargetCrossEntropy() 391 | elif args.smoothing > 0.: 392 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 393 | else: 394 | criterion = torch.nn.CrossEntropyLoss() 395 | 396 | print("criterion = %s" % str(criterion)) 397 | 398 | utils.auto_load_model( 399 | args=args, model=model, model_without_ddp=model_without_ddp, 400 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 401 | 402 | if args.eval: 403 | print(f"Eval only mode") 404 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 405 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 406 | return 407 | 408 | max_accuracy = 0.0 409 | if args.model_ema and args.model_ema_eval: 410 | max_accuracy_ema = 0.0 411 | 412 | print("Start training for %d epochs" % args.epochs) 413 | start_time = 
time.time() 414 | for epoch in range(args.start_epoch, args.epochs): 415 | if args.distributed: 416 | data_loader_train.sampler.set_epoch(epoch) 417 | if log_writer is not None: 418 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 419 | if wandb_logger: 420 | wandb_logger.set_steps() 421 | train_stats = train_one_epoch( 422 | model, criterion, data_loader_train, optimizer, 423 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 424 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch, 425 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, schedules=schedules, 426 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 427 | use_amp=args.use_amp 428 | ) 429 | if args.output_dir and args.save_ckpt: 430 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 431 | utils.save_model( 432 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 433 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 434 | if data_loader_val is not None: 435 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 436 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 437 | if max_accuracy < test_stats["acc1"]: 438 | max_accuracy = test_stats["acc1"] 439 | if args.output_dir and args.save_ckpt: 440 | utils.save_model( 441 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 442 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 443 | print(f'Max accuracy: {max_accuracy:.2f}%') 444 | 445 | if log_writer is not None: 446 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 447 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 448 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 449 | 450 | log_stats = 
{**{f'train_{k}': v for k, v in train_stats.items()}, 451 | **{f'test_{k}': v for k, v in test_stats.items()}, 452 | 'epoch': epoch, 453 | 'n_parameters': n_parameters} 454 | 455 | # repeat testing routines for EMA, if ema eval is turned on 456 | if args.model_ema and args.model_ema_eval: 457 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp) 458 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 459 | if max_accuracy_ema < test_stats_ema["acc1"]: 460 | max_accuracy_ema = test_stats_ema["acc1"] 461 | if args.output_dir and args.save_ckpt: 462 | utils.save_model( 463 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 464 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 465 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 466 | if log_writer is not None: 467 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 468 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 469 | else: 470 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 471 | 'epoch': epoch, 472 | 'n_parameters': n_parameters} 473 | 474 | if args.output_dir and utils.is_main_process(): 475 | if log_writer is not None: 476 | log_writer.flush() 477 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 478 | f.write(json.dumps(log_stats) + "\n") 479 | 480 | if wandb_logger: 481 | wandb_logger.log_epoch_metrics(log_stats) 482 | 483 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir: 484 | wandb_logger.log_checkpoints() 485 | 486 | 487 | total_time = time.time() - start_time 488 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 489 | print('Training time {}'.format(total_time_str)) 490 | 491 | if __name__ == '__main__': 492 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', 
parents=[get_args_parser()]) 493 | args = parser.parse_args() 494 | if args.output_dir: 495 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 496 | main(args) 497 | -------------------------------------------------------------------------------- /main_soft.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.backends.cudnn as cudnn 8 | import json 9 | import os 10 | 11 | from pathlib import Path 12 | 13 | from timm.data.mixup import Mixup 14 | from timm.models import create_model 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.utils import ModelEma 17 | from optim_factory import create_optimizer, LayerDecayValueAssigner 18 | 19 | from datasets import build_dataset 20 | from engine_soft import train_one_epoch, evaluate 21 | 22 | from utils import NativeScalerWithGradNormCount as NativeScaler 23 | import utils 24 | from drop_scheduler import drop_scheduler 25 | from mask_scheduler import mask_scheduler 26 | 27 | import models 28 | 29 | def str2bool(v): 30 | """ 31 | Converts string to bool type; enables command line 32 | arguments in the format of '--arg1 true --arg2 false' 33 | """ 34 | if isinstance(v, bool): 35 | return v 36 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 37 | return True 38 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 39 | return False 40 | else: 41 | raise argparse.ArgumentTypeError('Boolean value expected.') 42 | 43 | def get_args_parser(): 44 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 45 | parser.add_argument('--batch_size', default=64, type=int, 46 | help='Per GPU batch size') 47 | parser.add_argument('--epochs', default=300, type=int) 48 | parser.add_argument('--update_freq', default=1, type=int, 49 | help='gradient accumulation steps') 
50 | 51 | # Model parameters 52 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL', 53 | help='Name of model to train') 54 | parser.add_argument('--input_size', default=224, type=int, 55 | help='image input size') 56 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 57 | help="Layer scale initial values") 58 | 59 | ########################## settings specific to this project ########################## 60 | 61 | # dropout and stochastic depth drop rate; set at most one to non-zero 62 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT', 63 | help='Dropout rate (default: 0.0)') 64 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', 65 | help='Drop path rate (default: 0.0)') 66 | 67 | # early / late dropout and stochastic depth settings 68 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode') 69 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'], 70 | help='drop schedule for early dropout / s.d. 
only') 71 | parser.add_argument('--cutoff_epoch', type=int, default=0, 72 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts') 73 | 74 | # soft mask settings 75 | parser.add_argument('--soft_mask', type=float, default=1.0, metavar='PCT', 76 | help='Soft mask rate (default: 1.0)') 77 | parser.add_argument('--mask_mode', type=str, default='standard', choices=['standard', 'soft'], help='mask mode') 78 | parser.add_argument('--mask_schedule', type=str, default='constant', choices=['constant', 'linear'], 79 | help='mask schedule for soft mask') 80 | parser.add_argument('--cutoff_soft', type=int, default=0, 81 | help='if mask_mode is early / late, this is the epoch where soft mask ends / starts') 82 | 83 | ####################################################################################### 84 | 85 | # EMA related parameters 86 | parser.add_argument('--model_ema', type=str2bool, default=False) 87 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 88 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 89 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 90 | 91 | # Optimization parameters 92 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 93 | help='Optimizer (default: "adamw"') 94 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 95 | help='Optimizer Epsilon (default: 1e-8)') 96 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 97 | help='Optimizer Betas (default: None, use opt default)') 98 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 99 | help='Clip gradient norm (default: None, no clipping)') 100 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 101 | help='SGD momentum (default: 0.9)') 102 | parser.add_argument('--weight_decay', type=float, 
default=0.05, 103 | help='weight decay (default: 0.05)') 104 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 105 | weight decay. We use a cosine schedule for WD; using a larger decay by 106 | the end of training improves performance for ViTs.""") 107 | 108 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR', 109 | help='learning rate (default: 4e-3), with total batch size 4096') 110 | parser.add_argument('--layer_decay', type=float, default=1.0) 111 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 112 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 113 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N', 114 | help='epochs to warmup LR, if scheduler supports') 115 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 116 | help='num of steps to warmup LR, overrides warmup_epochs if set > 0') 117 | 118 | # Augmentation parameters 119 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 120 | help='Color jitter factor (default: 0.4)') 121 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 122 | help='Use AutoAugment policy. "v0" or "original". 
(default: rand-m9-mstd0.5-inc1)') 123 | parser.add_argument('--smoothing', type=float, default=0.1, 124 | help='Label smoothing (default: 0.1)') 125 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 126 | help='Training interpolation (random, bilinear, bicubic; default: "bicubic")') 127 | 128 | # Evaluation parameters 129 | parser.add_argument('--crop_pct', type=float, default=None) 130 | 131 | # * Random Erase params 132 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 133 | help='Random erase prob (default: 0.25)') 134 | parser.add_argument('--remode', type=str, default='pixel', 135 | help='Random erase mode (default: "pixel")') 136 | parser.add_argument('--recount', type=int, default=1, 137 | help='Random erase count (default: 1)') 138 | parser.add_argument('--resplit', type=str2bool, default=False, 139 | help='Do not random erase first (clean) augmentation split') 140 | 141 | # * Mixup params 142 | parser.add_argument('--mixup', type=float, default=0.8, 143 | help='mixup alpha, mixup enabled if > 0.') 144 | parser.add_argument('--cutmix', type=float, default=1.0, 145 | help='cutmix alpha, cutmix enabled if > 0.') 146 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 147 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 148 | parser.add_argument('--mixup_prob', type=float, default=1.0, 149 | help='Probability of performing mixup or cutmix when either/both is enabled') 150 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 151 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 152 | parser.add_argument('--mixup_mode', type=str, default='batch', 153 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 154 | 155 | # * Finetuning params 156 | parser.add_argument('--finetune', default='', 157 | help='finetune from checkpoint') 158 | parser.add_argument('--head_init_scale', default=1.0, type=float, 159 | help='classifier head initial scale, typically adjusted in fine-tuning') 160 | parser.add_argument('--model_key', default='model|module', type=str, 161 | help='which key to load from saved state dict, usually model or model_ema') 162 | parser.add_argument('--model_prefix', default='', type=str) 163 | 164 | # Dataset parameters 165 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 166 | help='dataset path') 167 | parser.add_argument('--eval_data_path', default=None, type=str, 168 | help='dataset path for evaluation') 169 | parser.add_argument('--nb_classes', default=1000, type=int, 170 | help='number of classes') 171 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 172 | parser.add_argument('--data_set', default='IMNET', 173 | type=str, help='dataset name (default: IMNET)') 174 | parser.add_argument('--output_dir', default='', 175 | help='path where to save, empty for no saving') 176 | parser.add_argument('--log_dir', default=None, 177 | help='path to save TensorBoard logs') 178 | parser.add_argument('--device', default='cuda', 179 | help='device to use for training / testing') 180 | parser.add_argument('--seed', default=0, type=int) 181 | 182 | parser.add_argument('--resume', default='', 183 | help='resume from checkpoint') 184 | parser.add_argument('--auto_resume', type=str2bool, default=True) 185 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 186 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 187 | parser.add_argument('--save_ckpt_num', default=3, type=int) 188 | 189 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 190 | help='start epoch') 191 | parser.add_argument('--eval', 
type=str2bool, default=False, 192 | help='Perform evaluation only') 193 | parser.add_argument('--dist_eval', type=str2bool, default=True, 194 | help='Enabling distributed evaluation') 195 | parser.add_argument('--disable_eval', type=str2bool, default=False, 196 | help='Disabling evaluation during training') 197 | parser.add_argument('--num_workers', default=10, type=int) 198 | parser.add_argument('--pin_mem', type=str2bool, default=True, 199 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 200 | 201 | # distributed training parameters 202 | parser.add_argument('--world_size', default=1, type=int, 203 | help='number of distributed processes') 204 | parser.add_argument('--local_rank', default=-1, type=int) 205 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 206 | parser.add_argument('--dist_url', default='env://', 207 | help='url used to set up distributed training') 208 | 209 | parser.add_argument('--use_amp', type=str2bool, default=False, 210 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 211 | 212 | # Weights and Biases arguments 213 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 214 | help="enable logging to Weights and Biases") 215 | parser.add_argument('--project', default='convnext', type=str, 216 | help="The name of the W&B project where you're sending the new run.") 217 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 218 | help="Save model checkpoints as W&B Artifacts.") 219 | 220 | return parser 221 | 222 | def main(args): 223 | utils.init_distributed_mode(args) 224 | print(args) 225 | device = torch.device(args.device) 226 | 227 | # fix the seed for reproducibility 228 | seed = args.seed + utils.get_rank() 229 | torch.manual_seed(seed) 230 | np.random.seed(seed) 231 | cudnn.benchmark = True 232 | 233 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 234 | if args.disable_eval: 235 | args.dist_eval = False 236 | 
dataset_val = None 237 | else: 238 | dataset_val, _ = build_dataset(is_train=False, args=args) 239 | 240 | num_tasks = utils.get_world_size() 241 | global_rank = utils.get_rank() 242 | 243 | sampler_train = torch.utils.data.DistributedSampler( 244 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 245 | ) 246 | print("Sampler_train = %s" % str(sampler_train)) 247 | if args.dist_eval: 248 | if len(dataset_val) % num_tasks != 0: 249 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 250 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 251 | 'equal num of samples per-process.') 252 | sampler_val = torch.utils.data.DistributedSampler( 253 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 254 | else: 255 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 256 | 257 | if global_rank == 0 and args.log_dir is not None: 258 | os.makedirs(args.log_dir, exist_ok=True) 259 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 260 | else: 261 | log_writer = None 262 | 263 | if global_rank == 0 and args.enable_wandb: 264 | wandb_logger = utils.WandbLogger(args) 265 | else: 266 | wandb_logger = None 267 | 268 | data_loader_train = torch.utils.data.DataLoader( 269 | dataset_train, sampler=sampler_train, 270 | batch_size=args.batch_size, 271 | num_workers=args.num_workers, 272 | pin_memory=args.pin_mem, 273 | drop_last=True, 274 | ) 275 | 276 | if dataset_val is not None: 277 | data_loader_val = torch.utils.data.DataLoader( 278 | dataset_val, sampler=sampler_val, 279 | batch_size=int(1.5 * args.batch_size), 280 | num_workers=args.num_workers, 281 | pin_memory=args.pin_mem, 282 | drop_last=False 283 | ) 284 | else: 285 | data_loader_val = None 286 | 287 | mixup_fn = None 288 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 289 | if mixup_active: 290 | print("Mixup is activated!") 291 | mixup_fn = Mixup( 292 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 293 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 294 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 295 | 296 | model = utils.build_model(args) 297 | if args.finetune: 298 | if args.finetune.startswith('https'): 299 | checkpoint = torch.hub.load_state_dict_from_url( 300 | args.finetune, map_location='cpu', check_hash=True) 301 | else: 302 | checkpoint = torch.load(args.finetune, map_location='cpu') 303 | 304 | print("Load ckpt from %s" % args.finetune) 305 | checkpoint_model = None 306 | for model_key in args.model_key.split('|'): 307 | if model_key in checkpoint: 308 | checkpoint_model = checkpoint[model_key] 309 | print("Load state_dict by model_key = %s" % model_key) 310 | break 311 | if checkpoint_model is None: 312 | checkpoint_model = checkpoint 313 | state_dict = model.state_dict() 314 | for k in ['head.weight', 'head.bias']: 315 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 316 | print(f"Removing key {k} from pretrained checkpoint") 317 | del checkpoint_model[k] 318 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 319 | model.to(device) 320 | 321 | model_ema = None 322 | if args.model_ema: 323 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 324 | model_ema = ModelEma( 325 | model, 326 | decay=args.model_ema_decay, 327 | device='cpu' if args.model_ema_force_cpu else '', 328 | resume='') 329 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 330 | 331 | model_without_ddp = model 332 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 333 | 334 | print("Model = %s" % str(model_without_ddp)) 335 | print('number of params:', n_parameters) 336 | 337 | 
total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 338 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 339 | print("LR = %.8f" % args.lr) 340 | print("Total batch size = %d" % total_batch_size) 341 | print("Update frequency = %d" % args.update_freq) 342 | print("Number of training examples = %d" % len(dataset_train)) 343 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch) 344 | 345 | if args.layer_decay < 1.0 or args.layer_decay > 1.0: 346 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value. 347 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \ 348 | "Layer Decay impl only supports convnext_small/base/large/xlarge" 349 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 350 | else: 351 | assigner = None 352 | 353 | if assigner is not None: 354 | print("Assigned values = %s" % str(assigner.values)) 355 | 356 | if args.distributed: 357 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 358 | model_without_ddp = model.module 359 | 360 | optimizer = create_optimizer( 361 | args, model_without_ddp, skip_list=None, 362 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 363 | get_layer_scale=assigner.get_scale if assigner is not None else None) 364 | 365 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used 366 | 367 | print("Use Cosine LR scheduler") 368 | lr_schedule_values = utils.cosine_scheduler( 369 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 370 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 371 | ) 372 | 373 | if args.weight_decay_end is None: 374 | args.weight_decay_end = args.weight_decay 375 | wd_schedule_values = utils.cosine_scheduler( 376 | args.weight_decay, 
args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 377 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 378 | 379 | schedules = {} 380 | 381 | # At most one of dropout and stochastic depth should be enabled. 382 | assert(args.dropout == 0 or args.drop_path == 0) 383 | # ConvNeXt does not support dropout. 384 | assert(args.dropout == 0 if args.model.startswith("convnext") else True) 385 | 386 | if args.dropout > 0: 387 | schedules['do'] = drop_scheduler( 388 | args.dropout, args.epochs, num_training_steps_per_epoch, 389 | args.cutoff_epoch, args.drop_mode, args.drop_schedule) 390 | print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do']))) 391 | 392 | if args.drop_path > 0: 393 | schedules['dp'] = drop_scheduler( 394 | args.drop_path, args.epochs, num_training_steps_per_epoch, 395 | args.cutoff_epoch, args.drop_mode, args.drop_schedule) 396 | print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp']))) 397 | 398 | # Mask schedule enabled. 
399 | schedules['sm'] = mask_scheduler( 400 | args.soft_mask, args.epochs, num_training_steps_per_epoch, 401 | args.cutoff_soft, args.mask_mode, args.mask_schedule) 402 | print("Min MASK = %.7f, Max MASK = %.7f" % (min(schedules['sm']), max(schedules['sm']))) 403 | 404 | if mixup_fn is not None: 405 | # smoothing is handled with mixup label transform 406 | criterion = SoftTargetCrossEntropy() 407 | elif args.smoothing > 0.: 408 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 409 | else: 410 | criterion = torch.nn.CrossEntropyLoss() 411 | 412 | print("criterion = %s" % str(criterion)) 413 | 414 | utils.auto_load_model( 415 | args=args, model=model, model_without_ddp=model_without_ddp, 416 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 417 | 418 | if args.eval: 419 | print(f"Eval only mode") 420 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 421 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 422 | return 423 | 424 | max_accuracy = 0.0 425 | if args.model_ema and args.model_ema_eval: 426 | max_accuracy_ema = 0.0 427 | 428 | print("Start training for %d epochs" % args.epochs) 429 | start_time = time.time() 430 | for epoch in range(args.start_epoch, args.epochs): 431 | if args.distributed: 432 | data_loader_train.sampler.set_epoch(epoch) 433 | if log_writer is not None: 434 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 435 | if wandb_logger: 436 | wandb_logger.set_steps() 437 | train_stats = train_one_epoch( 438 | model, criterion, data_loader_train, optimizer, 439 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 440 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch, 441 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, schedules=schedules, 442 | num_training_steps_per_epoch=num_training_steps_per_epoch, 
update_freq=args.update_freq, 443 | use_amp=args.use_amp 444 | ) 445 | if args.output_dir and args.save_ckpt: 446 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 447 | utils.save_model( 448 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 449 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 450 | if data_loader_val is not None: 451 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 452 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 453 | if max_accuracy < test_stats["acc1"]: 454 | max_accuracy = test_stats["acc1"] 455 | if args.output_dir and args.save_ckpt: 456 | utils.save_model( 457 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 458 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 459 | print(f'Max accuracy: {max_accuracy:.2f}%') 460 | 461 | if log_writer is not None: 462 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 463 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 464 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 465 | 466 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 467 | **{f'test_{k}': v for k, v in test_stats.items()}, 468 | 'epoch': epoch, 469 | 'n_parameters': n_parameters} 470 | 471 | # repeat testing routines for EMA, if ema eval is turned on 472 | if args.model_ema and args.model_ema_eval: 473 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp) 474 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 475 | if max_accuracy_ema < test_stats_ema["acc1"]: 476 | max_accuracy_ema = test_stats_ema["acc1"] 477 | if args.output_dir and args.save_ckpt: 478 | utils.save_model( 479 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 480 | 
loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 481 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 482 | if log_writer is not None: 483 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 484 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 485 | else: 486 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 487 | 'epoch': epoch, 488 | 'n_parameters': n_parameters} 489 | 490 | if args.output_dir and utils.is_main_process(): 491 | if log_writer is not None: 492 | log_writer.flush() 493 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 494 | f.write(json.dumps(log_stats) + "\n") 495 | 496 | if wandb_logger: 497 | wandb_logger.log_epoch_metrics(log_stats) 498 | 499 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir: 500 | wandb_logger.log_checkpoints() 501 | 502 | 503 | total_time = time.time() - start_time 504 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 505 | print('Training time {}'.format(total_time_str)) 506 | 507 | if __name__ == '__main__': 508 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()]) 509 | args = parser.parse_args() 510 | if args.output_dir: 511 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 512 | main(args) 513 | -------------------------------------------------------------------------------- /main_soft_fthr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.backends.cudnn as cudnn 8 | import json 9 | import os 10 | 11 | from pathlib import Path 12 | 13 | from timm.data.mixup import Mixup 14 | from timm.models import create_model 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.utils 
import ModelEma 17 | from optim_factory import create_optimizer, LayerDecayValueAssigner 18 | 19 | from datasets import build_dataset 20 | from engine_soft import train_one_epoch, evaluate 21 | 22 | from utils import NativeScalerWithGradNormCount as NativeScaler 23 | import utils 24 | from drop_scheduler import drop_scheduler 25 | from mask_scheduler import mask_scheduler 26 | 27 | import models 28 | 29 | def str2bool(v): 30 | """ 31 | Converts string to bool type; enables command line 32 | arguments in the format of '--arg1 true --arg2 false' 33 | """ 34 | if isinstance(v, bool): 35 | return v 36 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 37 | return True 38 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 39 | return False 40 | else: 41 | raise argparse.ArgumentTypeError('Boolean value expected.') 42 | 43 | def get_args_parser(): 44 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 45 | parser.add_argument('--batch_size', default=64, type=int, 46 | help='Per GPU batch size') 47 | parser.add_argument('--epochs', default=300, type=int) 48 | parser.add_argument('--update_freq', default=1, type=int, 49 | help='gradient accumulation steps') 50 | 51 | # Model parameters 52 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL', 53 | help='Name of model to train') 54 | parser.add_argument('--input_size', default=224, type=int, 55 | help='image input size') 56 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 57 | help="Layer scale initial values") 58 | 59 | ########################## settings specific to this project ########################## 60 | 61 | # dropout and stochastic depth drop rate; set at most one to non-zero 62 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT', 63 | help='Dropout rate (default: 0.0)') 64 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', 65 | help='Drop path 
rate (default: 0.0)') 66 | 67 | # early / late dropout and stochastic depth settings 68 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode') 69 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'], 70 | help='drop schedule for early dropout / s.d. only') 71 | parser.add_argument('--cutoff_epoch', type=int, default=0, 72 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts') 73 | 74 | # soft mask settings 75 | parser.add_argument('--soft_mask', type=float, default=1.0, metavar='PCT', 76 | help='Soft mask rate (default: 1.0)') 77 | parser.add_argument('--mask_mode', type=str, default='standard', choices=['standard', 'soft'], help='mask mode') 78 | parser.add_argument('--mask_schedule', type=str, default='constant', choices=['constant', 'linear'], 79 | help='mask schedule for soft mask') 80 | parser.add_argument('--cutoff_soft', type=int, default=0, 81 | help='if mask_mode is soft, this is the cutoff epoch of the soft mask schedule') 82 | 83 | ####################################################################################### 84 | 85 | # EMA related parameters 86 | parser.add_argument('--model_ema', type=str2bool, default=False) 87 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 88 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 89 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Use the EMA model for evaluation during training.') 90 | 91 | # Optimization parameters 92 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 93 | help='Optimizer (default: "adamw")') 94 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 95 | help='Optimizer Epsilon (default: 1e-8)') 96 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 97 | help='Optimizer 
Betas (default: None, use opt default)') 98 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 99 | help='Clip gradient norm (default: None, no clipping)') 100 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 101 | help='SGD momentum (default: 0.9)') 102 | parser.add_argument('--weight_decay', type=float, default=0.05, 103 | help='weight decay (default: 0.05)') 104 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 105 | weight decay. We use a cosine schedule for WD and using a larger decay by 106 | the end of training improves performance for ViTs.""") 107 | 108 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR', 109 | help='learning rate (default: 4e-3), with total batch size 4096') 110 | parser.add_argument('--layer_decay', type=float, default=1.0) 111 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 112 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 113 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N', 114 | help='epochs to warmup LR, if scheduler supports') 115 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 116 | help='num of steps to warmup LR, overrides warmup_epochs if set > 0') 117 | 118 | # Augmentation parameters 119 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 120 | help='Color jitter factor (default: 0.4)') 121 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 122 | help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
123 | parser.add_argument('--smoothing', type=float, default=0.1, 124 | help='Label smoothing (default: 0.1)') 125 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 126 | help='Training interpolation (random, bilinear, bicubic; default: "bicubic")') 127 | 128 | # Evaluation parameters 129 | parser.add_argument('--crop_pct', type=float, default=None) 130 | 131 | # * Random Erase params 132 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 133 | help='Random erase prob (default: 0.25)') 134 | parser.add_argument('--remode', type=str, default='pixel', 135 | help='Random erase mode (default: "pixel")') 136 | parser.add_argument('--recount', type=int, default=1, 137 | help='Random erase count (default: 1)') 138 | parser.add_argument('--resplit', type=str2bool, default=False, 139 | help='Do not random erase first (clean) augmentation split') 140 | 141 | # * Mixup params 142 | parser.add_argument('--mixup', type=float, default=0.8, 143 | help='mixup alpha, mixup enabled if > 0.') 144 | parser.add_argument('--cutmix', type=float, default=1.0, 145 | help='cutmix alpha, cutmix enabled if > 0.') 146 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 147 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 148 | parser.add_argument('--mixup_prob', type=float, default=1.0, 149 | help='Probability of performing mixup or cutmix when either/both is enabled') 150 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 151 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 152 | parser.add_argument('--mixup_mode', type=str, default='batch', 153 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 154 | 155 | # * Finetuning params 156 | parser.add_argument('--finetune', default='', 157 | help='finetune from checkpoint') 158 | parser.add_argument('--head_init_scale', default=1.0, type=float, 159 | help='classifier head initial scale, typically adjusted in fine-tuning') 160 | parser.add_argument('--model_key', default='model|module', type=str, 161 | help='which key to load from saved state dict, usually model or model_ema') 162 | parser.add_argument('--model_prefix', default='', type=str) 163 | 164 | # Dataset parameters 165 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 166 | help='dataset path') 167 | parser.add_argument('--eval_data_path', default=None, type=str, 168 | help='dataset path for evaluation') 169 | parser.add_argument('--nb_classes', default=1000, type=int, 170 | help='number of classification classes') 171 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 172 | parser.add_argument('--data_set', default='IMNET', 173 | type=str, help='dataset type (default: IMNET)') 174 | parser.add_argument('--output_dir', default='', 175 | help='path where to save, empty for no saving') 176 | parser.add_argument('--log_dir', default=None, 177 | help='path for TensorBoard logs') 178 | parser.add_argument('--device', default='cuda', 179 | help='device to use for training / testing') 180 | parser.add_argument('--seed', default=0, type=int) 181 | 182 | parser.add_argument('--resume', default='', 183 | help='resume from checkpoint') 184 | parser.add_argument('--auto_resume', type=str2bool, default=True) 185 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 186 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 187 | parser.add_argument('--save_ckpt_num', default=3, type=int) 188 | 189 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 190 | help='start epoch') 191 | parser.add_argument('--eval', 
type=str2bool, default=False, 192 | help='Perform evaluation only') 193 | parser.add_argument('--dist_eval', type=str2bool, default=True, 194 | help='Enabling distributed evaluation') 195 | parser.add_argument('--disable_eval', type=str2bool, default=False, 196 | help='Disabling evaluation during training') 197 | parser.add_argument('--num_workers', default=10, type=int) 198 | parser.add_argument('--pin_mem', type=str2bool, default=True, 199 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 200 | 201 | # distributed training parameters 202 | parser.add_argument('--world_size', default=1, type=int, 203 | help='number of distributed processes') 204 | parser.add_argument('--local_rank', default=-1, type=int) 205 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 206 | parser.add_argument('--dist_url', default='env://', 207 | help='url used to set up distributed training') 208 | 209 | parser.add_argument('--use_amp', type=str2bool, default=False, 210 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 211 | 212 | # Weights and Biases arguments 213 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 214 | help="enable logging to Weights and Biases") 215 | parser.add_argument('--project', default='convnext', type=str, 216 | help="The name of the W&B project where you're sending the new run.") 217 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 218 | help="Save model checkpoints as W&B Artifacts.") 219 | 220 | return parser 221 | 222 | def main(args): 223 | utils.init_distributed_mode(args) 224 | print(args) 225 | device = torch.device(args.device) 226 | 227 | # fix the seed for reproducibility 228 | seed = args.seed + utils.get_rank() 229 | torch.manual_seed(seed) 230 | np.random.seed(seed) 231 | cudnn.benchmark = True 232 | 233 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 234 | if args.disable_eval: 235 | args.dist_eval = False 236 | 
dataset_val = None 237 | else: 238 | dataset_val, _ = build_dataset(is_train=False, args=args) 239 | 240 | num_tasks = utils.get_world_size() 241 | global_rank = utils.get_rank() 242 | 243 | sampler_train = torch.utils.data.DistributedSampler( 244 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 245 | ) 246 | print("Sampler_train = %s" % str(sampler_train)) 247 | if args.dist_eval: 248 | if len(dataset_val) % num_tasks != 0: 249 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 250 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 251 | 'equal num of samples per-process.') 252 | sampler_val = torch.utils.data.DistributedSampler( 253 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 254 | else: 255 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 256 | 257 | if global_rank == 0 and args.log_dir is not None: 258 | os.makedirs(args.log_dir, exist_ok=True) 259 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 260 | else: 261 | log_writer = None 262 | 263 | if global_rank == 0 and args.enable_wandb: 264 | wandb_logger = utils.WandbLogger(args) 265 | else: 266 | wandb_logger = None 267 | 268 | data_loader_train = torch.utils.data.DataLoader( 269 | dataset_train, sampler=sampler_train, 270 | batch_size=args.batch_size, 271 | num_workers=args.num_workers, 272 | pin_memory=args.pin_mem, 273 | drop_last=True, 274 | ) 275 | 276 | if dataset_val is not None: 277 | data_loader_val = torch.utils.data.DataLoader( 278 | dataset_val, sampler=sampler_val, 279 | batch_size=int(1.5 * args.batch_size), 280 | num_workers=args.num_workers, 281 | pin_memory=args.pin_mem, 282 | drop_last=False 283 | ) 284 | else: 285 | data_loader_val = None 286 | 287 | mixup_fn = None 288 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 289 | if mixup_active: 290 | print("Mixup is activated!") 291 | mixup_fn = Mixup( 292 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 293 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 294 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 295 | 296 | model = utils.build_model(args) 297 | if args.finetune: 298 | if args.finetune.startswith('https'): 299 | checkpoint = torch.hub.load_state_dict_from_url( 300 | args.finetune, map_location='cpu', check_hash=True) 301 | else: 302 | checkpoint = torch.load(args.finetune, map_location='cpu') 303 | 304 | print("Load ckpt from %s" % args.finetune) 305 | checkpoint_model = None 306 | for model_key in args.model_key.split('|'): 307 | if model_key in checkpoint: 308 | checkpoint_model = checkpoint[model_key] 309 | print("Load state_dict by model_key = %s" % model_key) 310 | break 311 | if checkpoint_model is None: 312 | checkpoint_model = checkpoint 313 | state_dict = model.state_dict() 314 | for k in ['head.weight', 'head.bias']: 315 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 316 | print(f"Removing key {k} from pretrained checkpoint") 317 | del checkpoint_model[k] 318 | 319 | # Interpolate the position embedding to match the new input resolution 320 | # (extra tokens, e.g. the post-sequence [cls] token, are stored after the patch tokens) 321 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 322 | embedding_size = pos_embed_checkpoint.shape[-1] 323 | num_patches = model.patch_embed.num_patches 324 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 325 | # height (== width) for the checkpoint position embedding 326 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 327 | # height (== width) for the new position embedding 328 | new_size = int(num_patches ** 0.5) 329 | # extra tokens are kept unchanged; slicing by orig_size ** 2 (not -num_extra_tokens) also works when there are none 330 | extra_tokens = pos_embed_checkpoint[:, orig_size * orig_size:] 331 | # only the patch position tokens are interpolated
332 | pos_tokens = pos_embed_checkpoint[:, :orig_size * orig_size] 333 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 334 | pos_tokens = torch.nn.functional.interpolate( 335 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 336 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 337 | new_pos_embed = torch.cat((pos_tokens, extra_tokens), dim=1) 338 | checkpoint_model['pos_embed'] = new_pos_embed 339 | # position embedding interpolation done 340 | 341 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 342 | model.to(device) 343 | 344 | model_ema = None 345 | if args.model_ema: 346 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 347 | model_ema = ModelEma( 348 | model, 349 | decay=args.model_ema_decay, 350 | device='cpu' if args.model_ema_force_cpu else '', 351 | resume='') 352 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 353 | 354 | model_without_ddp = model 355 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 356 | 357 | print("Model = %s" % str(model_without_ddp)) 358 | print('number of params:', n_parameters) 359 | 360 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 361 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 362 | print("LR = %.8f" % args.lr) 363 | print("Total batch size = %d" % total_batch_size) 364 | print("Update frequency = %d" % args.update_freq) 365 | print("Number of training examples = %d" % len(dataset_train)) 366 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch) 367 | 368 | if args.layer_decay < 1.0 or args.layer_decay > 1.0: 369 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value. 
370 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \ 371 | "Layer Decay impl only supports convnext_small/base/large/xlarge" 372 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 373 | else: 374 | assigner = None 375 | 376 | if assigner is not None: 377 | print("Assigned values = %s" % str(assigner.values)) 378 | 379 | if args.distributed: 380 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 381 | model_without_ddp = model.module 382 | 383 | optimizer = create_optimizer( 384 | args, model_without_ddp, skip_list=None, 385 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 386 | get_layer_scale=assigner.get_scale if assigner is not None else None) 387 | 388 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used 389 | 390 | print("Use Cosine LR scheduler") 391 | lr_schedule_values = utils.cosine_scheduler( 392 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 393 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 394 | ) 395 | 396 | if args.weight_decay_end is None: 397 | args.weight_decay_end = args.weight_decay 398 | wd_schedule_values = utils.cosine_scheduler( 399 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 400 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 401 | 402 | schedules = {} 403 | 404 | # At most one of dropout and stochastic depth should be enabled. 405 | assert(args.dropout == 0 or args.drop_path == 0) 406 | # ConvNeXt does not support dropout. 
407 | assert(args.dropout == 0 if args.model.startswith("convnext") else True) 408 | 409 | if args.dropout > 0: 410 | schedules['do'] = drop_scheduler( 411 | args.dropout, args.epochs, num_training_steps_per_epoch, 412 | args.cutoff_epoch, args.drop_mode, args.drop_schedule) 413 | print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do']))) 414 | 415 | if args.drop_path > 0: 416 | schedules['dp'] = drop_scheduler( 417 | args.drop_path, args.epochs, num_training_steps_per_epoch, 418 | args.cutoff_epoch, args.drop_mode, args.drop_schedule) 419 | print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp']))) 420 | 421 | # Mask schedule enabled. 422 | schedules['sm'] = mask_scheduler( 423 | args.soft_mask, args.epochs, num_training_steps_per_epoch, 424 | args.cutoff_soft, args.mask_mode, args.mask_schedule) 425 | print("Min MASK = %.7f, Max MASK = %.7f" % (min(schedules['sm']), max(schedules['sm']))) 426 | 427 | if mixup_fn is not None: 428 | # smoothing is handled with mixup label transform 429 | criterion = SoftTargetCrossEntropy() 430 | elif args.smoothing > 0.: 431 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 432 | else: 433 | criterion = torch.nn.CrossEntropyLoss() 434 | 435 | print("criterion = %s" % str(criterion)) 436 | 437 | utils.auto_load_model( 438 | args=args, model=model, model_without_ddp=model_without_ddp, 439 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 440 | 441 | if args.eval: 442 | print(f"Eval only mode") 443 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 444 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 445 | return 446 | 447 | max_accuracy = 0.0 448 | if args.model_ema and args.model_ema_eval: 449 | max_accuracy_ema = 0.0 450 | 451 | print("Start training for %d epochs" % args.epochs) 452 | start_time = time.time() 453 | for epoch in range(args.start_epoch, args.epochs): 454 
| if args.distributed: 455 | data_loader_train.sampler.set_epoch(epoch) 456 | if log_writer is not None: 457 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 458 | if wandb_logger: 459 | wandb_logger.set_steps() 460 | train_stats = train_one_epoch( 461 | model, criterion, data_loader_train, optimizer, 462 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 463 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch, 464 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, schedules=schedules, 465 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 466 | use_amp=args.use_amp 467 | ) 468 | if args.output_dir and args.save_ckpt: 469 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 470 | utils.save_model( 471 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 472 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 473 | if data_loader_val is not None: 474 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 475 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 476 | if max_accuracy < test_stats["acc1"]: 477 | max_accuracy = test_stats["acc1"] 478 | if args.output_dir and args.save_ckpt: 479 | utils.save_model( 480 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 481 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 482 | print(f'Max accuracy: {max_accuracy:.2f}%') 483 | 484 | if log_writer is not None: 485 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 486 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 487 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 488 | 489 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 490 | **{f'test_{k}': 
v for k, v in test_stats.items()}, 491 | 'epoch': epoch, 492 | 'n_parameters': n_parameters} 493 | 494 | # repeat testing routines for EMA, if ema eval is turned on 495 | if args.model_ema and args.model_ema_eval: 496 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp) 497 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 498 | if max_accuracy_ema < test_stats_ema["acc1"]: 499 | max_accuracy_ema = test_stats_ema["acc1"] 500 | if args.output_dir and args.save_ckpt: 501 | utils.save_model( 502 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 503 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 504 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 505 | if log_writer is not None: 506 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 507 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 508 | else: 509 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 510 | 'epoch': epoch, 511 | 'n_parameters': n_parameters} 512 | 513 | if args.output_dir and utils.is_main_process(): 514 | if log_writer is not None: 515 | log_writer.flush() 516 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 517 | f.write(json.dumps(log_stats) + "\n") 518 | 519 | if wandb_logger: 520 | wandb_logger.log_epoch_metrics(log_stats) 521 | 522 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir: 523 | wandb_logger.log_checkpoints() 524 | 525 | 526 | total_time = time.time() - start_time 527 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 528 | print('Training time {}'.format(total_time_str)) 529 | 530 | if __name__ == '__main__': 531 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()]) 532 | args = parser.parse_args() 533 | if args.output_dir: 534 | 
Path(args.output_dir).mkdir(parents=True, exist_ok=True) 535 | main(args) 536 | -------------------------------------------------------------------------------- /mask_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def mask_scheduler(soft_mask_rate, epochs, niter_per_ep, cutoff_epoch=0, mode="standard", schedule="constant"): 4 | assert mode in ["standard", "soft"] 5 | if mode == "standard": 6 | return np.full(epochs * niter_per_ep, soft_mask_rate) 7 | 8 | early_iters = cutoff_epoch * niter_per_ep 9 | late_iters = (epochs - cutoff_epoch) * niter_per_ep 10 | 11 | if mode == "soft": 12 | assert schedule in ["constant", "linear"] 13 | if schedule == 'constant': 14 | early_schedule = np.full(early_iters, soft_mask_rate) 15 | elif schedule == 'linear': 16 | early_schedule = np.linspace(soft_mask_rate, 0, early_iters) 17 | final_schedule = np.concatenate((early_schedule, np.full(late_iters, 0))) 18 | 19 | assert len(final_schedule) == epochs * niter_per_ep 20 | return final_schedule -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .illama import * -------------------------------------------------------------------------------- /models/illama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from typing import Optional, Tuple 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | 9 | 10 | class RMSNorm(torch.nn.Module): 11 | def __init__(self, dim: int, eps: float = 1e-6): 12 | super().__init__() 13 | self.eps = eps 14 | self.weight = nn.Parameter(torch.ones(dim)) 15 | 16 | def _norm(self, x): 17 | return x * 
torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 18 | 19 | def forward(self, x): 20 | output = self._norm(x.float()).type_as(x) 21 | return output * self.weight 22 | 23 | 24 | class Mlp(nn.Module): 25 | def __init__(self, in_features, hidden_features, multiple_of=256, act_layer=nn.GELU, drop=0.): 26 | super().__init__() 27 | hidden_features = int(2 * hidden_features / 3) 28 | hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) 29 | 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=False) 31 | self.fc2 = nn.Linear(hidden_features, in_features, bias=False) 32 | self.fc3 = nn.Linear(in_features, hidden_features, bias=False) 33 | 34 | def forward(self, x): 35 | x = F.silu(self.fc1(x)) * self.fc3(x) # [B, N+1, 4D*2/3] 36 | # print(x.shape) 37 | x = self.fc2(x) 38 | return x 39 | 40 | 41 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 42 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [32], float32 43 | t = torch.arange(end, device=freqs.device) # [197], int64 44 | freqs = torch.outer(t, freqs).float() # [197, 32], float32 45 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [197, 32], complex64 46 | return freqs_cis # [197, 32], complex64 47 | 48 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 49 | ndim = x.ndim # 4, since [bsz, 197, 3, 32], complex64 50 | assert 0 <= 1 < ndim 51 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 52 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # [1, 197, 1, 32], list 53 | return freqs_cis.view(*shape) # [1, 197, 1, 32], complex64 54 | 55 | def apply_rotary_emb( 56 | xq: torch.Tensor, 57 | xk: torch.Tensor, 58 | freqs_cis: torch.Tensor, 59 | ) -> Tuple[torch.Tensor, torch.Tensor]: 60 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bsz, 197, 3, 64], float32 → [bsz, 197, 3, 32, 2], float32 → [bsz, 197, 3, 32], complex64 61 | xk_ = 
torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [bsz, 197, 3, 64], float32 → [bsz, 197, 3, 32, 2], float32 → [bsz, 197, 3, 32], complex64 62 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, 197, 1, 32], complex64 63 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bsz, 197, 3, 64], float32 64 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) # [bsz, 197, 3, 64], float32 65 | return xq_out.type_as(xq), xk_out.type_as(xk) # [bsz, 197, 3, 64], float32, [bsz, 197, 3, 64], float32 66 | 67 | 68 | class Attention(nn.Module): 69 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 70 | super().__init__() 71 | self.num_heads = num_heads 72 | head_dim = dim // num_heads 73 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 74 | self.scale = qk_scale or head_dim ** -0.5 75 | 76 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 77 | # self.attn_drop = nn.Dropout(attn_drop) 78 | self.proj = nn.Linear(dim, dim) 79 | # self.proj_drop = nn.Dropout(proj_drop) 80 | 81 | def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): 82 | B, N, C = x.shape 83 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) # [3, B, N, self.num_heads, C // self.num_heads] 84 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # [B, N, self.num_heads, C // self.num_heads] 85 | 86 | q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) 87 | 88 | q = q.transpose(1, 2) # [B, self.num_heads, N, C // self.num_heads] 89 | k = k.transpose(1, 2) # [B, self.num_heads, N, C // self.num_heads] 90 | v = v.transpose(1, 2) # [B, self.num_heads, N, C // self.num_heads] 91 | attn = (q @ k.transpose(-2, -1)) * self.scale # [B, self.num_heads, N, N] 92 | attn = attn.softmax(dim=-1) 93 | if mask is not None: 94 | attn = attn * mask # (B, H, N, N) 95 | # attn = 
self.attn_drop(attn) 96 | 97 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 98 | x = self.proj(x) 99 | # x = self.proj_drop(x) 100 | return x 101 | 102 | 103 | class Block(nn.Module): 104 | 105 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 106 | drop_path=0., act_layer=nn.GELU, norm_layer=RMSNorm): 107 | super().__init__() 108 | self.norm1 = norm_layer(dim) 109 | self.attn = Attention( 110 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 111 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 112 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 113 | self.norm2 = norm_layer(dim) 114 | mlp_hidden_dim = int(dim * mlp_ratio) 115 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 116 | 117 | def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): 118 | x = x + self.drop_path(self.attn(self.norm1(x), freqs_cis, mask)) 119 | x = x + self.drop_path(self.mlp(self.norm2(x))) 120 | return x 121 | 122 | 123 | class PatchEmbed(nn.Module): 124 | """ Image to Patch Embedding 125 | """ 126 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 127 | super().__init__() 128 | img_size = to_2tuple(img_size) 129 | patch_size = to_2tuple(patch_size) 130 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 131 | self.img_size = img_size 132 | self.patch_size = patch_size 133 | self.num_patches = num_patches 134 | 135 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 136 | 137 | def forward(self, x): 138 | B, C, H, W = x.shape 139 | # FIXME look at relaxing size constraints 140 | assert H == self.img_size[0] and W == self.img_size[1], \ 141 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
142 | x = self.proj(x).flatten(2).transpose(1, 2) 143 | return x 144 | 145 | 146 | class HybridEmbed(nn.Module): 147 | """ CNN Feature Map Embedding 148 | Extract feature map from CNN, flatten, project to embedding dim. 149 | """ 150 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 151 | super().__init__() 152 | assert isinstance(backbone, nn.Module) 153 | img_size = to_2tuple(img_size) 154 | self.img_size = img_size 155 | self.backbone = backbone 156 | if feature_size is None: 157 | with torch.no_grad(): 158 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 159 | # map for all networks, the feature metadata has reliable channel and stride info, but using 160 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 161 | training = backbone.training 162 | if training: 163 | backbone.eval() 164 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 165 | feature_size = o.shape[-2:] 166 | feature_dim = o.shape[1] 167 | backbone.train(training) 168 | else: 169 | feature_size = to_2tuple(feature_size) 170 | feature_dim = self.backbone.feature_info.channels()[-1] 171 | self.num_patches = feature_size[0] * feature_size[1] 172 | self.proj = nn.Linear(feature_dim, embed_dim) 173 | 174 | def forward(self, x): 175 | x = self.backbone(x)[-1] 176 | x = x.flatten(2).transpose(1, 2) 177 | x = self.proj(x) 178 | return x 179 | 180 | 181 | class VisionTransformer(nn.Module): 182 | """ Vision Transformer with support for patch or hybrid CNN input stage 183 | """ 184 | 185 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 186 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 187 | drop_path_rate=0., hybrid_backbone=None, norm_layer=RMSNorm): 188 | super().__init__() 189 | self.num_classes = num_classes 190 | self.num_features = 
self.embed_dim = embed_dim # num_features for consistency with other models 191 | # Added: keep drop_rate and reuse it as the attention dropout rate 192 | self.drop_rate=drop_rate 193 | attn_drop_rate=drop_rate 194 | # Added: soft-mask mixing rate (0.0 = fully causal mask) 195 | self.soft_mask_rate = 0.0 196 | if hybrid_backbone is not None: 197 | self.patch_embed = HybridEmbed( 198 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 199 | else: 200 | self.patch_embed = PatchEmbed( 201 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 202 | num_patches = self.patch_embed.num_patches 203 | 204 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 205 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 206 | self.pos_drop = nn.Dropout(p=drop_rate) 207 | self.depth = depth 208 | 209 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 210 | self.blocks = nn.ModuleList([ 211 | Block( 212 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 213 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 214 | for i in range(depth)]) 215 | self.norm = norm_layer(embed_dim) 216 | 217 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 218 | #self.repr = nn.Linear(embed_dim, representation_size) 219 | #self.repr_act = nn.Tanh() 220 | 221 | # Classifier head 222 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 223 | 224 | trunc_normal_(self.pos_embed, std=.02) 225 | trunc_normal_(self.cls_token, std=.02) 226 | self.apply(self._init_weights) 227 | 228 | self.freqs_cis = precompute_freqs_cis( 229 | self.num_features // num_heads, num_patches + 1 230 | ) 231 | 232 | def _init_weights(self, m): 233 | if isinstance(m, nn.Linear): 234 | trunc_normal_(m.weight, std=.02) 235 | if isinstance(m, nn.Linear) and m.bias is not None: 236 | nn.init.constant_(m.bias, 0) 237 | elif
isinstance(m, nn.LayerNorm): 238 | nn.init.constant_(m.bias, 0) 239 | nn.init.constant_(m.weight, 1.0) 240 | 241 | @torch.jit.ignore 242 | def no_weight_decay(self): 243 | return {'pos_embed', 'cls_token'} 244 | 245 | def get_classifier(self): 246 | return self.head 247 | 248 | def reset_classifier(self, num_classes, global_pool=''): 249 | self.num_classes = num_classes 250 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 251 | 252 | def forward_features(self, x): 253 | B = x.shape[0] 254 | x = self.patch_embed(x) 255 | freqs_cis = self.freqs_cis.to(x.device) 256 | 257 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 258 | x = torch.cat((x, cls_tokens), dim=1) 259 | x = x + self.pos_embed 260 | x = self.pos_drop(x) 261 | N = x.shape[1] 262 | 263 | mask = None 264 | if N > 1: 265 | mask_bidirectional = torch.full((1, 1, N, N), 1, device=x.device).type_as(x) 266 | mask_causal = torch.full((1, 1, N, N), 0, device=x.device) 267 | tril_mask = torch.tril(torch.ones((N, N), dtype=torch.bool)) 268 | mask_causal[:, :, tril_mask] = 1 269 | mask_causal = mask_causal.type_as(x) 270 | # soft_mask_rate is scheduled from 1 (fully bidirectional) down to 0 (fully causal) 271 | mask = self.soft_mask_rate * mask_bidirectional + (1 - self.soft_mask_rate) * mask_causal 272 | 273 | for blk in self.blocks: 274 | x = blk(x, freqs_cis, mask) 275 | 276 | x = self.norm(x) 277 | return x[:, -1] 278 | 279 | def forward(self, x): 280 | x = self.forward_features(x) 281 | x = self.head(x) 282 | return x 283 | 284 | def update_drop_path(self, drop_path_rate): 285 | self.drop_path = drop_path_rate 286 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, self.depth)] 287 | for i in range(self.depth): 288 | self.blocks[i].drop_path.drop_prob = dp_rates[i] 289 | 290 | def update_dropout(self, drop_rate): 291 | self.drop_rate = drop_rate 292 | for module in self.modules(): 293 | if isinstance(module, nn.Dropout): 294 | module.p = drop_rate 295 | 296 |
def update_soft_mask(self, soft_mask_rate): 297 | self.soft_mask_rate = soft_mask_rate 298 | 299 | 300 | @register_model 301 | def illama_tiny(pretrained=False, **kwargs): 302 | model = VisionTransformer( 303 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 304 | norm_layer=partial(RMSNorm, eps=1e-6), **kwargs) 305 | return model 306 | 307 | @register_model 308 | def illama_small(pretrained=False, **kwargs): 309 | model = VisionTransformer( 310 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 311 | norm_layer=partial(RMSNorm, eps=1e-6), **kwargs) 312 | return model 313 | 314 | @register_model 315 | def illama_base(pretrained=False, **kwargs): 316 | model = VisionTransformer( 317 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 318 | norm_layer=partial(RMSNorm, eps=1e-6), **kwargs) 319 | return model 320 | 321 | @register_model 322 | def illama_large(pretrained=False, **kwargs): 323 | model = VisionTransformer( 324 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 325 | norm_layer=partial(RMSNorm, eps=1e-6), **kwargs) 326 | return model 327 | 328 | 329 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | 4 | from timm.optim.adafactor import Adafactor 5 | from timm.optim.adahessian import Adahessian 6 | from timm.optim.adamp import AdamP 7 | from timm.optim.lookahead import Lookahead 8 | from timm.optim.nadam import Nadam 9 | # from timm.optim.novograd import NovoGrad 10 | from timm.optim.nvnovograd import NvNovoGrad 11 | from timm.optim.radam import RAdam 12 | from timm.optim.rmsprop_tf import RMSpropTF 13 | from timm.optim.sgdp import SGDP 14 | 15 | import json 16 | 17 | try: 18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, 
FusedSGD 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | def get_num_layer_for_convnext(var_name): 25 | """ 26 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 27 | consecutive blocks, including possible neighboring downsample layers; 28 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 29 | """ 30 | num_max_layer = 12 31 | if var_name.startswith("downsample_layers"): 32 | stage_id = int(var_name.split('.')[1]) 33 | if stage_id == 0: 34 | layer_id = 0 35 | elif stage_id == 1 or stage_id == 2: 36 | layer_id = stage_id + 1 37 | elif stage_id == 3: 38 | layer_id = 12 39 | return layer_id 40 | 41 | elif var_name.startswith("stages"): 42 | stage_id = int(var_name.split('.')[1]) 43 | block_id = int(var_name.split('.')[2]) 44 | if stage_id == 0 or stage_id == 1: 45 | layer_id = stage_id + 1 46 | elif stage_id == 2: 47 | layer_id = 3 + block_id // 3 48 | elif stage_id == 3: 49 | layer_id = 12 50 | return layer_id 51 | else: 52 | return num_max_layer + 1 53 | 54 | class LayerDecayValueAssigner(object): 55 | def __init__(self, values): 56 | self.values = values 57 | 58 | def get_scale(self, layer_id): 59 | return self.values[layer_id] 60 | 61 | def get_layer_id(self, var_name): 62 | return get_num_layer_for_convnext(var_name) 63 | 64 | 65 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 66 | parameter_group_names = {} 67 | parameter_group_vars = {} 68 | 69 | for name, param in model.named_parameters(): 70 | if not param.requires_grad: 71 | continue # frozen weights 72 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 73 | group_name = "no_decay" 74 | this_weight_decay = 0. 
75 | else: 76 | group_name = "decay" 77 | this_weight_decay = weight_decay 78 | if get_num_layer is not None: 79 | layer_id = get_num_layer(name) 80 | group_name = "layer_%d_%s" % (layer_id, group_name) 81 | else: 82 | layer_id = None 83 | 84 | if group_name not in parameter_group_names: 85 | if get_layer_scale is not None: 86 | scale = get_layer_scale(layer_id) 87 | else: 88 | scale = 1. 89 | 90 | parameter_group_names[group_name] = { 91 | "weight_decay": this_weight_decay, 92 | "params": [], 93 | "lr_scale": scale 94 | } 95 | parameter_group_vars[group_name] = { 96 | "weight_decay": this_weight_decay, 97 | "params": [], 98 | "lr_scale": scale 99 | } 100 | 101 | parameter_group_vars[group_name]["params"].append(param) 102 | parameter_group_names[group_name]["params"].append(name) 103 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 104 | return list(parameter_group_vars.values()) 105 | 106 | 107 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 108 | opt_lower = args.opt.lower() 109 | weight_decay = args.weight_decay 110 | # if weight_decay and filter_bias_and_bn: 111 | if filter_bias_and_bn: 112 | skip = {} 113 | if skip_list is not None: 114 | skip = skip_list 115 | elif hasattr(model, 'no_weight_decay'): 116 | skip = model.no_weight_decay() 117 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 118 | weight_decay = 0. 
119 | else: 120 | parameters = model.parameters() 121 | 122 | if 'fused' in opt_lower: 123 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 124 | 125 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 126 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 127 | opt_args['eps'] = args.opt_eps 128 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 129 | opt_args['betas'] = args.opt_betas 130 | 131 | opt_split = opt_lower.split('_') 132 | opt_lower = opt_split[-1] 133 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 134 | opt_args.pop('eps', None) 135 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 136 | elif opt_lower == 'momentum': 137 | opt_args.pop('eps', None) 138 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 139 | elif opt_lower == 'adam': 140 | optimizer = optim.Adam(parameters, **opt_args) 141 | elif opt_lower == 'adamw': 142 | optimizer = optim.AdamW(parameters, **opt_args) 143 | elif opt_lower == 'nadam': 144 | optimizer = Nadam(parameters, **opt_args) 145 | elif opt_lower == 'radam': 146 | optimizer = RAdam(parameters, **opt_args) 147 | elif opt_lower == 'adamp': 148 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 149 | elif opt_lower == 'sgdp': 150 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 151 | elif opt_lower == 'adadelta': 152 | optimizer = optim.Adadelta(parameters, **opt_args) 153 | elif opt_lower == 'adafactor': 154 | if not args.lr: 155 | opt_args['lr'] = None 156 | optimizer = Adafactor(parameters, **opt_args) 157 | elif opt_lower == 'adahessian': 158 | optimizer = Adahessian(parameters, **opt_args) 159 | elif opt_lower == 'rmsprop': 160 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 161 | elif opt_lower == 'rmsproptf': 162 | optimizer = RMSpropTF(parameters, alpha=0.9, 
momentum=args.momentum, **opt_args) 163 | elif opt_lower == 'novograd': 164 | raise ValueError("'novograd' needs timm's NovoGrad, whose import is commented out at the top of this file; use 'nvnovograd' instead") 165 | elif opt_lower == 'nvnovograd': 166 | optimizer = NvNovoGrad(parameters, **opt_args) 167 | elif opt_lower == 'fusedsgd': 168 | opt_args.pop('eps', None) 169 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 170 | elif opt_lower == 'fusedmomentum': 171 | opt_args.pop('eps', None) 172 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 173 | elif opt_lower == 'fusedadam': 174 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 175 | elif opt_lower == 'fusedadamw': 176 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 177 | elif opt_lower == 'fusedlamb': 178 | optimizer = FusedLAMB(parameters, **opt_args) 179 | elif opt_lower == 'fusednovograd': 180 | opt_args.setdefault('betas', (0.95, 0.98)) 181 | optimizer = FusedNovoGrad(parameters, **opt_args) 182 | else: 183 | raise ValueError("Invalid optimizer: %s" % args.opt) 184 | 185 | if len(opt_split) > 1: 186 | if opt_split[0] == 'lookahead': 187 | optimizer = Lookahead(optimizer) 188 | 189 | return optimizer 190 | -------------------------------------------------------------------------------- /scripts/eval_illama_in1k_224.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | root_imagenet='/mnt/petrelfs/wangjiahao/datasets/classificaton/imagenet/' 4 | 5 | # illama-tiny: 75.0 6 | MODEL=illama_tiny 7 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-tiny-in1k-75.0.pth' 8 | 9 | srun -p gvembodied \ 10 | --job-name=evaluation_224 \ 11 | --gres=gpu:2 \ 12 | --cpus-per-task=32 \ 13 | --preempt \ 14 | --quotatype=spot \ 15 | python -m torch.distributed.launch --nproc_per_node=2 main.py \ 16 | --model $MODEL --eval true \ 17 | --data_path $root_imagenet \ 18 | --resume $RESUME 19 | 20 | 21 | # illama-small: 79.9 22 | MODEL=illama_small 23 |
RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-small-in1k-79.9.pth' 24 | 25 | srun -p gvembodied \ 26 | --job-name=evaluation_224 \ 27 | --gres=gpu:2 \ 28 | --cpus-per-task=32 \ 29 | --preempt \ 30 | --quotatype=spot \ 31 | python -m torch.distributed.launch --nproc_per_node=2 main.py \ 32 | --model $MODEL --eval true \ 33 | --data_path $root_imagenet \ 34 | --resume $RESUME 35 | 36 | 37 | # illama-base: 81.6 38 | MODEL=illama_base 39 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-base-in1k-81.6.pth' 40 | 41 | srun -p gvembodied \ 42 | --job-name=evaluation_224 \ 43 | --gres=gpu:2 \ 44 | --cpus-per-task=32 \ 45 | --preempt \ 46 | --quotatype=spot \ 47 | python -m torch.distributed.launch --nproc_per_node=2 main.py \ 48 | --model $MODEL --eval true \ 49 | --data_path $root_imagenet \ 50 | --resume $RESUME 51 | 52 | 53 | # illama-base: 83.6 54 | MODEL=illama_base 55 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-base-in21kin1k-224-83.6.pth' 56 | 57 | srun -p gvembodied \ 58 | --job-name=evaluation_224 \ 59 | --gres=gpu:2 \ 60 | --cpus-per-task=32 \ 61 | --preempt \ 62 | --quotatype=spot \ 63 | python -m torch.distributed.launch --nproc_per_node=2 main.py \ 64 | --model $MODEL --eval true \ 65 | --data_path $root_imagenet \ 66 | --resume $RESUME 67 | 68 | 69 | # illama-large: 84.8 70 | MODEL=illama_large 71 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-large-in21kin1k-224-84.8.pth' 72 | 73 | srun -p gvembodied \ 74 | --job-name=evaluation_224 \ 75 | --gres=gpu:2 \ 76 | --cpus-per-task=32 \ 77 | --preempt \ 78 | --quotatype=spot \ 79 | python -m torch.distributed.launch --nproc_per_node=2 main.py \ 80 | --model $MODEL --eval true \ 81 | --data_path $root_imagenet \ 82 | --resume $RESUME -------------------------------------------------------------------------------- /scripts/eval_illama_in1k_384.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | 
root_imagenet='/mnt/petrelfs/wangjiahao/datasets/classificaton/imagenet/' 4 | 5 | # illama-base: 83.0 6 | MODEL=illama_base 7 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-base-in1k-384-83.0.pth' 8 | 9 | srun -p gvembodied \ 10 | --job-name=evaluation_384 \ 11 | --gres=gpu:2 \ 12 | --cpus-per-task=32 \ 13 | --preempt \ 14 | --quotatype=spot \ 15 | python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \ 16 | --model $MODEL --input_size 384 --eval true \ 17 | --data_path $root_imagenet \ 18 | --resume $RESUME 19 | 20 | 21 | # illama-base: 85.0 22 | MODEL=illama_base 23 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-base-in21kin1k-384-85.0.pth' 24 | 25 | srun -p gvembodied \ 26 | --job-name=evaluation_384 \ 27 | --gres=gpu:2 \ 28 | --cpus-per-task=32 \ 29 | --preempt \ 30 | --quotatype=spot \ 31 | python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \ 32 | --model $MODEL --input_size 384 --eval true \ 33 | --data_path $root_imagenet \ 34 | --resume $RESUME 35 | 36 | 37 | # illama-large: 86.0 38 | MODEL=illama_large 39 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-large-in21kin1k-384-86.0.pth' 40 | 41 | srun -p gvembodied \ 42 | --job-name=evaluation_384 \ 43 | --gres=gpu:2 \ 44 | --cpus-per-task=32 \ 45 | --preempt \ 46 | --quotatype=spot \ 47 | python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \ 48 | --model $MODEL --input_size 384 --eval true \ 49 | --data_path $root_imagenet \ 50 | --resume $RESUME -------------------------------------------------------------------------------- /scripts/train_illama_base_from_llama2.sh: -------------------------------------------------------------------------------- 1 | root_imagenet='/your/path/to/imagenet/' 2 | MODEL=illama_base 3 | OUTPUT='output/path' 4 | FINETUNE='/your/path/to/llama2/pretrained/illama_ws_base.pth' 5 | 6 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \ 7 | --model $MODEL --epochs 300 --mixup 0.95 
--cutmix 1.0 \ 8 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 9 | --drop_path 0.4 --drop_mode standard \ 10 | --mask_mode soft --mask_schedule linear --cutoff_soft 25 \ 11 | --finetune $FINETUNE \ 12 | --data_path $root_imagenet \ 13 | --output_dir $OUTPUT -------------------------------------------------------------------------------- /scripts/train_illama_base_in1k.sh: -------------------------------------------------------------------------------- 1 | root_imagenet='/your/path/to/imagenet/' 2 | MODEL=illama_base 3 | OUTPUT='output/path' 4 | 5 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \ 6 | --model $MODEL --epochs 300 --mixup 0.95 --cutmix 1.0 \ 7 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 8 | --drop_path 0.4 --drop_mode standard \ 9 | --mask_mode soft --mask_schedule linear --cutoff_soft 25 \ 10 | --data_path $root_imagenet \ 11 | --output_dir $OUTPUT -------------------------------------------------------------------------------- /scripts/train_illama_small_from_llama2.sh: -------------------------------------------------------------------------------- 1 | root_imagenet='/your/path/to/imagenet/' 2 | MODEL=illama_small 3 | OUTPUT='output/path' 4 | FINETUNE='/your/path/to/llama2/pretrained/illama_ws_small.pth' 5 | 6 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \ 7 | --model $MODEL --epochs 300 --warmup_epochs 5 --mixup 0.5 --cutmix 0.5 \ 8 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 9 | --drop_path 0.1 --drop_mode standard \ 10 | --mask_mode soft --mask_schedule linear --cutoff_soft 50 \ 11 | --finetune $FINETUNE \ 12 | --data_path $root_imagenet \ 13 | --output_dir $OUTPUT -------------------------------------------------------------------------------- /scripts/train_illama_small_in1k.sh: -------------------------------------------------------------------------------- 1 | root_imagenet='/your/path/to/imagenet/' 2 | MODEL=illama_small 3 | OUTPUT='output/path' 4 | 5 | python -m torch.distributed.launch 
--nproc_per_node=8 main_soft.py \ 6 | --model $MODEL --epochs 300 --warmup_epochs 5 --mixup 0.5 --cutmix 0.5 \ 7 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 8 | --drop_path 0.1 --drop_mode standard \ 9 | --mask_mode soft --mask_schedule linear --cutoff_soft 50 \ 10 | --data_path $root_imagenet \ 11 | --output_dir $OUTPUT -------------------------------------------------------------------------------- /scripts/train_illama_tiny_from_llama2.sh: -------------------------------------------------------------------------------- 1 | root_imagenet='/your/path/to/imagenet/' 2 | MODEL=illama_tiny 3 | OUTPUT='output/path' 4 | FINETUNE='/your/path/to/llama2/pretrained/illama_ws_tiny.pth' 5 | 6 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \ 7 | --model $MODEL --epochs 300 --warmup_epochs 5 --mixup 0.1 --cutmix 0.1 \ 8 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 9 | --dropout 0 --drop_mode standard \ 10 | --mask_mode soft --mask_schedule constant --cutoff_soft 50 \ 11 | --finetune $FINETUNE \ 12 | --data_path $root_imagenet \ 13 | --output_dir $OUTPUT -------------------------------------------------------------------------------- /scripts/train_illama_tiny_in1k.sh: -------------------------------------------------------------------------------- 1 | root_imagenet='/your/path/to/imagenet/' 2 | MODEL=illama_tiny 3 | OUTPUT='output/path' 4 | 5 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \ 6 | --model $MODEL --epochs 300 --warmup_epochs 5 --mixup 0.1 --cutmix 0.1 \ 7 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 8 | --dropout 0 --drop_mode standard \ 9 | --mask_mode soft --mask_schedule constant --cutoff_soft 50 \ 10 | --data_path $root_imagenet \ 11 | --output_dir $OUTPUT -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import math 5 | import time 6 | from collections import 
defaultdict, deque 7 | import datetime 8 | import numpy as np 9 | from timm.utils import get_state_dict 10 | 11 | from pathlib import Path 12 | from timm.models import create_model 13 | import torch 14 | import torch.distributed as dist 15 | from math import inf # torch._six was removed in PyTorch 2.0 16 | 17 | from tensorboardX import SummaryWriter 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque! 40 | """ 41 | if not is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value) 79 | 80 | 81 | class MetricLogger(object): 82 | def __init__(self, delimiter="\t"): 83 | self.meters = 
defaultdict(SmoothedValue) 84 | self.delimiter =
delimiter 85 | 86 | def update(self, **kwargs): 87 | for k, v in kwargs.items(): 88 | if v is None: 89 | continue 90 | if isinstance(v, torch.Tensor): 91 | v = v.item() 92 | assert isinstance(v, (float, int)) 93 | self.meters[k].update(v) 94 | 95 | def __getattr__(self, attr): 96 | if attr in self.meters: 97 | return self.meters[attr] 98 | if attr in self.__dict__: 99 | return self.__dict__[attr] 100 | raise AttributeError("'{}' object has no attribute '{}'".format( 101 | type(self).__name__, attr)) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append( 107 | "{}: {}".format(name, str(meter)) 108 | ) 109 | return self.delimiter.join(loss_str) 110 | 111 | def synchronize_between_processes(self): 112 | for meter in self.meters.values(): 113 | meter.synchronize_between_processes() 114 | 115 | def add_meter(self, name, meter): 116 | self.meters[name] = meter 117 | 118 | def log_every(self, iterable, print_freq, header=None): 119 | i = 0 120 | if not header: 121 | header = '' 122 | start_time = time.time() 123 | end = time.time() 124 | iter_time = SmoothedValue(fmt='{avg:.4f}') 125 | data_time = SmoothedValue(fmt='{avg:.4f}') 126 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 127 | log_msg = [ 128 | header, 129 | '[{0' + space_fmt + '}/{1}]', 130 | 'eta: {eta}', 131 | '{meters}', 132 | 'time: {time}', 133 | 'data: {data}' 134 | ] 135 | if torch.cuda.is_available(): 136 | log_msg.append('max mem: {memory:.0f}') 137 | log_msg = self.delimiter.join(log_msg) 138 | MB = 1024.0 * 1024.0 139 | for obj in iterable: 140 | data_time.update(time.time() - end) 141 | yield obj 142 | iter_time.update(time.time() - end) 143 | if i % print_freq == 0 or i == len(iterable) - 1: 144 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 145 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 146 | if torch.cuda.is_available(): 147 | print(log_msg.format( 148 | i, len(iterable), eta=eta_string, 149 | 
meters=str(self), 150 | time=str(iter_time), data=str(data_time), 151 | memory=torch.cuda.max_memory_allocated() / MB)) 152 | else: 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time))) 157 | i += 1 158 | end = time.time() 159 | total_time = time.time() - start_time 160 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 161 | print('{} Total time: {} ({:.4f} s / it)'.format( 162 | header, total_time_str, total_time / len(iterable))) 163 | 164 | 165 | class TensorboardLogger(object): 166 | def __init__(self, log_dir): 167 | self.writer = SummaryWriter(logdir=log_dir) 168 | self.step = 0 169 | 170 | def set_step(self, step=None): 171 | if step is not None: 172 | self.step = step 173 | else: 174 | self.step += 1 175 | 176 | def update(self, head='scalar', step=None, **kwargs): 177 | for k, v in kwargs.items(): 178 | if v is None: 179 | continue 180 | if isinstance(v, torch.Tensor): 181 | v = v.item() 182 | assert isinstance(v, (float, int)) 183 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 184 | 185 | def flush(self): 186 | self.writer.flush() 187 | 188 | 189 | class WandbLogger(object): 190 | def __init__(self, args): 191 | self.args = args 192 | 193 | try: 194 | import wandb 195 | self._wandb = wandb 196 | except ImportError: 197 | raise ImportError( 198 | "To use the Weights and Biases Logger, please install wandb. " 199 | "Run `pip install wandb` to install it." 200 | ) 201 | 202 | # Initialize a W&B run 203 | if self._wandb.run is None: 204 | self._wandb.init( 205 | project=args.project, 206 | config=args 207 | ) 208 | 209 | def log_epoch_metrics(self, metrics, commit=True): 210 | """ 211 | Log train/test metrics onto W&B.
212 | """ 213 | # Log number of model parameters as W&B summary 214 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None) 215 | metrics.pop('n_parameters', None) 216 | 217 | # Log current epoch 218 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False) 219 | metrics.pop('epoch') 220 | 221 | for k, v in metrics.items(): 222 | if 'train' in k: 223 | self._wandb.log({f'Global Train/{k}': v}, commit=False) 224 | elif 'test' in k: 225 | self._wandb.log({f'Global Test/{k}': v}, commit=False) 226 | 227 | self._wandb.log({}) 228 | 229 | def log_checkpoints(self): 230 | output_dir = self.args.output_dir 231 | model_artifact = self._wandb.Artifact( 232 | self._wandb.run.id + "_model", type="model" 233 | ) 234 | 235 | model_artifact.add_dir(output_dir) 236 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"]) 237 | 238 | def set_steps(self): 239 | # Set global training step 240 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step') 241 | # Set epoch-wise step 242 | self._wandb.define_metric('Global Train/*', step_metric='epoch') 243 | self._wandb.define_metric('Global Test/*', step_metric='epoch') 244 | 245 | 246 | def setup_for_distributed(is_master): 247 | """ 248 | This function disables printing when not in master process 249 | """ 250 | import builtins as __builtin__ 251 | builtin_print = __builtin__.print 252 | 253 | def print(*args, **kwargs): 254 | force = kwargs.pop('force', False) 255 | if is_master or force: 256 | builtin_print(*args, **kwargs) 257 | 258 | __builtin__.print = print 259 | 260 | 261 | def is_dist_avail_and_initialized(): 262 | if not dist.is_available(): 263 | return False 264 | if not dist.is_initialized(): 265 | return False 266 | return True 267 | 268 | 269 | def get_world_size(): 270 | if not is_dist_avail_and_initialized(): 271 | return 1 272 | return dist.get_world_size() 273 | 274 | 275 | def get_rank(): 276 | if not is_dist_avail_and_initialized(): 
277 | return 0 278 | return dist.get_rank() 279 | 280 | 281 | def is_main_process(): 282 | return get_rank() == 0 283 | 284 | 285 | def save_on_master(*args, **kwargs): 286 | if is_main_process(): 287 | torch.save(*args, **kwargs) 288 | 289 | 290 | def init_distributed_mode(args): 291 | 292 | if args.dist_on_itp: 293 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 294 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 295 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 296 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 297 | os.environ['LOCAL_RANK'] = str(args.gpu) 298 | os.environ['RANK'] = str(args.rank) 299 | os.environ['WORLD_SIZE'] = str(args.world_size) 300 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 301 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 302 | args.rank = int(os.environ["RANK"]) 303 | args.world_size = int(os.environ['WORLD_SIZE']) 304 | args.gpu = int(os.environ['LOCAL_RANK']) 305 | elif 'SLURM_PROCID' in os.environ: 306 | args.rank = int(os.environ['SLURM_PROCID']) 307 | args.gpu = args.rank % torch.cuda.device_count() 308 | 309 | os.environ['RANK'] = str(args.rank) 310 | os.environ['LOCAL_RANK'] = str(args.gpu) 311 | os.environ['WORLD_SIZE'] = str(args.world_size) 312 | else: 313 | print('Not using distributed mode') 314 | args.distributed = False 315 | return 316 | 317 | args.distributed = True 318 | 319 | torch.cuda.set_device(args.gpu) 320 | args.dist_backend = 'nccl' 321 | print('| distributed init (rank {}): {}, gpu {}'.format( 322 | args.rank, args.dist_url, args.gpu), flush=True) 323 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 324 | world_size=args.world_size, rank=args.rank) 325 | torch.distributed.barrier() 326 | setup_for_distributed(args.rank == 0) 327 | 328 | 329 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 330 | missing_keys = 
[] 331 | unexpected_keys = [] 332 | error_msgs = [] 333 | # copy state_dict so _load_from_state_dict can modify it 334 | metadata = getattr(state_dict, '_metadata', None) 335 | state_dict = state_dict.copy() 336 | if metadata is not None: 337 | state_dict._metadata = metadata 338 | 339 | def load(module, prefix=''): 340 | local_metadata = {} if metadata is None else metadata.get( 341 | prefix[:-1], {}) 342 | module._load_from_state_dict( 343 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 344 | for name, child in module._modules.items(): 345 | if child is not None: 346 | load(child, prefix + name + '.') 347 | 348 | load(model, prefix=prefix) 349 | 350 | warn_missing_keys = [] 351 | ignore_missing_keys = [] 352 | for key in missing_keys: 353 | keep_flag = True 354 | for ignore_key in ignore_missing.split('|'): 355 | if ignore_key in key: 356 | keep_flag = False 357 | break 358 | if keep_flag: 359 | warn_missing_keys.append(key) 360 | else: 361 | ignore_missing_keys.append(key) 362 | 363 | missing_keys = warn_missing_keys 364 | 365 | if len(missing_keys) > 0: 366 | print("Weights of {} not initialized from pretrained model: {}".format( 367 | model.__class__.__name__, missing_keys)) 368 | if len(unexpected_keys) > 0: 369 | print("Weights from pretrained model not used in {}: {}".format( 370 | model.__class__.__name__, unexpected_keys)) 371 | if len(ignore_missing_keys) > 0: 372 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 373 | model.__class__.__name__, ignore_missing_keys)) 374 | if len(error_msgs) > 0: 375 | print('\n'.join(error_msgs)) 376 | 377 | 378 | class NativeScalerWithGradNormCount: 379 | state_dict_key = "amp_scaler" 380 | 381 | def __init__(self): 382 | self._scaler = torch.cuda.amp.GradScaler() 383 | 384 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 385 | self._scaler.scale(loss).backward(create_graph=create_graph) 
386 | if update_grad: 387 | if clip_grad is not None: 388 | assert parameters is not None 389 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 390 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 391 | else: 392 | self._scaler.unscale_(optimizer) 393 | norm = get_grad_norm_(parameters) 394 | self._scaler.step(optimizer) 395 | self._scaler.update() 396 | else: 397 | norm = None 398 | return norm 399 | 400 | def state_dict(self): 401 | return self._scaler.state_dict() 402 | 403 | def load_state_dict(self, state_dict): 404 | self._scaler.load_state_dict(state_dict) 405 | 406 | 407 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 408 | if isinstance(parameters, torch.Tensor): 409 | parameters = [parameters] 410 | parameters = [p for p in parameters if p.grad is not None] 411 | norm_type = float(norm_type) 412 | if len(parameters) == 0: 413 | return torch.tensor(0.) 414 | device = parameters[0].grad.device 415 | if norm_type == inf: 416 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 417 | else: 418 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 419 | return total_norm 420 | 421 | 422 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 423 | start_warmup_value=0, warmup_steps=-1): 424 | warmup_schedule = np.array([]) 425 | warmup_iters = warmup_epochs * niter_per_ep 426 | if warmup_steps > 0: 427 | warmup_iters = warmup_steps 428 | print("Set warmup steps = %d" % warmup_iters) 429 | if warmup_iters > 0: # honor warmup_steps even when warmup_epochs == 0 430 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 431 | 432 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 433 | schedule = np.array( 434 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for 
i in iters]) 435 | 436 | schedule =
np.concatenate((warmup_schedule, schedule)) 437 | 438 | assert len(schedule) == epochs * niter_per_ep 439 | return schedule 440 | 441 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 442 | output_dir = Path(args.output_dir) 443 | epoch_name = str(epoch) 444 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 445 | for checkpoint_path in checkpoint_paths: 446 | to_save = { 447 | 'model': model_without_ddp.state_dict(), 448 | 'optimizer': optimizer.state_dict(), 449 | 'epoch': epoch, 450 | 'scaler': loss_scaler.state_dict(), 451 | 'args': args, 452 | } 453 | 454 | if model_ema is not None: 455 | to_save['model_ema'] = get_state_dict(model_ema) 456 | 457 | save_on_master(to_save, checkpoint_path) 458 | 459 | if is_main_process() and isinstance(epoch, int): 460 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq 461 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del) 462 | if os.path.exists(old_ckpt): 463 | os.remove(old_ckpt) 464 | 465 | 466 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 467 | output_dir = Path(args.output_dir) 468 | if args.auto_resume and len(args.resume) == 0: 469 | import glob 470 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 471 | latest_ckpt = -1 472 | for ckpt in all_checkpoints: 473 | t = ckpt.split('-')[-1].split('.')[0] 474 | if t.isdigit(): 475 | latest_ckpt = max(int(t), latest_ckpt) 476 | if latest_ckpt >= 0: 477 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 478 | print("Auto resume checkpoint: %s" % args.resume) 479 | 480 | if args.resume: 481 | if args.resume.startswith('https'): 482 | checkpoint = torch.hub.load_state_dict_from_url( 483 | args.resume, map_location='cpu', check_hash=True) 484 | else: 485 | checkpoint = torch.load(args.resume, map_location='cpu') 486 | if 'model' in checkpoint: 487 | 
model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 488 | else: 489 | model_without_ddp.load_state_dict(checkpoint, strict=False) 490 | # model_without_ddp.load_state_dict(checkpoint['model']) 491 | print("Resume checkpoint %s" % args.resume) 492 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 493 | optimizer.load_state_dict(checkpoint['optimizer']) 494 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 495 | args.start_epoch = checkpoint['epoch'] + 1 496 | else: 497 | assert args.eval, 'Does not support resuming with checkpoint-best' 498 | if hasattr(args, 'model_ema') and args.model_ema: 499 | if 'model_ema' in checkpoint.keys(): 500 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 501 | else: 502 | model_ema.ema.load_state_dict(checkpoint['model']) 503 | if 'scaler' in checkpoint: 504 | loss_scaler.load_state_dict(checkpoint['scaler']) 505 | print("With optim & sched!") 506 | 507 | def reg_scheduler(base_value, final_value, epochs, niter_per_ep, early_epochs=0, early_value=None, 508 | mode='linear', early_mode='regular'): 509 | early_schedule = np.array([]) 510 | early_iters = early_epochs * niter_per_ep 511 | if early_value is None: 512 | early_value = final_value 513 | if early_epochs > 0: 514 | print(f"Set early value to {early_mode} {early_value}") 515 | if early_mode == 'regular': 516 | early_schedule = np.array([early_value] * early_iters) 517 | elif early_mode == 'linear': 518 | early_schedule = np.linspace(early_value, base_value, early_iters) 519 | elif early_mode == 'cosine': 520 | early_schedule = np.array( 521 | [base_value + 0.5 * (early_value - base_value) * (1 + math.cos(math.pi * i / early_iters)) for i in np.arange(early_iters)]) 522 | regular_epochs = epochs - early_epochs 523 | iters = np.arange(regular_epochs * niter_per_ep) 524 | schedule = np.linspace(base_value, final_value, len(iters)) 525 | schedule = np.concatenate((early_schedule, schedule)) 526 | 527 | 
assert len(schedule) == epochs * niter_per_ep 528 | return schedule 529 | 530 | def calculate_distance(args, model_without_ddp, device): 531 | output_dir = Path(args.output_dir) 532 | start_path = os.path.join(output_dir, 'checkpoint-start.pth') 533 | if not os.path.exists(start_path): 534 | return -1 535 | model_start = build_model(args) 536 | checkpoint_start = torch.load(start_path, map_location='cpu') 537 | model_start.load_state_dict(checkpoint_start['model']) 538 | model_start.to(device) 539 | cur = torch.tensor([]).to(device) 540 | start = torch.tensor([]).to(device) 541 | with torch.no_grad(): 542 | for name, p in model_without_ddp.named_parameters(): 543 | cur = torch.cat((cur, p.flatten().clone().detach())) 544 | for name, p in model_start.named_parameters(): 545 | start = torch.cat((start, p.flatten().clone().detach())) 546 | return torch.nn.MSELoss()(start, cur).item() 547 | 548 | def build_model(args): 549 | if args.model.startswith("convnext"): 550 | model = create_model( 551 | args.model, 552 | pretrained=False, 553 | num_classes=args.nb_classes, 554 | drop_path_rate=args.drop_path, 555 | layer_scale_init_value=args.layer_scale_init_value, 556 | head_init_scale=args.head_init_scale, 557 | drop_rate=args.dropout, 558 | ) 559 | else: 560 | model = create_model( 561 | args.model, 562 | pretrained=False, 563 | num_classes=args.nb_classes, 564 | drop_path_rate=args.drop_path, 565 | drop_rate =args.dropout, 566 | img_size=args.input_size 567 | ) 568 | return model --------------------------------------------------------------------------------