├── .gitignore
├── README.md
├── datasets.py
├── drop_scheduler.py
├── engine.py
├── engine_soft.py
├── llama2
│   └── weight_selection.py
├── logs
│   ├── illama-base-in1k-384-83.0.txt
│   ├── illama-base-in1k-81.6.txt
│   ├── illama-base-in21k-checkpoint-89.txt
│   ├── illama-base-in21kin1k-224-83.6.txt
│   ├── illama-base-in21kin1k-384-85.0.txt
│   ├── illama-large-in21k-checkpoint-89.txt
│   ├── illama-large-in21kin1k-224-84.8.txt
│   ├── illama-large-in21kin1k-384-86.0.txt
│   ├── illama-small-in1k-79.9.txt
│   └── illama-tiny-in1k-75.0.txt
├── main.py
├── main_soft.py
├── main_soft_fthr.py
├── mask_scheduler.py
├── models
│   ├── __init__.py
│   └── illama.py
├── optim_factory.py
├── scripts
│   ├── eval_illama_in1k_224.sh
│   ├── eval_illama_in1k_384.sh
│   ├── train_illama_base_from_llama2.sh
│   ├── train_illama_base_in1k.sh
│   ├── train_illama_small_from_llama2.sh
│   ├── train_illama_small_in1k.sh
│   ├── train_illama_tiny_from_llama2.sh
│   └── train_illama_tiny_in1k.sh
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # pyc
2 | __pycache__/
3 |
4 | # iLLaMA
5 | output/
6 | scripts/debug.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Adapting LLaMA Decoder to Vision Transformer](https://arxiv.org/pdf/2404.06773)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | Image credit: DALL·E
12 |
13 |
14 |
15 | This is a PyTorch implementation of iLLaMA proposed by our paper "[Adapting LLaMA Decoder to Vision Transformer](https://arxiv.org/abs/2404.06773)".
16 |
17 |
18 | 
19 | Figure 1: Left: iLLaMA architecture. Right: our design roadmap. Colored and gray bars
20 | represent the results of the tiny and base regimes, with the red line depicting the training loss of the
21 | tiny regime. iLLaMA strives to process visual tokens using standard LLaMA components, e.g., causal
22 | self-attention. The proposed PS [cls] and soft mask strategy help overcome training challenges.
23 |
24 |
25 |
26 | 
27 | Figure 2: (a) mask in causal self-attention. (b) mask in causal self-attention with our post-sequence
28 | class token (PS [cls]) method. (c) modified causal mask.
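The difference between panels (a) and (b) can be sketched in a few lines. This is a minimal illustration rather than the repository's implementation: `causal_mask` and the toy `num_patches` value are hypothetical names introduced here, and a 1 marks a position that a query token may attend to.

```python
def causal_mask(n):
    """Lower-triangular mask: query row i may attend to key columns 0..i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


num_patches = 4                      # toy value, chosen only for illustration
mask = causal_mask(num_patches + 1)  # patch tokens plus one [cls] token

# (a) [cls] placed first: its query row is row 0, so it sees only itself.
visible_to_cls_first = sum(mask[0])

# (b) PS [cls]: [cls] placed last, so its query row is the final row and
# it can attend to every patch token as well as itself.
visible_to_cls_last = sum(mask[-1])

print(visible_to_cls_first, visible_to_cls_last)
```

The mask matrix is identical in both cases; only the position of the class token in the sequence changes how much of the image it can aggregate.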
29 |
30 |
31 |
32 | 
33 | Figure 3: (a) Soft mask gradually transitions from a bi-directional mask into a causal mask during
34 | training through a constant or linear schedule. (b) Ablation results of training loss and test accuracy.
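One way to picture the soft mask is as a blend of an all-ones (bi-directional) mask and a lower-triangular (causal) mask, controlled by a rate that the schedule moves from 1 to 0 during training. The sketch below assumes that blending form; `soft_mask` and `rate` are illustrative names and may not match what `models/illama.py` actually does.

```python
def soft_mask(n, rate):
    """Blend rate * bi-directional + (1 - rate) * causal (assumed form).

    Entries on or below the diagonal are always 1.0; entries above the
    diagonal (future tokens) take the value `rate`.
    """
    if not 0.0 <= rate <= 1.0:
        raise ValueError("rate must lie in [0, 1]")
    return [[1.0 if j <= i else float(rate) for j in range(n)]
            for i in range(n)]


# rate = 1.0 gives a bi-directional mask, rate = 0.0 a causal mask.
for rate in (1.0, 0.5, 0.0):
    print(rate, soft_mask(3, rate))
```

Under this reading, a linear schedule lowers `rate` gradually, while a constant schedule holds it fixed before switching to the causal mask.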
35 |
36 |
37 |
38 | ## Requirements
39 | PyTorch and timm 0.5.4 (`pip install timm==0.5.4`).
40 |
41 | Data preparation: arrange ImageNet in the following folder structure. You can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).
42 |
43 | ```
44 | │imagenet/
45 | ├──train/
46 | │ ├── n01440764
47 | │ │ ├── n01440764_10026.JPEG
48 | │ │ ├── n01440764_10027.JPEG
49 | │ │ ├── ......
50 | │ ├── ......
51 | ├──val/
52 | │ ├── n01440764
53 | │ │ ├── ILSVRC2012_val_00000293.JPEG
54 | │ │ ├── ILSVRC2012_val_00002138.JPEG
55 | │ │ ├── ......
56 | │ ├── ......
57 | ```
58 |
59 |
60 | ## Models
61 | ### iLLaMA on ImageNet-1K
62 | | Model | Pre-trained dataset | Resolution | Params | MACs | Top1 Acc |
63 | | :--- | :--- | :---: | :---: | :---: | :---: |
64 | | [illama_tiny](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-tiny-in1k-75.0.pth?download=true) | - | 224 | 5.7M | 1.3G | 75.0 |
65 | | [illama_small](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-small-in1k-79.9.pth?download=true) | - | 224 | 21.9M | 4.6G | 79.9 |
66 | | [illama_base](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-base-in1k-81.6.pth?download=true) | - | 224 | 86.3M | 17.6G | 81.6 |
67 | | [illama_base](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-base-in1k-384-83.0.pth?download=true) | - | 384 | 86.3M | 55.5G | 83.0 |
68 | | [illama_base](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-base-in21kin1k-224-83.6.pth?download=true) | ImageNet-21K | 224 | 86.3M | 17.6G | 83.6 |
69 | | [illama_base](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-base-in21kin1k-384-85.0.pth?download=true) | ImageNet-21K | 384 | 86.3M | 55.5G | 85.0 |
70 | | [illama_large](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-large-in21kin1k-224-84.8.pth?download=true) | ImageNet-21K | 224 | 310.2M | 62.8G | 84.8 |
71 | | [illama_large](https://huggingface.co/techmonsterwang/iLLaMA/resolve/main/illama-large-in21kin1k-384-86.0.pth?download=true) | ImageNet-21K | 384 | 310.2M | 194.7G | 86.0 |
72 |
73 |
74 |
75 |
76 | ## Evaluate
77 |
78 | To evaluate models at 224 resolution, run:
79 |
80 | ```bash
81 | MODEL=illama_tiny
82 | RESUME='/your/path/to/model.pth'
83 |
84 | python -m torch.distributed.launch --nproc_per_node=2 main.py \
85 | --model $MODEL --eval true \
86 | --data_path $root_imagenet \
87 | --resume $RESUME
88 | ```
89 |
90 | To evaluate models at 384 resolution, run:
91 |
92 | ```bash
93 | MODEL=illama_base
94 | RESUME='/your/path/to/model.pth'
95 |
96 | python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \
97 | --model $MODEL --input_size 384 --eval true \
98 | --data_path $root_imagenet \
99 | --resume $RESUME
100 | ```
101 |
102 | ## Train
103 | We use a batch size of 4096 by default with 8 GPUs.
104 |
105 |
106 | ```bash
107 | bash scripts/train_illama_tiny_in1k.sh
108 | ```
109 | Training scripts for other models are in [scripts](/scripts/).
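When fewer GPUs are available, the total batch size of 4096 can in principle be kept by accumulating gradients, since `engine.py` divides the loss by `update_freq` and steps the optimizer once every `update_freq` micro-batches. The helper below is a small arithmetic sketch; `required_update_freq` is a hypothetical name, and the per-GPU batch sizes are illustrative assumptions, not values taken from the training scripts.

```python
def required_update_freq(total_batch_size, num_gpus, batch_size_per_gpu):
    """Gradient-accumulation factor needed to reach `total_batch_size`."""
    per_step = num_gpus * batch_size_per_gpu
    if per_step <= 0 or total_batch_size % per_step != 0:
        raise ValueError("total batch size must be a positive multiple of "
                         "num_gpus * batch_size_per_gpu")
    return total_batch_size // per_step


# Illustrative settings only (per-GPU sizes are assumptions):
print(required_update_freq(4096, num_gpus=8, batch_size_per_gpu=128))  # 4
print(required_update_freq(4096, num_gpus=4, batch_size_per_gpu=128))  # 8
```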
110 |
111 |
112 | ## Initialization Using LLaMA2-7B (Optional)
113 | We use the weight selection method to select weights from LLaMA2-7B.
114 |
115 | ```bash
116 | python llama2/weight_selection.py
117 | ```
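The index choice made along each tensor dimension by `uniform_element_selection` in `llama2/weight_selection.py` can be illustrated without PyTorch. The following is a pure-Python sketch of that rule for a single dimension; `uniform_indices` is a hypothetical helper, and its float rounding may differ from `torch.linspace` in rare tie cases.

```python
def uniform_indices(teacher_dim, student_dim):
    """Indices kept along one dimension when shrinking teacher -> student."""
    if not 1 <= student_dim <= teacher_dim:
        raise ValueError("need 1 <= student_dim <= teacher_dim")
    if teacher_dim % student_dim == 0:
        # Evenly divisible: take every `step`-th element.
        step = teacher_dim // student_dim
        return [i * step for i in range(student_dim)]
    # Otherwise: evenly spaced positions from 0 to teacher_dim - 1, rounded.
    return [round(i * (teacher_dim - 1) / (student_dim - 1))
            for i in range(student_dim)]


print(uniform_indices(8, 4))    # divisible case
print(uniform_indices(10, 4))   # non-divisible case
print(len(uniform_indices(4096, 192)))
```

For example, shrinking a LLaMA2-7B dimension of 4096 to the tiny model's 192 falls into the non-divisible branch, because 4096 is not a multiple of 192.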
118 |
119 | Then we use the selected weights to initialize our iLLaMA-T/S/B.
120 |
121 | ```bash
122 | bash scripts/train_illama_tiny_from_llama2.sh
123 | ```
124 | Training scripts for other models are in [scripts](/scripts/).
125 |
126 |
127 | ## Bibtex
128 | ```
129 | @article{wang2024adapting,
130 | title={Adapting LLaMA Decoder to Vision Transformer},
131 | author={Wang, Jiahao and Shao, Wenqi and Chen, Mengzhao and Wu, Chengyue and Liu, Yong and Zhang, Kaipeng and Zhang, Songyang and Chen, Kai and Luo, Ping},
132 | journal={arXiv preprint arXiv:2404.06773},
133 | year={2024}
134 | }
135 | ```
136 |
137 | ## Acknowledgment
138 |
139 | Our implementation is based on [pytorch-image-models](https://github.com/huggingface/pytorch-image-models), [llama](https://github.com/meta-llama/llama), [dropout](https://github.com/facebookresearch/dropout), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt), [weight-selection](https://github.com/OscarXZQ/weight-selection), and [MambaOut](https://github.com/yuweihao/MambaOut).
140 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import datasets, transforms
3 |
4 | from timm.data.constants import \
5 |     IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
6 | from timm.data import create_transform
7 |
8 | def build_dataset(is_train, args):
9 |     transform = build_transform(is_train, args)
10 |
11 |     print("Transform = ")
12 |     if isinstance(transform, tuple):
13 |         for trans in transform:
14 |             print(" - - - - - - - - - - ")
15 |             for t in trans.transforms:
16 |                 print(t)
17 |     else:
18 |         for t in transform.transforms:
19 |             print(t)
20 |     print("---------------------------")
21 |
22 |     if args.data_set == 'CIFAR10':
23 |         dataset = datasets.CIFAR10(args.data_path, train=is_train, download=True, transform=transform)
24 |         nb_classes = 10
25 |     elif args.data_set == 'CIFAR100':
26 |         dataset = datasets.CIFAR100(args.data_path, train=is_train, download=True, transform=transform)
27 |         nb_classes = 100
28 |     elif args.data_set == 'IMNET':
29 |         print("reading from datapath", args.data_path)
30 |         root = os.path.join(args.data_path, 'train' if is_train else 'val')
31 |         dataset = datasets.ImageFolder(root, transform=transform)
32 |         nb_classes = 1000
33 |     elif args.data_set == 'IMNET21K':
34 |         print("reading from datapath", args.data_path)
35 |         root = os.path.join(args.data_path, 'train' if is_train else 'val')
36 |         dataset = datasets.ImageFolder(root, transform=transform)
37 |         nb_classes = args.nb_classes
38 |         assert len(dataset.class_to_idx) == nb_classes
39 |     elif args.data_set == "image_folder":
40 |         root = args.data_path if is_train else args.eval_data_path
41 |         dataset = datasets.ImageFolder(root, transform=transform)
42 |         nb_classes = args.nb_classes
43 |         assert len(dataset.class_to_idx) == nb_classes
44 |     else:
45 |         raise NotImplementedError()
46 |     print("Number of the class = %d" % nb_classes)
47 |
48 |     return dataset, nb_classes
49 |
50 |
51 | def build_transform(is_train, args):
52 |     resize_im = args.input_size > 32
53 |     imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
54 |     mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
55 |     std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
56 |
57 |     if is_train:
58 |         # this should always dispatch to transforms_imagenet_train
59 |         transform = create_transform(
60 |             input_size=args.input_size,
61 |             is_training=True,
62 |             color_jitter=args.color_jitter,
63 |             auto_augment=args.aa,
64 |             interpolation=args.train_interpolation,
65 |             re_prob=args.reprob,
66 |             re_mode=args.remode,
67 |             re_count=args.recount,
68 |             mean=mean,
69 |             std=std,
70 |         )
71 |         if not resize_im:
72 |             transform.transforms[0] = transforms.RandomCrop(
73 |                 args.input_size, padding=4)
74 |         return transform
75 |
76 |     t = []
77 |     if resize_im:
78 |         # warping (no cropping) when evaluated at 384 or larger
79 |         if args.input_size >= 384:
80 |             t.append(
81 |                 transforms.Resize((args.input_size, args.input_size),
82 |                                   interpolation=transforms.InterpolationMode.BICUBIC),
83 |             )
84 |             print(f"Warping {args.input_size} size input images...")
85 |         else:
86 |             if args.crop_pct is None:
87 |                 args.crop_pct = 224 / 256
88 |             size = int(args.input_size / args.crop_pct)
89 |             t.append(
90 |                 # to maintain same ratio w.r.t. 224 images
91 |                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
92 |             )
93 |             t.append(transforms.CenterCrop(args.input_size))
94 |
95 |     t.append(transforms.ToTensor())
96 |     t.append(transforms.Normalize(mean, std))
97 |     return transforms.Compose(t)
98 |
--------------------------------------------------------------------------------
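The evaluation branch of `build_transform` above picks one of three preprocessing paths depending on `input_size`. The arithmetic can be checked with a small standalone sketch; `eval_resize_plan` is a hypothetical helper that only mirrors the size logic and does not build any torchvision transforms.

```python
def eval_resize_plan(input_size, crop_pct=None):
    """Describe the eval-time resizing chosen for a given input size."""
    if input_size <= 32:
        # Small inputs (e.g. CIFAR) are neither resized nor cropped.
        return ("none",)
    if input_size >= 384:
        # Warp directly to a square, with no center crop.
        return ("warp", input_size, input_size)
    if crop_pct is None:
        crop_pct = 224 / 256
    # Resize the shorter side, then center-crop to input_size.
    return ("resize_center_crop", int(input_size / crop_pct), input_size)


print(eval_resize_plan(224))
print(eval_resize_plan(384))
```

With the default `crop_pct` of 224/256, a 224 input is resized to 256 and then center-cropped, while a 384 input is warped without cropping.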
/drop_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode="standard", schedule="constant"):
4 |     assert mode in ["standard", "early", "late"]
5 |     if mode == "standard":
6 |         return np.full(epochs * niter_per_ep, drop_rate)
7 |
8 |     early_iters = cutoff_epoch * niter_per_ep
9 |     late_iters = (epochs - cutoff_epoch) * niter_per_ep
10 |
11 |     if mode == "early":
12 |         assert schedule in ["constant", "linear"]
13 |         if schedule == 'constant':
14 |             early_schedule = np.full(early_iters, drop_rate)
15 |         elif schedule == 'linear':
16 |             early_schedule = np.linspace(drop_rate, 0, early_iters)
17 |         final_schedule = np.concatenate((early_schedule, np.full(late_iters, 0)))
18 |
19 |     elif mode == "late":
20 |         assert schedule in ["constant"]
21 |         early_schedule = np.full(early_iters, 0)
22 |         final_schedule = np.concatenate((early_schedule, np.full(late_iters, drop_rate)))
23 |
24 |     assert len(final_schedule) == epochs * niter_per_ep
25 |     return final_schedule
--------------------------------------------------------------------------------
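To see the shape of the per-iteration schedule that `drop_scheduler` returns in `mode="early"` with `schedule="linear"`, the following pure-Python mirror may help. It is an illustration only: `early_linear_schedule` is a hypothetical name, it returns a list instead of a NumPy array, and the toy arguments are not taken from any training script.

```python
def early_linear_schedule(drop_rate, epochs, niter_per_ep, cutoff_epoch):
    """Linear decay from drop_rate to 0 before the cutoff, then all zeros."""
    early_iters = cutoff_epoch * niter_per_ep
    late_iters = (epochs - cutoff_epoch) * niter_per_ep
    if early_iters > 1:
        # Endpoint-inclusive spacing, like np.linspace(drop_rate, 0, early_iters).
        early = [drop_rate * (1 - i / (early_iters - 1))
                 for i in range(early_iters)]
    else:
        early = [drop_rate] * early_iters
    return early + [0.0] * late_iters


# Toy values: 4 epochs of 3 iterations each, cutoff after epoch 2.
schedule = early_linear_schedule(0.1, epochs=4, niter_per_ep=3, cutoff_epoch=2)
print(len(schedule), schedule[0], schedule[5], schedule[-1])
```

The training loop then indexes this array with the global iteration `it`, so the rate can change at every optimizer step rather than once per epoch.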
/engine.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Iterable, Optional
3 | import torch
4 | from timm.data import Mixup
5 | from timm.utils import accuracy, ModelEma
6 |
7 | import utils
8 |
9 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
10 |                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
11 |                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
12 |                     model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
13 |                     wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, schedules={},
14 |                     num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
15 |     model.train(True)
16 |     metric_logger = utils.MetricLogger(delimiter=" ")
17 |     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
18 |     metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
19 |     header = 'Epoch: [{}]'.format(epoch)
20 |     print_freq = 10
21 |
22 |     optimizer.zero_grad()
23 |
24 |     for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
25 |         step = data_iter_step // update_freq
26 |         if step >= num_training_steps_per_epoch:
27 |             continue
28 |         it = start_steps + step # global training iteration
29 |         # Update LR & WD for the first acc
30 |         if data_iter_step % update_freq == 0:
31 |             if lr_schedule_values is not None or wd_schedule_values is not None:
32 |                 for i, param_group in enumerate(optimizer.param_groups):
33 |                     if lr_schedule_values is not None:
34 |                         param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
35 |                     if wd_schedule_values is not None and param_group["weight_decay"] > 0:
36 |                         param_group["weight_decay"] = wd_schedule_values[it]
37 |         if 'dp' in schedules:
38 |             model.module.update_drop_path(schedules['dp'][it])
39 |         if 'do' in schedules:
40 |             model.module.update_dropout(schedules['do'][it])
41 |
42 |         samples = samples.to(device, non_blocking=True)
43 |         targets = targets.to(device, non_blocking=True)
44 |
45 |         if mixup_fn is not None:
46 |             samples, targets = mixup_fn(samples, targets)
47 |
48 |         if use_amp:
49 |             with torch.cuda.amp.autocast():
50 |                 output = model(samples)
51 |                 loss = criterion(output, targets)
52 |         else: # full precision
53 |             output = model(samples)
54 |             loss = criterion(output, targets)
55 |
56 |         loss_value = loss.item()
57 |
58 |         if not math.isfinite(loss_value): # this could trigger if using AMP
59 |             print("Loss is {}, stopping training".format(loss_value))
60 |             assert math.isfinite(loss_value)
61 |
62 |         if use_amp:
63 |             # this attribute is added by timm on one optimizer (adahessian)
64 |             is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
65 |             loss /= update_freq
66 |             grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
67 |                                     parameters=model.parameters(), create_graph=is_second_order,
68 |                                     update_grad=(data_iter_step + 1) % update_freq == 0)
69 |             if (data_iter_step + 1) % update_freq == 0:
70 |                 optimizer.zero_grad()
71 |                 if model_ema is not None:
72 |                     model_ema.update(model)
73 |         else: # full precision
74 |             loss /= update_freq
75 |             loss.backward()
76 |             if (data_iter_step + 1) % update_freq == 0:
77 |                 optimizer.step()
78 |                 optimizer.zero_grad()
79 |                 if model_ema is not None:
80 |                     model_ema.update(model)
81 |
82 |         torch.cuda.synchronize()
83 |
84 |         if mixup_fn is None:
85 |             class_acc = (output.max(-1)[-1] == targets).float().mean()
86 |         else:
87 |             class_acc = None
88 |         metric_logger.update(loss=loss_value)
89 |         metric_logger.update(class_acc=class_acc)
90 |         min_lr = 10.
91 |         max_lr = 0.
92 |         for group in optimizer.param_groups:
93 |             min_lr = min(min_lr, group["lr"])
94 |             max_lr = max(max_lr, group["lr"])
95 |
96 |         metric_logger.update(lr=max_lr)
97 |         metric_logger.update(min_lr=min_lr)
98 |         weight_decay_value = None
99 |         for group in optimizer.param_groups:
100 |             if group["weight_decay"] > 0:
101 |                 weight_decay_value = group["weight_decay"]
102 |         metric_logger.update(weight_decay=weight_decay_value)
103 |
104 |         if 'dp' in schedules:
105 |             metric_logger.update(drop_path=model.module.drop_path)
106 |
107 |         if 'do' in schedules:
108 |             metric_logger.update(dropout=model.module.drop_rate)
109 |
110 |         if use_amp:
111 |             metric_logger.update(grad_norm=grad_norm)
112 |
113 |         if log_writer is not None:
114 |             log_writer.update(loss=loss_value, head="loss")
115 |             log_writer.update(class_acc=class_acc, head="loss")
116 |             log_writer.update(lr=max_lr, head="opt")
117 |             log_writer.update(min_lr=min_lr, head="opt")
118 |             log_writer.update(weight_decay=weight_decay_value, head="opt")
119 |             if use_amp:
120 |                 log_writer.update(grad_norm=grad_norm, head="opt")
121 |             log_writer.set_step()
122 |
123 |         if wandb_logger:
124 |             wandb_logger._wandb.log({
125 |                 'Rank-0 Batch Wise/train_loss': loss_value,
126 |                 'Rank-0 Batch Wise/train_max_lr': max_lr,
127 |                 'Rank-0 Batch Wise/train_min_lr': min_lr
128 |             }, commit=False)
129 |             if class_acc:
130 |                 wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
131 |             if use_amp:
132 |                 wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
133 |             wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
134 |
135 |
136 |     # gather the stats from all processes
137 |     metric_logger.synchronize_between_processes()
138 |     print("Averaged stats:", metric_logger)
139 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
140 |
141 | @torch.no_grad()
142 | def evaluate(data_loader, model, device, use_amp=False):
143 |     criterion = torch.nn.CrossEntropyLoss()
144 |
145 |     metric_logger = utils.MetricLogger(delimiter=" ")
146 |     header = 'Test:'
147 |
148 |     # switch to evaluation mode
149 |     model.eval()
150 |     for batch in metric_logger.log_every(data_loader, 10, header):
151 |         images = batch[0]
152 |         target = batch[-1]
153 |
154 |         images = images.to(device, non_blocking=True)
155 |         target = target.to(device, non_blocking=True)
156 |
157 |         # compute output
158 |         if use_amp:
159 |             with torch.cuda.amp.autocast():
160 |                 output = model(images)
161 |                 loss = criterion(output, target)
162 |         else:
163 |             output = model(images)
164 |             loss = criterion(output, target)
165 |
166 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
167 |
168 |         batch_size = images.shape[0]
169 |         metric_logger.update(loss=loss.item())
170 |         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
171 |         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
172 |     # gather the stats from all processes
173 |     metric_logger.synchronize_between_processes()
174 |     print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
175 |           .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
176 |
177 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
178 |
178 |
--------------------------------------------------------------------------------
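The gradient-accumulation bookkeeping in `train_one_epoch` above (the `step`, `update_freq`, and `num_training_steps_per_epoch` logic) can be traced without a model. The sketch below is a simplified simulation; `accumulation_trace` is a hypothetical helper, and the toy arguments are chosen only to show which micro-batches trigger an optimizer step.

```python
def accumulation_trace(num_micro_batches, update_freq, num_training_steps_per_epoch):
    """Return (data_iter_step, step) pairs at which the optimizer would step."""
    updates = []
    for data_iter_step in range(num_micro_batches):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue  # surplus micro-batches are skipped
        if (data_iter_step + 1) % update_freq == 0:
            # Last micro-batch of an accumulation group: optimizer.step() here.
            updates.append((data_iter_step, step))
    return updates


# 10 micro-batches, accumulate over 4, at most 2 optimizer steps per epoch.
print(accumulation_trace(10, update_freq=4, num_training_steps_per_epoch=2))
```

In this toy run the optimizer steps after micro-batches 3 and 7, and micro-batches 8 and 9 are skipped because they would belong to a third step.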
/engine_soft.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Iterable, Optional
3 | import torch
4 | from timm.data import Mixup
5 | from timm.utils import accuracy, ModelEma
6 |
7 | import utils
8 |
9 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
10 |                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
11 |                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
12 |                     model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
13 |                     wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, schedules={},
14 |                     num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
15 |     model.train(True)
16 |     metric_logger = utils.MetricLogger(delimiter=" ")
17 |     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
18 |     metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
19 |     header = 'Epoch: [{}]'.format(epoch)
20 |     print_freq = 10
21 |
22 |     optimizer.zero_grad()
23 |
24 |     for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
25 |         step = data_iter_step // update_freq
26 |         if step >= num_training_steps_per_epoch:
27 |             continue
28 |         it = start_steps + step # global training iteration
29 |         # Update LR & WD for the first acc
30 |         if data_iter_step % update_freq == 0:
31 |             if lr_schedule_values is not None or wd_schedule_values is not None:
32 |                 for i, param_group in enumerate(optimizer.param_groups):
33 |                     if lr_schedule_values is not None:
34 |                         param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
35 |                     if wd_schedule_values is not None and param_group["weight_decay"] > 0:
36 |                         param_group["weight_decay"] = wd_schedule_values[it]
37 |         if 'dp' in schedules:
38 |             model.module.update_drop_path(schedules['dp'][it])
39 |         if 'do' in schedules:
40 |             model.module.update_dropout(schedules['do'][it])
41 |         if 'sm' in schedules:
42 |             model.module.update_soft_mask(schedules['sm'][it])
43 |
44 |         samples = samples.to(device, non_blocking=True)
45 |         targets = targets.to(device, non_blocking=True)
46 |
47 |         if mixup_fn is not None:
48 |             samples, targets = mixup_fn(samples, targets)
49 |
50 |         if use_amp:
51 |             with torch.cuda.amp.autocast():
52 |                 output = model(samples)
53 |                 loss = criterion(output, targets)
54 |         else: # full precision
55 |             output = model(samples)
56 |             loss = criterion(output, targets)
57 |
58 |         loss_value = loss.item()
59 |
60 |         if not math.isfinite(loss_value): # this could trigger if using AMP
61 |             print("Loss is {}, stopping training".format(loss_value))
62 |             assert math.isfinite(loss_value)
63 |
64 |         if use_amp:
65 |             # this attribute is added by timm on one optimizer (adahessian)
66 |             is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
67 |             loss /= update_freq
68 |             grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
69 |                                     parameters=model.parameters(), create_graph=is_second_order,
70 |                                     update_grad=(data_iter_step + 1) % update_freq == 0)
71 |             if (data_iter_step + 1) % update_freq == 0:
72 |                 optimizer.zero_grad()
73 |                 if model_ema is not None:
74 |                     model_ema.update(model)
75 |         else: # full precision
76 |             loss /= update_freq
77 |             loss.backward()
78 |             if (data_iter_step + 1) % update_freq == 0:
79 |                 optimizer.step()
80 |                 optimizer.zero_grad()
81 |                 if model_ema is not None:
82 |                     model_ema.update(model)
83 |
84 |         torch.cuda.synchronize()
85 |
86 |         if mixup_fn is None:
87 |             class_acc = (output.max(-1)[-1] == targets).float().mean()
88 |         else:
89 |             class_acc = None
90 |         metric_logger.update(loss=loss_value)
91 |         metric_logger.update(class_acc=class_acc)
92 |         min_lr = 10.
93 |         max_lr = 0.
94 |         for group in optimizer.param_groups:
95 |             min_lr = min(min_lr, group["lr"])
96 |             max_lr = max(max_lr, group["lr"])
97 |
98 |         metric_logger.update(lr=max_lr)
99 |         metric_logger.update(min_lr=min_lr)
100 |         weight_decay_value = None
101 |         for group in optimizer.param_groups:
102 |             if group["weight_decay"] > 0:
103 |                 weight_decay_value = group["weight_decay"]
104 |         metric_logger.update(weight_decay=weight_decay_value)
105 |
106 |         if 'dp' in schedules:
107 |             metric_logger.update(drop_path=model.module.drop_path)
108 |
109 |         if 'do' in schedules:
110 |             metric_logger.update(dropout=model.module.drop_rate)
111 |
112 |         if 'sm' in schedules:
113 |             metric_logger.update(soft_mask_rate=model.module.soft_mask_rate)
114 |
115 |         if use_amp:
116 |             metric_logger.update(grad_norm=grad_norm)
117 |
118 |         if log_writer is not None:
119 |             log_writer.update(loss=loss_value, head="loss")
120 |             log_writer.update(class_acc=class_acc, head="loss")
121 |             log_writer.update(lr=max_lr, head="opt")
122 |             log_writer.update(min_lr=min_lr, head="opt")
123 |             log_writer.update(weight_decay=weight_decay_value, head="opt")
124 |             if use_amp:
125 |                 log_writer.update(grad_norm=grad_norm, head="opt")
126 |             log_writer.set_step()
127 |
128 |         if wandb_logger:
129 |             wandb_logger._wandb.log({
130 |                 'Rank-0 Batch Wise/train_loss': loss_value,
131 |                 'Rank-0 Batch Wise/train_max_lr': max_lr,
132 |                 'Rank-0 Batch Wise/train_min_lr': min_lr
133 |             }, commit=False)
134 |             if class_acc:
135 |                 wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
136 |             if use_amp:
137 |                 wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
138 |             wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
139 |
140 |
141 |     # gather the stats from all processes
142 |     metric_logger.synchronize_between_processes()
143 |     print("Averaged stats:", metric_logger)
144 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
145 |
146 | @torch.no_grad()
147 | def evaluate(data_loader, model, device, use_amp=False):
148 |     criterion = torch.nn.CrossEntropyLoss()
149 |
150 |     metric_logger = utils.MetricLogger(delimiter=" ")
151 |     header = 'Test:'
152 |
153 |     # switch to evaluation mode
154 |     model.eval()
155 |     for batch in metric_logger.log_every(data_loader, 10, header):
156 |         images = batch[0]
157 |         target = batch[-1]
158 |
159 |         images = images.to(device, non_blocking=True)
160 |         target = target.to(device, non_blocking=True)
161 |
162 |         # compute output
163 |         if use_amp:
164 |             with torch.cuda.amp.autocast():
165 |                 output = model(images)
166 |                 loss = criterion(output, target)
167 |         else:
168 |             output = model(images)
169 |             loss = criterion(output, target)
170 |
171 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
172 |
173 |         batch_size = images.shape[0]
174 |         metric_logger.update(loss=loss.item())
175 |         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
176 |         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
177 |     # gather the stats from all processes
178 |     metric_logger.synchronize_between_processes()
179 |     print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
180 |           .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
181 |
182 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
183 |
183 |
--------------------------------------------------------------------------------
/llama2/weight_selection.py:
--------------------------------------------------------------------------------
1 | from safetensors import safe_open
2 | import os
3 | import torch
4 | import collections
5 | import sys
6 | from functools import partial
7 | sys.path.insert(0, '/mnt/petrelfs/wangjiahao/DoiT')
8 | from models.illama import VisionTransformer as iLLaMa
9 | from models.illama import RMSNorm
10 |
11 |
12 | def uniform_element_selection(wt, s_shape):
13 |     assert wt.dim() == len(s_shape), "Tensors have different number of dimensions"
14 |     ws = wt.clone()
15 |     for dim in range(wt.dim()):
16 |         assert wt.shape[dim] >= s_shape[dim], "Teacher's dimension should not be smaller than student's dimension" # determine whether teacher is larger than student on this dimension
17 |         if wt.shape[dim] % s_shape[dim] == 0:
18 |             step = wt.shape[dim] // s_shape[dim]
19 |             indices = torch.arange(s_shape[dim]) * step
20 |         else:
21 |             indices = torch.round(torch.linspace(0, wt.shape[dim]-1, s_shape[dim])).to(torch.int64)
22 |         ws = torch.index_select(ws, dim, indices)
23 |     assert ws.shape == s_shape
24 |     return ws
25 |
26 |
27 | # show iLLaMA keys
28 | # illama_t = iLLaMa(
29 | # patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
30 | # norm_layer=partial(RMSNorm, eps=1e-6), num_classes=1000, drop_path_rate=0.0, drop_rate=0.0)
31 | tensors_illama = torch.load('/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-tiny-in1k-75.0.pth', map_location='cpu')
32 | if 'model' in tensors_illama.keys():
33 |     tensors_illama = tensors_illama['model']
34 | print("illama keys:")
35 | print("\n".join(tensors_illama.keys()))
36 |
37 |
38 | # load llama2-7b-hf
39 | tensors_llama2 = {}
40 | path="llama2-7b-hf"
41 | for n in ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]:
42 |     file_name = os.path.join(path, n)
43 |     with safe_open(file_name, framework="pt", device="cpu") as f:
44 |         for k in f.keys():
45 |             tensors_llama2[k] = f.get_tensor(k)
46 | print("llama2 keys:")
47 | print("\n".join(tensors_llama2.keys()))
48 | print("load done")
49 |
50 | # (576, 192)
51 | print(tensors_illama["blocks.8.attn.qkv.weight"].shape)
52 | # (4096, 4096)
53 | print(tensors_llama2["model.layers.12.self_attn.q_proj.weight"].shape)
54 | # (11008, 4096)
55 | print(tensors_llama2["model.layers.12.mlp.up_proj.weight"].shape)
56 | # (11008, 4096)
57 | print(tensors_llama2["model.layers.12.mlp.gate_proj.weight"].shape)
58 | # (4096)
59 | print(tensors_llama2["model.layers.12.input_layernorm.weight"].shape)
60 |
61 |
62 | # illama:
63 | # blocks.8.norm1.weight
64 | # blocks.8.attn.qkv.weight
65 | # blocks.8.attn.qkv.bias
66 | # blocks.8.attn.proj.weight
67 | # blocks.8.attn.proj.bias
68 | # blocks.8.norm2.weight
69 | # blocks.8.mlp.fc1.weight
70 | # blocks.8.mlp.fc2.weight
71 | # blocks.8.mlp.fc3.weight
72 |
73 | # llama2:
74 | # model.layers.12.input_layernorm.weight
75 | # model.layers.12.mlp.down_proj.weight
76 | # model.layers.12.mlp.gate_proj.weight
77 | # model.layers.12.mlp.up_proj.weight
78 | # model.layers.12.post_attention_layernorm.weight
79 | # model.layers.12.self_attn.k_proj.weight
80 | # model.layers.12.self_attn.o_proj.weight
81 | # model.layers.12.self_attn.q_proj.weight
82 | # model.layers.12.self_attn.rotary_emb.inv_freq
83 | # model.layers.12.self_attn.v_proj.weight
84 |
85 | tensors_llama2_to_illama = collections.OrderedDict()
86 | # for k,v in tensors_llama2.items():
87 | for i in range(12):
88 |     # norm
89 |     tensors_llama2_to_illama["blocks." + str(i) + ".norm1.weight"] = \
90 |         tensors_llama2["model.layers." + str(i) + ".input_layernorm.weight"]
91 |     tensors_llama2_to_illama["blocks." + str(i) + ".norm2.weight"] = \
92 |         tensors_llama2["model.layers." + str(i) + ".post_attention_layernorm.weight"]
93 |     # attn
94 |     tensors_llama2_to_illama["blocks." + str(i) + ".attn.qkv.weight"] = \
95 |         torch.cat((tensors_llama2["model.layers." + str(i) + ".self_attn.q_proj.weight"],
96 |                    tensors_llama2["model.layers." + str(i) + ".self_attn.k_proj.weight"],
97 |                    tensors_llama2["model.layers." + str(i) + ".self_attn.v_proj.weight"]), dim=0)
98 |     tensors_llama2_to_illama["blocks." + str(i) + ".attn.proj.weight"] = \
99 |         tensors_llama2["model.layers." + str(i) + ".self_attn.o_proj.weight"]
100 |     # ffn
101 |     tensors_llama2_to_illama["blocks." + str(i) + ".mlp.fc1.weight"] = \
102 |         tensors_llama2["model.layers." + str(i) + ".mlp.gate_proj.weight"]
103 |     tensors_llama2_to_illama["blocks." + str(i) + ".mlp.fc2.weight"] = \
104 |         tensors_llama2["model.layers." + str(i) + ".mlp.down_proj.weight"]
105 |     tensors_llama2_to_illama["blocks." + str(i) + ".mlp.fc3.weight"] = \
106 |         tensors_llama2["model.layers." + str(i) + ".mlp.up_proj.weight"]
107 | print("llama2_to_illama keys:")
108 | print("\n".join(tensors_llama2_to_illama.keys()))
109 |
110 | # save new
111 | # torch.save(tensors_llama2_to_illama, 'llama2/1.pth')
112 |
113 |
114 |
115 | # weight selection from llama2 to illama
116 | student_illama_tiny = iLLaMa(
117 |     patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
118 |     norm_layer=partial(RMSNorm, eps=1e-6), num_classes=1000, drop_path_rate=0.0, drop_rate=0.0)
119 | student_illama_small = iLLaMa(
120 |     patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
121 |     norm_layer=partial(RMSNorm, eps=1e-6), num_classes=1000, drop_path_rate=0.0, drop_rate=0.0)
122 | student_illama_base = iLLaMa(
123 |     patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
124 |     norm_layer=partial(RMSNorm, eps=1e-6), num_classes=1000, drop_path_rate=0.0, drop_rate=0.0)
125 |
126 | teacher_weights = tensors_llama2_to_illama
127 | student_weights_tiny = student_illama_tiny.state_dict()
128 | student_weights_small = student_illama_small.state_dict()
129 | student_weights_base = student_illama_base.state_dict()
130 |
131 | # illama tiny weight selection
132 | student_weight_selection_tiny = collections.OrderedDict()
133 | for key in student_weights_tiny.keys():
134 |     if "norm1.weight" in key or "norm2.weight" in key or "attn.qkv.weight" in key or "attn.proj.weight" in key or "mlp.fc1.weight" in key or "mlp.fc2.weight" in key or "mlp.fc3.weight" in key:
135 |         print(f"initializing {key}")
136 |         print("teacher_weights shape:", teacher_weights[key].shape)
137 |         print("student_weights_tiny shape:", student_weights_tiny[key].shape)
138 |         student_weight_selection_tiny[key] = uniform_element_selection(teacher_weights[key], student_weights_tiny[key].shape)
139 | print("student_weights_tiny initialization done")
140 |
141 | # illama small weight selection
142 | student_weight_selection_small = collections.OrderedDict()
143 | for key in student_weights_small.keys():
144 |     if "norm1.weight" in key or "norm2.weight" in key or "attn.qkv.weight" in key or "attn.proj.weight" in key or "mlp.fc1.weight" in key or "mlp.fc2.weight" in key or "mlp.fc3.weight" in key:
145 |         print(f"initializing {key}")
146 |         print("teacher_weights shape:", teacher_weights[key].shape)
147 |         print("student_weights_small shape:", student_weights_small[key].shape)
148 |         student_weight_selection_small[key] = uniform_element_selection(teacher_weights[key], student_weights_small[key].shape)
149 | print("student_weights_small initialization done")
150 |
151 | # illama base weight selection
152 | student_weight_selection_base = collections.OrderedDict()
153 | for key in student_weights_base.keys():
154 |     if "norm1.weight" in key or "norm2.weight" in key or "attn.qkv.weight" in key or "attn.proj.weight" in key or "mlp.fc1.weight" in key or "mlp.fc2.weight" in key or "mlp.fc3.weight" in key:
155 |         print(f"initializing {key}")
156 |         print("teacher_weights shape:", teacher_weights[key].shape)
157 |         print("student_weights_base shape:", student_weights_base[key].shape)
158 |         student_weight_selection_base[key] = uniform_element_selection(teacher_weights[key], student_weights_base[key].shape)
159 | print("student_weights_base initialization done")
160 |
161 |
162 | # check that every selected key exists in the student model and that shapes match (tiny, small, base)
163 | for key in student_weight_selection_tiny.keys():
164 |     assert key in student_weights_tiny, f"Key {key} not found in Model iLLaMA-T"
165 |     assert student_weight_selection_tiny[key].shape == student_weights_tiny[key].shape, f"Shape mismatch for key {key}: {student_weight_selection_tiny[key].shape} != {student_weights_tiny[key].shape}"
166 | for key in student_weight_selection_small.keys():
167 |     assert key in student_weights_small, f"Key {key} not found in Model iLLaMA-S"
168 |     assert student_weight_selection_small[key].shape == student_weights_small[key].shape, f"Shape mismatch for key {key}: {student_weight_selection_small[key].shape} != {student_weights_small[key].shape}"
169 | for key in student_weight_selection_base.keys():
170 |     assert key in student_weights_base, f"Key {key} not found in Model iLLaMA-B"
171 |     assert student_weight_selection_base[key].shape == student_weights_base[key].shape, f"Shape mismatch for key {key}: {student_weight_selection_base[key].shape} != {student_weights_base[key].shape}"
172 |
173 | print("All keys checked, you can use student_weight_selection_tiny, student_weight_selection_small, student_weight_selection_base.")
174 |
175 | # save weight selection for illama (torch.save does not create directories, so llama2/pretrained/ must already exist)
176 | torch.save(student_weight_selection_tiny, 'llama2/pretrained/illama_ws_tiny.pth')
177 | torch.save(student_weight_selection_small, 'llama2/pretrained/illama_ws_small.pth')
178 | torch.save(student_weight_selection_base, 'llama2/pretrained/illama_ws_base.pth')
179 |
--------------------------------------------------------------------------------
/logs/illama-base-in1k-384-83.0.txt:
--------------------------------------------------------------------------------
1 | {"train_lr": 7.992788848282348e-05, "train_min_lr": 7.992788848282348e-05, "train_loss": 2.473320546398441, "train_class_acc": 0.648087779776179, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.733891578917285, "test_acc1": 81.76600249023437, "test_acc5": 95.97600270751953, "epoch": 0, "n_parameters": 86794216}
2 | {"train_lr": 7.949599477065516e-05, "train_min_lr": 7.949599477065516e-05, "train_loss": 2.2826324206866997, "train_class_acc": 0.6871768210431655, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7196540747989566, "test_acc1": 82.07200255126953, "test_acc5": 96.05600268554687, "epoch": 1, "n_parameters": 86794216}
3 | {"train_lr": 7.863685277934232e-05, "train_min_lr": 7.863685277934232e-05, "train_loss": 2.2307381229601697, "train_class_acc": 0.6979994129696243, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7166183656650303, "test_acc1": 82.18800258789062, "test_acc5": 96.01400269775391, "epoch": 2, "n_parameters": 86794216}
4 | {"train_lr": 7.735987544832995e-05, "train_min_lr": 7.735987544832995e-05, "train_loss": 2.199187022056892, "train_class_acc": 0.7046206784572342, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7158498064906315, "test_acc1": 82.26000260009765, "test_acc5": 96.04400275634765, "epoch": 3, "n_parameters": 86794216}
5 | {"train_lr": 7.567905360848068e-05, "train_min_lr": 7.567905360848068e-05, "train_loss": 2.176901005824907, "train_class_acc": 0.7092958445743405, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7096047366308119, "test_acc1": 82.44200252441406, "test_acc5": 96.15400270996093, "epoch": 4, "n_parameters": 86794216}
6 | {"train_lr": 7.361280269560707e-05, "train_min_lr": 7.361280269560707e-05, "train_loss": 2.155886616605482, "train_class_acc": 0.714189585831335, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7127004508711113, "test_acc1": 82.37200253417969, "test_acc5": 96.11600269287109, "epoch": 5, "n_parameters": 86794216}
7 | {"train_lr": 7.118376098710012e-05, "train_min_lr": 7.118376098710012e-05, "train_loss": 2.143780262636052, "train_class_acc": 0.7167828237410072, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.712347365247498, "test_acc1": 82.32000246337891, "test_acc5": 96.13600271484376, "epoch": 6, "n_parameters": 86794216}
8 | {"train_lr": 6.841854157222974e-05, "train_min_lr": 6.841854157222974e-05, "train_loss": 2.1295641221516997, "train_class_acc": 0.7199880720423661, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7088617964413102, "test_acc1": 82.33000249267577, "test_acc5": 96.1380027319336, "epoch": 7, "n_parameters": 86794216}
9 | {"train_lr": 6.534744077356283e-05, "train_min_lr": 6.534744077356283e-05, "train_loss": 2.117899560379229, "train_class_acc": 0.7227358738009593, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7082845338243451, "test_acc1": 82.44600254394531, "test_acc5": 96.18600267333984, "epoch": 8, "n_parameters": 86794216}
10 | {"train_lr": 6.200410621411962e-05, "train_min_lr": 6.200410621411962e-05, "train_loss": 2.105851644964265, "train_class_acc": 0.7254204448940847, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7060970708599869, "test_acc1": 82.48400255371094, "test_acc5": 96.21400270996094, "epoch": 9, "n_parameters": 86794216}
11 | {"train_lr": 5.8425168166971116e-05, "train_min_lr": 5.8425168166971116e-05, "train_loss": 2.0965708284486206, "train_class_acc": 0.7277209482414069, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7031197266738493, "test_acc1": 82.65600259033204, "test_acc5": 96.1560027368164, "epoch": 10, "n_parameters": 86794216}
12 | {"train_lr": 5.464983822630283e-05, "train_min_lr": 5.464983822630283e-05, "train_loss": 2.087975359384081, "train_class_acc": 0.7299106027677857, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7062890046126856, "test_acc1": 82.49000248291016, "test_acc5": 96.24600264404297, "epoch": 11, "n_parameters": 86794216}
13 | {"train_lr": 5.0719479696983124e-05, "train_min_lr": 5.0719479696983124e-05, "train_loss": 2.0755441978746516, "train_class_acc": 0.7325998576139089, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7021110344550195, "test_acc1": 82.65200256591797, "test_acc5": 96.22800266845704, "epoch": 12, "n_parameters": 86794216}
14 | {"train_lr": 4.667715440953944e-05, "train_min_lr": 4.667715440953944e-05, "train_loss": 2.069228977078109, "train_class_acc": 0.7342329261590728, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6990430768266434, "test_acc1": 82.67000250732421, "test_acc5": 96.24800268310547, "epoch": 13, "n_parameters": 86794216}
15 | {"train_lr": 4.256715092573226e-05, "train_min_lr": 4.256715092573226e-05, "train_loss": 2.060961584470493, "train_class_acc": 0.7357473396282974, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6991592321104108, "test_acc1": 82.67000253417969, "test_acc5": 96.28400275634766, "epoch": 14, "n_parameters": 86794216}
16 | {"train_lr": 3.843449930380363e-05, "train_min_lr": 3.843449930380363e-05, "train_loss": 2.0526455630340594, "train_class_acc": 0.7377301283972821, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.7010124566189888, "test_acc1": 82.69000256835938, "test_acc5": 96.27200267578125, "epoch": 15, "n_parameters": 86794216}
17 | {"train_lr": 3.432447773973649e-05, "train_min_lr": 3.432447773973649e-05, "train_loss": 2.0455116031558918, "train_class_acc": 0.7400072129796164, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6991047438248309, "test_acc1": 82.73400261962891, "test_acc5": 96.31600272216797, "epoch": 16, "n_parameters": 86794216}
18 | {"train_lr": 3.0282116489863717e-05, "train_min_lr": 3.0282116489863717e-05, "train_loss": 2.036234163123069, "train_class_acc": 0.7421242693345323, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6974382813067268, "test_acc1": 82.86800261962891, "test_acc5": 96.26000268310547, "epoch": 17, "n_parameters": 86794216}
19 | {"train_lr": 2.635170450995756e-05, "train_min_lr": 2.635170450995756e-05, "train_loss": 2.031415472404181, "train_class_acc": 0.7433147232214229, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6950523978355163, "test_acc1": 82.86200258544922, "test_acc5": 96.27200266845703, "epoch": 18, "n_parameters": 86794216}
20 | {"train_lr": 2.2576304216161565e-05, "train_min_lr": 2.2576304216161565e-05, "train_loss": 2.026128534757667, "train_class_acc": 0.7445496727617905, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6956147605212022, "test_acc1": 82.88400256835938, "test_acc5": 96.3520026928711, "epoch": 19, "n_parameters": 86794216}
21 | {"train_lr": 1.8997279684147784e-05, "train_min_lr": 1.8997279684147784e-05, "train_loss": 2.020787144066404, "train_class_acc": 0.7456807991107114, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6950576499118718, "test_acc1": 82.85400256835938, "test_acc5": 96.28000270019531, "epoch": 20, "n_parameters": 86794216}
22 | {"train_lr": 1.5653843455648113e-05, "train_min_lr": 1.5653843455648113e-05, "train_loss": 2.0178974027381384, "train_class_acc": 0.7466074015787371, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6941298640596162, "test_acc1": 82.87400258544922, "test_acc5": 96.32000274169921, "epoch": 21, "n_parameters": 86794216}
23 | {"train_lr": 1.2582626917640712e-05, "train_min_lr": 1.2582626917640712e-05, "train_loss": 2.0134329836723284, "train_class_acc": 0.7475152690347722, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6930451631574458, "test_acc1": 82.90600259521484, "test_acc5": 96.32400272460937, "epoch": 22, "n_parameters": 86794216}
24 | {"train_lr": 9.817278961209553e-06, "train_min_lr": 9.817278961209553e-06, "train_loss": 2.010525077039437, "train_class_acc": 0.7482654501398881, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6927108393967835, "test_acc1": 82.94400260009766, "test_acc5": 96.34200279785156, "epoch": 23, "n_parameters": 86794216}
25 | {"train_lr": 7.388097317251548e-06, "train_min_lr": 7.388097317251548e-06, "train_loss": 2.007418536215568, "train_class_acc": 0.7493333458233413, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6922111400680346, "test_acc1": 82.94600257324218, "test_acc5": 96.33200274169921, "epoch": 24, "n_parameters": 86794216}
26 | {"train_lr": 5.321696608196864e-06, "train_min_lr": 5.321696608196864e-06, "train_loss": 2.002901304453659, "train_class_acc": 0.7501826663669064, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.692061053752501, "test_acc1": 82.95000255615234, "test_acc5": 96.33200280029297, "epoch": 25, "n_parameters": 86794216}
27 | {"train_lr": 3.6407167526360116e-06, "train_min_lr": 3.6407167526360116e-06, "train_loss": 2.001760026268798, "train_class_acc": 0.7506760216826539, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6917145001268341, "test_acc1": 82.95800252929688, "test_acc5": 96.34800278320313, "epoch": 26, "n_parameters": 86794216}
28 | {"train_lr": 2.3635749176341855e-06, "train_min_lr": 2.3635749176341855e-06, "train_loss": 2.002015811621214, "train_class_acc": 0.7507548648581135, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6918871564264516, "test_acc1": 82.95200257324218, "test_acc5": 96.33200277832032, "epoch": 27, "n_parameters": 86794216}
29 | {"train_lr": 1.5042637363947687e-06, "train_min_lr": 1.5042637363947687e-06, "train_loss": 1.9995884048453958, "train_class_acc": 0.7511397132294164, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6912183112443517, "test_acc1": 82.9560025805664, "test_acc5": 96.33800274902343, "epoch": 28, "n_parameters": 86794216}
30 | {"train_lr": 1.0721980020418422e-06, "train_min_lr": 1.0721980020418422e-06, "train_loss": 2.000189851820588, "train_class_acc": 0.7509312862210232, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.8000000000000723, "train_soft_mask_rate": 0.0, "test_loss": 0.6910126191272422, "test_acc1": 82.99800246337891, "test_acc5": 96.33600276855469, "epoch": 29, "n_parameters": 86794216}
31 |
--------------------------------------------------------------------------------
/logs/illama-base-in21kin1k-224-83.6.txt:
--------------------------------------------------------------------------------
1 | {"train_lr": 7.992788848282348e-05, "train_min_lr": 7.992788848282348e-05, "train_loss": 2.8800309649784026, "train_class_acc": 0.6066779388988809, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.8331461997404126, "test_acc1": 79.12400257568359, "test_acc5": 95.1300026611328, "epoch": 0, "n_parameters": 86502376}
2 | {"train_lr": 7.949599477065516e-05, "train_min_lr": 7.949599477065516e-05, "train_loss": 2.2193321749556074, "train_class_acc": 0.6949448253896883, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.7763625056916521, "test_acc1": 80.35400242919921, "test_acc5": 95.79000268554688, "epoch": 1, "n_parameters": 86502376}
3 | {"train_lr": 7.863685277934232e-05, "train_min_lr": 7.863685277934232e-05, "train_loss": 2.145499920530571, "train_class_acc": 0.7092076338928857, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.7521247481145249, "test_acc1": 81.05800251953124, "test_acc5": 96.04000268554688, "epoch": 2, "n_parameters": 86502376}
4 | {"train_lr": 7.735987544832995e-05, "train_min_lr": 7.735987544832995e-05, "train_loss": 2.0983406849008954, "train_class_acc": 0.718855384442446, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.7348379289077325, "test_acc1": 81.4920025769043, "test_acc5": 96.17200262939453, "epoch": 3, "n_parameters": 86502376}
5 | {"train_lr": 7.567905360848068e-05, "train_min_lr": 7.567905360848068e-05, "train_loss": 2.0673020389720165, "train_class_acc": 0.7251214653277378, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.71815004223915, "test_acc1": 81.904002578125, "test_acc5": 96.32000264404297, "epoch": 4, "n_parameters": 86502376}
6 | {"train_lr": 7.361280269560707e-05, "train_min_lr": 7.361280269560707e-05, "train_loss": 2.039417670573667, "train_class_acc": 0.7317458533173461, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.710230432720926, "test_acc1": 82.15400252319336, "test_acc5": 96.39000264648438, "epoch": 5, "n_parameters": 86502376}
7 | {"train_lr": 7.118376098710012e-05, "train_min_lr": 7.118376098710012e-05, "train_loss": 2.0196545109164705, "train_class_acc": 0.735794957783773, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.7012611285975302, "test_acc1": 82.18800258666992, "test_acc5": 96.47600269042968, "epoch": 6, "n_parameters": 86502376}
8 | {"train_lr": 6.841854157222974e-05, "train_min_lr": 6.841854157222974e-05, "train_loss": 2.0017266717883087, "train_class_acc": 0.7400204836131095, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.697616238310535, "test_acc1": 82.52600250366211, "test_acc5": 96.49200275146484, "epoch": 7, "n_parameters": 86502376}
9 | {"train_lr": 6.534744077356283e-05, "train_min_lr": 6.534744077356283e-05, "train_loss": 1.9866541259699493, "train_class_acc": 0.7436441471822542, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6930679365324269, "test_acc1": 82.61600248535156, "test_acc5": 96.51000275634766, "epoch": 8, "n_parameters": 86502376}
10 | {"train_lr": 6.200410621411962e-05, "train_min_lr": 6.200410621411962e-05, "train_loss": 1.9714939381977876, "train_class_acc": 0.7470945118904876, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6908177238825055, "test_acc1": 82.52600251464844, "test_acc5": 96.59200271240235, "epoch": 9, "n_parameters": 86502376}
11 | {"train_lr": 5.8425168166971116e-05, "train_min_lr": 5.8425168166971116e-05, "train_loss": 1.960133994753055, "train_class_acc": 0.7494847871702638, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6893090736263126, "test_acc1": 82.65200252685547, "test_acc5": 96.58000280029297, "epoch": 10, "n_parameters": 86502376}
12 | {"train_lr": 5.464983822630283e-05, "train_min_lr": 5.464983822630283e-05, "train_loss": 1.946019628276523, "train_class_acc": 0.7532965814848122, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6812137730903071, "test_acc1": 82.84400247802735, "test_acc5": 96.63000275146484, "epoch": 11, "n_parameters": 86502376}
13 | {"train_lr": 5.0719479696983124e-05, "train_min_lr": 5.0719479696983124e-05, "train_loss": 1.9346819410054423, "train_class_acc": 0.7557016886490807, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6739478692668313, "test_acc1": 83.09200239746093, "test_acc5": 96.69000272460937, "epoch": 12, "n_parameters": 86502376}
14 | {"train_lr": 4.667715440953944e-05, "train_min_lr": 4.667715440953944e-05, "train_loss": 1.9267961068881405, "train_class_acc": 0.75771492181255, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6747824225560514, "test_acc1": 83.08600253417968, "test_acc5": 96.75600277099609, "epoch": 13, "n_parameters": 86502376}
15 | {"train_lr": 4.256715092573226e-05, "train_min_lr": 4.256715092573226e-05, "train_loss": 1.916994124752202, "train_class_acc": 0.7603752935151878, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6742744893688737, "test_acc1": 83.1600024609375, "test_acc5": 96.69000269287109, "epoch": 14, "n_parameters": 86502376}
16 | {"train_lr": 3.843449930380363e-05, "train_min_lr": 3.843449930380363e-05, "train_loss": 1.9078451729953574, "train_class_acc": 0.7626750162370104, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6682798480686113, "test_acc1": 83.12800247558594, "test_acc5": 96.66800276855469, "epoch": 15, "n_parameters": 86502376}
17 | {"train_lr": 3.432447773973649e-05, "train_min_lr": 3.432447773973649e-05, "train_loss": 1.9017058968645253, "train_class_acc": 0.7641628884392486, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6697903494741625, "test_acc1": 83.13600244628907, "test_acc5": 96.710002734375, "epoch": 16, "n_parameters": 86502376}
18 | {"train_lr": 3.0282116489863717e-05, "train_min_lr": 3.0282116489863717e-05, "train_loss": 1.892893716779282, "train_class_acc": 0.7661409934552358, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6674242881304435, "test_acc1": 83.25400260986328, "test_acc5": 96.72800265625, "epoch": 17, "n_parameters": 86502376}
19 | {"train_lr": 2.635170450995756e-05, "train_min_lr": 2.635170450995756e-05, "train_loss": 1.8847338850770017, "train_class_acc": 0.7681589103717026, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.664745727070529, "test_acc1": 83.29600263183593, "test_acc5": 96.77000273681641, "epoch": 18, "n_parameters": 86502376}
20 | {"train_lr": 2.2576304216161565e-05, "train_min_lr": 2.2576304216161565e-05, "train_loss": 1.8793799648854992, "train_class_acc": 0.7697154776179057, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6642096964713499, "test_acc1": 83.43800253173828, "test_acc5": 96.77800274169923, "epoch": 19, "n_parameters": 86502376}
21 | {"train_lr": 1.8997279684147784e-05, "train_min_lr": 1.8997279684147784e-05, "train_loss": 1.8746926908661945, "train_class_acc": 0.7709285696442846, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6633108184452503, "test_acc1": 83.33400262695312, "test_acc5": 96.75200265869141, "epoch": 20, "n_parameters": 86502376}
22 | {"train_lr": 1.5653843455648113e-05, "train_min_lr": 1.5653843455648113e-05, "train_loss": 1.8700135239421558, "train_class_acc": 0.7721518098021583, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6620311849956749, "test_acc1": 83.49600257568359, "test_acc5": 96.81000268554688, "epoch": 21, "n_parameters": 86502376}
23 | {"train_lr": 1.2582626917640712e-05, "train_min_lr": 1.2582626917640712e-05, "train_loss": 1.8679893643503709, "train_class_acc": 0.7726155013489209, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6601207956515309, "test_acc1": 83.5120025805664, "test_acc5": 96.7720026977539, "epoch": 22, "n_parameters": 86502376}
24 | {"train_lr": 9.817278961209553e-06, "train_min_lr": 9.817278961209553e-06, "train_loss": 1.8622591408203355, "train_class_acc": 0.7740495103916867, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.660049624840608, "test_acc1": 83.48000255615234, "test_acc5": 96.81000270751953, "epoch": 23, "n_parameters": 86502376}
25 | {"train_lr": 7.388097317251548e-06, "train_min_lr": 7.388097317251548e-06, "train_loss": 1.8585193322585831, "train_class_acc": 0.7752017136290967, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6595075904526784, "test_acc1": 83.55000259033203, "test_acc5": 96.78800273681641, "epoch": 24, "n_parameters": 86502376}
26 | {"train_lr": 5.321696608196864e-06, "train_min_lr": 5.321696608196864e-06, "train_loss": 1.8578995852280769, "train_class_acc": 0.7751564373501199, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.657806609423106, "test_acc1": 83.60400248779297, "test_acc5": 96.79800274414063, "epoch": 25, "n_parameters": 86502376}
27 | {"train_lr": 3.6407167526360116e-06, "train_min_lr": 3.6407167526360116e-06, "train_loss": 1.8572985059429081, "train_class_acc": 0.7757184877098321, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6586373214545937, "test_acc1": 83.59400256347656, "test_acc5": 96.78200276123047, "epoch": 26, "n_parameters": 86502376}
28 | {"train_lr": 2.3635749176341855e-06, "train_min_lr": 2.3635749176341855e-06, "train_loss": 1.8557780712819119, "train_class_acc": 0.7756099807653877, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6584229432330787, "test_acc1": 83.53800257568359, "test_acc5": 96.79200271484375, "epoch": 27, "n_parameters": 86502376}
29 | {"train_lr": 1.5042637363947687e-06, "train_min_lr": 1.5042637363947687e-06, "train_loss": 1.8542869655306724, "train_class_acc": 0.7759206697142286, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6582214436884827, "test_acc1": 83.57600252685548, "test_acc5": 96.79600274169921, "epoch": 28, "n_parameters": 86502376}
30 | {"train_lr": 1.0721980020418422e-06, "train_min_lr": 1.0721980020418422e-06, "train_loss": 1.8537212703302561, "train_class_acc": 0.7762789768185452, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.20000000000001808, "train_soft_mask_rate": 0.0, "test_loss": 0.6580269141762539, "test_acc1": 83.56400259277343, "test_acc5": 96.79200272949218, "epoch": 29, "n_parameters": 86502376}
31 |
--------------------------------------------------------------------------------
/logs/illama-base-in21kin1k-384-85.0.txt:
--------------------------------------------------------------------------------
1 | {"train_lr": 0.00010990050436237771, "train_min_lr": 0.00010990050436237771, "train_loss": 2.873497698356291, "train_class_acc": 0.5845987272681854, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.7941034439580079, "test_acc1": 80.10400251403809, "test_acc5": 95.64000197753906, "epoch": 0, "n_parameters": 86794216}
2 | {"train_lr": 0.0001093046003797633, "train_min_lr": 0.0001093046003797633, "train_loss": 2.1891729478314104, "train_class_acc": 0.6994349832633893, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.7331855757112019, "test_acc1": 81.6440025, "test_acc5": 96.29000178466796, "epoch": 1, "n_parameters": 86794216}
3 | {"train_lr": 0.00010811920193605486, "train_min_lr": 0.00010811920193605486, "train_loss": 2.0989335492375396, "train_class_acc": 0.7191855890287769, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6977391123586116, "test_acc1": 82.28000253540038, "test_acc5": 96.6640016796875, "epoch": 2, "n_parameters": 86794216}
4 | {"train_lr": 0.00010635729650465759, "train_min_lr": 0.00010635729650465759, "train_loss": 2.044830535443948, "train_class_acc": 0.7307005957733813, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6783538484995849, "test_acc1": 82.9180024987793, "test_acc5": 96.83400157958984, "epoch": 3, "n_parameters": 86794216}
5 | {"train_lr": 0.00010403818789018106, "train_min_lr": 0.00010403818789018106, "train_loss": 2.0058495105370033, "train_class_acc": 0.7397300909272582, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.661780134467872, "test_acc1": 83.26400256591796, "test_acc5": 96.93800155395508, "epoch": 4, "n_parameters": 86794216}
6 | {"train_lr": 0.00010118728473191276, "train_min_lr": 0.00010118728473191276, "train_loss": 1.974607850857478, "train_class_acc": 0.7472795201338929, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6588617378553211, "test_acc1": 83.33800248962402, "test_acc5": 96.9840015612793, "epoch": 5, "n_parameters": 86794216}
7 | {"train_lr": 9.783582212144216e-05, "train_min_lr": 9.783582212144216e-05, "train_loss": 1.9508491760591784, "train_class_acc": 0.7529296875, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6439973158098158, "test_acc1": 83.67800259216308, "test_acc5": 97.14000151123047, "epoch": 6, "n_parameters": 86794216}
8 | {"train_lr": 9.40205193844685e-05, "train_min_lr": 9.40205193844685e-05, "train_loss": 1.927743834557293, "train_class_acc": 0.7581224083233413, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6432044282341483, "test_acc1": 83.71600251037597, "test_acc5": 97.12800151611329, "epoch": 7, "n_parameters": 86794216}
9 | {"train_lr": 8.978317777618145e-05, "train_min_lr": 8.978317777618145e-05, "train_loss": 1.909357096816746, "train_class_acc": 0.7624626861011191, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6321660951921766, "test_acc1": 83.95000252563477, "test_acc5": 97.20400146850587, "epoch": 8, "n_parameters": 86794216}
10 | {"train_lr": 8.517022249796265e-05, "train_min_lr": 8.517022249796265e-05, "train_loss": 1.8913158439174236, "train_class_acc": 0.7667912544964028, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6311975168742211, "test_acc1": 84.0760025769043, "test_acc5": 97.18400147705078, "epoch": 9, "n_parameters": 86794216}
11 | {"train_lr": 8.023219405316332e-05, "train_min_lr": 8.023219405316332e-05, "train_loss": 1.8768250636360366, "train_class_acc": 0.7702845536071143, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6237856106889716, "test_acc1": 84.27800251464843, "test_acc5": 97.26600143920898, "epoch": 10, "n_parameters": 86794216}
12 | {"train_lr": 7.502319451477216e-05, "train_min_lr": 7.502319451477216e-05, "train_loss": 1.8619147903106505, "train_class_acc": 0.7740620003996802, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6272280179463698, "test_acc1": 84.36800256347657, "test_acc5": 97.22600145751953, "epoch": 11, "n_parameters": 86794216}
13 | {"train_lr": 6.960029477178692e-05, "train_min_lr": 6.960029477178692e-05, "train_loss": 1.849825464002937, "train_class_acc": 0.7773375049960032, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6230855795636887, "test_acc1": 84.43800255615234, "test_acc5": 97.23800145874023, "epoch": 12, "n_parameters": 86794216}
14 | {"train_lr": 6.402290924860439e-05, "train_min_lr": 6.402290924860439e-05, "train_loss": 1.8378819944076448, "train_class_acc": 0.7804053632094324, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.622640575275139, "test_acc1": 84.46200261657715, "test_acc5": 97.28800141845703, "epoch": 13, "n_parameters": 86794216}
15 | {"train_lr": 5.8352144948162407e-05, "train_min_lr": 5.8352144948162407e-05, "train_loss": 1.827634315547659, "train_class_acc": 0.783148481215028, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6218125294836887, "test_acc1": 84.50000257263184, "test_acc5": 97.27200144287109, "epoch": 14, "n_parameters": 86794216}
16 | {"train_lr": 5.265013195081766e-05, "train_min_lr": 5.265013195081766e-05, "train_loss": 1.8147839116398616, "train_class_acc": 0.7861304706235012, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6165422496325211, "test_acc1": 84.60000260375976, "test_acc5": 97.31800141357422, "epoch": 15, "n_parameters": 86794216}
17 | {"train_lr": 4.6979342704193364e-05, "train_min_lr": 4.6979342704193364e-05, "train_loss": 1.8073137391844003, "train_class_acc": 0.7881702450539568, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6163868690409138, "test_acc1": 84.58200257385253, "test_acc5": 97.33000140991211, "epoch": 16, "n_parameters": 86794216}
18 | {"train_lr": 4.1401907561964056e-05, "train_min_lr": 4.1401907561964056e-05, "train_loss": 1.7976006015250794, "train_class_acc": 0.7906003322342127, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6167496893897482, "test_acc1": 84.6180025604248, "test_acc5": 97.35400142089844, "epoch": 17, "n_parameters": 86794216}
19 | {"train_lr": 3.597893407070102e-05, "train_min_lr": 3.597893407070102e-05, "train_loss": 1.7888045026217219, "train_class_acc": 0.7930522769284573, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6136899749388012, "test_acc1": 84.5920025415039, "test_acc5": 97.39400140869141, "epoch": 18, "n_parameters": 86794216}
20 | {"train_lr": 3.076983746280512e-05, "train_min_lr": 3.076983746280512e-05, "train_loss": 1.783275410302168, "train_class_acc": 0.7947243767486011, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6135569893937002, "test_acc1": 84.66600256591796, "test_acc5": 97.36400138305665, "epoch": 19, "n_parameters": 86794216}
21 | {"train_lr": 2.5831689690786165e-05, "train_min_lr": 2.5831689690786165e-05, "train_loss": 1.7752880633955332, "train_class_acc": 0.7968812450039968, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6124859313405622, "test_acc1": 84.76600255981445, "test_acc5": 97.40800135131836, "epoch": 20, "n_parameters": 86794216}
22 | {"train_lr": 2.1218594135007993e-05, "train_min_lr": 2.1218594135007993e-05, "train_loss": 1.770902261998561, "train_class_acc": 0.7981021432853717, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6106539139188684, "test_acc1": 84.81200262145997, "test_acc5": 97.38000138061524, "epoch": 21, "n_parameters": 86794216}
23 | {"train_lr": 1.69810928357321e-05, "train_min_lr": 1.69810928357321e-05, "train_loss": 1.7643373457746183, "train_class_acc": 0.7997430180855316, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6096352562558092, "test_acc1": 84.88000256713867, "test_acc5": 97.38400138061523, "epoch": 22, "n_parameters": 86794216}
24 | {"train_lr": 1.3165612743947417e-05, "train_min_lr": 1.3165612743947417e-05, "train_loss": 1.7615033509440536, "train_class_acc": 0.8006266861510791, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6073408679915342, "test_acc1": 84.89200262695313, "test_acc5": 97.43200135620117, "epoch": 23, "n_parameters": 86794216}
25 | {"train_lr": 9.81395705798004e-06, "train_min_lr": 9.81395705798004e-06, "train_loss": 1.757708400523634, "train_class_acc": 0.8013019272082335, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.608022909829129, "test_acc1": 84.94000262390136, "test_acc5": 97.41000136230468, "epoch": 24, "n_parameters": 86794216}
26 | {"train_lr": 6.962847218904466e-06, "train_min_lr": 6.962847218904466e-06, "train_loss": 1.7527120730532206, "train_class_acc": 0.8027710643984812, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6065240236306099, "test_acc1": 84.92000258117676, "test_acc5": 97.44800136108398, "epoch": 25, "n_parameters": 86794216}
27 | {"train_lr": 4.643520582750963e-06, "train_min_lr": 4.643520582750963e-06, "train_loss": 1.750563392784896, "train_class_acc": 0.8033440435151878, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6067701281401617, "test_acc1": 84.99000257385254, "test_acc5": 97.44600135864258, "epoch": 26, "n_parameters": 86794216}
28 | {"train_lr": 2.8813881774952494e-06, "train_min_lr": 2.8813881774952494e-06, "train_loss": 1.7516952332535045, "train_class_acc": 0.8030598958333334, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6067123004089028, "test_acc1": 84.98800256896973, "test_acc5": 97.43000136108398, "epoch": 27, "n_parameters": 86794216}
29 | {"train_lr": 1.6957562945193597e-06, "train_min_lr": 1.6957562945193597e-06, "train_loss": 1.7505262671343738, "train_class_acc": 0.8033760491606715, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6069099562442452, "test_acc1": 84.98800257080079, "test_acc5": 97.43000136108398, "epoch": 28, "n_parameters": 86794216}
30 | {"train_lr": 1.0996149648425372e-06, "train_min_lr": 1.0996149648425372e-06, "train_loss": 1.7485280318288423, "train_class_acc": 0.8038975069944044, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.2000000000000318, "train_soft_mask_rate": 0.0, "test_loss": 0.6071232186801646, "test_acc1": 84.98400258239747, "test_acc5": 97.4180013684082, "epoch": 29, "n_parameters": 86794216}
31 |
--------------------------------------------------------------------------------
/logs/illama-large-in21kin1k-224-84.8.txt:
--------------------------------------------------------------------------------
1 | {"train_lr": 5.99461445631214e-05, "train_min_lr": 5.99461445631214e-05, "train_loss": 2.7198627313895285, "train_class_acc": 0.6766891174560352, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6804418705816146, "test_acc1": 82.96000253662109, "test_acc5": 96.45800264160157, "epoch": 0, "n_parameters": 310445032}
2 | {"train_lr": 5.9623591031248656e-05, "train_min_lr": 5.9623591031248656e-05, "train_loss": 1.9470356172538823, "train_class_acc": 0.7654407723820943, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6517717496150721, "test_acc1": 83.69000252197266, "test_acc5": 96.84800266845703, "epoch": 1, "n_parameters": 310445032}
3 | {"train_lr": 5.898195334153406e-05, "train_min_lr": 5.898195334153406e-05, "train_loss": 1.8859988836188921, "train_class_acc": 0.7761657861211031, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.644474639695934, "test_acc1": 83.92600262451172, "test_acc5": 96.99000259765624, "epoch": 2, "n_parameters": 310445032}
4 | {"train_lr": 5.802826141077773e-05, "train_min_lr": 5.802826141077773e-05, "train_loss": 1.848835029488178, "train_class_acc": 0.7838791466826539, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6354520462439147, "test_acc1": 84.10400263671875, "test_acc5": 97.09200262695313, "epoch": 3, "n_parameters": 310445032}
5 | {"train_lr": 5.6772964087346344e-05, "train_min_lr": 5.6772964087346344e-05, "train_loss": 1.8186543585239745, "train_class_acc": 0.7905823778477218, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6252702991869158, "test_acc1": 84.28800256591796, "test_acc5": 97.17200262451172, "epoch": 4, "n_parameters": 310445032}
6 | {"train_lr": 5.522981467140261e-05, "train_min_lr": 5.522981467140261e-05, "train_loss": 1.7934198190125343, "train_class_acc": 0.7961583857913669, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6280702026916369, "test_acc1": 84.43800261230469, "test_acc5": 97.20400264160156, "epoch": 5, "n_parameters": 310445032}
7 | {"train_lr": 5.34157202308725e-05, "train_min_lr": 5.34157202308725e-05, "train_loss": 1.7750700143107074, "train_class_acc": 0.8004088916366906, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.626661965431546, "test_acc1": 84.5240024584961, "test_acc5": 97.25000264648438, "epoch": 6, "n_parameters": 310445032}
8 | {"train_lr": 5.135055636407023e-05, "train_min_lr": 5.135055636407023e-05, "train_loss": 1.7532443771009252, "train_class_acc": 0.8054899830135891, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6199031330730397, "test_acc1": 84.57000253662109, "test_acc5": 97.33400265136719, "epoch": 7, "n_parameters": 310445032}
9 | {"train_lr": 4.905694943848367e-05, "train_min_lr": 4.905694943848367e-05, "train_loss": 1.7408968822585402, "train_class_acc": 0.8087194307054356, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.62204097242642, "test_acc1": 84.67200252197266, "test_acc5": 97.26200261230468, "epoch": 8, "n_parameters": 310445032}
10 | {"train_lr": 4.656002869155787e-05, "train_min_lr": 4.656002869155787e-05, "train_loss": 1.7256579877798268, "train_class_acc": 0.8119886902977618, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6242762830620276, "test_acc1": 84.61400246337891, "test_acc5": 97.28200252685546, "epoch": 9, "n_parameters": 310445032}
11 | {"train_lr": 4.38871509095102e-05, "train_min_lr": 4.38871509095102e-05, "train_loss": 1.7129281256064999, "train_class_acc": 0.8155319494404476, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6209116319393497, "test_acc1": 84.7500024584961, "test_acc5": 97.33800256347656, "epoch": 10, "n_parameters": 310445032}
12 | {"train_lr": 4.1067600700656485e-05, "train_min_lr": 4.1067600700656485e-05, "train_loss": 1.6997128057864597, "train_class_acc": 0.8191392198741008, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6250483979620324, "test_acc1": 84.59800246826173, "test_acc5": 97.28200264648437, "epoch": 11, "n_parameters": 310445032}
13 | {"train_lr": 3.813226964711385e-05, "train_min_lr": 3.813226964711385e-05, "train_loss": 1.6904406517053203, "train_class_acc": 0.8210681454836131, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6233285891446676, "test_acc1": 84.73200243408203, "test_acc5": 97.26800258789062, "epoch": 12, "n_parameters": 310445032}
14 | {"train_lr": 3.511331785016224e-05, "train_min_lr": 3.511331785016224e-05, "train_loss": 1.6803397394651227, "train_class_acc": 0.8235848820943246, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6287899412897252, "test_acc1": 84.66200260009765, "test_acc5": 97.27600256835937, "epoch": 13, "n_parameters": 310445032}
15 | {"train_lr": 3.2043821577445613e-05, "train_min_lr": 3.2043821577445613e-05, "train_loss": 1.672017324668207, "train_class_acc": 0.8262663306854516, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6278802271832833, "test_acc1": 84.64600250732421, "test_acc5": 97.3100025805664, "epoch": 14, "n_parameters": 310445032}
16 | {"train_lr": 2.8957410872461055e-05, "train_min_lr": 2.8957410872461055e-05, "train_loss": 1.66248219554349, "train_class_acc": 0.8286121103117506, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6275539799444093, "test_acc1": 84.6720025366211, "test_acc5": 97.27400262207031, "epoch": 15, "n_parameters": 310445032}
17 | {"train_lr": 2.5887901096765108e-05, "train_min_lr": 2.5887901096765108e-05, "train_loss": 1.6523609593552318, "train_class_acc": 0.8313396158073542, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6307366626081234, "test_acc1": 84.72800237548829, "test_acc5": 97.29800254394532, "epoch": 16, "n_parameters": 310445032}
18 | {"train_lr": 2.2868922441796946e-05, "train_min_lr": 2.2868922441796946e-05, "train_loss": 1.6468906652965516, "train_class_acc": 0.8327057104316546, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6287269541951082, "test_acc1": 84.78000256347656, "test_acc5": 97.2840025390625, "epoch": 17, "n_parameters": 310445032}
19 | {"train_lr": 1.9933551469461998e-05, "train_min_lr": 1.9933551469461998e-05, "train_loss": 1.6423381217610207, "train_class_acc": 0.834031212529976, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6298858197619215, "test_acc1": 84.80200260498047, "test_acc5": 97.32000254882813, "epoch": 18, "n_parameters": 310445032}
20 | {"train_lr": 1.7113948718399234e-05, "train_min_lr": 1.7113948718399234e-05, "train_loss": 1.6362359538751778, "train_class_acc": 0.8353895008992805, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6317386629045693, "test_acc1": 84.75400255859375, "test_acc5": 97.27000256835937, "epoch": 19, "n_parameters": 310445032}
21 | {"train_lr": 1.4441006346388868e-05, "train_min_lr": 1.4441006346388868e-05, "train_loss": 1.6317886930557606, "train_class_acc": 0.8367150029976019, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.631641619341121, "test_acc1": 84.77200250244141, "test_acc5": 97.27600258300781, "epoch": 20, "n_parameters": 310445032}
22 | {"train_lr": 1.194400966940797e-05, "train_min_lr": 1.194400966940797e-05, "train_loss": 1.6257266758186975, "train_class_acc": 0.8386322192246203, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6315950325811068, "test_acc1": 84.7240025366211, "test_acc5": 97.31400260498047, "epoch": 21, "n_parameters": 310445032}
23 | {"train_lr": 9.650316305579721e-06, "train_min_lr": 9.650316305579721e-06, "train_loss": 1.6226078195424436, "train_class_acc": 0.8390529763689049, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6312967267896952, "test_acc1": 84.78600256591797, "test_acc5": 97.27000255371094, "epoch": 22, "n_parameters": 310445032}
24 | {"train_lr": 7.585056439384332e-06, "train_min_lr": 7.585056439384332e-06, "train_loss": 1.6201602544495575, "train_class_acc": 0.8399928494704236, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.631438076019799, "test_acc1": 84.79600241455078, "test_acc5": 97.27800255615234, "epoch": 23, "n_parameters": 310445032}
25 | {"train_lr": 5.770857490099259e-06, "train_min_lr": 5.770857490099259e-06, "train_loss": 1.618003329766883, "train_class_acc": 0.8402816809052758, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6306168148649558, "test_acc1": 84.76200250732421, "test_acc5": 97.29400259765625, "epoch": 24, "n_parameters": 310445032}
26 | {"train_lr": 4.227596201058397e-06, "train_min_lr": 4.227596201058397e-06, "train_loss": 1.6154780858676496, "train_class_acc": 0.8410123463729017, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.632067259187803, "test_acc1": 84.79600251708985, "test_acc5": 97.28800254638672, "epoch": 25, "n_parameters": 310445032}
27 | {"train_lr": 2.9721808658927067e-06, "train_min_lr": 2.9721808658927067e-06, "train_loss": 1.6140473674830915, "train_class_acc": 0.8416587042865707, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6312973588449127, "test_acc1": 84.83000244628906, "test_acc5": 97.27400254394531, "epoch": 26, "n_parameters": 310445032}
28 | {"train_lr": 2.0183660777267993e-06, "train_min_lr": 2.0183660777267993e-06, "train_loss": 1.6156177416944104, "train_class_acc": 0.8415580035971223, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6316051020763302, "test_acc1": 84.79000246826172, "test_acc5": 97.27800256347656, "epoch": 27, "n_parameters": 310445032}
29 | {"train_lr": 1.3766020309783612e-06, "train_min_lr": 1.3766020309783612e-06, "train_loss": 1.6134667002307854, "train_class_acc": 0.8415853254896083, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.632169377478195, "test_acc1": 84.79400240234375, "test_acc5": 97.29200255859375, "epoch": 28, "n_parameters": 310445032}
30 | {"train_lr": 1.053920026841383e-06, "train_min_lr": 1.053920026841383e-06, "train_loss": 1.6139339337841594, "train_class_acc": 0.8416836843025579, "train_weight_decay": 1.0000000000001084e-08, "train_drop_path": 0.29999999999997284, "train_soft_mask_rate": 0.0, "test_loss": 0.6321293162964, "test_acc1": 84.78600248046875, "test_acc5": 97.30600255371094, "epoch": 29, "n_parameters": 310445032}
31 |
--------------------------------------------------------------------------------
/logs/illama-large-in21kin1k-384-86.0.txt:
--------------------------------------------------------------------------------
1 | {"train_lr": 3.496896466349359e-05, "train_min_lr": 3.496896466349359e-05, "train_loss": 3.4040759515127452, "train_class_acc": 0.5474323666067147, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.7216167944301476, "test_acc1": 82.15600245239258, "test_acc5": 95.91400194091797, "epoch": 0, "n_parameters": 310834152}
2 | {"train_lr": 3.4783086356990544e-05, "train_min_lr": 3.4783086356990544e-05, "train_loss": 2.052658562197817, "train_class_acc": 0.7441101806055156, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.6449875789212769, "test_acc1": 83.74600245239257, "test_acc5": 96.87000160522462, "epoch": 1, "n_parameters": 310834152}
3 | {"train_lr": 3.441332904427394e-05, "train_min_lr": 3.441332904427394e-05, "train_loss": 1.9515358309367958, "train_class_acc": 0.7633596248001598, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.6230926276809069, "test_acc1": 84.35800254638671, "test_acc5": 97.18800145507812, "epoch": 2, "n_parameters": 310834152}
4 | {"train_lr": 3.386374386383809e-05, "train_min_lr": 3.386374386383809e-05, "train_loss": 1.897901961399759, "train_class_acc": 0.7735717675859313, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.6125124040170362, "test_acc1": 84.66400259338378, "test_acc5": 97.27800142578126, "epoch": 3, "n_parameters": 310834152}
5 | {"train_lr": 3.314035218592863e-05, "train_min_lr": 3.314035218592863e-05, "train_loss": 1.8600160417219194, "train_class_acc": 0.7820524830135891, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5977429341974443, "test_acc1": 85.0540025817871, "test_acc5": 97.46000132568359, "epoch": 4, "n_parameters": 310834152}
6 | {"train_lr": 3.225107964114749e-05, "train_min_lr": 3.225107964114749e-05, "train_loss": 1.830572905389668, "train_class_acc": 0.7881803931854516, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5893618462757132, "test_acc1": 85.29800263305664, "test_acc5": 97.53000135253906, "epoch": 5, "n_parameters": 310834152}
7 | {"train_lr": 3.1205669285587655e-05, "train_min_lr": 3.1205669285587655e-05, "train_loss": 1.8076860456690966, "train_class_acc": 0.7934660084432454, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5866689494850249, "test_acc1": 85.39600262695312, "test_acc5": 97.58600130737305, "epoch": 6, "n_parameters": 310834152}
8 | {"train_lr": 3.0015574853870956e-05, "train_min_lr": 3.0015574853870956e-05, "train_loss": 1.7887820138318313, "train_class_acc": 0.7975385316746603, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5859094357039508, "test_acc1": 85.33000260742187, "test_acc5": 97.65400125488281, "epoch": 7, "n_parameters": 310834152}
9 | {"train_lr": 2.8693835269634953e-05, "train_min_lr": 2.8693835269634953e-05, "train_loss": 1.7704070983923597, "train_class_acc": 0.801730490607514, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5785367794234977, "test_acc1": 85.49200261840821, "test_acc5": 97.74800122314453, "epoch": 8, "n_parameters": 310834152}
10 | {"train_lr": 2.7254931788355584e-05, "train_min_lr": 2.7254931788355584e-05, "train_loss": 1.7566198213152486, "train_class_acc": 0.8048787532474021, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5768480996861769, "test_acc1": 85.70000265258788, "test_acc5": 97.69600127685547, "epoch": 9, "n_parameters": 310834152}
11 | {"train_lr": 2.5714629337683663e-05, "train_min_lr": 2.5714629337683663e-05, "train_loss": 1.7435758125980934, "train_class_acc": 0.8079981327438049, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5750291858268081, "test_acc1": 85.58600266601563, "test_acc5": 97.75000124389648, "epoch": 10, "n_parameters": 310834152}
12 | {"train_lr": 2.4089803793598676e-05, "train_min_lr": 2.4089803793598676e-05, "train_loss": 1.7315405576774996, "train_class_acc": 0.8110628684552358, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5750140335475062, "test_acc1": 85.6760026525879, "test_acc5": 97.78000123291015, "epoch": 11, "n_parameters": 310834152}
13 | {"train_lr": 2.239825708477748e-05, "train_min_lr": 2.239825708477748e-05, "train_loss": 1.7212515797236745, "train_class_acc": 0.8138005220823341, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5737644145800465, "test_acc1": 85.70400269042969, "test_acc5": 97.79000122314453, "epoch": 12, "n_parameters": 310834152}
14 | {"train_lr": 2.0658522150940823e-05, "train_min_lr": 2.0658522150940823e-05, "train_loss": 1.7101750767258979, "train_class_acc": 0.8160253047561951, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5718247301669585, "test_acc1": 85.7860026928711, "test_acc5": 97.8000012097168, "epoch": 13, "n_parameters": 310834152}
15 | {"train_lr": 1.8889659892087196e-05, "train_min_lr": 1.8889659892087196e-05, "train_loss": 1.7035075478094945, "train_class_acc": 0.8176162195243805, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5718602961343195, "test_acc1": 85.9500026171875, "test_acc5": 97.8180012133789, "epoch": 14, "n_parameters": 310834152}
16 | {"train_lr": 1.7111050333282612e-05, "train_min_lr": 1.7111050333282612e-05, "train_loss": 1.6949291081650915, "train_class_acc": 0.8199417028876899, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5712083624390407, "test_acc1": 85.84400260742187, "test_acc5": 97.82400120239258, "epoch": 15, "n_parameters": 310834152}
17 | {"train_lr": 1.534218029305105e-05, "train_min_lr": 1.534218029305105e-05, "train_loss": 1.6870110205409647, "train_class_acc": 0.8220735911270983, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.570508421045439, "test_acc1": 85.84200260864257, "test_acc5": 97.85600119018555, "epoch": 16, "n_parameters": 310834152}
18 | {"train_lr": 1.3602429881713486e-05, "train_min_lr": 1.3602429881713486e-05, "train_loss": 1.6825196964221059, "train_class_acc": 0.82322267186251, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5710876270579138, "test_acc1": 85.90600260253906, "test_acc5": 97.82000119018555, "epoch": 17, "n_parameters": 310834152}
19 | {"train_lr": 1.1910860168842462e-05, "train_min_lr": 1.1910860168842462e-05, "train_loss": 1.6754372675912819, "train_class_acc": 0.8250719736710631, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5708488073499723, "test_acc1": 85.8620026220703, "test_acc5": 97.81800118774414, "epoch": 18, "n_parameters": 310834152}
20 | {"train_lr": 1.0286004346196181e-05, "train_min_lr": 1.0286004346196181e-05, "train_loss": 1.6718671010906914, "train_class_acc": 0.8259642286171063, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5703325133941982, "test_acc1": 85.83000261474609, "test_acc5": 97.82800118896485, "epoch": 19, "n_parameters": 310834152}
21 | {"train_lr": 8.745664674190202e-06, "train_min_lr": 8.745664674190202e-06, "train_loss": 1.667895621664042, "train_class_acc": 0.8266925522082335, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5726590922309293, "test_acc1": 85.74800260009765, "test_acc5": 97.81800119018554, "epoch": 20, "n_parameters": 310834152}
22 | {"train_lr": 7.306717436607989e-06, "train_min_lr": 7.306717436607989e-06, "train_loss": 1.663098725781357, "train_class_acc": 0.8280797237210232, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5705665085086478, "test_acc1": 85.86000262939453, "test_acc5": 97.84200119140625, "epoch": 21, "n_parameters": 310834152}
23 | {"train_lr": 5.984928040503593e-06, "train_min_lr": 5.984928040503593e-06, "train_loss": 1.6612017263682912, "train_class_acc": 0.8286690959732215, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5713938242859311, "test_acc1": 85.8300025793457, "test_acc5": 97.79600120483398, "epoch": 22, "n_parameters": 310834152}
24 | {"train_lr": 4.794778287102872e-06, "train_min_lr": 4.794778287102872e-06, "train_loss": 1.657341641054558, "train_class_acc": 0.8294949977517986, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.570363612395936, "test_acc1": 85.92000260131836, "test_acc5": 97.81800120117188, "epoch": 23, "n_parameters": 310834152}
25 | {"train_lr": 3.7493077061589098e-06, "train_min_lr": 3.7493077061589098e-06, "train_loss": 1.6565849094641962, "train_class_acc": 0.8301234012789768, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5699021836521525, "test_acc1": 85.85400256835938, "test_acc5": 97.83400119506835, "epoch": 24, "n_parameters": 310834152}
26 | {"train_lr": 2.8599706921353547e-06, "train_min_lr": 2.8599706921353547e-06, "train_loss": 1.6534634246502062, "train_class_acc": 0.8304403352318146, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5695552553810265, "test_acc1": 85.8520026208496, "test_acc5": 97.84200118652343, "epoch": 25, "n_parameters": 310834152}
27 | {"train_lr": 2.1365110074636167e-06, "train_min_lr": 2.1365110074636167e-06, "train_loss": 1.652146488183932, "train_class_acc": 0.8312748238908872, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5698677959719894, "test_acc1": 85.84400259277344, "test_acc5": 97.85800118774414, "epoch": 26, "n_parameters": 310834152}
28 | {"train_lr": 1.58685502784256e-06, "train_min_lr": 1.58685502784256e-06, "train_loss": 1.6519211692472013, "train_class_acc": 0.8312717013888888, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5704294863902746, "test_acc1": 85.81200258178711, "test_acc5": 97.83800119262695, "epoch": 27, "n_parameters": 310834152}
29 | {"train_lr": 1.2170248992078712e-06, "train_min_lr": 1.2170248992078712e-06, "train_loss": 1.6494522876918412, "train_class_acc": 0.831925865557554, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5699063584572454, "test_acc1": 85.86200257080078, "test_acc5": 97.84400118896484, "epoch": 28, "n_parameters": 310834152}
30 | {"train_lr": 1.0310725578407872e-06, "train_min_lr": 1.0310725578407872e-06, "train_loss": 1.6503264461305263, "train_class_acc": 0.8314910571542766, "train_weight_decay": 9.999999999998698e-09, "train_drop_path": 0.30000000000003596, "train_soft_mask_rate": 0.0, "test_loss": 0.5699177861563376, "test_acc1": 85.86400259155273, "test_acc5": 97.83600119018554, "epoch": 29, "n_parameters": 310834152}
31 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.nn as nn
7 | import torch.backends.cudnn as cudnn
8 | import json
9 | import os
10 |
11 | from pathlib import Path
12 |
13 | from timm.data.mixup import Mixup
14 | from timm.models import create_model
15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
16 | from timm.utils import ModelEma
17 | from optim_factory import create_optimizer, LayerDecayValueAssigner
18 |
19 | from datasets import build_dataset
20 | from engine import train_one_epoch, evaluate
21 |
22 | from utils import NativeScalerWithGradNormCount as NativeScaler
23 | import utils
24 | from drop_scheduler import drop_scheduler
25 |
26 | import models
27 |
28 | def str2bool(v):
29 | """
30 | Converts string to bool type; enables command line
31 | arguments in the format of '--arg1 true --arg2 false'
32 | """
33 | if isinstance(v, bool):
34 | return v
35 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
36 | return True
37 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
38 | return False
39 | else:
40 | raise argparse.ArgumentTypeError('Boolean value expected.')
41 |
42 | def get_args_parser():
43 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False)
44 | parser.add_argument('--batch_size', default=64, type=int,
45 | help='Per GPU batch size')
46 | parser.add_argument('--epochs', default=300, type=int)
47 | parser.add_argument('--update_freq', default=1, type=int,
48 | help='gradient accumulation steps')
49 |
50 | # Model parameters
51 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL',
52 | help='Name of model to train')
53 | parser.add_argument('--input_size', default=224, type=int,
54 | help='image input size')
55 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float,
56 | help="Layer scale initial values")
57 |
58 | ########################## settings specific to this project ##########################
59 |
60 | # dropout and stochastic depth drop rate; set at most one to non-zero
61 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT',
62 | help='Dropout rate (default: 0.0)')
63 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',
64 | help='Drop path rate (default: 0.0)')
65 |
66 | # early / late dropout and stochastic depth settings
67 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode')
68 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'],
69 | help='drop schedule for early dropout / s.d. only')
70 | parser.add_argument('--cutoff_epoch', type=int, default=0,
71 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts')
72 |
73 | #######################################################################################
74 |
75 | # EMA related parameters
76 | parser.add_argument('--model_ema', type=str2bool, default=False)
77 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
78 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
79 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.')
80 |
81 | # Optimization parameters
82 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
83 | help='Optimizer (default: "adamw")')
84 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
85 | help='Optimizer Epsilon (default: 1e-8)')
86 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
87 | help='Optimizer Betas (default: None, use opt default)')
88 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
89 | help='Clip gradient norm (default: None, no clipping)')
90 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
91 | help='SGD momentum (default: 0.9)')
92 | parser.add_argument('--weight_decay', type=float, default=0.05,
93 | help='weight decay (default: 0.05)')
94 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
95 | weight decay. We use a cosine schedule for WD and using a larger decay by
96 | the end of training improves performance for ViTs.""")
97 |
98 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR',
99 | help='learning rate (default: 4e-3), with total batch size 4096')
100 | parser.add_argument('--layer_decay', type=float, default=1.0)
101 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
102 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
103 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N',
104 | help='epochs to warmup LR, if scheduler supports')
105 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
106 | help='number of steps to warm up the LR; overrides warmup_epochs if set > 0')
107 |
108 | # Augmentation parameters
109 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
110 | help='Color jitter factor (default: 0.4)')
111 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
112 | help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
113 | parser.add_argument('--smoothing', type=float, default=0.1,
114 | help='Label smoothing (default: 0.1)')
115 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
116 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
117 |
118 | # Evaluation parameters
119 | parser.add_argument('--crop_pct', type=float, default=None)
120 |
121 | # * Random Erase params
122 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
123 | help='Random erase prob (default: 0.25)')
124 | parser.add_argument('--remode', type=str, default='pixel',
125 | help='Random erase mode (default: "pixel")')
126 | parser.add_argument('--recount', type=int, default=1,
127 | help='Random erase count (default: 1)')
128 | parser.add_argument('--resplit', type=str2bool, default=False,
129 | help='Do not random erase first (clean) augmentation split')
130 |
131 | # * Mixup params
132 | parser.add_argument('--mixup', type=float, default=0.8,
133 | help='mixup alpha, mixup enabled if > 0.')
134 | parser.add_argument('--cutmix', type=float, default=1.0,
135 | help='cutmix alpha, cutmix enabled if > 0.')
136 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
137 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
138 | parser.add_argument('--mixup_prob', type=float, default=1.0,
139 | help='Probability of performing mixup or cutmix when either/both is enabled')
140 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
141 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
142 | parser.add_argument('--mixup_mode', type=str, default='batch',
143 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
144 |
145 | # * Finetuning params
146 | parser.add_argument('--finetune', default='',
147 | help='finetune from checkpoint')
148 | parser.add_argument('--head_init_scale', default=1.0, type=float,
149 | help='classifier head initial scale, typically adjusted in fine-tuning')
150 | parser.add_argument('--model_key', default='model|module', type=str,
151 | help='which key to load from saved state dict, usually model or model_ema')
152 | parser.add_argument('--model_prefix', default='', type=str)
153 |
154 | # Dataset parameters
155 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
156 | help='dataset path')
157 | parser.add_argument('--eval_data_path', default=None, type=str,
158 | help='dataset path for evaluation')
159 | parser.add_argument('--nb_classes', default=1000, type=int,
160 | help='number of classes')
161 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
162 | parser.add_argument('--data_set', default='IMNET',
163 | type=str, help='dataset type (default: IMNET)')
164 | parser.add_argument('--output_dir', default='',
165 | help='path where to save, empty for no saving')
166 | parser.add_argument('--log_dir', default=None,
167 | help='path where TensorBoard logs are written')
168 | parser.add_argument('--device', default='cuda',
169 | help='device to use for training / testing')
170 | parser.add_argument('--seed', default=0, type=int)
171 |
172 | parser.add_argument('--resume', default='',
173 | help='resume from checkpoint')
174 | parser.add_argument('--auto_resume', type=str2bool, default=True)
175 | parser.add_argument('--save_ckpt', type=str2bool, default=True)
176 | parser.add_argument('--save_ckpt_freq', default=1, type=int)
177 | parser.add_argument('--save_ckpt_num', default=3, type=int)
178 |
179 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
180 | help='start epoch')
181 | parser.add_argument('--eval', type=str2bool, default=False,
182 | help='Perform evaluation only')
183 | parser.add_argument('--dist_eval', type=str2bool, default=True,
184 | help='Enabling distributed evaluation')
185 | parser.add_argument('--disable_eval', type=str2bool, default=False,
186 | help='Disabling evaluation during training')
187 | parser.add_argument('--num_workers', default=10, type=int)
188 | parser.add_argument('--pin_mem', type=str2bool, default=True,
189 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
190 |
191 | # distributed training parameters
192 | parser.add_argument('--world_size', default=1, type=int,
193 | help='number of distributed processes')
194 | parser.add_argument('--local_rank', default=-1, type=int)
195 | parser.add_argument('--dist_on_itp', type=str2bool, default=False)
196 | parser.add_argument('--dist_url', default='env://',
197 | help='url used to set up distributed training')
198 |
199 | parser.add_argument('--use_amp', type=str2bool, default=False,
200 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not")
201 |
202 | # Weights and Biases arguments
203 | parser.add_argument('--enable_wandb', type=str2bool, default=False,
204 | help="enable logging to Weights and Biases")
205 | parser.add_argument('--project', default='convnext', type=str,
206 | help="The name of the W&B project where you're sending the new run.")
207 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False,
208 | help="Save model checkpoints as W&B Artifacts.")
209 |
210 | return parser
211 |
212 | def main(args):
213 | utils.init_distributed_mode(args)
214 | print(args)
215 | device = torch.device(args.device)
216 |
217 | # fix the seed for reproducibility
218 | seed = args.seed + utils.get_rank()
219 | torch.manual_seed(seed)
220 | np.random.seed(seed)
221 | cudnn.benchmark = True
222 |
223 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
224 | if args.disable_eval:
225 | args.dist_eval = False
226 | dataset_val = None
227 | else:
228 | dataset_val, _ = build_dataset(is_train=False, args=args)
229 |
230 | num_tasks = utils.get_world_size()
231 | global_rank = utils.get_rank()
232 |
233 | sampler_train = torch.utils.data.DistributedSampler(
234 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
235 | )
236 | print("Sampler_train = %s" % str(sampler_train))
237 | if args.dist_eval:
238 | if len(dataset_val) % num_tasks != 0:
239 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
240 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
241 | 'equal num of samples per-process.')
242 | sampler_val = torch.utils.data.DistributedSampler(
243 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
244 | else:
245 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
246 |
247 | if global_rank == 0 and args.log_dir is not None:
248 | os.makedirs(args.log_dir, exist_ok=True)
249 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
250 | else:
251 | log_writer = None
252 |
253 | if global_rank == 0 and args.enable_wandb:
254 | wandb_logger = utils.WandbLogger(args)
255 | else:
256 | wandb_logger = None
257 |
258 | data_loader_train = torch.utils.data.DataLoader(
259 | dataset_train, sampler=sampler_train,
260 | batch_size=args.batch_size,
261 | num_workers=args.num_workers,
262 | pin_memory=args.pin_mem,
263 | drop_last=True,
264 | )
265 |
266 | if dataset_val is not None:
267 | data_loader_val = torch.utils.data.DataLoader(
268 | dataset_val, sampler=sampler_val,
269 | batch_size=int(1.5 * args.batch_size),
270 | num_workers=args.num_workers,
271 | pin_memory=args.pin_mem,
272 | drop_last=False
273 | )
274 | else:
275 | data_loader_val = None
276 |
277 | mixup_fn = None
278 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
279 | if mixup_active:
280 | print("Mixup is activated!")
281 | mixup_fn = Mixup(
282 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
283 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
284 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
285 |
286 | model = utils.build_model(args)
287 | if args.finetune:
288 | if args.finetune.startswith('https'):
289 | checkpoint = torch.hub.load_state_dict_from_url(
290 | args.finetune, map_location='cpu', check_hash=True)
291 | else:
292 | checkpoint = torch.load(args.finetune, map_location='cpu')
293 |
294 | print("Load ckpt from %s" % args.finetune)
295 | checkpoint_model = None
296 | for model_key in args.model_key.split('|'):
297 | if model_key in checkpoint:
298 | checkpoint_model = checkpoint[model_key]
299 | print("Load state_dict by model_key = %s" % model_key)
300 | break
301 | if checkpoint_model is None:
302 | checkpoint_model = checkpoint
303 | state_dict = model.state_dict()
304 | for k in ['head.weight', 'head.bias']:
305 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
306 | print(f"Removing key {k} from pretrained checkpoint")
307 | del checkpoint_model[k]
308 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
309 | model.to(device)
310 |
311 | model_ema = None
312 | if args.model_ema:
313 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
314 | model_ema = ModelEma(
315 | model,
316 | decay=args.model_ema_decay,
317 | device='cpu' if args.model_ema_force_cpu else '',
318 | resume='')
319 | print("Using EMA with decay = %.8f" % args.model_ema_decay)
320 |
321 | model_without_ddp = model
322 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
323 |
324 | print("Model = %s" % str(model_without_ddp))
325 | print('number of params:', n_parameters)
326 |
327 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
328 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
329 | print("LR = %.8f" % args.lr)
330 | print("Batch size = %d" % total_batch_size)
331 | print("Update frequency = %d" % args.update_freq)
332 | print("Number of training examples = %d" % len(dataset_train))
333 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
334 |
335 | if args.layer_decay < 1.0 or args.layer_decay > 1.0:
336 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value.
337 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \
338 | "Layer Decay impl only supports convnext_small/base/large/xlarge"
339 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
340 | else:
341 | assigner = None
342 |
343 | if assigner is not None:
344 | print("Assigned values = %s" % str(assigner.values))
345 |
346 | if args.distributed:
347 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
348 | model_without_ddp = model.module
349 |
350 | optimizer = create_optimizer(
351 | args, model_without_ddp, skip_list=None,
352 | get_num_layer=assigner.get_layer_id if assigner is not None else None,
353 | get_layer_scale=assigner.get_scale if assigner is not None else None)
354 |
355 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used
356 |
357 | print("Use Cosine LR scheduler")
358 | lr_schedule_values = utils.cosine_scheduler(
359 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
360 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
361 | )
362 |
363 | if args.weight_decay_end is None:
364 | args.weight_decay_end = args.weight_decay
365 | wd_schedule_values = utils.cosine_scheduler(
366 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
367 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
368 |
369 | schedules = {}
370 |
371 | # At most one of dropout and stochastic depth should be enabled.
372 | assert(args.dropout == 0 or args.drop_path == 0)
373 | # ConvNeXt does not support dropout.
374 | assert(args.dropout == 0 if args.model.startswith("convnext") else True)
375 |
376 | if args.dropout > 0:
377 | schedules['do'] = drop_scheduler(
378 | args.dropout, args.epochs, num_training_steps_per_epoch,
379 | args.cutoff_epoch, args.drop_mode, args.drop_schedule)
380 | print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do'])))
381 |
382 | if args.drop_path > 0:
383 | schedules['dp'] = drop_scheduler(
384 | args.drop_path, args.epochs, num_training_steps_per_epoch,
385 | args.cutoff_epoch, args.drop_mode, args.drop_schedule)
386 | print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))
387 |
388 | if mixup_fn is not None:
389 | # smoothing is handled with mixup label transform
390 | criterion = SoftTargetCrossEntropy()
391 | elif args.smoothing > 0.:
392 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
393 | else:
394 | criterion = torch.nn.CrossEntropyLoss()
395 |
396 | print("criterion = %s" % str(criterion))
397 |
398 | utils.auto_load_model(
399 | args=args, model=model, model_without_ddp=model_without_ddp,
400 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
401 |
402 | if args.eval:
403 | print("Eval only mode")
404 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
405 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
406 | return
407 |
408 | max_accuracy = 0.0
409 | if args.model_ema and args.model_ema_eval:
410 | max_accuracy_ema = 0.0
411 |
412 | print("Start training for %d epochs" % args.epochs)
413 | start_time = time.time()
414 | for epoch in range(args.start_epoch, args.epochs):
415 | if args.distributed:
416 | data_loader_train.sampler.set_epoch(epoch)
417 | if log_writer is not None:
418 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
419 | if wandb_logger:
420 | wandb_logger.set_steps()
421 | train_stats = train_one_epoch(
422 | model, criterion, data_loader_train, optimizer,
423 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
424 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch,
425 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, schedules=schedules,
426 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
427 | use_amp=args.use_amp
428 | )
429 | if args.output_dir and args.save_ckpt:
430 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
431 | utils.save_model(
432 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
433 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
434 | if data_loader_val is not None:
435 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
436 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
437 | if max_accuracy < test_stats["acc1"]:
438 | max_accuracy = test_stats["acc1"]
439 | if args.output_dir and args.save_ckpt:
440 | utils.save_model(
441 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
442 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
443 | print(f'Max accuracy: {max_accuracy:.2f}%')
444 |
445 | if log_writer is not None:
446 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
447 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
448 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
449 |
450 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
451 | **{f'test_{k}': v for k, v in test_stats.items()},
452 | 'epoch': epoch,
453 | 'n_parameters': n_parameters}
454 |
455 | # repeat testing routines for EMA, if ema eval is turned on
456 | if args.model_ema and args.model_ema_eval:
457 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp)
458 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%")
459 | if max_accuracy_ema < test_stats_ema["acc1"]:
460 | max_accuracy_ema = test_stats_ema["acc1"]
461 | if args.output_dir and args.save_ckpt:
462 | utils.save_model(
463 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
464 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema)
465 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%')
466 | if log_writer is not None:
467 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch)
468 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}})
469 | else:
470 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
471 | 'epoch': epoch,
472 | 'n_parameters': n_parameters}
473 |
474 | if args.output_dir and utils.is_main_process():
475 | if log_writer is not None:
476 | log_writer.flush()
477 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
478 | f.write(json.dumps(log_stats) + "\n")
479 |
480 | if wandb_logger:
481 | wandb_logger.log_epoch_metrics(log_stats)
482 |
483 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir:
484 | wandb_logger.log_checkpoints()
485 |
486 |
487 | total_time = time.time() - start_time
488 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
489 | print('Training time {}'.format(total_time_str))
490 |
491 | if __name__ == '__main__':
492 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()])
493 | args = parser.parse_args()
494 | if args.output_dir:
495 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
496 | main(args)
497 |
--------------------------------------------------------------------------------
/main_soft.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.nn as nn
7 | import torch.backends.cudnn as cudnn
8 | import json
9 | import os
10 |
11 | from pathlib import Path
12 |
13 | from timm.data.mixup import Mixup
14 | from timm.models import create_model
15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
16 | from timm.utils import ModelEma
17 | from optim_factory import create_optimizer, LayerDecayValueAssigner
18 |
19 | from datasets import build_dataset
20 | from engine_soft import train_one_epoch, evaluate
21 |
22 | from utils import NativeScalerWithGradNormCount as NativeScaler
23 | import utils
24 | from drop_scheduler import drop_scheduler
25 | from mask_scheduler import mask_scheduler
26 |
27 | import models
28 |
29 | def str2bool(v):
30 | """
31 | Converts string to bool type; enables command line
32 | arguments in the format of '--arg1 true --arg2 false'
33 | """
34 | if isinstance(v, bool):
35 | return v
36 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
37 | return True
38 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
39 | return False
40 | else:
41 | raise argparse.ArgumentTypeError('Boolean value expected.')
42 |
43 | def get_args_parser():
44 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False)
45 | parser.add_argument('--batch_size', default=64, type=int,
46 | help='Per GPU batch size')
47 | parser.add_argument('--epochs', default=300, type=int)
48 | parser.add_argument('--update_freq', default=1, type=int,
49 | help='gradient accumulation steps')
50 |
51 | # Model parameters
52 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL',
53 | help='Name of model to train')
54 | parser.add_argument('--input_size', default=224, type=int,
55 | help='image input size')
56 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float,
57 | help="Layer scale initial values")
58 |
59 | ########################## settings specific to this project ##########################
60 |
61 | # dropout and stochastic depth drop rate; set at most one to non-zero
62 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT',
63 | help='Dropout rate (default: 0.0)')
64 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',
65 | help='Drop path rate (default: 0.0)')
66 |
67 | # early / late dropout and stochastic depth settings
68 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode')
69 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'],
70 | help='drop schedule for early dropout / s.d. only')
71 | parser.add_argument('--cutoff_epoch', type=int, default=0,
72 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts')
73 |
74 | # soft mask settings
75 | parser.add_argument('--soft_mask', type=float, default=1.0, metavar='PCT',
76 | help='Soft mask rate (default: 1.0)')
77 | parser.add_argument('--mask_mode', type=str, default='standard', choices=['standard', 'soft'], help='mask mode')
78 | parser.add_argument('--mask_schedule', type=str, default='constant', choices=['constant', 'linear'],
79 | help='mask schedule for soft mask')
80 | parser.add_argument('--cutoff_soft', type=int, default=0,
81 | help='if mask_mode is soft, this is the cutoff epoch of the soft mask schedule')
82 |
83 | #######################################################################################
84 |
85 | # EMA related parameters
86 | parser.add_argument('--model_ema', type=str2bool, default=False)
87 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
88 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
89 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.')
90 |
91 | # Optimization parameters
92 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
93 | help='Optimizer (default: "adamw")')
94 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
95 | help='Optimizer Epsilon (default: 1e-8)')
96 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
97 | help='Optimizer Betas (default: None, use opt default)')
98 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
99 | help='Clip gradient norm (default: None, no clipping)')
100 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
101 | help='SGD momentum (default: 0.9)')
102 | parser.add_argument('--weight_decay', type=float, default=0.05,
103 | help='weight decay (default: 0.05)')
104 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
105 | weight decay. We use a cosine schedule for WD and using a larger decay by
106 | the end of training improves performance for ViTs.""")
107 |
108 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR',
109 | help='learning rate (default: 4e-3), with total batch size 4096')
110 | parser.add_argument('--layer_decay', type=float, default=1.0)
111 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
112 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
113 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N',
114 | help='epochs to warmup LR, if scheduler supports')
115 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
116 | help='number of steps to warm up the LR; overrides warmup_epochs if set > 0')
117 |
118 | # Augmentation parameters
119 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
120 | help='Color jitter factor (default: 0.4)')
121 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
122 | help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
123 | parser.add_argument('--smoothing', type=float, default=0.1,
124 | help='Label smoothing (default: 0.1)')
125 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
126 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
127 |
128 | # Evaluation parameters
129 | parser.add_argument('--crop_pct', type=float, default=None)
130 |
131 | # * Random Erase params
132 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
133 | help='Random erase prob (default: 0.25)')
134 | parser.add_argument('--remode', type=str, default='pixel',
135 | help='Random erase mode (default: "pixel")')
136 | parser.add_argument('--recount', type=int, default=1,
137 | help='Random erase count (default: 1)')
138 | parser.add_argument('--resplit', type=str2bool, default=False,
139 | help='Do not random erase first (clean) augmentation split')
140 |
141 | # * Mixup params
142 | parser.add_argument('--mixup', type=float, default=0.8,
143 | help='mixup alpha, mixup enabled if > 0.')
144 | parser.add_argument('--cutmix', type=float, default=1.0,
145 | help='cutmix alpha, cutmix enabled if > 0.')
146 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
147 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
148 | parser.add_argument('--mixup_prob', type=float, default=1.0,
149 | help='Probability of performing mixup or cutmix when either/both is enabled')
150 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
151 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
152 | parser.add_argument('--mixup_mode', type=str, default='batch',
153 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
154 |
155 | # * Finetuning params
156 | parser.add_argument('--finetune', default='',
157 | help='finetune from checkpoint')
158 | parser.add_argument('--head_init_scale', default=1.0, type=float,
159 | help='classifier head initial scale, typically adjusted in fine-tuning')
160 | parser.add_argument('--model_key', default='model|module', type=str,
161 | help='which key to load from saved state dict, usually model or model_ema')
162 | parser.add_argument('--model_prefix', default='', type=str)
163 |
164 | # Dataset parameters
165 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
166 | help='dataset path')
167 | parser.add_argument('--eval_data_path', default=None, type=str,
168 | help='dataset path for evaluation')
169 | parser.add_argument('--nb_classes', default=1000, type=int,
170 | help='number of the classification types')
171 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
172 | parser.add_argument('--data_set', default='IMNET',
173 | type=str, help='name of the dataset to build (default: "IMNET")')
174 | parser.add_argument('--output_dir', default='',
175 | help='path where to save, empty for no saving')
176 | parser.add_argument('--log_dir', default=None,
177 | help='path where TensorBoard logs are written')
178 | parser.add_argument('--device', default='cuda',
179 | help='device to use for training / testing')
180 | parser.add_argument('--seed', default=0, type=int)
181 |
182 | parser.add_argument('--resume', default='',
183 | help='resume from checkpoint')
184 | parser.add_argument('--auto_resume', type=str2bool, default=True)
185 | parser.add_argument('--save_ckpt', type=str2bool, default=True)
186 | parser.add_argument('--save_ckpt_freq', default=1, type=int)
187 | parser.add_argument('--save_ckpt_num', default=3, type=int)
188 |
189 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
190 | help='start epoch')
191 | parser.add_argument('--eval', type=str2bool, default=False,
192 | help='Perform evaluation only')
193 | parser.add_argument('--dist_eval', type=str2bool, default=True,
194 | help='Enabling distributed evaluation')
195 | parser.add_argument('--disable_eval', type=str2bool, default=False,
196 | help='Disabling evaluation during training')
197 | parser.add_argument('--num_workers', default=10, type=int)
198 | parser.add_argument('--pin_mem', type=str2bool, default=True,
199 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
200 |
201 | # distributed training parameters
202 | parser.add_argument('--world_size', default=1, type=int,
203 | help='number of distributed processes')
204 | parser.add_argument('--local_rank', default=-1, type=int)
205 | parser.add_argument('--dist_on_itp', type=str2bool, default=False)
206 | parser.add_argument('--dist_url', default='env://',
207 | help='url used to set up distributed training')
208 |
209 | parser.add_argument('--use_amp', type=str2bool, default=False,
210 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not")
211 |
212 | # Weights and Biases arguments
213 | parser.add_argument('--enable_wandb', type=str2bool, default=False,
214 | help="enable logging to Weights and Biases")
215 | parser.add_argument('--project', default='convnext', type=str,
216 | help="The name of the W&B project where you're sending the new run.")
217 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False,
218 | help="Save model checkpoints as W&B Artifacts.")
219 |
220 | return parser
221 |
222 | def main(args):
223 | utils.init_distributed_mode(args)
224 | print(args)
225 | device = torch.device(args.device)
226 |
227 | # fix the seed for reproducibility
228 | seed = args.seed + utils.get_rank()
229 | torch.manual_seed(seed)
230 | np.random.seed(seed)
231 | cudnn.benchmark = True
232 |
233 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
234 | if args.disable_eval:
235 | args.dist_eval = False
236 | dataset_val = None
237 | else:
238 | dataset_val, _ = build_dataset(is_train=False, args=args)
239 |
240 | num_tasks = utils.get_world_size()
241 | global_rank = utils.get_rank()
242 |
243 | sampler_train = torch.utils.data.DistributedSampler(
244 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
245 | )
246 | print("Sampler_train = %s" % str(sampler_train))
247 | if args.dist_eval:
248 | if len(dataset_val) % num_tasks != 0:
249 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
250 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
251 | 'equal num of samples per-process.')
252 | sampler_val = torch.utils.data.DistributedSampler(
253 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
254 | else:
255 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
256 |
257 | if global_rank == 0 and args.log_dir is not None:
258 | os.makedirs(args.log_dir, exist_ok=True)
259 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
260 | else:
261 | log_writer = None
262 |
263 | if global_rank == 0 and args.enable_wandb:
264 | wandb_logger = utils.WandbLogger(args)
265 | else:
266 | wandb_logger = None
267 |
268 | data_loader_train = torch.utils.data.DataLoader(
269 | dataset_train, sampler=sampler_train,
270 | batch_size=args.batch_size,
271 | num_workers=args.num_workers,
272 | pin_memory=args.pin_mem,
273 | drop_last=True,
274 | )
275 |
276 | if dataset_val is not None:
277 | data_loader_val = torch.utils.data.DataLoader(
278 | dataset_val, sampler=sampler_val,
279 | batch_size=int(1.5 * args.batch_size),
280 | num_workers=args.num_workers,
281 | pin_memory=args.pin_mem,
282 | drop_last=False
283 | )
284 | else:
285 | data_loader_val = None
286 |
287 | mixup_fn = None
288 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
289 | if mixup_active:
290 | print("Mixup is activated!")
291 | mixup_fn = Mixup(
292 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
293 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
294 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
295 |
296 | model = utils.build_model(args)
297 | if args.finetune:
298 | if args.finetune.startswith('https'):
299 | checkpoint = torch.hub.load_state_dict_from_url(
300 | args.finetune, map_location='cpu', check_hash=True)
301 | else:
302 | checkpoint = torch.load(args.finetune, map_location='cpu')
303 |
304 | print("Load ckpt from %s" % args.finetune)
305 | checkpoint_model = None
306 | for model_key in args.model_key.split('|'):
307 | if model_key in checkpoint:
308 | checkpoint_model = checkpoint[model_key]
309 | print("Load state_dict by model_key = %s" % model_key)
310 | break
311 | if checkpoint_model is None:
312 | checkpoint_model = checkpoint
313 | state_dict = model.state_dict()
314 | for k in ['head.weight', 'head.bias']:
315 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
316 | print(f"Removing key {k} from pretrained checkpoint")
317 | del checkpoint_model[k]
318 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
319 | model.to(device)
320 |
321 | model_ema = None
322 | if args.model_ema:
323 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
324 | model_ema = ModelEma(
325 | model,
326 | decay=args.model_ema_decay,
327 | device='cpu' if args.model_ema_force_cpu else '',
328 | resume='')
329 | print("Using EMA with decay = %.8f" % args.model_ema_decay)
330 |
331 | model_without_ddp = model
332 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
333 |
334 | print("Model = %s" % str(model_without_ddp))
335 | print('number of params:', n_parameters)
336 |
337 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
338 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
339 | print("LR = %.8f" % args.lr)
340 | print("Batch size = %d" % total_batch_size)
341 | print("Update frequency = %d" % args.update_freq)
342 | print("Number of training examples = %d" % len(dataset_train))
343 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
344 |
345 | if args.layer_decay < 1.0 or args.layer_decay > 1.0:
346 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value.
347 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \
348 | "Layer Decay impl only supports convnext_small/base/large/xlarge"
349 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
350 | else:
351 | assigner = None
352 |
353 | if assigner is not None:
354 | print("Assigned values = %s" % str(assigner.values))
355 |
356 | if args.distributed:
357 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
358 | model_without_ddp = model.module
359 |
360 | optimizer = create_optimizer(
361 | args, model_without_ddp, skip_list=None,
362 | get_num_layer=assigner.get_layer_id if assigner is not None else None,
363 | get_layer_scale=assigner.get_scale if assigner is not None else None)
364 |
365 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used
366 |
367 | print("Use Cosine LR scheduler")
368 | lr_schedule_values = utils.cosine_scheduler(
369 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
370 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
371 | )
372 |
373 | if args.weight_decay_end is None:
374 | args.weight_decay_end = args.weight_decay
375 | wd_schedule_values = utils.cosine_scheduler(
376 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
377 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
378 |
379 | schedules = {}
380 |
381 | # At most one of dropout and stochastic depth should be enabled.
382 | assert(args.dropout == 0 or args.drop_path == 0)
383 | # ConvNeXt does not support dropout.
384 | assert(args.dropout == 0 if args.model.startswith("convnext") else True)
385 |
386 | if args.dropout > 0:
387 | schedules['do'] = drop_scheduler(
388 | args.dropout, args.epochs, num_training_steps_per_epoch,
389 | args.cutoff_epoch, args.drop_mode, args.drop_schedule)
390 | print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do'])))
391 |
392 | if args.drop_path > 0:
393 | schedules['dp'] = drop_scheduler(
394 | args.drop_path, args.epochs, num_training_steps_per_epoch,
395 | args.cutoff_epoch, args.drop_mode, args.drop_schedule)
396 | print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))
397 |
398 | # Mask schedule enabled.
399 | schedules['sm'] = mask_scheduler(
400 | args.soft_mask, args.epochs, num_training_steps_per_epoch,
401 | args.cutoff_soft, args.mask_mode, args.mask_schedule)
402 | print("Min MASK = %.7f, Max MASK = %.7f" % (min(schedules['sm']), max(schedules['sm'])))
403 |
404 | if mixup_fn is not None:
405 | # smoothing is handled with mixup label transform
406 | criterion = SoftTargetCrossEntropy()
407 | elif args.smoothing > 0.:
408 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
409 | else:
410 | criterion = torch.nn.CrossEntropyLoss()
411 |
412 | print("criterion = %s" % str(criterion))
413 |
414 | utils.auto_load_model(
415 | args=args, model=model, model_without_ddp=model_without_ddp,
416 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
417 |
418 | if args.eval:
419 | print("Eval only mode")
420 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
421 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
422 | return
423 |
424 | max_accuracy = 0.0
425 | if args.model_ema and args.model_ema_eval:
426 | max_accuracy_ema = 0.0
427 |
428 | print("Start training for %d epochs" % args.epochs)
429 | start_time = time.time()
430 | for epoch in range(args.start_epoch, args.epochs):
431 | if args.distributed:
432 | data_loader_train.sampler.set_epoch(epoch)
433 | if log_writer is not None:
434 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
435 | if wandb_logger:
436 | wandb_logger.set_steps()
437 | train_stats = train_one_epoch(
438 | model, criterion, data_loader_train, optimizer,
439 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
440 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch,
441 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, schedules=schedules,
442 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
443 | use_amp=args.use_amp
444 | )
445 | if args.output_dir and args.save_ckpt:
446 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
447 | utils.save_model(
448 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
449 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
450 | if data_loader_val is not None:
451 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
452 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
453 | if max_accuracy < test_stats["acc1"]:
454 | max_accuracy = test_stats["acc1"]
455 | if args.output_dir and args.save_ckpt:
456 | utils.save_model(
457 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
458 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
459 | print(f'Max accuracy: {max_accuracy:.2f}%')
460 |
461 | if log_writer is not None:
462 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
463 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
464 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
465 |
466 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
467 | **{f'test_{k}': v for k, v in test_stats.items()},
468 | 'epoch': epoch,
469 | 'n_parameters': n_parameters}
470 |
471 | # repeat testing routines for EMA, if ema eval is turned on
472 | if args.model_ema and args.model_ema_eval:
473 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp)
474 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%")
475 | if max_accuracy_ema < test_stats_ema["acc1"]:
476 | max_accuracy_ema = test_stats_ema["acc1"]
477 | if args.output_dir and args.save_ckpt:
478 | utils.save_model(
479 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
480 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema)
481 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%')
482 | if log_writer is not None:
483 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch)
484 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}})
485 | else:
486 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
487 | 'epoch': epoch,
488 | 'n_parameters': n_parameters}
489 |
490 | if args.output_dir and utils.is_main_process():
491 | if log_writer is not None:
492 | log_writer.flush()
493 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
494 | f.write(json.dumps(log_stats) + "\n")
495 |
496 | if wandb_logger:
497 | wandb_logger.log_epoch_metrics(log_stats)
498 |
499 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir:
500 | wandb_logger.log_checkpoints()
501 |
502 |
503 | total_time = time.time() - start_time
504 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
505 | print('Training time {}'.format(total_time_str))
506 |
507 | if __name__ == '__main__':
508 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()])
509 | args = parser.parse_args()
510 | if args.output_dir:
511 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
512 | main(args)
513 |
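The script above derives every per-step schedule (learning rate, weight decay, drop rate, soft mask) from the effective batch size, and builds the layer-wise decay scales from `--layer_decay`. The following is a minimal sketch of that arithmetic only, not part of the repository; the helper names `steps_per_epoch` and `layer_decay_scales` are hypothetical and exist purely for illustration.

```python
def steps_per_epoch(num_examples, batch_size, update_freq, world_size):
    """Mirror the arithmetic in main(): effective batch size and optimizer steps per epoch.

    Hypothetical helper; the script computes these values inline.
    """
    total_batch_size = batch_size * update_freq * world_size
    # Integer division: the incomplete trailing batch is not counted as a step.
    return total_batch_size, num_examples // total_batch_size


def layer_decay_scales(layer_decay, num_layers=12):
    """Mirror the list passed to LayerDecayValueAssigner in main().

    Index 0 gets the strongest decay; the last index gets scale 1.0.
    """
    return [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]


if __name__ == "__main__":
    # ImageNet-1K training set size, 8 processes, per-GPU batch 64, no accumulation.
    print(steps_per_epoch(1281167, 64, 1, 8))
    print(layer_decay_scales(0.5, num_layers=2))
```

Because the schedules are indexed by optimizer step, changing `--batch_size`, `--update_freq`, or the number of processes changes the length of every schedule array, not only the learning-rate one.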
--------------------------------------------------------------------------------
/main_soft_fthr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.nn as nn
7 | import torch.backends.cudnn as cudnn
8 | import json
9 | import os
10 |
11 | from pathlib import Path
12 |
13 | from timm.data.mixup import Mixup
14 | from timm.models import create_model
15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
16 | from timm.utils import ModelEma
17 | from optim_factory import create_optimizer, LayerDecayValueAssigner
18 |
19 | from datasets import build_dataset
20 | from engine_soft import train_one_epoch, evaluate
21 |
22 | from utils import NativeScalerWithGradNormCount as NativeScaler
23 | import utils
24 | from drop_scheduler import drop_scheduler
25 | from mask_scheduler import mask_scheduler
26 |
27 | import models
28 |
29 | def str2bool(v):
30 | """
31 | Converts string to bool type; enables command line
32 | arguments in the format of '--arg1 true --arg2 false'
33 | """
34 | if isinstance(v, bool):
35 | return v
36 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
37 | return True
38 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
39 | return False
40 | else:
41 | raise argparse.ArgumentTypeError('Boolean value expected.')
42 |
43 | def get_args_parser():
44 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False)
45 | parser.add_argument('--batch_size', default=64, type=int,
46 | help='Per GPU batch size')
47 | parser.add_argument('--epochs', default=300, type=int)
48 | parser.add_argument('--update_freq', default=1, type=int,
49 | help='gradient accumulation steps')
50 |
51 | # Model parameters
52 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL',
53 | help='Name of model to train')
54 | parser.add_argument('--input_size', default=224, type=int,
55 | help='image input size')
56 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float,
57 | help="Layer scale initial values")
58 |
59 | ########################## settings specific to this project ##########################
60 |
61 | # dropout and stochastic depth drop rate; set at most one to non-zero
62 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT',
63 | help='Dropout rate (default: 0.0)')
64 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',
65 | help='Drop path rate (default: 0.0)')
66 |
67 | # early / late dropout and stochastic depth settings
68 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode')
69 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'],
70 | help='drop schedule for early dropout / s.d. only')
71 | parser.add_argument('--cutoff_epoch', type=int, default=0,
72 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts')
73 |
74 | # soft mask settings
75 | parser.add_argument('--soft_mask', type=float, default=1.0, metavar='PCT',
76 | help='Soft mask rate (default: 1.0)')
77 | parser.add_argument('--mask_mode', type=str, default='standard', choices=['standard', 'soft'], help='mask mode')
78 | parser.add_argument('--mask_schedule', type=str, default='constant', choices=['constant', 'linear'],
79 | help='mask schedule for soft mask')
80 | parser.add_argument('--cutoff_soft', type=int, default=0,
81 | help='if mask_mode is soft, this is the cutoff epoch of the soft mask schedule')
82 |
83 | #######################################################################################
84 |
85 | # EMA related parameters
86 | parser.add_argument('--model_ema', type=str2bool, default=False)
87 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
88 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
89 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.')
90 |
91 | # Optimization parameters
92 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
93 | help='Optimizer (default: "adamw")')
94 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
95 | help='Optimizer Epsilon (default: 1e-8)')
96 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
97 | help='Optimizer Betas (default: None, use opt default)')
98 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
99 | help='Clip gradient norm (default: None, no clipping)')
100 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
101 | help='SGD momentum (default: 0.9)')
102 | parser.add_argument('--weight_decay', type=float, default=0.05,
103 | help='weight decay (default: 0.05)')
104 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
105 | weight decay. We use a cosine schedule for WD and using a larger decay by
106 | the end of training improves performance for ViTs.""")
107 |
108 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR',
109 | help='learning rate (default: 4e-3), with total batch size 4096')
110 | parser.add_argument('--layer_decay', type=float, default=1.0)
111 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
112 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
113 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N',
114 | help='epochs to warmup LR, if scheduler supports')
115 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
116 | help='number of steps to warm up LR, overrides warmup_epochs if set > 0')
117 |
118 | # Augmentation parameters
119 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
120 | help='Color jitter factor (default: 0.4)')
121 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
122 | help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
123 | parser.add_argument('--smoothing', type=float, default=0.1,
124 | help='Label smoothing (default: 0.1)')
125 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
126 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
127 |
128 | # Evaluation parameters
129 | parser.add_argument('--crop_pct', type=float, default=None)
130 |
131 | # * Random Erase params
132 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
133 | help='Random erase prob (default: 0.25)')
134 | parser.add_argument('--remode', type=str, default='pixel',
135 | help='Random erase mode (default: "pixel")')
136 | parser.add_argument('--recount', type=int, default=1,
137 | help='Random erase count (default: 1)')
138 | parser.add_argument('--resplit', type=str2bool, default=False,
139 | help='Do not random erase first (clean) augmentation split')
140 |
141 | # * Mixup params
142 | parser.add_argument('--mixup', type=float, default=0.8,
143 | help='mixup alpha, mixup enabled if > 0.')
144 | parser.add_argument('--cutmix', type=float, default=1.0,
145 | help='cutmix alpha, cutmix enabled if > 0.')
146 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
147 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
148 | parser.add_argument('--mixup_prob', type=float, default=1.0,
149 | help='Probability of performing mixup or cutmix when either/both is enabled')
150 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
151 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
152 | parser.add_argument('--mixup_mode', type=str, default='batch',
153 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
154 |
155 | # * Finetuning params
156 | parser.add_argument('--finetune', default='',
157 | help='finetune from checkpoint')
158 | parser.add_argument('--head_init_scale', default=1.0, type=float,
159 | help='classifier head initial scale, typically adjusted in fine-tuning')
160 | parser.add_argument('--model_key', default='model|module', type=str,
161 | help='which key to load from saved state dict, usually model or model_ema')
162 | parser.add_argument('--model_prefix', default='', type=str)
163 |
164 | # Dataset parameters
165 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
166 | help='dataset path')
167 | parser.add_argument('--eval_data_path', default=None, type=str,
168 | help='dataset path for evaluation')
169 | parser.add_argument('--nb_classes', default=1000, type=int,
170 | help='number of the classification types')
171 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
172 | parser.add_argument('--data_set', default='IMNET',
173 | type=str, help='name of the dataset to build (default: "IMNET")')
174 | parser.add_argument('--output_dir', default='',
175 | help='path where to save, empty for no saving')
176 | parser.add_argument('--log_dir', default=None,
177 | help='path where TensorBoard logs are written')
178 | parser.add_argument('--device', default='cuda',
179 | help='device to use for training / testing')
180 | parser.add_argument('--seed', default=0, type=int)
181 |
182 | parser.add_argument('--resume', default='',
183 | help='resume from checkpoint')
184 | parser.add_argument('--auto_resume', type=str2bool, default=True)
185 | parser.add_argument('--save_ckpt', type=str2bool, default=True)
186 | parser.add_argument('--save_ckpt_freq', default=1, type=int)
187 | parser.add_argument('--save_ckpt_num', default=3, type=int)
188 |
189 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
190 | help='start epoch')
191 | parser.add_argument('--eval', type=str2bool, default=False,
192 | help='Perform evaluation only')
193 | parser.add_argument('--dist_eval', type=str2bool, default=True,
194 | help='Enabling distributed evaluation')
195 | parser.add_argument('--disable_eval', type=str2bool, default=False,
196 | help='Disabling evaluation during training')
197 | parser.add_argument('--num_workers', default=10, type=int)
198 | parser.add_argument('--pin_mem', type=str2bool, default=True,
199 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
200 |
201 | # distributed training parameters
202 | parser.add_argument('--world_size', default=1, type=int,
203 | help='number of distributed processes')
204 | parser.add_argument('--local_rank', default=-1, type=int)
205 | parser.add_argument('--dist_on_itp', type=str2bool, default=False)
206 | parser.add_argument('--dist_url', default='env://',
207 | help='url used to set up distributed training')
208 |
209 | parser.add_argument('--use_amp', type=str2bool, default=False,
210 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not")
211 |
212 | # Weights and Biases arguments
213 | parser.add_argument('--enable_wandb', type=str2bool, default=False,
214 | help="enable logging to Weights and Biases")
215 | parser.add_argument('--project', default='convnext', type=str,
216 | help="The name of the W&B project where you're sending the new run.")
217 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False,
218 | help="Save model checkpoints as W&B Artifacts.")
219 |
220 | return parser
221 |
222 | def main(args):
223 | utils.init_distributed_mode(args)
224 | print(args)
225 | device = torch.device(args.device)
226 |
227 | # fix the seed for reproducibility
228 | seed = args.seed + utils.get_rank()
229 | torch.manual_seed(seed)
230 | np.random.seed(seed)
231 | cudnn.benchmark = True
232 |
233 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
234 | if args.disable_eval:
235 | args.dist_eval = False
236 | dataset_val = None
237 | else:
238 | dataset_val, _ = build_dataset(is_train=False, args=args)
239 |
240 | num_tasks = utils.get_world_size()
241 | global_rank = utils.get_rank()
242 |
243 | sampler_train = torch.utils.data.DistributedSampler(
244 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
245 | )
246 | print("Sampler_train = %s" % str(sampler_train))
247 | if args.dist_eval:
248 | if len(dataset_val) % num_tasks != 0:
249 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
250 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
251 | 'equal num of samples per-process.')
252 | sampler_val = torch.utils.data.DistributedSampler(
253 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
254 | else:
255 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
256 |
257 | if global_rank == 0 and args.log_dir is not None:
258 | os.makedirs(args.log_dir, exist_ok=True)
259 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
260 | else:
261 | log_writer = None
262 |
263 | if global_rank == 0 and args.enable_wandb:
264 | wandb_logger = utils.WandbLogger(args)
265 | else:
266 | wandb_logger = None
267 |
268 | data_loader_train = torch.utils.data.DataLoader(
269 | dataset_train, sampler=sampler_train,
270 | batch_size=args.batch_size,
271 | num_workers=args.num_workers,
272 | pin_memory=args.pin_mem,
273 | drop_last=True,
274 | )
275 |
276 | if dataset_val is not None:
277 | data_loader_val = torch.utils.data.DataLoader(
278 | dataset_val, sampler=sampler_val,
279 | batch_size=int(1.5 * args.batch_size),
280 | num_workers=args.num_workers,
281 | pin_memory=args.pin_mem,
282 | drop_last=False
283 | )
284 | else:
285 | data_loader_val = None
286 |
287 | mixup_fn = None
288 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
289 | if mixup_active:
290 | print("Mixup is activated!")
291 | mixup_fn = Mixup(
292 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
293 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
294 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
295 |
296 | model = utils.build_model(args)
297 | if args.finetune:
298 | if args.finetune.startswith('https'):
299 | checkpoint = torch.hub.load_state_dict_from_url(
300 | args.finetune, map_location='cpu', check_hash=True)
301 | else:
302 | checkpoint = torch.load(args.finetune, map_location='cpu')
303 |
304 | print("Load ckpt from %s" % args.finetune)
305 | checkpoint_model = None
306 | for model_key in args.model_key.split('|'):
307 | if model_key in checkpoint:
308 | checkpoint_model = checkpoint[model_key]
309 | print("Load state_dict by model_key = %s" % model_key)
310 | break
311 | if checkpoint_model is None:
312 | checkpoint_model = checkpoint
313 | state_dict = model.state_dict()
314 | for k in ['head.weight', 'head.bias']:
315 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
316 | print(f"Removing key {k} from pretrained checkpoint")
317 | del checkpoint_model[k]
318 |
319 | # Added for higher-resolution fine-tuning: interpolate the checkpoint position embedding to the new input size
320 | print("Interpolating position embedding")
321 | pos_embed_checkpoint = checkpoint_model['pos_embed']
322 | embedding_size = pos_embed_checkpoint.shape[-1]
323 | num_patches = model.patch_embed.num_patches
324 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
325 | # height (== width) for the checkpoint position embedding
326 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
327 | # height (== width) for the new position embedding
328 | new_size = int(num_patches ** 0.5)
329 | # extra tokens (e.g. the class token) occupy the end of the sequence and are kept unchanged
330 | extra_tokens = pos_embed_checkpoint[:, -num_extra_tokens:]
331 | # only the position tokens are interpolated
332 | pos_tokens = pos_embed_checkpoint[:, :-num_extra_tokens]
333 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
334 | pos_tokens = torch.nn.functional.interpolate(
335 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
336 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
337 | new_pos_embed = torch.cat((pos_tokens, extra_tokens), dim=1)
338 | checkpoint_model['pos_embed'] = new_pos_embed
339 | # interpolate position embedding done
340 |
341 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
342 | model.to(device)
343 |
344 | model_ema = None
345 | if args.model_ema:
346 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
347 | model_ema = ModelEma(
348 | model,
349 | decay=args.model_ema_decay,
350 | device='cpu' if args.model_ema_force_cpu else '',
351 | resume='')
352 | print("Using EMA with decay = %.8f" % args.model_ema_decay)
353 |
354 | model_without_ddp = model
355 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
356 |
357 | print("Model = %s" % str(model_without_ddp))
358 | print('number of params:', n_parameters)
359 |
360 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
361 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
362 | print("LR = %.8f" % args.lr)
363 | print("Batch size = %d" % total_batch_size)
364 | print("Update frequency = %d" % args.update_freq)
365 | print("Number of training examples = %d" % len(dataset_train))
366 | print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
367 |
368 | if args.layer_decay != 1.0:
369 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value.
370 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \
371 | "Layer Decay impl only supports convnext_small/base/large/xlarge"
372 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
373 | else:
374 | assigner = None
375 |
376 | if assigner is not None:
377 | print("Assigned values = %s" % str(assigner.values))
378 |
379 | if args.distributed:
380 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
381 | model_without_ddp = model.module
382 |
383 | optimizer = create_optimizer(
384 | args, model_without_ddp, skip_list=None,
385 | get_num_layer=assigner.get_layer_id if assigner is not None else None,
386 | get_layer_scale=assigner.get_scale if assigner is not None else None)
387 |
388 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used
389 |
390 | print("Use Cosine LR scheduler")
391 | lr_schedule_values = utils.cosine_scheduler(
392 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
393 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
394 | )
395 |
396 | if args.weight_decay_end is None:
397 | args.weight_decay_end = args.weight_decay
398 | wd_schedule_values = utils.cosine_scheduler(
399 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
400 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
401 |
402 | schedules = {}
403 |
404 | # At most one of dropout and stochastic depth should be enabled.
405 | assert(args.dropout == 0 or args.drop_path == 0)
406 | # ConvNeXt does not support dropout.
407 | assert(args.dropout == 0 if args.model.startswith("convnext") else True)
408 |
409 | if args.dropout > 0:
410 | schedules['do'] = drop_scheduler(
411 | args.dropout, args.epochs, num_training_steps_per_epoch,
412 | args.cutoff_epoch, args.drop_mode, args.drop_schedule)
413 | print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do'])))
414 |
415 | if args.drop_path > 0:
416 | schedules['dp'] = drop_scheduler(
417 | args.drop_path, args.epochs, num_training_steps_per_epoch,
418 | args.cutoff_epoch, args.drop_mode, args.drop_schedule)
419 | print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp'])))
420 |
421 | # Mask schedule enabled.
422 | schedules['sm'] = mask_scheduler(
423 | args.soft_mask, args.epochs, num_training_steps_per_epoch,
424 | args.cutoff_soft, args.mask_mode, args.mask_schedule)
425 | print("Min MASK = %.7f, Max MASK = %.7f" % (min(schedules['sm']), max(schedules['sm'])))
426 |
427 | if mixup_fn is not None:
428 | # smoothing is handled with mixup label transform
429 | criterion = SoftTargetCrossEntropy()
430 | elif args.smoothing > 0.:
431 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
432 | else:
433 | criterion = torch.nn.CrossEntropyLoss()
434 |
435 | print("criterion = %s" % str(criterion))
436 |
437 | utils.auto_load_model(
438 | args=args, model=model, model_without_ddp=model_without_ddp,
439 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
440 |
441 | if args.eval:
442 | print("Eval only mode")
443 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
444 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
445 | return
446 |
447 | max_accuracy = 0.0
448 | if args.model_ema and args.model_ema_eval:
449 | max_accuracy_ema = 0.0
450 |
451 | print("Start training for %d epochs" % args.epochs)
452 | start_time = time.time()
453 | for epoch in range(args.start_epoch, args.epochs):
454 | if args.distributed:
455 | data_loader_train.sampler.set_epoch(epoch)
456 | if log_writer is not None:
457 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
458 | if wandb_logger:
459 | wandb_logger.set_steps()
460 | train_stats = train_one_epoch(
461 | model, criterion, data_loader_train, optimizer,
462 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
463 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch,
464 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, schedules=schedules,
465 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
466 | use_amp=args.use_amp
467 | )
468 | if args.output_dir and args.save_ckpt:
469 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
470 | utils.save_model(
471 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
472 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
473 | if data_loader_val is not None:
474 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
475 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
476 | if max_accuracy < test_stats["acc1"]:
477 | max_accuracy = test_stats["acc1"]
478 | if args.output_dir and args.save_ckpt:
479 | utils.save_model(
480 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
481 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
482 | print(f'Max accuracy: {max_accuracy:.2f}%')
483 |
484 | if log_writer is not None:
485 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
486 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
487 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
488 |
489 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
490 | **{f'test_{k}': v for k, v in test_stats.items()},
491 | 'epoch': epoch,
492 | 'n_parameters': n_parameters}
493 |
494 | # repeat testing routines for EMA, if ema eval is turned on
495 | if args.model_ema and args.model_ema_eval:
496 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp)
497 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%")
498 | if max_accuracy_ema < test_stats_ema["acc1"]:
499 | max_accuracy_ema = test_stats_ema["acc1"]
500 | if args.output_dir and args.save_ckpt:
501 | utils.save_model(
502 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
503 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema)
504 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%')
505 | if log_writer is not None:
506 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch)
507 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}})
508 | else:
509 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
510 | 'epoch': epoch,
511 | 'n_parameters': n_parameters}
512 |
513 | if args.output_dir and utils.is_main_process():
514 | if log_writer is not None:
515 | log_writer.flush()
516 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
517 | f.write(json.dumps(log_stats) + "\n")
518 |
519 | if wandb_logger:
520 | wandb_logger.log_epoch_metrics(log_stats)
521 |
522 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir:
523 | wandb_logger.log_checkpoints()
524 |
525 |
526 | total_time = time.time() - start_time
527 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
528 | print('Training time {}'.format(total_time_str))
529 |
530 | if __name__ == '__main__':
531 | parser = argparse.ArgumentParser('iLLaMA training and evaluation script', parents=[get_args_parser()])
532 | args = parser.parse_args()
533 | if args.output_dir:
534 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
535 | main(args)
536 |
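
The position-embedding interpolation above depends on three sizes: the number of non-patch tokens, the side of the checkpoint's patch grid, and the side of the target grid. The sketch below reproduces only that arithmetic in plain Python, assuming square grids and a single class token stored at the end of `pos_embed` (which is why the script slices `[:, -num_extra_tokens:]`). The function name `pos_embed_grid_sizes` and the 224-to-384 scenario are illustrative assumptions, and `math.isqrt` stands in for the script's `int(x ** 0.5)`.

```python
import math


def pos_embed_grid_sizes(ckpt_tokens: int, model_tokens: int, num_patches: int):
    """Return (num_extra_tokens, orig_size, new_size) for square patch grids."""
    # Tokens that are not image patches (here: the single class token).
    num_extra_tokens = model_tokens - num_patches
    # Side length of the checkpoint's patch grid.
    orig_size = math.isqrt(ckpt_tokens - num_extra_tokens)
    # Side length of the target model's patch grid.
    new_size = math.isqrt(num_patches)
    return num_extra_tokens, orig_size, new_size


# Hypothetical case: a 224x224 checkpoint loaded into a 384x384 model, patch size 16.
ckpt_tokens = (224 // 16) ** 2 + 1   # 196 patches + 1 class token = 197
num_patches = (384 // 16) ** 2       # 576 patches
model_tokens = num_patches + 1       # 577 tokens

extra, orig_size, new_size = pos_embed_grid_sizes(ckpt_tokens, model_tokens, num_patches)
print(extra, orig_size, new_size)  # 1 14 24
```

In this scenario the 14x14 grid of patch embeddings would be resized to 24x24, and the single class-token embedding would be appended unchanged.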
--------------------------------------------------------------------------------
/mask_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def mask_scheduler(soft_mask_rate, epochs, niter_per_ep, cutoff_epoch=0, mode="standard", schedule="constant"):
4 | assert mode in ["standard", "soft"]
5 | if mode == "standard":
6 | return np.full(epochs * niter_per_ep, soft_mask_rate)
7 |
8 | early_iters = cutoff_epoch * niter_per_ep
9 | late_iters = (epochs - cutoff_epoch) * niter_per_ep
10 |
11 | if mode == "soft":
12 | assert schedule in ["constant", "linear"]
13 | if schedule == 'constant':
14 | early_schedule = np.full(early_iters, soft_mask_rate)
15 | elif schedule == 'linear':
16 | early_schedule = np.linspace(soft_mask_rate, 0, early_iters)
17 | final_schedule = np.concatenate((early_schedule, np.full(late_iters, 0)))
18 |
19 | assert len(final_schedule) == epochs * niter_per_ep
20 | return final_schedule
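
To see what the `"soft"` mode of `mask_scheduler` produces, the following dependency-free sketch rebuilds the same per-iteration schedule with Python lists instead of NumPy. The name `soft_mask_schedule` and the tiny epoch counts are assumptions chosen for illustration; the linear branch mirrors the end points of `np.linspace(rate, 0, early_iters)`.

```python
def soft_mask_schedule(rate, epochs, iters_per_epoch, cutoff_epoch, schedule="linear"):
    """Per-iteration soft mask rate: a ramp (or constant) phase, then zeros."""
    early = cutoff_epoch * iters_per_epoch
    late = (epochs - cutoff_epoch) * iters_per_epoch
    if schedule == "constant":
        head = [float(rate)] * early
    elif schedule == "linear":
        if early <= 1:
            # Zero or one point: nothing to interpolate.
            head = [float(rate)] * early
        else:
            # First value is `rate`, last value is 0, evenly spaced in between.
            head = [rate * (1 - i / (early - 1)) for i in range(early)]
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    # After the cutoff epoch the mask is fully causal (rate 0).
    return head + [0.0] * late


# Hypothetical run: 4 epochs of 2 iterations each, ramping down over the first 2 epochs.
values = soft_mask_schedule(1.0, epochs=4, iters_per_epoch=2, cutoff_epoch=2)
print(len(values), values[0], values[3], values[-1])
```

Each entry would be passed to the model once per iteration, so the attention mask moves from bidirectional to causal during the early epochs and stays causal afterwards.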
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .illama import *
--------------------------------------------------------------------------------
/models/illama.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 | from typing import Optional, Tuple
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 | from timm.models.registry import register_model
8 |
9 |
10 | class RMSNorm(torch.nn.Module):
11 | def __init__(self, dim: int, eps: float = 1e-6):
12 | super().__init__()
13 | self.eps = eps
14 | self.weight = nn.Parameter(torch.ones(dim))
15 |
16 | def _norm(self, x):
17 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
18 |
19 | def forward(self, x):
20 | output = self._norm(x.float()).type_as(x)
21 | return output * self.weight
22 |
23 |
24 | class Mlp(nn.Module):
25 | def __init__(self, in_features, hidden_features, multiple_of=256, act_layer=nn.GELU, drop=0.):
26 | super().__init__()
27 | hidden_features = int(2 * hidden_features / 3)
28 | hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
29 |
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
31 | self.fc2 = nn.Linear(hidden_features, in_features, bias=False)
32 | self.fc3 = nn.Linear(in_features, hidden_features, bias=False)
33 |
34 | def forward(self, x):
35 | x = F.silu(self.fc1(x)) * self.fc3(x) # [B, N+1, 4D*2/3]
36 | # print(x.shape)
37 | x = self.fc2(x)
38 | return x
39 |
40 |
41 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
42 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [32], float32
43 | t = torch.arange(end, device=freqs.device) # [197], int64
44 | freqs = torch.outer(t, freqs).float() # [197, 32], float32
45 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [197, 32], complex64
46 | return freqs_cis # [197, 32], complex64
47 |
48 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
49 | ndim = x.ndim # 4, since [bsz, 197, 3, 32], complex64
50 | assert 0 <= 1 < ndim
51 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
52 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # [1, 197, 1, 32], list
53 | return freqs_cis.view(*shape) # [1, 197, 1, 32], complex64
54 |
55 | def apply_rotary_emb(
56 | xq: torch.Tensor,
57 | xk: torch.Tensor,
58 | freqs_cis: torch.Tensor,
59 | ) -> Tuple[torch.Tensor, torch.Tensor]:
60 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bsz, 197, 3, 64], float32 → [bsz, 197, 3, 32, 2], float32 → [bsz, 197, 3, 32], complex64
61 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [bsz, 197, 3, 64], float32 → [bsz, 197, 3, 32, 2], float32 → [bsz, 197, 3, 32], complex64
62 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, 197, 1, 32], complex64
63 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bsz, 197, 3, 64], float32
64 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) # [bsz, 197, 3, 64], float32
65 | return xq_out.type_as(xq), xk_out.type_as(xk) # [bsz, 197, 3, 64], float32, [bsz, 197, 3, 64], float32
66 |
67 |
68 | class Attention(nn.Module):
69 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
70 | super().__init__()
71 | self.num_heads = num_heads
72 | head_dim = dim // num_heads
73 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
74 | self.scale = qk_scale or head_dim ** -0.5
75 |
76 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
77 | # self.attn_drop = nn.Dropout(attn_drop)
78 | self.proj = nn.Linear(dim, dim)
79 | # self.proj_drop = nn.Dropout(proj_drop)
80 |
81 | def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
82 | B, N, C = x.shape
83 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) # [3, B, N, self.num_heads, C // self.num_heads]
84 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # [B, N, self.num_heads, C // self.num_heads]
85 |
86 | q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
87 |
88 | q = q.transpose(1, 2) # [B, self.num_heads, N, C // self.num_heads]
89 | k = k.transpose(1, 2) # [B, self.num_heads, N, C // self.num_heads]
90 | v = v.transpose(1, 2) # [B, self.num_heads, N, C // self.num_heads]
91 | attn = (q @ k.transpose(-2, -1)) * self.scale # [B, self.num_heads, N, N]
92 | attn = attn.softmax(dim=-1)
93 | if mask is not None:
94 | attn = attn * mask # (B, H, N, N)
95 | # attn = self.attn_drop(attn)
96 |
97 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
98 | x = self.proj(x)
99 | # x = self.proj_drop(x)
100 | return x
101 |
102 |
103 | class Block(nn.Module):
104 |
105 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
106 | drop_path=0., act_layer=nn.GELU, norm_layer=RMSNorm):
107 | super().__init__()
108 | self.norm1 = norm_layer(dim)
109 | self.attn = Attention(
110 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
111 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
112 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
113 | self.norm2 = norm_layer(dim)
114 | mlp_hidden_dim = int(dim * mlp_ratio)
115 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
116 |
117 | def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
118 | x = x + self.drop_path(self.attn(self.norm1(x), freqs_cis, mask))
119 | x = x + self.drop_path(self.mlp(self.norm2(x)))
120 | return x
121 |
122 |
123 | class PatchEmbed(nn.Module):
124 | """ Image to Patch Embedding
125 | """
126 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
127 | super().__init__()
128 | img_size = to_2tuple(img_size)
129 | patch_size = to_2tuple(patch_size)
130 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
131 | self.img_size = img_size
132 | self.patch_size = patch_size
133 | self.num_patches = num_patches
134 |
135 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
136 |
137 | def forward(self, x):
138 | B, C, H, W = x.shape
139 | # FIXME look at relaxing size constraints
140 | assert H == self.img_size[0] and W == self.img_size[1], \
141 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
142 | x = self.proj(x).flatten(2).transpose(1, 2)
143 | return x
144 |
145 |
146 | class HybridEmbed(nn.Module):
147 | """ CNN Feature Map Embedding
148 | Extract feature map from CNN, flatten, project to embedding dim.
149 | """
150 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
151 | super().__init__()
152 | assert isinstance(backbone, nn.Module)
153 | img_size = to_2tuple(img_size)
154 | self.img_size = img_size
155 | self.backbone = backbone
156 | if feature_size is None:
157 | with torch.no_grad():
158 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
159 | # map for all networks, the feature metadata has reliable channel and stride info, but using
160 | # stride to calc feature dim requires info about padding of each stage that isn't captured.
161 | training = backbone.training
162 | if training:
163 | backbone.eval()
164 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
165 | feature_size = o.shape[-2:]
166 | feature_dim = o.shape[1]
167 | backbone.train(training)
168 | else:
169 | feature_size = to_2tuple(feature_size)
170 | feature_dim = self.backbone.feature_info.channels()[-1]
171 | self.num_patches = feature_size[0] * feature_size[1]
172 | self.proj = nn.Linear(feature_dim, embed_dim)
173 |
174 | def forward(self, x):
175 | x = self.backbone(x)[-1]
176 | x = x.flatten(2).transpose(1, 2)
177 | x = self.proj(x)
178 | return x
179 |
180 |
181 | class VisionTransformer(nn.Module):
182 | """ Vision Transformer with support for patch or hybrid CNN input stage
183 | """
184 |
185 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
186 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
187 | drop_path_rate=0., hybrid_backbone=None, norm_layer=RMSNorm):
188 | super().__init__()
189 | self.num_classes = num_classes
190 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
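191 | # Added: keep drop_rate on the model and reuse it as the attention dropout rate
192 | self.drop_rate = drop_rate
193 | attn_drop_rate = drop_rate
194 | # Added: soft mask rate, updated during training via update_soft_mask()
195 | self.soft_mask_rate = 0.0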
196 | if hybrid_backbone is not None:
197 | self.patch_embed = HybridEmbed(
198 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
199 | else:
200 | self.patch_embed = PatchEmbed(
201 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
202 | num_patches = self.patch_embed.num_patches
203 |
204 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
205 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
206 | self.pos_drop = nn.Dropout(p=drop_rate)
207 | self.depth = depth
208 |
209 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
210 | self.blocks = nn.ModuleList([
211 | Block(
212 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
213 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
214 | for i in range(depth)])
215 | self.norm = norm_layer(embed_dim)
216 |
217 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
218 | #self.repr = nn.Linear(embed_dim, representation_size)
219 | #self.repr_act = nn.Tanh()
220 |
221 | # Classifier head
222 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
223 |
224 | trunc_normal_(self.pos_embed, std=.02)
225 | trunc_normal_(self.cls_token, std=.02)
226 | self.apply(self._init_weights)
227 |
228 | self.freqs_cis = precompute_freqs_cis(
229 | self.num_features // num_heads, num_patches + 1
230 | )
231 |
232 | def _init_weights(self, m):
233 | if isinstance(m, nn.Linear):
234 | trunc_normal_(m.weight, std=.02)
235 | if isinstance(m, nn.Linear) and m.bias is not None:
236 | nn.init.constant_(m.bias, 0)
237 | elif isinstance(m, nn.LayerNorm):
238 | nn.init.constant_(m.bias, 0)
239 | nn.init.constant_(m.weight, 1.0)
240 |
241 | @torch.jit.ignore
242 | def no_weight_decay(self):
243 | return {'pos_embed', 'cls_token'}
244 |
245 | def get_classifier(self):
246 | return self.head
247 |
248 | def reset_classifier(self, num_classes, global_pool=''):
249 | self.num_classes = num_classes
250 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
251 |
252 | def forward_features(self, x):
253 | B = x.shape[0]
254 | x = self.patch_embed(x)
255 | freqs_cis = self.freqs_cis.to(x.device)
256 |
257 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
258 | x = torch.cat((x, cls_tokens), dim=1)
259 | x = x + self.pos_embed
260 | x = self.pos_drop(x)
261 | N = x.shape[1]
262 |
263 | mask = None
264 | if N > 1:
265 | mask_bidirectional = torch.full((1, 1, N, N), 1, device=x.device).type_as(x)
266 | mask_causal = torch.full((1, 1, N, N), 0, device=x.device)
267 | tril_mask = torch.tril(torch.ones((N, N), dtype=torch.bool, device=x.device))
268 | mask_causal[:, :, tril_mask] = 1
269 | mask_causal = mask_causal.type_as(x)
270 | # soft_mask_rate goes from 1 (fully bidirectional) to 0 (fully causal)
271 | mask = self.soft_mask_rate * mask_bidirectional + (1 - self.soft_mask_rate) * mask_causal
272 |
273 | for blk in self.blocks:
274 | x = blk(x, freqs_cis, mask)
275 |
276 | x = self.norm(x)
277 | return x[:, -1]
278 |
279 | def forward(self, x):
280 | x = self.forward_features(x)
281 | x = self.head(x)
282 | return x
283 |
284 | def update_drop_path(self, drop_path_rate):
285 | self.drop_path = drop_path_rate
286 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, self.depth)]
287 | for i in range(self.depth):
288 | self.blocks[i].drop_path.drop_prob = dp_rates[i]
289 |
290 | def update_dropout(self, drop_rate):
291 | self.drop_rate = drop_rate
292 | for module in self.modules():
293 | if isinstance(module, nn.Dropout):
294 | module.p = drop_rate
295 |
296 | def update_soft_mask(self, soft_mask_rate):
297 | self.soft_mask_rate = soft_mask_rate
298 |
299 |
300 | @register_model
301 | def illama_tiny(pretrained=False, **kwargs):
302 | model = VisionTransformer(
303 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
304 | norm_layer=partial(RMSNorm, eps=1e-6), **kwargs)
305 | return model
306 |
307 | @register_model
308 | def illama_small(pretrained=False, **kwargs):
309 | model = VisionTransformer(
310 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
311 | norm_layer=partial(RMSNorm, eps=1e-6), **kwargs)
312 | return model
313 |
314 | @register_model
315 | def illama_base(pretrained=False, **kwargs):
316 | model = VisionTransformer(
317 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
318 | norm_layer=partial(RMSNorm, eps=1e-6), **kwargs)
319 | return model
320 |
321 | @register_model
322 | def illama_large(pretrained=False, **kwargs):
323 | model = VisionTransformer(
324 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
325 | norm_layer=partial(RMSNorm, eps=1e-6), **kwargs)
326 | return model
327 |
328 |
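
The rotary embedding in `precompute_freqs_cis` and `apply_rotary_emb` treats each adjacent pair of channels as a complex number and multiplies it by a unit complex number whose angle is `position * frequency`. The sketch below shows the same idea for a single token vector using only the standard library; the helper names `rotary_freqs`, `rotate`, and `dot`, and the example vectors, are assumptions made for illustration. It checks the property that makes the encoding relative: the query-key score depends only on the offset between the two positions.

```python
import cmath


def rotary_freqs(head_dim, theta=10000.0):
    # One rotation frequency per pair of channels, as in precompute_freqs_cis.
    return [1.0 / (theta ** (i / head_dim)) for i in range(0, head_dim, 2)]


def rotate(vec, position, freqs):
    """Apply the rotary embedding to one per-head vector (even-length list of floats)."""
    out = []
    for j, freq in enumerate(freqs):
        pair = complex(vec[2 * j], vec[2 * j + 1])         # view (x0, x1) as x0 + i*x1
        rotated = pair * cmath.rect(1.0, position * freq)  # multiply by e^{i * pos * freq}
        out.extend([rotated.real, rotated.imag])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


freqs = rotary_freqs(head_dim=4)
q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]

# Same offset (2) at different absolute positions gives the same score.
score_a = dot(rotate(q, 5, freqs), rotate(k, 3, freqs))
score_b = dot(rotate(q, 12, freqs), rotate(k, 10, freqs))
print(abs(score_a - score_b) < 1e-9)
```

The rotation also leaves the length of each vector unchanged, since every pair is multiplied by a complex number of magnitude one.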
329 |
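
The soft mask built in `forward_features` is a convex blend of an all-ones (bidirectional) mask and a lower-triangular (causal) mask, controlled by `soft_mask_rate`. The plain-Python sketch below builds that blend for a small sequence so the three regimes are easy to inspect; the function name `soft_mask` and the sequence length of 3 are illustrative assumptions.

```python
def soft_mask(n, soft_mask_rate):
    """Blend an all-ones mask with a lower-triangular (causal) mask."""
    mask = []
    for i in range(n):          # query position
        row = []
        for j in range(n):      # key position
            causal = 1.0 if j <= i else 0.0
            row.append(soft_mask_rate * 1.0 + (1.0 - soft_mask_rate) * causal)
        mask.append(row)
    return mask


for rate in (1.0, 0.5, 0.0):
    print(rate, soft_mask(3, rate))
```

At rate 1.0 every entry is 1 (bidirectional), at 0.5 the future positions are down-weighted to 0.5, and at 0.0 they are zeroed (causal). Note that in `Attention.forward` this mask multiplies the attention weights after the softmax, so with a rate below 1 the rows are not renormalized.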
--------------------------------------------------------------------------------
/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 |
4 | from timm.optim.adafactor import Adafactor
5 | from timm.optim.adahessian import Adahessian
6 | from timm.optim.adamp import AdamP
7 | from timm.optim.lookahead import Lookahead
8 | from timm.optim.nadam import Nadam
9 | # from timm.optim.novograd import NovoGrad
10 | from timm.optim.nvnovograd import NvNovoGrad
11 | from timm.optim.radam import RAdam
12 | from timm.optim.rmsprop_tf import RMSpropTF
13 | from timm.optim.sgdp import SGDP
14 |
15 | import json
16 |
17 | try:
18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | def get_num_layer_for_convnext(var_name):
25 | """
26 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three
27 | consecutive blocks, including possible neighboring downsample layers;
28 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
29 | """
30 | num_max_layer = 12
31 | if var_name.startswith("downsample_layers"):
32 | stage_id = int(var_name.split('.')[1])
33 | if stage_id == 0:
34 | layer_id = 0
35 | elif stage_id == 1 or stage_id == 2:
36 | layer_id = stage_id + 1
37 | elif stage_id == 3:
38 | layer_id = 12
39 | return layer_id
40 |
41 | elif var_name.startswith("stages"):
42 | stage_id = int(var_name.split('.')[1])
43 | block_id = int(var_name.split('.')[2])
44 | if stage_id == 0 or stage_id == 1:
45 | layer_id = stage_id + 1
46 | elif stage_id == 2:
47 | layer_id = 3 + block_id // 3
48 | elif stage_id == 3:
49 | layer_id = 12
50 | return layer_id
51 | else:
52 | return num_max_layer + 1
53 |
54 | class LayerDecayValueAssigner(object):
55 | def __init__(self, values):
56 | self.values = values
57 |
58 | def get_scale(self, layer_id):
59 | return self.values[layer_id]
60 |
61 | def get_layer_id(self, var_name):
62 | return get_num_layer_for_convnext(var_name)
63 |
64 |
65 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
66 | parameter_group_names = {}
67 | parameter_group_vars = {}
68 |
69 | for name, param in model.named_parameters():
70 | if not param.requires_grad:
71 | continue # frozen weights
72 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
73 | group_name = "no_decay"
74 | this_weight_decay = 0.
75 | else:
76 | group_name = "decay"
77 | this_weight_decay = weight_decay
78 | if get_num_layer is not None:
79 | layer_id = get_num_layer(name)
80 | group_name = "layer_%d_%s" % (layer_id, group_name)
81 | else:
82 | layer_id = None
83 |
84 | if group_name not in parameter_group_names:
85 | if get_layer_scale is not None:
86 | scale = get_layer_scale(layer_id)
87 | else:
88 | scale = 1.
89 |
90 | parameter_group_names[group_name] = {
91 | "weight_decay": this_weight_decay,
92 | "params": [],
93 | "lr_scale": scale
94 | }
95 | parameter_group_vars[group_name] = {
96 | "weight_decay": this_weight_decay,
97 | "params": [],
98 | "lr_scale": scale
99 | }
100 |
101 | parameter_group_vars[group_name]["params"].append(param)
102 | parameter_group_names[group_name]["params"].append(name)
103 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
104 | return list(parameter_group_vars.values())
105 |
106 |
107 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
108 | opt_lower = args.opt.lower()
109 | weight_decay = args.weight_decay
110 | # if weight_decay and filter_bias_and_bn:
111 | if filter_bias_and_bn:
112 | skip = {}
113 | if skip_list is not None:
114 | skip = skip_list
115 | elif hasattr(model, 'no_weight_decay'):
116 | skip = model.no_weight_decay()
117 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
118 | weight_decay = 0.
119 | else:
120 | parameters = model.parameters()
121 |
122 | if 'fused' in opt_lower:
123 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
124 |
125 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
126 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
127 | opt_args['eps'] = args.opt_eps
128 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
129 | opt_args['betas'] = args.opt_betas
130 |
131 | opt_split = opt_lower.split('_')
132 | opt_lower = opt_split[-1]
133 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
134 | opt_args.pop('eps', None)
135 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
136 | elif opt_lower == 'momentum':
137 | opt_args.pop('eps', None)
138 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
139 | elif opt_lower == 'adam':
140 | optimizer = optim.Adam(parameters, **opt_args)
141 | elif opt_lower == 'adamw':
142 | optimizer = optim.AdamW(parameters, **opt_args)
143 | elif opt_lower == 'nadam':
144 | optimizer = Nadam(parameters, **opt_args)
145 | elif opt_lower == 'radam':
146 | optimizer = RAdam(parameters, **opt_args)
147 | elif opt_lower == 'adamp':
148 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
149 | elif opt_lower == 'sgdp':
150 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
151 | elif opt_lower == 'adadelta':
152 | optimizer = optim.Adadelta(parameters, **opt_args)
153 | elif opt_lower == 'adafactor':
154 | if not args.lr:
155 | opt_args['lr'] = None
156 | optimizer = Adafactor(parameters, **opt_args)
157 | elif opt_lower == 'adahessian':
158 | optimizer = Adahessian(parameters, **opt_args)
159 | elif opt_lower == 'rmsprop':
160 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
161 | elif opt_lower == 'rmsproptf':
162 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
163 | elif opt_lower == 'novograd':
164 | raise NotImplementedError("NovoGrad import is commented out above; use 'nvnovograd' instead")
165 | elif opt_lower == 'nvnovograd':
166 | optimizer = NvNovoGrad(parameters, **opt_args)
167 | elif opt_lower == 'fusedsgd':
168 | opt_args.pop('eps', None)
169 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
170 | elif opt_lower == 'fusedmomentum':
171 | opt_args.pop('eps', None)
172 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
173 | elif opt_lower == 'fusedadam':
174 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
175 | elif opt_lower == 'fusedadamw':
176 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
177 | elif opt_lower == 'fusedlamb':
178 | optimizer = FusedLAMB(parameters, **opt_args)
179 | elif opt_lower == 'fusednovograd':
180 | opt_args.setdefault('betas', (0.95, 0.98))
181 | optimizer = FusedNovoGrad(parameters, **opt_args)
182 | else:
183 | raise ValueError("Invalid optimizer: %s" % args.opt)
184 |
185 | if len(opt_split) > 1:
186 | if opt_split[0] == 'lookahead':
187 | optimizer = Lookahead(optimizer)
188 |
189 | return optimizer
190 |
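
The grouping rule in `get_parameter_groups` can be illustrated without a model: a parameter is exempt from weight decay if it is one-dimensional, if its name ends in `.bias`, or if it appears in the skip list (here the set returned by `no_weight_decay()`). The sketch below applies that rule to `(name, shape)` pairs; the function name `split_decay_groups`, the parameter list, and the weight decay value are hypothetical and only loosely modelled on `illama_tiny`.

```python
def split_decay_groups(named_shapes, weight_decay, skip_list=()):
    """Assign parameter names to 'decay' or 'no_decay' groups by shape and name."""
    groups = {
        "decay": {"weight_decay": weight_decay, "params": []},
        "no_decay": {"weight_decay": 0.0, "params": []},
    }
    for name, shape in named_shapes:
        # Norm scales and biases are 1-D; embeddings are skipped by name.
        if len(shape) == 1 or name.endswith(".bias") or name in skip_list:
            groups["no_decay"]["params"].append(name)
        else:
            groups["decay"]["params"].append(name)
    return groups


# Hypothetical parameter names and shapes.
params = [
    ("pos_embed", (1, 197, 192)),
    ("cls_token", (1, 1, 192)),
    ("blocks.0.norm1.weight", (192,)),
    ("blocks.0.attn.qkv.weight", (576, 192)),
    ("blocks.0.attn.qkv.bias", (576,)),
    ("head.weight", (1000, 192)),
]
groups = split_decay_groups(params, weight_decay=0.05, skip_list={"pos_embed", "cls_token"})
print(groups["decay"]["params"])
print(groups["no_decay"]["params"])
```

Only the two weight matrices land in the decayed group; the embeddings, the norm scale, and the bias are left undecayed.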
--------------------------------------------------------------------------------
/scripts/eval_illama_in1k_224.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | root_imagenet='/mnt/petrelfs/wangjiahao/datasets/classificaton/imagenet/'
4 |
5 | # illama-tiny: 75.0
6 | MODEL=illama_tiny
7 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-tiny-in1k-75.0.pth'
8 |
9 | srun -p gvembodied \
10 | --job-name=evaluation_224 \
11 | --gres=gpu:2 \
12 | --cpus-per-task=32 \
13 | --preempt \
14 | --quotatype=spot \
15 | python -m torch.distributed.launch --nproc_per_node=2 main.py \
16 | --model $MODEL --eval true \
17 | --data_path $root_imagenet \
18 | --resume $RESUME
19 |
20 |
21 | # illama-small: 79.9
22 | MODEL=illama_small
23 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-small-in1k-79.9.pth'
24 |
25 | srun -p gvembodied \
26 | --job-name=evaluation_224 \
27 | --gres=gpu:2 \
28 | --cpus-per-task=32 \
29 | --preempt \
30 | --quotatype=spot \
31 | python -m torch.distributed.launch --nproc_per_node=2 main.py \
32 | --model $MODEL --eval true \
33 | --data_path $root_imagenet \
34 | --resume $RESUME
35 |
36 |
37 | # illama-base: 81.6
38 | MODEL=illama_base
39 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-base-in1k-81.6.pth'
40 |
41 | srun -p gvembodied \
42 | --job-name=evaluation_224 \
43 | --gres=gpu:2 \
44 | --cpus-per-task=32 \
45 | --preempt \
46 | --quotatype=spot \
47 | python -m torch.distributed.launch --nproc_per_node=2 main.py \
48 | --model $MODEL --eval true \
49 | --data_path $root_imagenet \
50 | --resume $RESUME
51 |
52 |
53 | # illama-base: 83.6
54 | MODEL=illama_base
55 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-base-in21kin1k-224-83.6.pth'
56 |
57 | srun -p gvembodied \
58 | --job-name=evaluation_224 \
59 | --gres=gpu:2 \
60 | --cpus-per-task=32 \
61 | --preempt \
62 | --quotatype=spot \
63 | python -m torch.distributed.launch --nproc_per_node=2 main.py \
64 | --model $MODEL --eval true \
65 | --data_path $root_imagenet \
66 | --resume $RESUME
67 |
68 |
69 | # illama-large: 84.8
70 | MODEL=illama_large
71 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-large-in21kin1k-224-84.8.pth'
72 |
73 | srun -p gvembodied \
74 | --job-name=evaluation_224 \
75 | --gres=gpu:2 \
76 | --cpus-per-task=32 \
77 | --preempt \
78 | --quotatype=spot \
79 | python -m torch.distributed.launch --nproc_per_node=2 main.py \
80 | --model $MODEL --eval true \
81 | --data_path $root_imagenet \
82 | --resume $RESUME
--------------------------------------------------------------------------------
/scripts/eval_illama_in1k_384.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | root_imagenet='/mnt/petrelfs/wangjiahao/datasets/classificaton/imagenet/'
4 |
5 | # illama-base: 83.0
6 | MODEL=illama_base
7 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-base-in1k-384-83.0.pth'
8 |
9 | srun -p gvembodied \
10 | --job-name=evaluation_384 \
11 | --gres=gpu:2 \
12 | --cpus-per-task=32 \
13 | --preempt \
14 | --quotatype=spot \
15 | python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \
16 | --model $MODEL --input_size 384 --eval true \
17 | --data_path $root_imagenet \
18 | --resume $RESUME
19 |
20 |
21 | # illama-base: 85.0
22 | MODEL=illama_base
23 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-base-in21kin1k-384-85.0.pth'
24 |
25 | srun -p gvembodied \
26 | --job-name=evaluation_384 \
27 | --gres=gpu:2 \
28 | --cpus-per-task=32 \
29 | --preempt \
30 | --quotatype=spot \
31 | python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \
32 | --model $MODEL --input_size 384 --eval true \
33 | --data_path $root_imagenet \
34 | --resume $RESUME
35 |
36 |
37 | # illama-large: 86.0
38 | MODEL=illama_large
39 | RESUME='/mnt/petrelfs/wangjiahao/DoiT/pretrained/illama-large-in21kin1k-384-86.0.pth'
40 |
41 | srun -p gvembodied \
42 | --job-name=evaluation_384 \
43 | --gres=gpu:2 \
44 | --cpus-per-task=32 \
45 | --preempt \
46 | --quotatype=spot \
47 | python -m torch.distributed.launch --nproc_per_node=2 main_soft_fthr.py \
48 | --model $MODEL --input_size 384 --eval true \
49 | --data_path $root_imagenet \
50 | --resume $RESUME
--------------------------------------------------------------------------------
/scripts/train_illama_base_from_llama2.sh:
--------------------------------------------------------------------------------
1 | root_imagenet='/your/path/to/imagenet/'
2 | MODEL=illama_base
3 | OUTPUT='output/path'
4 | FINETUNE='/your/path/to/llama2/pretrained/illama_ws_base.pth'
5 |
6 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \
7 | --model $MODEL --epochs 300 --mixup 0.95 --cutmix 1.0 \
8 | --batch_size 128 --lr 4e-3 --update_freq 4 \
9 | --drop_path 0.4 --drop_mode standard \
10 | --mask_mode soft --mask_schedule linear --cutoff_soft 25 \
11 | --finetune $FINETUNE \
12 | --data_path $root_imagenet \
13 | --output_dir $OUTPUT
--------------------------------------------------------------------------------
/scripts/train_illama_base_in1k.sh:
--------------------------------------------------------------------------------
1 | root_imagenet='/your/path/to/imagenet/'
2 | MODEL=illama_base
3 | OUTPUT='output/path'
4 |
5 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \
6 | --model $MODEL --epochs 300 --mixup 0.95 --cutmix 1.0 \
7 | --batch_size 128 --lr 4e-3 --update_freq 4 \
8 | --drop_path 0.4 --drop_mode standard \
9 | --mask_mode soft --mask_schedule linear --cutoff_soft 25 \
10 | --data_path $root_imagenet \
11 | --output_dir $OUTPUT
--------------------------------------------------------------------------------
/scripts/train_illama_small_from_llama2.sh:
--------------------------------------------------------------------------------
1 | root_imagenet='/your/path/to/imagenet/'
2 | MODEL=illama_small
3 | OUTPUT='output/path'
4 | FINETUNE='/your/path/to/llama2/pretrained/illama_ws_small.pth'
5 |
6 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \
7 | --model $MODEL --epochs 300 --warmup_epochs 5 --mixup 0.5 --cutmix 0.5 \
8 | --batch_size 128 --lr 4e-3 --update_freq 4 \
9 | --drop_path 0.1 --drop_mode standard \
10 | --mask_mode soft --mask_schedule linear --cutoff_soft 50 \
11 | --finetune $FINETUNE \
12 | --data_path $root_imagenet \
13 | --output_dir $OUTPUT
--------------------------------------------------------------------------------
/scripts/train_illama_small_in1k.sh:
--------------------------------------------------------------------------------
1 | root_imagenet='/your/path/to/imagenet/'
2 | MODEL=illama_small
3 | OUTPUT='output/path'
4 |
5 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \
6 | --model $MODEL --epochs 300 --warmup_epochs 5 --mixup 0.5 --cutmix 0.5 \
7 | --batch_size 128 --lr 4e-3 --update_freq 4 \
8 | --drop_path 0.1 --drop_mode standard \
9 | --mask_mode soft --mask_schedule linear --cutoff_soft 50 \
10 | --data_path $root_imagenet \
11 | --output_dir $OUTPUT
--------------------------------------------------------------------------------
/scripts/train_illama_tiny_from_llama2.sh:
--------------------------------------------------------------------------------
1 | root_imagenet='/your/path/to/imagenet/'
2 | MODEL=illama_tiny
3 | OUTPUT='output/path'
4 | FINETUNE='/your/path/to/llama2/pretrained/illama_ws_tiny.pth'
5 |
6 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \
7 | --model $MODEL --epochs 300 --warmup_epochs 5 --mixup 0.1 --cutmix 0.1 \
8 | --batch_size 128 --lr 4e-3 --update_freq 4 \
9 | --dropout 0 --drop_mode standard \
10 | --mask_mode soft --mask_schedule constant --cutoff_soft 50 \
11 | --finetune $FINETUNE \
12 | --data_path $root_imagenet \
13 | --output_dir $OUTPUT
--------------------------------------------------------------------------------
/scripts/train_illama_tiny_in1k.sh:
--------------------------------------------------------------------------------
1 | root_imagenet='/your/path/to/imagenet/'
2 | MODEL=illama_tiny
3 | OUTPUT='output/path'
4 |
5 | python -m torch.distributed.launch --nproc_per_node=8 main_soft.py \
6 | --model $MODEL --epochs 300 --warmup_epochs 5 --mixup 0.1 --cutmix 0.1 \
7 | --batch_size 128 --lr 4e-3 --update_freq 4 \
8 | --dropout 0 --drop_mode standard \
9 | --mask_mode soft --mask_schedule constant --cutoff_soft 50 \
10 | --data_path $root_imagenet \
11 | --output_dir $OUTPUT
--------------------------------------------------------------------------------
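All six training scripts above pass `--nproc_per_node=8`, `--batch_size 128` and `--update_freq 4`. Assuming `--batch_size` is the per-GPU batch size and `--update_freq` is the number of gradient-accumulation steps (the usual convention in ConvNeXt-style training code, not verified here against main_soft.py), the effective batch size can be sketched as follows. `effective_batch_size` is a hypothetical helper, not a function from this repository.

```python
def effective_batch_size(per_gpu_batch, num_gpus, update_freq):
    """Samples contributing to one optimizer step under the stated assumptions."""
    return per_gpu_batch * num_gpus * update_freq


# Values taken from the training scripts: 128 per GPU, 8 GPUs, update_freq 4.
total = effective_batch_size(per_gpu_batch=128, num_gpus=8, update_freq=4)
print(total)  # 4096
```

When running on fewer GPUs, raising `--update_freq` proportionally keeps this product, and therefore the `--lr 4e-3` setting, comparable.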
/utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import os
4 | import math
5 | import time
6 | from collections import defaultdict, deque
7 | import datetime
8 | import numpy as np
9 | from timm.utils import get_state_dict
10 |
11 | from pathlib import Path
12 | from timm.models import create_model
13 | import torch
14 | import torch.distributed as dist
15 | from math import inf  # torch._six is unavailable in recent PyTorch releases
16 |
17 | from tensorboardX import SummaryWriter
18 |
19 | class SmoothedValue(object):
20 | """Track a series of values and provide access to smoothed values over a
21 | window or the global series average.
22 | """
23 |
24 | def __init__(self, window_size=20, fmt=None):
25 | if fmt is None:
26 | fmt = "{median:.4f} ({global_avg:.4f})"
27 | self.deque = deque(maxlen=window_size)
28 | self.total = 0.0
29 | self.count = 0
30 | self.fmt = fmt
31 |
32 | def update(self, value, n=1):
33 | self.deque.append(value)
34 | self.count += n
35 | self.total += value * n
36 |
37 | def synchronize_between_processes(self):
38 | """
39 | Warning: does not synchronize the deque!
40 | """
41 | if not is_dist_avail_and_initialized():
42 | return
43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
44 | dist.barrier()
45 | dist.all_reduce(t)
46 | t = t.tolist()
47 | self.count = int(t[0])
48 | self.total = t[1]
49 |
50 | @property
51 | def median(self):
52 | d = torch.tensor(list(self.deque))
53 | return d.median().item()
54 |
55 | @property
56 | def avg(self):
57 | d = torch.tensor(list(self.deque), dtype=torch.float32)
58 | return d.mean().item()
59 |
60 | @property
61 | def global_avg(self):
62 | return self.total / self.count
63 |
64 | @property
65 | def max(self):
66 | return max(self.deque)
67 |
68 | @property
69 | def value(self):
70 | return self.deque[-1]
71 |
72 | def __str__(self):
73 | return self.fmt.format(
74 | median=self.median,
75 | avg=self.avg,
76 | global_avg=self.global_avg,
77 | max=self.max,
78 | value=self.value)
79 |
80 |
81 | class MetricLogger(object):
82 | def __init__(self, delimiter="\t"):
83 | self.meters = defaultdict(SmoothedValue)
84 | self.delimiter = delimiter
85 |
86 | def update(self, **kwargs):
87 | for k, v in kwargs.items():
88 | if v is None:
89 | continue
90 | if isinstance(v, torch.Tensor):
91 | v = v.item()
92 | assert isinstance(v, (float, int))
93 | self.meters[k].update(v)
94 |
95 | def __getattr__(self, attr):
96 | if attr in self.meters:
97 | return self.meters[attr]
98 | if attr in self.__dict__:
99 | return self.__dict__[attr]
100 | raise AttributeError("'{}' object has no attribute '{}'".format(
101 | type(self).__name__, attr))
102 |
103 | def __str__(self):
104 | loss_str = []
105 | for name, meter in self.meters.items():
106 | loss_str.append(
107 | "{}: {}".format(name, str(meter))
108 | )
109 | return self.delimiter.join(loss_str)
110 |
111 | def synchronize_between_processes(self):
112 | for meter in self.meters.values():
113 | meter.synchronize_between_processes()
114 |
115 | def add_meter(self, name, meter):
116 | self.meters[name] = meter
117 |
118 | def log_every(self, iterable, print_freq, header=None):
119 | i = 0
120 | if not header:
121 | header = ''
122 | start_time = time.time()
123 | end = time.time()
124 | iter_time = SmoothedValue(fmt='{avg:.4f}')
125 | data_time = SmoothedValue(fmt='{avg:.4f}')
126 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
127 | log_msg = [
128 | header,
129 | '[{0' + space_fmt + '}/{1}]',
130 | 'eta: {eta}',
131 | '{meters}',
132 | 'time: {time}',
133 | 'data: {data}'
134 | ]
135 | if torch.cuda.is_available():
136 | log_msg.append('max mem: {memory:.0f}')
137 | log_msg = self.delimiter.join(log_msg)
138 | MB = 1024.0 * 1024.0
139 | for obj in iterable:
140 | data_time.update(time.time() - end)
141 | yield obj
142 | iter_time.update(time.time() - end)
143 | if i % print_freq == 0 or i == len(iterable) - 1:
144 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
145 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
146 | if torch.cuda.is_available():
147 | print(log_msg.format(
148 | i, len(iterable), eta=eta_string,
149 | meters=str(self),
150 | time=str(iter_time), data=str(data_time),
151 | memory=torch.cuda.max_memory_allocated() / MB))
152 | else:
153 | print(log_msg.format(
154 | i, len(iterable), eta=eta_string,
155 | meters=str(self),
156 | time=str(iter_time), data=str(data_time)))
157 | i += 1
158 | end = time.time()
159 | total_time = time.time() - start_time
160 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
161 | print('{} Total time: {} ({:.4f} s / it)'.format(
162 | header, total_time_str, total_time / len(iterable)))
163 |
164 |
165 | class TensorboardLogger(object):
166 | def __init__(self, log_dir):
167 | self.writer = SummaryWriter(logdir=log_dir)
168 | self.step = 0
169 |
170 | def set_step(self, step=None):
171 | if step is not None:
172 | self.step = step
173 | else:
174 | self.step += 1
175 |
176 | def update(self, head='scalar', step=None, **kwargs):
177 | for k, v in kwargs.items():
178 | if v is None:
179 | continue
180 | if isinstance(v, torch.Tensor):
181 | v = v.item()
182 | assert isinstance(v, (float, int))
183 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
184 |
185 | def flush(self):
186 | self.writer.flush()
187 |
188 |
189 | class WandbLogger(object):
190 | def __init__(self, args):
191 | self.args = args
192 |
193 | try:
194 | import wandb
195 | self._wandb = wandb
196 | except ImportError:
197 | raise ImportError(
198 | "To use the Weights and Biases Logger please install wandb. "
199 | "Run `pip install wandb` to install it."
200 | )
201 |
202 | # Initialize a W&B run
203 | if self._wandb.run is None:
204 | self._wandb.init(
205 | project=args.project,
206 | config=args
207 | )
208 |
209 | def log_epoch_metrics(self, metrics, commit=True):
210 | """
211 | Log train/test metrics onto W&B.
212 | """
213 | # Log number of model parameters as W&B summary
214 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
215 | metrics.pop('n_parameters', None)
216 |
217 | # Log current epoch
218 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
219 | metrics.pop('epoch')
220 |
221 | for k, v in metrics.items():
222 | if 'train' in k:
223 | self._wandb.log({f'Global Train/{k}': v}, commit=False)
224 | elif 'test' in k:
225 | self._wandb.log({f'Global Test/{k}': v}, commit=False)
226 |
227 | self._wandb.log({})
228 |
229 | def log_checkpoints(self):
230 | output_dir = self.args.output_dir
231 | model_artifact = self._wandb.Artifact(
232 | self._wandb.run.id + "_model", type="model"
233 | )
234 |
235 | model_artifact.add_dir(output_dir)
236 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])
237 |
238 | def set_steps(self):
239 | # Set global training step
240 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
241 | # Set epoch-wise step
242 | self._wandb.define_metric('Global Train/*', step_metric='epoch')
243 | self._wandb.define_metric('Global Test/*', step_metric='epoch')
244 |
245 |
246 | def setup_for_distributed(is_master):
247 | """
248 | This function disables printing when not in master process
249 | """
250 | import builtins as __builtin__
251 | builtin_print = __builtin__.print
252 |
253 | def print(*args, **kwargs):
254 | force = kwargs.pop('force', False)
255 | if is_master or force:
256 | builtin_print(*args, **kwargs)
257 |
258 | __builtin__.print = print
259 |
260 |
261 | def is_dist_avail_and_initialized():
262 | if not dist.is_available():
263 | return False
264 | if not dist.is_initialized():
265 | return False
266 | return True
267 |
268 |
269 | def get_world_size():
270 | if not is_dist_avail_and_initialized():
271 | return 1
272 | return dist.get_world_size()
273 |
274 |
275 | def get_rank():
276 | if not is_dist_avail_and_initialized():
277 | return 0
278 | return dist.get_rank()
279 |
280 |
281 | def is_main_process():
282 | return get_rank() == 0
283 |
284 |
285 | def save_on_master(*args, **kwargs):
286 | if is_main_process():
287 | torch.save(*args, **kwargs)
288 |
289 |
290 | def init_distributed_mode(args):
291 |
292 | if args.dist_on_itp:
293 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
294 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
295 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
296 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
297 | os.environ['LOCAL_RANK'] = str(args.gpu)
298 | os.environ['RANK'] = str(args.rank)
299 | os.environ['WORLD_SIZE'] = str(args.world_size)
300 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
301 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
302 | args.rank = int(os.environ["RANK"])
303 | args.world_size = int(os.environ['WORLD_SIZE'])
304 | args.gpu = int(os.environ['LOCAL_RANK'])
305 | elif 'SLURM_PROCID' in os.environ:
306 | args.rank = int(os.environ['SLURM_PROCID'])
307 | args.gpu = args.rank % torch.cuda.device_count()
308 |
309 | os.environ['RANK'] = str(args.rank)
310 | os.environ['LOCAL_RANK'] = str(args.gpu)
311 | os.environ['WORLD_SIZE'] = str(args.world_size)
312 | else:
313 | print('Not using distributed mode')
314 | args.distributed = False
315 | return
316 |
317 | args.distributed = True
318 |
319 | torch.cuda.set_device(args.gpu)
320 | args.dist_backend = 'nccl'
321 | print('| distributed init (rank {}): {}, gpu {}'.format(
322 | args.rank, args.dist_url, args.gpu), flush=True)
323 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
324 | world_size=args.world_size, rank=args.rank)
325 | torch.distributed.barrier()
326 | setup_for_distributed(args.rank == 0)
327 |
328 |
329 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
330 | missing_keys = []
331 | unexpected_keys = []
332 | error_msgs = []
333 | # copy state_dict so _load_from_state_dict can modify it
334 | metadata = getattr(state_dict, '_metadata', None)
335 | state_dict = state_dict.copy()
336 | if metadata is not None:
337 | state_dict._metadata = metadata
338 |
339 | def load(module, prefix=''):
340 | local_metadata = {} if metadata is None else metadata.get(
341 | prefix[:-1], {})
342 | module._load_from_state_dict(
343 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
344 | for name, child in module._modules.items():
345 | if child is not None:
346 | load(child, prefix + name + '.')
347 |
348 | load(model, prefix=prefix)
349 |
350 | warn_missing_keys = []
351 | ignore_missing_keys = []
352 | for key in missing_keys:
353 | keep_flag = True
354 | for ignore_key in ignore_missing.split('|'):
355 | if ignore_key in key:
356 | keep_flag = False
357 | break
358 | if keep_flag:
359 | warn_missing_keys.append(key)
360 | else:
361 | ignore_missing_keys.append(key)
362 |
363 | missing_keys = warn_missing_keys
364 |
365 | if len(missing_keys) > 0:
366 | print("Weights of {} not initialized from pretrained model: {}".format(
367 | model.__class__.__name__, missing_keys))
368 | if len(unexpected_keys) > 0:
369 | print("Weights from pretrained model not used in {}: {}".format(
370 | model.__class__.__name__, unexpected_keys))
371 | if len(ignore_missing_keys) > 0:
372 | print("Ignored weights of {} not initialized from pretrained model: {}".format(
373 | model.__class__.__name__, ignore_missing_keys))
374 | if len(error_msgs) > 0:
375 | print('\n'.join(error_msgs))
376 |
377 |
378 | class NativeScalerWithGradNormCount:
379 | state_dict_key = "amp_scaler"
380 |
381 | def __init__(self):
382 | self._scaler = torch.cuda.amp.GradScaler()
383 |
384 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
385 | self._scaler.scale(loss).backward(create_graph=create_graph)
386 | if update_grad:
387 | if clip_grad is not None:
388 | assert parameters is not None
389 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
390 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
391 | else:
392 | self._scaler.unscale_(optimizer)
393 | norm = get_grad_norm_(parameters)
394 | self._scaler.step(optimizer)
395 | self._scaler.update()
396 | else:
397 | norm = None
398 | return norm
399 |
400 | def state_dict(self):
401 | return self._scaler.state_dict()
402 |
403 | def load_state_dict(self, state_dict):
404 | self._scaler.load_state_dict(state_dict)
405 |
406 |
407 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
408 | if isinstance(parameters, torch.Tensor):
409 | parameters = [parameters]
410 | parameters = [p for p in parameters if p.grad is not None]
411 | norm_type = float(norm_type)
412 | if len(parameters) == 0:
413 | return torch.tensor(0.)
414 | device = parameters[0].grad.device
415 | if norm_type == inf:
416 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
417 | else:
418 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
419 | return total_norm
420 |
421 |
422 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
423 | start_warmup_value=0, warmup_steps=-1):
424 | warmup_schedule = np.array([])
425 | warmup_iters = warmup_epochs * niter_per_ep
426 | if warmup_steps > 0:
427 | warmup_iters = warmup_steps
428 | print("Set warmup steps = %d" % warmup_iters)
429 | if warmup_iters > 0:  # also covers warmup specified only through warmup_steps
430 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
431 |
432 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
433 | schedule = np.array(
434 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
435 |
436 | schedule = np.concatenate((warmup_schedule, schedule))
437 |
438 | assert len(schedule) == epochs * niter_per_ep
439 | return schedule
440 |
441 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
442 | output_dir = Path(args.output_dir)
443 | epoch_name = str(epoch)
444 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
445 | for checkpoint_path in checkpoint_paths:
446 | to_save = {
447 | 'model': model_without_ddp.state_dict(),
448 | 'optimizer': optimizer.state_dict(),
449 | 'epoch': epoch,
450 | 'scaler': loss_scaler.state_dict(),
451 | 'args': args,
452 | }
453 |
454 | if model_ema is not None:
455 | to_save['model_ema'] = get_state_dict(model_ema)
456 |
457 | save_on_master(to_save, checkpoint_path)
458 |
459 | if is_main_process() and isinstance(epoch, int):
460 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
461 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
462 | if os.path.exists(old_ckpt):
463 | os.remove(old_ckpt)
464 |
465 |
466 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
467 | output_dir = Path(args.output_dir)
468 | if args.auto_resume and len(args.resume) == 0:
469 | import glob
470 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
471 | latest_ckpt = -1
472 | for ckpt in all_checkpoints:
473 | t = ckpt.split('-')[-1].split('.')[0]
474 | if t.isdigit():
475 | latest_ckpt = max(int(t), latest_ckpt)
476 | if latest_ckpt >= 0:
477 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
478 | print("Auto resume checkpoint: %s" % args.resume)
479 |
480 | if args.resume:
481 | if args.resume.startswith('https'):
482 | checkpoint = torch.hub.load_state_dict_from_url(
483 | args.resume, map_location='cpu', check_hash=True)
484 | else:
485 | checkpoint = torch.load(args.resume, map_location='cpu')
486 | if 'model' in checkpoint:
487 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
488 | else:
489 | model_without_ddp.load_state_dict(checkpoint, strict=False)
490 | # model_without_ddp.load_state_dict(checkpoint['model'])
491 | print("Resume checkpoint %s" % args.resume)
492 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
493 | optimizer.load_state_dict(checkpoint['optimizer'])
494 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema'
495 | args.start_epoch = checkpoint['epoch'] + 1
496 | else:
497 | assert args.eval, 'Does not support resuming with checkpoint-best'
498 | if hasattr(args, 'model_ema') and args.model_ema:
499 | if 'model_ema' in checkpoint.keys():
500 | model_ema.ema.load_state_dict(checkpoint['model_ema'])
501 | else:
502 | model_ema.ema.load_state_dict(checkpoint['model'])
503 | if 'scaler' in checkpoint:
504 | loss_scaler.load_state_dict(checkpoint['scaler'])
505 | print("With optim & sched!")
506 |
507 | def reg_scheduler(base_value, final_value, epochs, niter_per_ep, early_epochs=0, early_value=None,
508 | mode='linear', early_mode='regular'):
509 | early_schedule = np.array([])
510 | early_iters = early_epochs * niter_per_ep
511 | if early_value is None:
512 | early_value = final_value
513 | if early_epochs > 0:
514 | print(f"Set early value to {early_mode} {early_value}")
515 | if early_mode == 'regular':
516 | early_schedule = np.array([early_value] * early_iters)
517 | elif early_mode == 'linear':
518 | early_schedule = np.linspace(early_value, base_value, early_iters)
519 | elif early_mode == 'cosine':
520 | early_schedule = np.array(
521 | [base_value + 0.5 * (early_value - base_value) * (1 + math.cos(math.pi * i / early_iters)) for i in np.arange(early_iters)])
522 | regular_epochs = epochs - early_epochs
523 | iters = np.arange(regular_epochs * niter_per_ep)
524 | schedule = np.linspace(base_value, final_value, len(iters))
525 | schedule = np.concatenate((early_schedule, schedule))
526 |
527 | assert len(schedule) == epochs * niter_per_ep
528 | return schedule
529 |
530 | def calculate_distance(args, model_without_ddp, device):
531 | output_dir = Path(args.output_dir)
532 | start_path = os.path.join(output_dir, 'checkpoint-start.pth')
533 | if not os.path.exists(start_path):
534 | return -1
535 | model_start = build_model(args)
536 | checkpoint_start = torch.load(start_path, map_location='cpu')
537 | model_start.load_state_dict(checkpoint_start['model'])
538 | model_start.to(device)
539 | cur = torch.tensor([]).to(device)
540 | start = torch.tensor([]).to(device)
541 | with torch.no_grad():
542 | for name, p in model_without_ddp.named_parameters():
543 | cur = torch.cat((cur, p.flatten().clone().detach()))
544 | for name, p in model_start.named_parameters():
545 | start = torch.cat((start, p.flatten().clone().detach()))
546 | return torch.nn.MSELoss()(start, cur).item()
547 |
548 | def build_model(args):
549 | if args.model.startswith("convnext"):
550 | model = create_model(
551 | args.model,
552 | pretrained=False,
553 | num_classes=args.nb_classes,
554 | drop_path_rate=args.drop_path,
555 | layer_scale_init_value=args.layer_scale_init_value,
556 | head_init_scale=args.head_init_scale,
557 | drop_rate=args.dropout,
558 | )
559 | else:
560 | model = create_model(
561 | args.model,
562 | pretrained=False,
563 | num_classes=args.nb_classes,
564 | drop_path_rate=args.drop_path,
565 | drop_rate=args.dropout,
566 | img_size=args.input_size
567 | )
568 | return model
--------------------------------------------------------------------------------
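The per-iteration schedule built by `cosine_scheduler` in utils.py (linear warmup followed by a half-cosine decay) can be illustrated without NumPy. This is a minimal sketch of the same shape, not the repository's implementation: `cosine_with_warmup` and its `total_iters` / `warmup_iters` parameters are hypothetical names that replace the `epochs * niter_per_ep` bookkeeping of the original.

```python
import math


def cosine_with_warmup(base_value, final_value, total_iters,
                       warmup_iters=0, start_warmup_value=0.0):
    """Return one value per iteration: linear warmup, then cosine decay."""
    # Linear warmup, endpoints inclusive (mirrors np.linspace semantics).
    if warmup_iters == 1:
        warmup = [start_warmup_value]
    else:
        warmup = [
            start_warmup_value
            + (base_value - start_warmup_value) * i / (warmup_iters - 1)
            for i in range(warmup_iters)
        ]

    # Half-cosine decay from base_value towards final_value.
    decay_iters = total_iters - warmup_iters
    decay = [
        final_value
        + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / decay_iters))
        for i in range(decay_iters)
    ]

    schedule = warmup + decay
    assert len(schedule) == total_iters
    return schedule


# Example: peak value 4e-3 (the --lr used in the scripts), 100 iterations, 10 of warmup.
schedule = cosine_with_warmup(4e-3, 1e-6, total_iters=100, warmup_iters=10)
print(schedule[0], schedule[9], schedule[-1])
```

Because the decay index stops at `decay_iters - 1`, the last entry ends slightly above `final_value` rather than exactly on it, which matches the formula used in `cosine_scheduler`.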