├── .gitignore
├── README.md
├── dataset_fruit_veg
│   ├── split_dataset.py
│   └── statistic_mean_std.py
├── train.py
└── util
    ├── crop.py
    ├── datasets.py
    ├── lars.py
    ├── lr_decay.py
    ├── lr_sched.py
    ├── misc.py
    └── pos_embed.py

/.gitignore:
--------------------------------------------------------------------------------
1 | debug
2 | .idea
3 | dataset_fruit_veg/raw
4 | dataset_fruit_veg/test
5 | dataset_fruit_veg/train
6 | output_dir_pretrained
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fruit and Vegetable Classification Based on ResNet18
2 |
3 | > Implemented by following [this Bilibili tutorial](https://www.bilibili.com/video/BV1vZ4y1h7X9/?spm_id_from=333.788)
4 |
5 | ## Dataset
6 | - Dataset: https://aistudio.baidu.com/aistudio/datasetdetail/119023/0
7 | - After downloading, extract it into the `dataset_fruit_veg` directory and rename the folder to `raw`.
8 | - Run `split_dataset.py`.
9 |
10 | ## Train
11 | - Set `mode` under `if __name__ == '__main__'` to `train`.
12 | - Run `train.py`:
13 | ```shell
14 | python train.py
15 | ```
16 | - To change the training parameters, edit the defaults in the `get_args_parser` function of `train.py`, or pass them on the command line, e.g.:
17 | ```shell
18 | python train.py --batch_size=36 --epochs=30
19 | ```
20 | - After training, checkpoints are saved under `output_dir_pretrained`. For testing, set the `default` value of `resume` in `get_args_parser` to the checkpoint file produced by training, and that model will be used for inference.
21 |
22 | ## Test
23 | - Set `mode` under `if __name__ == '__main__'` to `infer`.
24 | - Run `train.py`:
25 | ```shell
26 | python train.py
27 | ```
28 | - The program iterates over the images under `dataset_fruit_veg/test` and prints the confidence score and predicted class for each image.
--------------------------------------------------------------------------------
/dataset_fruit_veg/split_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Created: 2023/3/22 11:39
4 | @Author: Red9th
5 | @File: split_dataset.py
6 | @Software: PyCharm
7 | """
8 | import os
9 | import glob
10 | import random
11 | import shutil
12 | from PIL import Image
13 |
14 | if __name__ == '__main__':
15 |     test_split_ratio = 0.05
16 |     desired_size = 128
17 |     raw_path = './raw'
18 |
19 |     dirs = glob.glob(os.path.join(raw_path, '*'))
20 |     dirs = [d for d in dirs if os.path.isdir(d)]
21 |
22 |     # print(f'Total: {len(dirs)} classes: {dirs}')
23 |
24 |     for path in dirs:
25 |         path = os.path.basename(path)  # class folder name (portable, unlike splitting on '\\')
26 |
27 |         os.makedirs(f'train/{path}', exist_ok=True)
28 |         os.makedirs(f'test/{path}', exist_ok=True)
29 |
30 |         files = glob.glob(os.path.join(raw_path, path, '*.jpg'))
31 |         files += glob.glob(os.path.join(raw_path, path, '*.png'))
32 |
33 |         random.shuffle(files)
34 |
35 |         boundary = int(len(files) * test_split_ratio)
36 |
37 |         for i, file in enumerate(files):
38 |             img = Image.open(file).convert('RGB')
39 |             old_size = img.size
40 |             ratio = float(desired_size) / max(old_size)
41 |             new_size = (int(old_size[0] * ratio), int(old_size[1] * ratio))
42 |             im = img.resize(new_size, Image.LANCZOS)  # Image.ANTIALIAS is deprecated; LANCZOS is the same filter
43 |             new_im = Image.new('RGB', (desired_size, desired_size))
44 |             new_im.paste(im, ((desired_size - new_size[0]) // 2, (desired_size - new_size[1]) // 2))
45 |
46 |             assert new_im.mode == 'RGB'
47 |
48 |             if i < boundary:  # first ~5% of the shuffled files form the test split
49 |                 new_im.save(os.path.join(f'test/{path}', os.path.basename(file).split('.')[0] + '.jpg'))
50 |             else:
51 |                 new_im.save(os.path.join(f'train/{path}', os.path.basename(file).split('.')[0] + '.jpg'))
52 |
53 |     # test_files = glob.glob(os.path.join('test', '*', '*.jpg'))
54 |     # train_files = glob.glob(os.path.join('train', '*', '*.jpg'))
55 |
56 |     # print(f'total {len(test_files)} files for testing')
57 |     # print(f'total {len(train_files)} files for training')
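Note that the script resolves `./raw`, `train/` and `test/` relative to the current working directory, so run it from inside `dataset_fruit_veg`. As a quick sanity check after it finishes, a sketch along these lines (run from the same directory) counts what the split produced:

```python
import glob
import os

# Count the images produced by split_dataset.py (run from dataset_fruit_veg/).
for split in ('train', 'test'):
    files = glob.glob(os.path.join(split, '*', '*.jpg'))
    classes = {os.path.basename(os.path.dirname(f)) for f in files}
    print(f'{split}: {len(files)} images across {len(classes)} classes')
```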
--------------------------------------------------------------------------------
/dataset_fruit_veg/statistic_mean_std.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Created: 2023/3/22 17:54
4 | @Author: Red9th
5 | @File: statistic_mean_std.py
6 | @Software: PyCharm
7 | """
8 | # Compute the per-channel mean and standard deviation over all training images (used for normalization)
9 | import os
10 | import glob
11 | import numpy as np
12 | from PIL import Image
13 |
14 | if __name__ == '__main__':
15 |     train_files = glob.glob(os.path.join('train', '*', '*.jpg'))
16 |
17 |     print(f'total {len(train_files)} files for training')
18 |
19 |     result = []
20 |     for file in train_files:
21 |         img = Image.open(file).convert('RGB')
22 |         img = np.array(img).astype(np.uint8)
23 |         img = img / 255.
24 |         result.append(img)
25 |
26 |     print(np.shape(result))  # [BS, H, W, C]
27 |     mean = np.mean(result, axis=(0, 1, 2))
28 |     std = np.std(result, axis=(0, 1, 2))
29 |     print(mean)
30 |     print(std)
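The printed statistics are intended for `torchvision.transforms.Normalize`; note that `build_transform` in `train.py` currently stops at `ToTensor()`, so they are not wired in yet. A minimal sketch of how they would be used (the numbers are placeholders — substitute the script's actual output):

```python
import torchvision.transforms as T

# Placeholder statistics -- replace with the values printed by statistic_mean_std.py.
DATASET_MEAN = [0.5, 0.5, 0.5]
DATASET_STD = [0.25, 0.25, 0.25]

transform = T.Compose([
    T.Resize((128, 128)),
    T.ToTensor(),                            # scales pixels to [0, 1]
    T.Normalize(DATASET_MEAN, DATASET_STD),  # then standardizes each channel
])
```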
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import math
4 | import os
5 | import random
6 | import sys
7 | import time
8 | from pathlib import Path
9 |
10 | import torch.utils.data
11 | import torchvision.datasets
12 |
13 | import timm
14 | from timm.utils import accuracy
15 | # TensorBoard writer for logging the loss curve
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 | from util import misc
19 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
20 |
21 | from collections.abc import Iterable
22 |
23 | from PIL import Image
24 |
25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26 |
27 | @torch.no_grad()
28 | def evaluate(data_loader, model, device):
29 |     criterion = torch.nn.CrossEntropyLoss()
30 |
31 |     metric_logger = misc.MetricLogger(delimiter="  ")
32 |     header = 'Test:'
33 |
34 |     model.eval()
35 |
36 |     # The model may still be on the CPU while the data is moved to CUDA, so put it on the same device
37 |     model = model.to(device)
38 |
39 |     for batch in metric_logger.log_every(data_loader, 10, header):
40 |         images = batch[0]
41 |         target = batch[-1]
42 |         images = images.to(device, non_blocking=True)
43 |         target = target.to(device, non_blocking=True)
44 |
45 |         output = model(images)
46 |         loss = criterion(output, target)
47 |
48 |         output = torch.nn.functional.softmax(output, dim=-1)
49 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
50 |
51 |         # Update the metric logger
52 |         batch_size = images.shape[0]
53 |         metric_logger.update(loss=loss.item())
54 |         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
55 |         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
56 |
57 |     metric_logger.synchronize_between_processes()
58 |     print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
59 |           .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
60 |
61 |     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
62 |
63 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
64 |                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
65 |                     device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
66 |                     log_writer=None,
67 |                     args=None):
68 |     model.train(True)
69 |     print_freq = 2
70 |     accum_iter = args.accum_iter  # number of batches between gradient updates (default: 1)
71 |
72 |     if log_writer is not None:
73 |         print('log_dir: {}'.format(log_writer.log_dir))
74 |
75 |     for data_iter_step, (samples, targets) in enumerate(data_loader):
76 |         # Move the batch to the target device
77 |         samples = samples.to(device, non_blocking=True)
78 |         targets = targets.to(device, non_blocking=True)
79 |
80 |         outputs = model(samples)
81 |
82 |         warmup_lr = args.lr
83 |         optimizer.param_groups[0]["lr"] = warmup_lr
84 |
85 |         loss = criterion(outputs, targets)
86 |         loss /= accum_iter
87 |
88 |         # Backward pass; the scaler only steps the optimizer every accum_iter batches
89 |         loss_scaler(loss, optimizer, clip_grad=max_norm,
90 |                     parameters=model.parameters(), create_graph=False,
91 |                     update_grad=(data_iter_step + 1) % accum_iter == 0)
92 |
93 |         loss_value = loss.item()
94 |
95 |         if (data_iter_step + 1) % accum_iter == 0:
96 |             optimizer.zero_grad()
97 |
98 |         if not math.isfinite(loss_value):
99 |             print("Loss is {}, stopping training".format(loss_value))
100 |             sys.exit(1)
101 |
102 |         # Write to TensorBoard
103 |         if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
104 |             epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
105 |             log_writer.add_scalar('loss', loss_value, epoch_1000x)
106 |             log_writer.add_scalar('lr', warmup_lr, epoch_1000x)
107 |             print(f"Epoch: {epoch}, Step: {data_iter_step}, Loss: {loss}, Lr: {warmup_lr}")
108 |
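The loop above divides the loss by `accum_iter` and only lets the scaler step the optimizer every `accum_iter` batches, i.e. gradient accumulation: the effective batch size becomes `batch_size * accum_iter`. A self-contained sketch of the same pattern without the AMP scaler (toy model and data, purely illustrative):

```python
import torch

# Minimal gradient-accumulation sketch: optimizer.step() runs every
# accum_iter batches, so gradients from accum_iter batches are averaged.
model = torch.nn.Linear(4, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accum_iter = 2

data_loader = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(4)]
for step, (samples, targets) in enumerate(data_loader):
    loss = criterion(model(samples), targets) / accum_iter  # scale so the sum averages correctly
    loss.backward()                                         # gradients accumulate in .grad
    if (step + 1) % accum_iter == 0:
        optimizer.step()
        optimizer.zero_grad()
```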
109 | def build_transform(is_train, args):
110 |     # Training transforms
111 |     if is_train:
112 |         print("train transform")
113 |         return torchvision.transforms.Compose([
114 |             # Resize to the input size the network expects
115 |             torchvision.transforms.Resize((args.input_size, args.input_size)),
116 |             torchvision.transforms.RandomHorizontalFlip(),
117 |             torchvision.transforms.RandomVerticalFlip(),
118 |             torchvision.transforms.RandomPerspective(distortion_scale=0.6, p=1.0),
119 |             torchvision.transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
120 |             # Convert the image to floats in [0, 1]
121 |             torchvision.transforms.ToTensor()
122 |         ])
123 |
124 |     # Evaluation transforms
125 |     print("eval transform")
126 |     return torchvision.transforms.Compose([
127 |         torchvision.transforms.Resize((args.input_size, args.input_size)),
128 |         torchvision.transforms.ToTensor()
129 |     ])
130 |
131 | def build_dataset(is_train, args):
132 |     transform = build_transform(is_train, args)
133 |     path = os.path.join(args.root_path, 'train' if is_train else 'test')
134 |     # For datasets organized as one folder per class
135 |     dataset = torchvision.datasets.ImageFolder(path, transform=transform)
136 |     info = dataset.find_classes(path)
137 |     print(f"finding classes from {path}:\t{info[0]}")  # class names
138 |     print(f"mapping classes from {path} to indexes:\t{info[1]}")  # class-to-index mapping
139 |
140 |     return dataset
141 |
142 | def get_args_parser():
143 |     parser = argparse.ArgumentParser('ResNet18 fruit and vegetable classification', add_help=False)
144 |     parser.add_argument('--batch_size', default=72, type=int,
145 |                         help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
146 |     parser.add_argument('--epochs', default=400, type=int)
147 |     parser.add_argument('--accum_iter', default=1, type=int,
148 |                         help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
149 |
150 |     # Model parameters
151 |
152 |     parser.add_argument('--input_size', default=128, type=int,
153 |                         help='images input size')
154 |
155 |     # Optimizer parameters
156 |     parser.add_argument('--weight_decay', type=float, default=0.0001,
157 |                         help='weight decay (default: 0.0001)')
158 |
159 |     parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
160 |                         help='learning rate (absolute lr)')
161 |
162 |     # Dataset parameters
163 |     parser.add_argument('--root_path', default='dataset_fruit_veg',
164 |                         help='dataset root directory')
165 |     parser.add_argument('--output_dir', default='./output_dir_pretrained',
166 |                         help='path where to save, empty for no saving')
167 |     parser.add_argument('--log_dir', default='./output_dir_pretrained',
168 |                         help='path where to tensorboard log')
169 |
170 |     # Whether to resume from an existing checkpoint (saved under output_dir_pretrained)
171 |     parser.add_argument('--resume', default='output_dir_pretrained/checkpoint-20.pth',
172 |                         # parser.add_argument('--resume', default='',
173 |                         help='resume from checkpoint')
174 |
175 |     parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
176 |                         help='start epoch')
177 |     parser.add_argument('--num_workers', default=5, type=int)
178 |     parser.add_argument('--pin_mem', action='store_true',
179 |                         help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
180 |     parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
181 |     parser.set_defaults(pin_mem=True)
182 |
183 |     return parser
184 |
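`main()` below instantiates the classifier with `timm.create_model('resnet18', num_classes=36, ...)`, which swaps ResNet18's ImageNet head for a fresh 36-way classifier. A quick shape sanity check (a sketch; `pretrained=False` here only to avoid a weight download):

```python
import timm
import torch

# timm rebuilds the classifier head for 36 classes, so the network emits
# one logit per fruit/vegetable class, regardless of the 128x128 input size.
model = timm.create_model('resnet18', pretrained=False, num_classes=36)
x = torch.randn(1, 3, 128, 128)
print(model(x).shape)  # torch.Size([1, 36])
```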
print("Saving checkpoint...") 258 | misc.save_model( 259 | args=args, model=model, model_without_ddp=model, optimizer=optimizer, 260 | loss_scaler=loss_scaler, epoch=epoch 261 | ) 262 | else: 263 | model = timm.create_model('resnet18', pretrained=False, num_classes=36, drop_rate=0.1, drop_path_rate=0.1) 264 | 265 | class_dict = { 266 | 'apple': 0, 'banana': 1, 'beetroot': 2, 'bell pepper': 3, 'cabbage': 4, 'capsicum': 5, 'carrot': 6, 267 | 'cauliflower': 7, 'chilli pepper': 8, 'corn': 9, 'cucumber': 10, 'eggplant': 11, 'garlic': 12, 268 | 'ginger': 13, 'grapes': 14, 'jalepeno': 15, 'kiwi': 16, 'lemon': 17, 'lettuce': 18, 'mango': 19, 269 | 'onion': 20, 'orange': 21, 'paprika': 22, 'pear': 23, 'peas': 24, 'pineapple': 25, 'pomegranate': 26, 270 | 'potato': 27, 'raddish': 28, 'soy beans': 29, 'spinach': 30, 'sweetcorn': 31, 'sweetpotato': 32, 271 | 'tomato': 33, 'turnip': 34, 'watermelon': 35 272 | } 273 | 274 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 275 | print('number of trainable params (M): %.2f' % (n_parameters / 1.e6)) 276 | 277 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 278 | os.makedirs(args.log_dir, exist_ok=True) 279 | loss_scaler = NativeScaler() 280 | 281 | # 训练中断后可以从读入已有模型 282 | misc.load_model(args=args, model_without_ddp=model, optimizer=optimizer, loss_scaler=loss_scaler) 283 | 284 | model.eval() 285 | 286 | image = Image.open(test_image_path).convert('RGB') 287 | image = image.resize((args.input_size, args.input_size), Image.ANTIALIAS) 288 | image = torchvision.transforms.ToTensor()(image).unsqueeze(0) 289 | 290 | with torch.no_grad(): 291 | output = model(image) 292 | 293 | output = torch.nn.functional.softmax(output, dim=-1) 294 | # 找到最大值对应的索引 295 | class_idx = torch.argmax(output, dim=1)[0] 296 | 297 | score = torch.max(output, dim=1)[0][0] 298 | print(f"image path is {test_image_path}") 299 | print(f"score is {score.item()}, class id is {class_idx.item()}, " 300 | f"class name is {list(class_dict.keys())[list(class_dict.values()).index(class_idx)]}") 301 | 302 | # time.sleep(0.5) 303 | 304 | if __name__ == '__main__': 305 | args = get_args_parser() 306 | # 得到一个对象 307 | args = args.parse_args() 308 | 309 | if args.output_dir: 310 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 311 | 312 | mode = 'infer' # infer or train 313 | 314 | if mode == 'train': 315 | main(args, mode=mode) 316 | else: 317 | images = glob.glob('dataset_fruit_veg/test/*/*.jpg') 318 | 319 | for image in images: 320 | print('\n') 321 | main(args, mode=mode, test_image_path=image) -------------------------------------------------------------------------------- /util/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Red9th/Maize-Classification-By-ResNet/772a545dae47e865dca2c65e29e515cdab5b6399/util/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 |
7 | import math
8 |
9 | import torch
10 |
11 | from torchvision import transforms
12 | from torchvision.transforms import functional as F
13 |
14 |
15 | class RandomResizedCrop(transforms.RandomResizedCrop):
16 |     """
17 |     RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18 |     This may lead to results different with torchvision's version.
19 |     Following BYOL's TF code:
20 |     https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21 |     """
22 |     @staticmethod
23 |     def get_params(img, scale, ratio):
24 |         width, height = F._get_image_size(img)
25 |         area = height * width
26 |
27 |         target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
28 |         log_ratio = torch.log(torch.tensor(ratio))
29 |         aspect_ratio = torch.exp(
30 |             torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
31 |         ).item()
32 |
33 |         w = int(round(math.sqrt(target_area * aspect_ratio)))
34 |         h = int(round(math.sqrt(target_area / aspect_ratio)))
35 |
36 |         w = min(w, width)
37 |         h = min(h, height)
38 |
39 |         i = torch.randint(0, height - h + 1, size=(1,)).item()
40 |         j = torch.randint(0, width - w + 1, size=(1,)).item()
41 |
42 |         return i, j, h, w
--------------------------------------------------------------------------------
/util/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # --------------------------------------------------------
10 |
11 | import os
12 | import PIL
13 |
14 | from torchvision import datasets, transforms
15 |
16 | from timm.data import create_transform
17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18 |
19 |
20 | def build_dataset(is_train, args):
21 |     transform = build_transform(is_train, args)
22 |
23 |     root = os.path.join(args.data_path, 'train' if is_train else 'val')
24 |     dataset = datasets.ImageFolder(root, transform=transform)
25 |
26 |     print(dataset)
27 |
28 |     return dataset
29 |
30 |
31 | def build_transform(is_train, args):
32 |     mean = IMAGENET_DEFAULT_MEAN
33 |     std = IMAGENET_DEFAULT_STD
34 |     # train transform
35 |     if is_train:
36 |         # this should always dispatch to transforms_imagenet_train
37 |         transform = create_transform(
38 |             input_size=args.input_size,
39 |             is_training=True,
40 |             color_jitter=args.color_jitter,
41 |             auto_augment=args.aa,
42 |             interpolation='bicubic',
43 |             re_prob=args.reprob,
44 |             re_mode=args.remode,
45 |             re_count=args.recount,
46 |             mean=mean,
47 |             std=std,
48 |         )
49 |         return transform
50 |
51 |     # eval transform
52 |     t = []
53 |     if args.input_size <= 224:
54 |         crop_pct = 224 / 256
55 |     else:
56 |         crop_pct = 1.0
57 |     size = int(args.input_size / crop_pct)
58 |     t.append(
59 |         transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
60 |     )
61 |     t.append(transforms.CenterCrop(args.input_size))
62 |
63 |     t.append(transforms.ToTensor())
64 |     t.append(transforms.Normalize(mean, std))
65 |     return transforms.Compose(t)
66 |
--------------------------------------------------------------------------------
/util/lars.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # LARS optimizer, implementation from MoCo v3:
8 | # https://github.com/facebookresearch/moco-v3
9 | # --------------------------------------------------------
10 |
11 | import torch
12 |
13 |
14 | class LARS(torch.optim.Optimizer):
15 |     """
16 |     LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
17 |     """
18 |     def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
19 |         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
20 |         super().__init__(params, defaults)
21 |
22 |     @torch.no_grad()
23 |     def step(self):
24 |         for g in self.param_groups:
25 |             for p in g['params']:
26 |                 dp = p.grad
27 |
28 |                 if dp is None:
29 |                     continue
30 |
31 |                 if p.ndim > 1:  # if not normalization gamma/beta or bias
32 |                     dp = dp.add(p, alpha=g['weight_decay'])
33 |                     param_norm = torch.norm(p)
34 |                     update_norm = torch.norm(dp)
35 |                     one = torch.ones_like(param_norm)
36 |                     q = torch.where(param_norm > 0.,
37 |                                     torch.where(update_norm > 0,
38 |                                                 (g['trust_coefficient'] * param_norm / update_norm), one),
39 |                                     one)
40 |                     dp = dp.mul(q)
41 |
42 |                 param_state = self.state[p]
43 |                 if 'mu' not in param_state:
44 |                     param_state['mu'] = torch.zeros_like(p)
45 |                 mu = param_state['mu']
46 |                 mu.mul_(g['momentum']).add_(dp)
47 |                 p.add_(mu, alpha=-g['lr'])
--------------------------------------------------------------------------------
/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import json
13 |
14 |
15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16 |     """
17 |     Parameter groups for layer-wise lr decay
18 |     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19 |     """
20 |     param_group_names = {}
21 |     param_groups = {}
22 |
23 |     num_layers = len(model.blocks) + 1
24 |
25 |     layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26 |
27 |     for n, p in model.named_parameters():
28 |         if not p.requires_grad:
29 |             continue
30 |
31 |         # no decay: all 1D parameters and model specific ones
32 |         if p.ndim == 1 or n in no_weight_decay_list:
33 |             g_decay = "no_decay"
34 |             this_decay = 0.
35 |         else:
36 |             g_decay = "decay"
37 |             this_decay = weight_decay
38 |
39 |         layer_id = get_layer_id_for_vit(n, num_layers)
40 |         group_name = "layer_%d_%s" % (layer_id, g_decay)
41 |
42 |         if group_name not in param_group_names:
43 |             this_scale = layer_scales[layer_id]
44 |
45 |             param_group_names[group_name] = {
46 |                 "lr_scale": this_scale,
47 |                 "weight_decay": this_decay,
48 |                 "params": [],
49 |             }
50 |             param_groups[group_name] = {
51 |                 "lr_scale": this_scale,
52 |                 "weight_decay": this_decay,
53 |                 "params": [],
54 |             }
55 |
56 |         param_group_names[group_name]["params"].append(n)
57 |         param_groups[group_name]["params"].append(p)
58 |
59 |     # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60 |
61 |     return list(param_groups.values())
62 |
63 |
64 | def get_layer_id_for_vit(name, num_layers):
65 |     """
66 |     Assign a parameter with its layer id
67 |     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68 |     """
69 |     if name in ['cls_token', 'pos_embed']:
70 |         return 0
71 |     elif name.startswith('patch_embed'):
72 |         return 0
73 |     elif name.startswith('blocks'):
74 |         return int(name.split('.')[1]) + 1
75 |     else:
76 |         return num_layers
--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | def adjust_learning_rate(optimizer, epoch, args):
10 |     """Decay the learning rate with half-cycle cosine after warmup"""
11 |     if epoch < args.warmup_epochs:
12 |         lr = args.lr * epoch / args.warmup_epochs
13 |     else:
14 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 |             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 |     for param_group in optimizer.param_groups:
17 |         if "lr_scale" in param_group:
18 |             param_group["lr"] = lr * param_group["lr_scale"]
19 |         else:
20 |             param_group["lr"] = lr
21 |     return lr
22 |
--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from math import inf  # upstream used `from torch._six import inf`; torch._six was removed in newer PyTorch
22 |
23 |
24 | class SmoothedValue(object):
25 |     """Track a series of values and provide access to smoothed values over a
26 |     window or the global series average.
27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | 
155 |                         time=str(iter_time), data=str(data_time),
156 |                         memory=torch.cuda.max_memory_allocated() / MB))
157 |                 else:
158 |                     print(log_msg.format(
159 |                         i, len(iterable), eta=eta_string,
160 |                         meters=str(self),
161 |                         time=str(iter_time), data=str(data_time)))
162 |             i += 1
163 |             end = time.time()
164 |         total_time = time.time() - start_time
165 |         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
166 |         print('{} Total time: {} ({:.4f} s / it)'.format(
167 |             header, total_time_str, total_time / len(iterable)))
168 |
169 |
170 | def setup_for_distributed(is_master):
171 |     """
172 |     This function disables printing when not in master process
173 |     """
174 |     builtin_print = builtins.print
175 |
176 |     def print(*args, **kwargs):
177 |         force = kwargs.pop('force', False)
178 |         force = force or (get_world_size() > 8)
179 |         if is_master or force:
180 |             now = datetime.datetime.now().time()
181 |             builtin_print('[{}] '.format(now), end='')  # print with time stamp
182 |             builtin_print(*args, **kwargs)
183 |
184 |     builtins.print = print
185 |
186 |
187 | def is_dist_avail_and_initialized():
188 |     if not dist.is_available():
189 |         return False
190 |     if not dist.is_initialized():
191 |         return False
192 |     return True
193 |
194 |
195 | def get_world_size():
196 |     if not is_dist_avail_and_initialized():
197 |         return 1
198 |     return dist.get_world_size()
199 |
200 |
201 | def get_rank():
202 |     if not is_dist_avail_and_initialized():
203 |         return 0
204 |     return dist.get_rank()
205 |
206 |
207 | def is_main_process():
208 |     return get_rank() == 0
209 |
210 |
211 | def save_on_master(*args, **kwargs):
212 |     if is_main_process():
213 |         torch.save(*args, **kwargs)
214 |
215 |
216 | def init_distributed_mode(args):
217 |     if args.dist_on_itp:
218 |         args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
219 |         args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
220 |         args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
221 |         args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
222 |         os.environ['LOCAL_RANK'] = str(args.gpu)
223 |         os.environ['RANK'] = str(args.rank)
224 |         os.environ['WORLD_SIZE'] = str(args.world_size)
225 |         # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
226 |     elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
227 |         args.rank = int(os.environ["RANK"])
228 |         args.world_size = int(os.environ['WORLD_SIZE'])
229 |         args.gpu = int(os.environ['LOCAL_RANK'])
230 |     elif 'SLURM_PROCID' in os.environ:
231 |         args.rank = int(os.environ['SLURM_PROCID'])
232 |         args.gpu = args.rank % torch.cuda.device_count()
233 |     else:
234 |         print('Not using distributed mode')
235 |         setup_for_distributed(is_master=True)  # hack
236 |         args.distributed = False
237 |         return
238 |
239 |     args.distributed = True
240 |
241 |     torch.cuda.set_device(args.gpu)
242 |     args.dist_backend = 'nccl'
243 |     print('| distributed init (rank {}): {}, gpu {}'.format(
244 |         args.rank, args.dist_url, args.gpu), flush=True)
245 |     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
246 |                                          world_size=args.world_size, rank=args.rank)
247 |     torch.distributed.barrier()
248 |     setup_for_distributed(args.rank == 0)
249 |
250 |
251 | class NativeScalerWithGradNormCount:
252 |     state_dict_key = "amp_scaler"
253 |
254 |     def __init__(self):
255 |         self._scaler = torch.cuda.amp.GradScaler()
256 |
257 |     def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
258 |         self._scaler.scale(loss).backward(create_graph=create_graph)
259 |         if update_grad:
260 |             if clip_grad is not None:
261 |                 assert parameters is not None
262 |                 self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
263 |                 norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
264 |             else:
265 |                 self._scaler.unscale_(optimizer)
266 |                 norm = get_grad_norm_(parameters)
267 |             self._scaler.step(optimizer)
268 |             self._scaler.update()
269 |         else:
270 |             norm = None
271 |         return norm
272 |
273 |     def state_dict(self):
274 |         return self._scaler.state_dict()
275 |
276 |     def load_state_dict(self, state_dict):
277 |         self._scaler.load_state_dict(state_dict)
278 |
279 |
280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
281 |     if isinstance(parameters, torch.Tensor):
282 |         parameters = [parameters]
283 |     parameters = [p for p in parameters if p.grad is not None]
284 |     norm_type = float(norm_type)
285 |     if len(parameters) == 0:
286 |         return torch.tensor(0.)
287 |     device = parameters[0].grad.device
288 |     if norm_type == inf:
289 |         total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
290 |     else:
291 |         total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
292 |     return total_norm
293 |
294 |
295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
296 |     output_dir = Path(args.output_dir)
297 |     epoch_name = str(epoch)
298 |     if loss_scaler is not None:
299 |         checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
300 |         for checkpoint_path in checkpoint_paths:
301 |             to_save = {
302 |                 'model': model_without_ddp.state_dict(),
303 |                 'optimizer': optimizer.state_dict(),
304 |                 'epoch': epoch,
305 |                 'scaler': loss_scaler.state_dict(),
306 |                 'args': args,
307 |             }
308 |
309 |             save_on_master(to_save, checkpoint_path)
310 |     else:
311 |         client_state = {'epoch': epoch}
312 |         model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
313 |
314 |
315 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
316 |     if args.resume:
317 |         if args.resume.startswith('https'):
318 |             checkpoint = torch.hub.load_state_dict_from_url(
319 |                 args.resume, map_location='cpu', check_hash=True)
320 |         else:
321 |             checkpoint = torch.load(args.resume, map_location='cpu')
322 |         model_without_ddp.load_state_dict(checkpoint['model'])
323 |         print("Resume checkpoint %s" % args.resume)
324 |         if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
325 |             optimizer.load_state_dict(checkpoint['optimizer'])
326 |             args.start_epoch = checkpoint['epoch'] + 1
327 |             if 'scaler' in checkpoint:
328 |                 loss_scaler.load_state_dict(checkpoint['scaler'])
329 |             print("With optim & sched!")
330 |
331 |
332 | def all_reduce_mean(x):
333 |     world_size = get_world_size()
334 |     if world_size > 1:
335 |         x_reduce = torch.tensor(x).cuda()
336 |         dist.all_reduce(x_reduce)
337 |         x_reduce /= world_size
338 |         return x_reduce.item()
339 |     else:
340 |         return x
--------------------------------------------------------------------------------
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 |     """
22 |     grid_size: int of the grid height and width
23 |     return:
24 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 |     """
26 |     grid_h = np.arange(grid_size, dtype=np.float32)
27 |     grid_w = np.arange(grid_size, dtype=np.float32)
28 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
29 |     grid = np.stack(grid, axis=0)
30 |
31 |     grid = grid.reshape([2, 1, grid_size, grid_size])
32 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 |     if cls_token:
34 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 |     return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 |     assert embed_dim % 2 == 0
40 |
41 |     # use half of dimensions to encode grid_h
42 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
43 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
44 |
45 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
46 |     return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 |     """
51 |     embed_dim: output dimension for each position
52 |     pos: a list of positions to be encoded: size (M,)
53 |     out: (M, D)
54 |     """
55 |     assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed from NumPy; float64 is equivalent
57 |     omega /= embed_dim / 2.
58 |     omega = 1. / 10000**omega  # (D/2,)
59 |
60 |     pos = pos.reshape(-1)  # (M,)
61 |     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
62 |
63 |     emb_sin = np.sin(out)  # (M, D/2)
64 |     emb_cos = np.cos(out)  # (M, D/2)
65 |
66 |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
67 |     return emb
68 |
69 |
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 |     if 'pos_embed' in checkpoint_model:
77 |         pos_embed_checkpoint = checkpoint_model['pos_embed']
78 |         embedding_size = pos_embed_checkpoint.shape[-1]
79 |         num_patches = model.patch_embed.num_patches
80 |         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 |         # height (== width) for the checkpoint position embedding
82 |         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 |         # height (== width) for the new position embedding
84 |         new_size = int(num_patches ** 0.5)
85 |         # class_token and dist_token are kept unchanged
86 |         if orig_size != new_size:
87 |             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 |             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 |             # only the position tokens are interpolated
90 |             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 |             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 |             pos_tokens = torch.nn.functional.interpolate(
93 |                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 |             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 |             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 |             checkpoint_model['pos_embed'] = new_pos_embed
97 |
--------------------------------------------------------------------------------
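As a quick shape check on `get_2d_sincos_pos_embed` above, a sketch (assuming `util/` is importable from the project root; per the docstring, `cls_token=True` prepends a zero row):

```python
import numpy as np

from util.pos_embed import get_2d_sincos_pos_embed

# For a 4x4 patch grid and a 64-dim embedding, expect (1 + 16, 64) with cls_token.
emb = get_2d_sincos_pos_embed(embed_dim=64, grid_size=4, cls_token=True)
print(emb.shape)               # (17, 64)
print(np.allclose(emb[0], 0))  # True -- the cls-token slot is all zeros
```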