├── util
│   ├── loss.py
│   ├── lr_sched.py
│   ├── lars.py
│   ├── datasets.py
│   ├── lr_decay.py
│   ├── pos_embed.py
│   └── misc.py
├── model
│   ├── diffusion.py
│   ├── diffmae.py
│   └── modules.py
├── README.md
├── dataset.py
├── main_pretrain.py
├── engine_pretrain.py
└── option.py
--------------------------------------------------------------------------------
/util/loss.py:
--------------------------------------------------------------------------------
import torch

def calc_for_diffmae(args, model, samples, pred, ids_restore, ids_masked):  # only apply on masked patches

    if args.multi_gpu:
        target = model.module.patchify(samples)
    else:
        target = model.patchify(samples)

    target = torch.gather(target, dim=1, index=ids_masked[:, :, None].expand(-1, -1, target.shape[2]))

    loss = torch.nn.functional.mse_loss(pred, target)

    return loss
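The `torch.gather` call above is what restricts the loss to masked patches: along the sequence dimension it selects exactly the rows of the patchified target indexed by `ids_masked`. A minimal sketch of that selection, with made-up shapes and values:

```python
import torch

# hypothetical shapes: batch of 1, L = 4 patches, D = 3 values per patch
target = torch.arange(12.).reshape(1, 4, 3)   # patchified image, (N, L, D)
ids_masked = torch.tensor([[1, 3]])           # indices of the masked patches

picked = torch.gather(target, dim=1,
                      index=ids_masked[:, :, None].expand(-1, -1, target.shape[2]))
print(picked.shape)  # torch.Size([1, 2, 3]) -> rows 1 and 3 of the patch sequence
```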
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DiffMAE - Unofficial PyTorch Implementation

This repository contains an **unofficial** PyTorch implementation of DiffMAE (Diffusion Masked Autoencoder).
For more details, please refer to the original paper: [DiffMAE: Diffusion Masked Autoencoder](https://arxiv.org/abs/2304.03283).

This code is based on [https://github.com/facebookresearch/mae](https://github.com/facebookresearch/mae).

I have currently uploaded only the pre-training stage code and will update the code related to the fine-tuning stage soon.

--------------------------------------------------------------------------------
/util/lr_sched.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr
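Note that engine_pretrain.py calls `adjust_learning_rate` with a fractional epoch (`data_iter_step / len(data_loader) + epoch`), so the warmup ramp and the cosine decay are applied per iteration, not per epoch. A standalone sketch of the resulting curve, using hypothetical settings:

```python
import math
from types import SimpleNamespace

# hypothetical settings: 100 epochs, 40 warmup epochs, absolute lr of 1e-3
args = SimpleNamespace(lr=1e-3, min_lr=0., warmup_epochs=40, epochs=100)

def lr_at(epoch):  # same formula as adjust_learning_rate, minus the optimizer update
    if epoch < args.warmup_epochs:
        return args.lr * epoch / args.warmup_epochs
    return args.min_lr + (args.lr - args.min_lr) * 0.5 * \
        (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))

for e in [0, 20, 40, 70, 100]:
    print(e, "%.2e" % lr_at(e))  # linear ramp 0 -> 1e-3 by epoch 40, cosine decay to 0 by 100
```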
--------------------------------------------------------------------------------
/util/lars.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
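For parameters with more than one dimension, LARS multiplies the weight-decayed gradient by the trust ratio `trust_coefficient * ||p|| / ||dp||` before it enters the momentum buffer. A toy numeric illustration (values invented, weight decay ignored):

```python
import torch

p = torch.full((2, 2), 2.0)      # parameter tensor, ||p|| = 4
grad = torch.full((2, 2), 0.5)   # pretend gradient, ||grad|| = 1

trust_coefficient = 0.001
q = trust_coefficient * p.norm() / grad.norm()   # 0.001 * 4 / 1 = 0.004
print(q)  # grad * q is what enters the momentum buffer for this layer
```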
--------------------------------------------------------------------------------
/util/datasets.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import os
import PIL

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset


def build_transform(is_train, args):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

def create_dataset(args):
    if args.dataset == 'cifar10':
        train_transform = transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        test_transform = transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_set = torchvision.datasets.CIFAR10(root=args.data_path, train=True,
                                                 download=True, transform=train_transform)
        test_set = torchvision.datasets.CIFAR10(root=args.data_path, train=False,
                                                download=True, transform=test_transform)

    elif args.dataset == 'mnist':
        transform = transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor()])

        train_set = torchvision.datasets.MNIST(root=args.data_path, train=True,
                                               download=True, transform=transform)
        test_set = torchvision.datasets.MNIST(root=args.data_path, train=False,
                                              download=True, transform=transform)

    valid_len = int(len(train_set) * 0.1)
    train_set, valid_set = random_split(train_set, [len(train_set) - valid_len, valid_len])

    dataloader = {}
    dataloader['train'] = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers, pin_memory=args.pin_mem)
    dataloader['valid'] = DataLoader(valid_set, batch_size=args.batch_size, shuffle=False, drop_last=True, num_workers=args.num_workers, pin_memory=args.pin_mem)
    dataloader['test'] = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.num_workers, pin_memory=args.pin_mem)

    return dataloader
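`create_dataset` returns a dict with 'train', 'valid' and 'test' loaders; the validation split is carved out of the training set (10%) via `random_split`. A minimal sketch of driving it directly with a hand-built namespace instead of option.py (paths and sizes are illustrative):

```python
from types import SimpleNamespace
from dataset import create_dataset

# hypothetical settings for a quick local run
args = SimpleNamespace(dataset='cifar10', data_path='./data', img_size=224,
                       batch_size=64, num_workers=0, pin_mem=True)

loaders = create_dataset(args)
images, labels = next(iter(loaders['train']))
print(images.shape)  # torch.Size([64, 3, 224, 224])
```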
--------------------------------------------------------------------------------
/model/diffusion.py:
--------------------------------------------------------------------------------
"""
This code is based on
https://github.com/taki0112/diffusion-pytorch
"""

import os
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from model.modules import linear_beta_schedule, cosine_beta_schedule
# from modules import linear_beta_schedule, cosine_beta_schedule
import logging
import torch.nn.functional as F

class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, objective='ddpm', schedule='linear', device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.objective = objective

        self.beta = self.prepare_noise_schedule(schedule, beta_start, beta_end)

        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self, schedule, beta_start, beta_end):
        if schedule == 'linear':
            return linear_beta_schedule(self.noise_steps, beta_start, beta_end)
        else:
            return cosine_beta_schedule(self.noise_steps)

    def noise_samples(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None].to(x.device)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None].to(x.device)
        z = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * z, z

    def sample_timesteps(self, n):
        t = torch.randint(low=1, high=self.noise_steps, size=(n,))
        return t

    def tensor_to_image(self, x):
        x = (x.clamp(-1, 1) + 1) / 2
        # x = (x * 255).type(torch.uint8)
        return x

    def sample(self, predicted_token):
        # reverse process
        with torch.no_grad():
            for i in tqdm(reversed(range(1, self.noise_steps))):
                t = (torch.ones(predicted_token.size()[0], dtype=torch.long) * i).to(self.device)

                alpha = self.alpha[t][:, None, None, None].to(predicted_token.device)
                beta = self.beta[t][:, None, None, None].to(predicted_token.device)
                alpha_hat = self.alpha_hat[t][:, None, None, None].to(predicted_token.device)
                alpha_hat_prev = self.alpha_hat[t - 1][:, None, None, None].to(predicted_token.device)
                beta_tilde = beta * (1 - alpha_hat_prev) / (1 - alpha_hat)  # similar to beta

                noise = torch.randn_like(predicted_token)

                predict_x0 = 0
                direction_point = 1 / torch.sqrt(alpha) * predicted_token
                random_noise = torch.sqrt(beta_tilde) * noise

                x = predict_x0 + direction_point + random_noise

        return torch.clamp(x, -1.0, 1.0)
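`noise_samples` is the closed-form forward process x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps, applied here to masked patch tokens rather than whole images. A quick shape check, with a hypothetical token batch:

```python
import torch
from model.diffusion import Diffusion

diff = Diffusion(schedule='cosine')

# hypothetical masked-token batch: (N, L_masked, D)
x0 = torch.randn(4, 147, 1024)
t = diff.sample_timesteps(4)          # one random timestep per sample

xt, eps = diff.noise_samples(x0, t)   # noised tokens and the noise that was added
print(xt.shape, eps.shape)            # both torch.Size([4, 147, 1024])
```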
--------------------------------------------------------------------------------
/main_pretrain.py:
--------------------------------------------------------------------------------
# main implementation code for pre-training

import os
import re
import glob
import time
import json
import datetime
import torch
import torch.nn as nn

import timm

assert timm.__version__ == "0.3.2"  # version check
import timm.optim.optim_factory as optim_factory

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

import option
from dataset import create_dataset
from model import diffmae, diffusion
from engine_pretrain import train_one_epoch, evaluate

def main(args):

    if args.mode != 'pretrain':
        print('Pre-training phase: args.mode has to be "pretrain"')
        exit(0)

    dataloader = create_dataset(args)

    diff = diffusion.Diffusion(schedule='cosine')
    model = diffmae.DiffMAE(args, diff)

    if args.cuda:
        args.device = "cuda:{}".format(args.gpu_ids[0])
        if args.multi_gpu:
            model = nn.DataParallel(model, output_device=args.gpu_ids[0], device_ids=args.gpu_ids)
    else:
        args.device = torch.device("cpu")

    model = model.to(args.device)
    eff_batch_size = args.batch_size * args.accum_iter

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("effective batch size: %d" % eff_batch_size)

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))

    loss_scaler = NativeScaler()

    if args.resume:
        misc.load_model(args=args, model_without_ddp=model, optimizer=optimizer, loss_scaler=loss_scaler)

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    log_writer = None
    for epoch in range(args.start_epoch, args.epochs):
        train_stats = train_one_epoch(
            model, dataloader['train'],
            optimizer, epoch, loss_scaler,
            log_writer=log_writer,
            args=args, iter=epoch
        )
        if args.output_dir and (epoch % args.save_freq == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.savedir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    pattern = re.compile(r'\d+')
    file_list = sorted(glob.glob('{}/*.pth'.format(args.savedir)), key=lambda x: int(pattern.findall(x)[-1]))
    args.resume = file_list[-1]
    misc.load_model(args=args, model_without_ddp=model, optimizer=optimizer, loss_scaler=loss_scaler)

    if args.eval:
        test_stats = evaluate(model, dataloader['test'], epoch='', args=args)

        log_stats = {**{f'test_{k}': v for k, v in test_stats.items()}}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.savedir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

if __name__ == '__main__':
    args = option.Options().gather_options()
    main(args)
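With only `--blr` given, the absolute learning rate follows the linear scaling rule `lr = blr * eff_batch_size / 256`, so `--blr` is the knob you normally tune. Worked example with hypothetical numbers:

```python
# linear lr scaling rule used in main() above, with hypothetical numbers
blr, batch_size, accum_iter = 1e-3, 64, 4
eff_batch_size = batch_size * accum_iter     # 256
lr = blr * eff_batch_size / 256              # 1.00e-03
print("effective batch size %d, actual lr %.2e" % (eff_batch_size, lr))
```

A typical single-GPU launch passes `--mode pretrain --dataset cifar10 --cuda --gpu_ids 0`; see option.py below for the full flag set.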
--------------------------------------------------------------------------------
/util/lr_decay.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# ELECTRA https://github.com/google-research/electra
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import json

def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    param_group_names = {}
    param_groups = {}

    num_layers = 4 + 1  # n_blocks + 1

    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_group_names:
            this_scale = layer_scales[layer_id]

            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["params"].append(n)
        param_groups[group_name]["params"].append(p)

    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())


def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ['cls_token', 'pos_embed']:
        return 0
    elif name.startswith('patch_embed'):
        return 0
    elif name.startswith('enc'):
        return int(name.split('.')[2]) + 1
    else:
        return num_layers


# def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
#     """
#     Parameter groups for layer-wise lr decay
#     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
#     """
#     param_group_names = {}
#     param_groups = {}

#     num_layers = len(model.blocks) + 1

#     layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

#     for n, p in model.named_parameters():
#         if not p.requires_grad:
#             continue

#         # no decay: all 1D parameters and model specific ones
#         if p.ndim == 1 or n in no_weight_decay_list:
#             g_decay = "no_decay"
#             this_decay = 0.
#         else:
#             g_decay = "decay"
#             this_decay = weight_decay

#         layer_id = get_layer_id_for_vit(n, num_layers)
#         group_name = "layer_%d_%s" % (layer_id, g_decay)

#         if group_name not in param_group_names:
#             this_scale = layer_scales[layer_id]

#             param_group_names[group_name] = {
#                 "lr_scale": this_scale,
#                 "weight_decay": this_decay,
#                 "params": [],
#             }
#             param_groups[group_name] = {
#                 "lr_scale": this_scale,
#                 "weight_decay": this_decay,
#                 "params": [],
#             }

#         param_group_names[group_name]["params"].append(n)
#         param_groups[group_name]["params"].append(p)

#     # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

#     return list(param_groups.values())


# def get_layer_id_for_vit(name, num_layers):
#     """
#     Assign a parameter with its layer id
#     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
#     """
#     if name in ['cls_token', 'pos_embed']:
#         return 0
#     elif name.startswith('patch_embed'):
#         return 0
#     elif name.startswith('blocks'):
#         return int(name.split('.')[1]) + 1
#     else:
#         return num_layers
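With the hardcoded `num_layers = 4 + 1`, `param_groups_lrd` produces six lr multipliers, decaying geometrically from the head down to the patch embedding. The resulting scales for the default `layer_decay = 0.75`:

```python
# layer-wise lr multipliers for num_layers = 4 + 1 and layer_decay = 0.75
layer_decay, num_layers = 0.75, 5
layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
print([round(s, 4) for s in layer_scales])
# [0.2373, 0.3164, 0.4219, 0.5625, 0.75, 1.0]
# patch_embed/cls_token get the smallest scale; layer_id == num_layers gets 1.0
```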
--------------------------------------------------------------------------------
/util/pos_embed.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)  # np.float was removed in NumPy 1.24+; use the builtin
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
# def interpolate_pos_embed(model, checkpoint_model):
#     if 'pos_embed' in checkpoint_model:
#         pos_embed_checkpoint = checkpoint_model['pos_embed']
#         embedding_size = pos_embed_checkpoint.shape[-1]
#         num_patches = model.patch_embed.num_patches
#         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
#         # height (== width) for the checkpoint position embedding
#         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
#         # height (== width) for the new position embedding
#         new_size = int(num_patches ** 0.5)
#         # class_token and dist_token are kept unchanged
#         if orig_size != new_size:
#             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
#             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
#             # only the position tokens are interpolated
#             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
#             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
#             pos_tokens = torch.nn.functional.interpolate(
#                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
#             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
#             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
#             checkpoint_model['pos_embed'] = new_pos_embed

""" my method """
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
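A quick shape check for the 2D sin-cos embedding (the embedding dimension matches the default in option.py; the grid size here is hypothetical):

```python
import numpy as np
from util.pos_embed import get_2d_sincos_pos_embed

pe = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=14, cls_token=True)
print(pe.shape)               # (197, 1024) = 1 cls token + 14*14 patches
print(np.allclose(pe[0], 0))  # True -> the cls token slot is zero-initialized
```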
--------------------------------------------------------------------------------
/engine_pretrain.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import os
import math
import sys
import imageio
import numpy as np
import matplotlib.pyplot as plt
from typing import Iterable
from PIL import Image

import torch
from torchvision.transforms.functional import to_pil_image

import util.misc as misc
import util.lr_sched as lr_sched
from util.loss import calc_for_diffmae

def concat_images_horizontally(image_list):
    widths, heights = zip(*(img.size for img in image_list))

    total_width = sum(widths)
    max_height = max(heights)

    new_img = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for img in image_list:
        new_img.paste(img, (x_offset, 0))
        x_offset += img.width

    return new_img

def denormalize(tensor):
    tensor = tensor * 0.5 + 0.5
    return tensor

def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    epoch: int, loss_scaler,
                    log_writer=None,
                    args=None, iter=0):

    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(args.device, non_blocking=True)

        with torch.cuda.amp.autocast():
            pred, ids_restore, mask, ids_masked, _ = model(samples)
            loss = calc_for_diffmae(args, model, samples, pred, ids_restore, ids_masked)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)

        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)


    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

def evaluate(model: torch.nn.Module,
             data_loader: Iterable,
             epoch='',
             log_writer=None,
             args=None):

    model.eval()

    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test: '
    print_freq = 20

    accum_iter = args.accum_iter

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        samples = samples.to(args.device, non_blocking=True)

        with torch.cuda.amp.autocast():
            pred, ids_restore, mask, ids_masked, ids_keep = model(samples)
            loss = calc_for_diffmae(args, model, samples, pred, ids_restore, ids_masked)

        loss_value = loss.item()

        loss /= accum_iter

        metric_logger.update(loss=loss_value)

        if args.sampling and data_iter_step % 100 == 0:
            if args.multi_gpu:
                model_ = model.module
            else:
                model_ = model

            for n in range(pred.size()[0]):
                if n % 100 == 0:
                    sampled_token = model_.diffusion.sample(pred[n].unsqueeze(0))
                    sampled_token = sampled_token.squeeze()
                    visible_tokens = model_.patchify(samples)

                    visible_tokens = torch.gather(visible_tokens, dim=1, index=ids_keep[:, :, None].expand(-1, -1, visible_tokens.shape[2]))
                    img = torch.cat([visible_tokens[n], sampled_token], dim=0)
                    img = torch.gather(img, dim=0, index=ids_restore[n].unsqueeze(-1).repeat(1, img.shape[1]))  # to unshuffle
                    img = model_.unpatchify(img.unsqueeze(0))

                    img = denormalize(img)
                    samples = denormalize(samples)

                    img_array = img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()  # (n_channels, height, width) -> (height, width, n_channels)
                    org_array = samples[n].squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
                    # org_img = to_pil_image(samples[n] * 255)
                    org_img = Image.fromarray((org_array * 255).astype(np.uint8))
                    img = Image.fromarray((img_array * 255).astype(np.uint8))

                    images = [org_img, img]
                    concatenated_image_horizontal = concat_images_horizontally(images)
                    concatenated_image_horizontal.save('{}/sample/output_{}_{}.png'.format(args.savedir, data_iter_step, n))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
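The `accum_iter` handling above (scale the loss, step the optimizer only every `accum_iter` iterations) is the standard gradient-accumulation pattern; stripped of the AMP loss scaler it reduces to the following self-contained sketch (stand-in model and data):

```python
import torch

# minimal sketch of the gradient-accumulation pattern used in train_one_epoch
model = torch.nn.Linear(8, 1)                      # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = [(torch.randn(2, 8), None) for _ in range(8)]
accum_iter = 4

optimizer.zero_grad()
for step, (samples, _) in enumerate(loader):
    loss = model(samples).mean() / accum_iter      # scale so summed grads match one big batch
    loss.backward()
    if (step + 1) % accum_iter == 0:               # step/zero only every accum_iter iterations
        optimizer.step()
        optimizer.zero_grad()
```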
--------------------------------------------------------------------------------
/model/diffmae.py:
--------------------------------------------------------------------------------
from functools import partial
import einops
import torch
import torch.nn as nn

from model.modules import PatchEmbed, EncoderBlock, DecoderBlock, get_2d_sincos_pos_embed

class DiffMAE(nn.Module):
    def __init__(self, args, diffusion, norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        self.patch_size = args.patch_size
        self.mask_ratio = args.mask_ratio
        self.patch_embed = PatchEmbed(args.img_size, args.patch_size,
                                      args.n_channels, args.emb_dim)

        self.num_patches = int(args.img_size // args.patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, args.emb_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, args.emb_dim) * .02)
        self.norm = norm_layer(args.emb_dim)

        self.blocks = nn.ModuleList([
            EncoderBlock(args.emb_dim, args.num_heads) for i in range(args.depth)])

        self.decoder_embed = nn.Linear(args.emb_dim, args.dec_emb_dim, bias=True)
        # num_masked_patches = self.num_patches - int(self.num_patches * (1 - args.mask_ratio))
        # self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_masked_patches + 1, args.dec_emb_dim) * .02)
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, args.dec_emb_dim) * .02)
        self.decoder_norm = norm_layer(args.dec_emb_dim)
        self.decoder_pred = nn.Linear(args.dec_emb_dim, args.patch_size ** 2 * args.n_channels, bias=True)

        self.dec_blocks = nn.ModuleList([
            DecoderBlock(args.emb_dim, args.dec_emb_dim, args.num_heads) for i in range(args.depth)])

        self.diffusion = diffusion

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(int(self.patch_embed.num_patches*self.mask_ratio)**.5), cls_token=False)
        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        # torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        ids_masked = ids_shuffle[:, len_keep:]

        visible_tokens = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        masked_tokens = torch.gather(x, dim=1, index=ids_masked.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return visible_tokens, masked_tokens, mask, ids_restore, ids_masked, ids_keep

    def forward(self, x):
        t = self.diffusion.sample_timesteps(x.shape[0])

        x = self.patch_embed(x)
        x += self.pos_embed[:, 1:, :]

        x, mask_token, mask, ids_restore, ids_masked, ids_keep = self.random_masking(x, self.mask_ratio)

        mask_token, noise = self.diffusion.noise_samples(mask_token, t)

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        outputs = []
        for block in self.blocks:
            x = block(x)
            outputs.append(x)

        outputs[-1] = self.norm(outputs[-1])

        mask_token = self.decoder_embed(mask_token)
        # mask_token += self.decoder_pos_embed[:, 1:, :]
        # gather the decoder positional embeddings of the masked positions
        # (no nn.Parameter wrapper here: creating a fresh Parameter every forward
        # would register a stray leaf tensor and cut the gather off from
        # self.decoder_pos_embed)
        decoder_pos_embed = torch.gather(
            self.decoder_pos_embed[:, 1:, :].repeat(mask_token.shape[0], 1, 1), dim=1,
            index=ids_masked.unsqueeze(-1).repeat(1, 1, mask_token.shape[-1]))
        mask_token += decoder_pos_embed
        for dec_block, enc_output in zip(self.dec_blocks, reversed(outputs)):
            mask_token = dec_block(mask_token, enc_output)
        x8 = self.decoder_norm(mask_token)
        x8 = self.decoder_pred(x8)

        return x8, ids_restore, mask, ids_masked, ids_keep
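`random_masking` derives the visible/masked split from a single `argsort` of uniform noise; `ids_restore` is the inverse permutation used later to unshuffle. A small self-contained demo of the index bookkeeping (dimensions invented):

```python
import torch

# argsort of random noise gives a per-sample permutation of patch indices
N, L, D, mask_ratio = 2, 8, 4, 0.75
x = torch.randn(N, L, D)
noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation (unshuffle)
len_keep = int(L * (1 - mask_ratio))             # 2 visible patches

visible = torch.gather(x, 1, ids_shuffle[:, :len_keep].unsqueeze(-1).repeat(1, 1, D))
masked = torch.gather(x, 1, ids_shuffle[:, len_keep:].unsqueeze(-1).repeat(1, 1, D))
print(visible.shape, masked.shape)  # torch.Size([2, 2, 4]) torch.Size([2, 6, 4])
```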
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
"""
This code is based on
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
"""

import os
import json
import random
import datetime
import argparse
import numpy as np

import torch


class Options():
    def __init__(self):
        self.initialized = False

    def initialize(self, parser):
        parser.add_argument('--batch_size', default=64, type=int,
                            help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
        parser.add_argument('--epochs', default=100, type=int)
        parser.add_argument('--save_freq', default=10, type=int)

        # Model parameters
        parser.add_argument('--model', default='diffmae', type=str, metavar='MODEL',
                            help='Name of model to train')
        parser.add_argument('--depth', default=8, type=int)
        parser.add_argument('--num_heads', default=8, type=int)
        parser.add_argument('--img_size', default=224, type=int,
                            help='input image size')
        parser.add_argument('--mask_ratio', default=0.75, type=float,
                            help='Masking ratio (percentage of removed patches).')
        parser.add_argument('--norm_pix_loss', action='store_true',
                            help='Use (per-patch) normalized pixels as targets for computing loss')
        parser.set_defaults(norm_pix_loss=False)

        # Optimizer parameters
        parser.add_argument('--weight_decay', type=float, default=0.05,
                            help='weight decay (default: 0.05)')
        parser.add_argument('--layer_decay', type=float, default=0.75,
                            help='layer-wise lr decay from ELECTRA/BEiT')
        parser.add_argument('--lr', type=float, default=None, metavar='LR',
                            help='learning rate (absolute lr)')
        parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
                            help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
        parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                            help='lower lr bound for cyclic schedulers that hit 0')
        parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                            help='epochs to warmup LR')
        parser.add_argument('--clip_grad', type=float, default=0.,
                            help='clip grad')

        # Dataset parameters
        parser.add_argument('--patch_size', default=8, type=int)
        parser.add_argument('--data_path', default='../../data/', type=str,
                            help='dataset path')
        parser.add_argument('--dataset', type=str, required=True,
                            help='name of dataset')
        parser.add_argument('--output_dir', default='./output_dir',
                            help='path where to save, empty for no saving')
        parser.add_argument('--log_dir', default='./output_dir',
                            help='path where to tensorboard log')

        parser.add_argument('--manual_seed', default=0, type=int)
        parser.add_argument('--finetune', default='',
                            help='model path for finetuning')
        parser.add_argument('--resume', default='',
                            help='resume from checkpoint')
        parser.add_argument('--eval', action='store_true',
                            help='Perform evaluation only')

        parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                            help='start epoch')
        parser.add_argument('--num_workers', default=0, type=int)
        parser.add_argument('--pin_mem', action='store_true',
                            help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
        parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
        parser.set_defaults(pin_mem=True)

        parser.add_argument('--n_channels', type=int, default=3,
                            help='number of features, default=3')
        parser.add_argument('--n_classes', type=int, default=10,
                            help='number of classes, default=10')

        # Training parameters
        parser.add_argument('--accum_iter', default=1, type=int,
                            help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
        parser.add_argument('--mode', type=str, choices=['pretrain', 'finetune'])
        parser.add_argument('--emb_dim', type=int, default=1024,
                            help='feature dimension for embedding')
        parser.add_argument('--dec_emb_dim', type=int, default=512,
                            help='feature dimension for decoder embedding')

        parser.add_argument('--cuda', action='store_true', help='enables cuda')
        parser.add_argument('--multi_gpu', action='store_true', help='enables multi gpu')
        parser.add_argument('--gpu_ids', type=int, nargs='+', default=[],
                            help='gpu ids: e.g. 0 0,1,2, 0,2. use [] for CPU')

        parser.add_argument('--dist_on_itp', action='store_true')

        parser.add_argument('--sampling', action='store_true')
        self.initialized = True
        return parser

    def save_options(self, args):
        note = input("Anything to note: ")

        os.makedirs(args.savedir, exist_ok=True)
        os.makedirs('{}/sample'.format(args.savedir), exist_ok=True)  # exist_ok belongs to makedirs, not format
        config_file = args.savedir + "/config.txt"
        with open(config_file, 'w') as f:
            json.dump(args.__dict__, f, indent=2)
            f.write("\nnote: {}\n".format(
                note
            ))

    def setup(self, args):
        try:
            os.makedirs(args.output_dir)
        except OSError:
            pass

        if args.manual_seed is None:
            args.manual_seed = random.randint(1, 10000)

        print("Random Seed: ", args.manual_seed)
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        torch.manual_seed(args.manual_seed)

        torch.backends.cudnn.benchmark = True

        if args.cuda:
            torch.cuda.manual_seed_all(args.manual_seed)

        dt = datetime.datetime.now()
        date = dt.strftime("%Y%m%d-%H%M")

        model_opt = args.dataset + "-" + date + "-" + args.model

        args.savedir = os.path.join(args.output_dir, model_opt)
        os.makedirs(args.savedir, exist_ok=True)

        args.log_file = os.path.join(args.savedir, 'log.csv')

    def set_device(self, args):
        n_gpu = torch.cuda.device_count()
        if args.multi_gpu and len(args.gpu_ids) == 0 and torch.cuda.is_available():
            args.gpu_ids = list(range(n_gpu))
        elif args.gpu_ids and torch.cuda.is_available():
            gpu_ids = args.gpu_ids
            args.gpu_ids = []
            for id in gpu_ids:
                if id >= 0 and id < n_gpu:
                    args.gpu_ids.append(id)
            args.gpu_ids = sorted(args.gpu_ids)
            if len(args.gpu_ids) > 1:
                args.multi_gpu = True
            else:
                args.multi_gpu = False
        else:
            args.gpu_ids = []

        if args.cuda:
            args.device = "cuda:{}".format(args.gpu_ids[0])
        else:
            args.device = "cpu"

    def gather_options(self):
        if not self.initialized:
            parser = argparse.ArgumentParser()
            parser = self.initialize(parser)

        self.parser = parser
        args = parser.parse_args()
        self.setup(args)
        self.set_device(args)
        print(args)
        self.save_options(args)

        return args
"cuda:{}".format(args.gpu_ids[0]) 168 | else: 169 | args.device = "cpu" 170 | 171 | def gather_options(self): 172 | if not self.initialized: 173 | parser = argparse.ArgumentParser() 174 | parser = self.initialize(parser) 175 | 176 | self.parser = parser 177 | args = parser.parse_args() 178 | self.setup(args) 179 | self.set_device(args) 180 | print(args) 181 | self.save_options(args) 182 | 183 | return args 184 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | import collections 7 | from itertools import repeat 8 | from collections import OrderedDict 9 | 10 | try: 11 | from torch import _assert 12 | except ImportError: 13 | def _assert(condition: bool, message: str): 14 | assert condition, message 15 | 16 | def _ntuple(n): 17 | def parse(x): 18 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 19 | return tuple(x) 20 | return tuple(repeat(x, n)) 21 | return parse 22 | 23 | to_2tuple = _ntuple(2) 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 2D Image to Patch Embedding 27 | """ 28 | def __init__( 29 | self, 30 | img_size=224, 31 | patch_size=16, 32 | in_chans=3, 33 | embed_dim=768, 34 | norm_layer=None, 35 | flatten=True, 36 | bias=True, 37 | ): 38 | super().__init__() 39 | img_size = to_2tuple(img_size) 40 | patch_size = to_2tuple(patch_size) 41 | self.img_size = img_size 42 | self.patch_size = patch_size 43 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 44 | self.num_patches = self.grid_size[0] * self.grid_size[1] 45 | self.flatten = flatten 46 | 47 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 48 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 49 | 50 | def forward(self, x): 51 | B, C, H, W = x.shape 52 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 53 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 54 | x = self.proj(x) 55 | if self.flatten: 56 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 57 | x = self.norm(x) 58 | return x 59 | 60 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 61 | """ 62 | grid_size: int of the grid height and width 63 | return: 64 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 65 | """ 66 | grid_h = np.arange(grid_size, dtype=np.float32) 67 | grid_w = np.arange(grid_size, dtype=np.float32) 68 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 69 | grid = np.stack(grid, axis=0) 70 | 71 | grid = grid.reshape([2, 1, grid_size, grid_size]) 72 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 73 | if cls_token: 74 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 75 | return pos_embed 76 | 77 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 78 | assert embed_dim % 2 == 0 79 | 80 | # use half of dimensions to encode grid_h 81 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 82 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 83 | 84 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 85 | return emb 86 | 87 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 88 | """ 89 | 
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_dim, num_heads, dropout):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_dim, emb_dim)
        self.qkv = nn.Linear(emb_dim, emb_dim * 3)

    def forward(self, x):
        qkv = self.qkv(x)
        qkv = rearrange(qkv, 'b n (h d qkv) -> (qkv) b h n d', h=self.num_heads, qkv=3)
        q, k, v = qkv[0], qkv[1], qkv[2]

        energy = torch.einsum('bhqd, bhkd -> bhqk', q, k)
        scaling = self.emb_dim ** 0.5  # note: `** 1/2` parses as (emb_dim ** 1) / 2, not sqrt
        energy = nn.functional.softmax(energy / scaling, dim=-1)

        energy = self.att_drop(energy)
        output = torch.einsum('bhqv, bhvd -> bhqd', energy, v)
        output = rearrange(output, 'b h q d -> b q (h d)')
        output = self.projection(output)
        return output

class MultiheadCrossAttention(nn.Module):
    # Cross-attention for the decoder: queries come from the decoder tokens,
    # while keys and values are projected from the encoder output.
    def __init__(self, enc_emb_dim, dec_emb_dim, num_heads, dropout):
        super().__init__()
        self.emb_dim = dec_emb_dim
        self.num_heads = num_heads
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(dec_emb_dim, dec_emb_dim)
        self.q = nn.Linear(dec_emb_dim, dec_emb_dim)
        self.k = nn.Linear(enc_emb_dim, dec_emb_dim)
        self.v = nn.Linear(enc_emb_dim, dec_emb_dim)

    def forward(self, x, enc_output):
        q = rearrange(self.q(x), 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(self.k(enc_output), 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(self.v(enc_output), 'b n (h d) -> b h n d', h=self.num_heads)

        energy = torch.einsum('bhqd, bhkd -> bhqk', q, k)
        scaling = self.emb_dim ** 0.5  # same precedence fix as in MultiHeadAttention
        energy = nn.functional.softmax(energy / scaling, dim=-1)
        energy = self.att_drop(energy)
        output = torch.einsum('bhqv, bhvd -> bhqd', energy, v)
        output = rearrange(output, 'b h q d -> b q (h d)')
        output = self.projection(output)
        return output

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_dim, expansion=4, drop_p=0.):
        super().__init__(
            nn.Linear(emb_dim, emb_dim * expansion),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(emb_dim * expansion, emb_dim)
        )

class EncoderBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, dropout=0.5):
        super().__init__()
        self.emb_dim = emb_dim
        self.block1 = nn.Sequential(
            nn.LayerNorm(emb_dim),
            MultiHeadAttention(emb_dim, num_heads, dropout),
            nn.Dropout(dropout)
        )
        self.block2 = nn.Sequential(
            nn.LayerNorm(emb_dim),
            FeedForwardBlock(emb_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        res1 = x
        x = self.block1(x)
        x = x + res1  # out-of-place add; in-place `+=` can break autograd
        res2 = x
        x = self.block2(x)
        x = x + res2
        return x
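# Shape sanity check (added for illustration; not part of the original file):
# self-attention and the full encoder block are shape-preserving maps on the
# token sequence, (B, N, D) -> (B, N, D). The sizes below are arbitrary
# example values.
def _demo_encoder_shapes():
    x = torch.randn(2, 196, 768)
    attn = MultiHeadAttention(emb_dim=768, num_heads=12, dropout=0.1)
    assert attn(x).shape == x.shape
    block = EncoderBlock(emb_dim=768, num_heads=12, dropout=0.1)
    assert block(x).shape == x.shape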
class DecoderBlock(nn.Module):
    def __init__(self, enc_emb_dim, dec_emb_dim, num_heads, dropout=0.5):
        super().__init__()
        self.emb_dim = dec_emb_dim
        self.norm1 = nn.LayerNorm(dec_emb_dim)
        self.attn1 = MultiheadCrossAttention(enc_emb_dim, dec_emb_dim, num_heads, dropout)
        self.drop1 = nn.Dropout(dropout)

        self.block2 = nn.Sequential(
            nn.LayerNorm(dec_emb_dim),
            FeedForwardBlock(dec_emb_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, enc_output):
        res1 = x
        x = self.norm1(x)
        x = self.attn1(x, enc_output)
        x = self.drop1(x)
        x = x + res1
        res2 = x
        x = self.block2(x)
        x = x + res2
        return x

""" for diffusion """

def linear_beta_schedule(timesteps, beta_start, beta_end):
    scale = 1000 / timesteps
    beta_start = scale * beta_start
    beta_end = scale * beta_end

    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)

    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])

    return torch.clamp(betas, 0, 0.999)
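# Sanity sketch (added for illustration; not part of the original file): both
# schedules return `timesteps` betas in (0, 1), and the cumulative alpha
# product decays as t grows. The 1e-4 / 2e-2 endpoints for the linear
# schedule are example values, not values fixed by this module.
def _demo_beta_schedules():
    T = 1000
    betas_cos = cosine_beta_schedule(T)
    betas_lin = linear_beta_schedule(T, beta_start=1e-4, beta_end=2e-2)
    assert betas_cos.shape == betas_lin.shape == (T,)
    alphas_cumprod = torch.cumprod(1.0 - betas_cos, dim=0)
    assert alphas_cumprod[-1] < alphas_cumprod[0]  # signal decays toward pure noise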
47 | """ 48 | if not is_dist_avail_and_initialized(): 49 | print('no use distributed training') 50 | return 51 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 52 | dist.barrier() 53 | dist.all_reduce(t) 54 | t = t.tolist() 55 | self.count = int(t[0]) 56 | self.total = t[1] 57 | 58 | @property 59 | def median(self): 60 | d = torch.tensor(list(self.deque)) 61 | return d.median().item() 62 | 63 | @property 64 | def avg(self): 65 | d = torch.tensor(list(self.deque), dtype=torch.float32) 66 | return d.mean().item() 67 | 68 | @property 69 | def global_avg(self): 70 | return self.total / self.count 71 | 72 | @property 73 | def max(self): 74 | return max(self.deque) 75 | 76 | @property 77 | def value(self): 78 | return self.deque[-1] 79 | 80 | def __str__(self): 81 | return self.fmt.format( 82 | median=self.median, 83 | avg=self.avg, 84 | global_avg=self.global_avg, 85 | max=self.max, 86 | value=self.value) 87 | 88 | 89 | class MetricLogger(object): 90 | def __init__(self, delimiter="\t"): 91 | self.meters = defaultdict(SmoothedValue) 92 | self.delimiter = delimiter 93 | 94 | def update(self, **kwargs): 95 | for k, v in kwargs.items(): 96 | if v is None: 97 | continue 98 | if isinstance(v, torch.Tensor): 99 | v = v.item() 100 | assert isinstance(v, (float, int)) 101 | self.meters[k].update(v) 102 | 103 | def __getattr__(self, attr): 104 | if attr in self.meters: 105 | return self.meters[attr] 106 | if attr in self.__dict__: 107 | return self.__dict__[attr] 108 | raise AttributeError("'{}' object has no attribute '{}'".format( 109 | type(self).__name__, attr)) 110 | 111 | def __str__(self): 112 | loss_str = [] 113 | for name, meter in self.meters.items(): 114 | # print('name: {}, meter:{}'.format(name, meter)) lr은 그냥 0.00000으로 출력, loss는 뒤에 괄호 붙는 형태로 출력 115 | loss_str.append( 116 | "{}: {}".format(name, str(meter)) 117 | ) # str(meter) -> 2.0018 (2.0018), 1.9997 (1.9944) 118 | return self.delimiter.join(loss_str) 119 | 120 | def synchronize_between_processes(self): 121 | for meter in self.meters.values(): 122 | meter.synchronize_between_processes() 123 | 124 | def add_meter(self, name, meter): 125 | self.meters[name] = meter 126 | 127 | def log_every(self, iterable, print_freq, header=None): 128 | i = 0 129 | if not header: 130 | header = '' 131 | start_time = time.time() 132 | end = time.time() 133 | iter_time = SmoothedValue(fmt='{avg:.4f}') 134 | data_time = SmoothedValue(fmt='{avg:.4f}') 135 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 136 | log_msg = [ 137 | header, 138 | '[{0' + space_fmt + '}/{1}]', 139 | 'eta: {eta}', 140 | '{meters}', 141 | 'time: {time}', 142 | 'data: {data}' 143 | ] 144 | if torch.cuda.is_available(): 145 | log_msg.append('max mem: {memory:.0f}') 146 | log_msg = self.delimiter.join(log_msg) 147 | MB = 1024.0 * 1024.0 148 | for obj in iterable: 149 | data_time.update(time.time() - end) 150 | yield obj 151 | iter_time.update(time.time() - end) 152 | if i % print_freq == 0 or i == len(iterable) - 1: 153 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 154 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) # 분 단위로 바꿔주는 작업? 
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            # how each meter prints depends on its fmt: lr prints as a bare
            # value such as 0.000000, while losses print as "median (global_avg)"
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )  # str(meter) -> 2.0018 (2.0018), 1.9997 (1.9944)
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))  # format the ETA as H:MM:SS
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),  # str(self) invokes MetricLogger.__str__ above
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
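# Usage sketch (added for illustration; not part of the original file):
# log_every wraps a sized iterable, yields its items unchanged, and prints
# the meters every `print_freq` iterations. `data_loader` (assumed to yield
# tensors) and the constant learning rate are stand-ins.
def _demo_metric_logger(data_loader):
    logger = MetricLogger(delimiter="  ")
    logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    for batch in logger.log_every(data_loader, 20, header='Epoch: [0]'):
        loss = batch.float().mean().item()  # placeholder "loss" from the batch
        logger.update(loss=loss, lr=1e-4)
    logger.synchronize_between_processes()
    print('Averaged stats:', logger)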
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)  # scale the loss, then backpropagate
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm
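# AMP training-step sketch (added for illustration; not part of the original
# file): `model`, `optimizer`, and `batch` are stand-ins. The scaler call
# backpropagates the scaled loss, optionally clips gradients, and returns
# the unscaled gradient norm.
def _demo_amp_step(model, optimizer, batch, loss_scaler):
    with torch.cuda.amp.autocast():
        loss = model(batch).mean()  # placeholder objective
    optimizer.zero_grad()
    grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                            parameters=model.parameters(), update_grad=True)
    return loss.item(), grad_norm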
352 | print("With optim & sched!") 353 | 354 | 355 | def all_reduce_mean(x): 356 | world_size = get_world_size() 357 | if world_size > 1: 358 | x_reduce = torch.tensor(x).to('cuda') 359 | dist.all_reduce(x_reduce) 360 | x_reduce /= world_size 361 | return x_reduce.item() 362 | else: 363 | return x 364 | --------------------------------------------------------------------------------