├── custom
│   ├── __init__.py
│   ├── caltime.py
│   ├── serverLog.py
│   ├── lr_scheduler.py
│   └── optimizers.py
├── test.sh
├── train.sh
├── README.md
├── configs
│   └── config.yml
├── datas
│   ├── benchmark.py
│   ├── utils.py
│   └── div2k.py
├── models
│   ├── m_network.py
│   └── m_block.py
├── test_custom_image.py
├── utils.py
├── test.py
└── train.py

/custom/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
python test.py --config ./configs/config.yml
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
export NCCL_IB_DISABLE=1; export NCCL_P2P_DISABLE=1; python train.py --config ./configs/config.yml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### This is the code of our recent work entitled SRConvNet: A Transformer-Style ConvNet for Lightweight Image Super-Resolution (IJCV 2024)

If you find our work useful in your research or publications, please consider citing:
```
@article{li2024srconvnet,
  title={SRConvNet: A Transformer-Style ConvNet for Lightweight Image Super-Resolution},
  author={Li, Feng and Cong, Runmin and Wu, Jingjing and Bai, Huihui and Wang, Meng and Zhao, Yao},
  journal={International Journal of Computer Vision},
  pages={1--17},
  year={2024},
  publisher={Springer}
}
```
--------------------------------------------------------------------------------
/configs/config.yml:
--------------------------------------------------------------------------------
model: 'm'
scale: 4
rgb_range: 255
colors: 3
num_blocks: 8
num_heads: 8
num_kernels: 16
dim: 64
optim: 'adam'
fp: 32
# loss: 'SmoothL1Loss'
loss: 'L1Loss'
pretrain: "/folder/model.pt"


## parameters for model training
patch_size: 256
batch_size: 128
data_repeat: 40
data_augment: 1

epochs: 1000
lr: 0.0002
decays: [500,800,900,950]
gamma: 0.5
log_every: 100
test_every: 1
log_path: "./experiments"
log_name:

## hardware specification
gpu_ids: [0,1]

## data specification
data_path: '/your/test/datasets/'
eval_sets: 'Set5'
#eval_sets: ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109']
--------------------------------------------------------------------------------
/custom/caltime.py:
--------------------------------------------------------------------------------
import time
import datetime

import pytz


class RemainTime:
    def __init__(self, epoch):
        self.start_time = time.time()
        self.epoch = epoch

    def update(self, now_epoch):
        epoch_time = time.time() - self.start_time
        epoch_remaining = self.epoch - now_epoch
        time_remaining = epoch_time * epoch_remaining
        tz = pytz.timezone('Asia/Shanghai')  # UTC+8
        t = datetime.datetime.fromtimestamp(int(time.time()) + time_remaining,
                                            tz).strftime('%Y-%m-%d %H:%M:%S')
        print('epochs remaining:', epoch_remaining, '\tfinishing time:', t)

        self.start_time = time.time()
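
# Usage sketch: RemainTime extrapolates the finish time from the duration of
# the last epoch, so it assumes epochs take roughly constant wall-clock time.
if __name__ == '__main__':
    rt = RemainTime(epoch=3)
    for e in range(1, 4):
        time.sleep(1)  # stand-in for one epoch of training
        rt.update(e)   # prints the epochs remaining and the projected finish time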
--------------------------------------------------------------------------------
/custom/serverLog.py:
--------------------------------------------------------------------------------
import requests
from urllib.parse import quote
import _thread
import time
import threading

threadLock = threading.Lock()


def threadSendLog(content, name):
    threadLock.acquire()
    content = quote(content, 'utf-8')
    name = quote(name, 'utf-8')
    url = ''
    # print('sendLog:' + url)
    try:
        # print("----------------sendLog...----------------")
        r = requests.get(url, timeout=5)
        # print('\nsendLog finish', r.status_code, r.content)
        # print('sendLog finish')
    except Exception as e:
        print('\nsendLog network error!', e)
    finally:
        # print("----------------sendLog...----------------")
        threadLock.release()


class LogClass:
    def __init__(self, on=False):
        self.on = on

    def sendLog(self, content, name):
        if self.on:
            try:
                _thread.start_new_thread(threadSendLog, (content, name))
            except Exception:
                print("cloud log error")


if __name__ == '__main__':
    log_class = LogClass()
    log_class.sendLog('35.8', 'PSNR')
    log_class.sendLog('35.8', 'PSNR')
    while True:
        time.sleep(1)  # keep the main thread alive so the log threads can run
--------------------------------------------------------------------------------
/datas/benchmark.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import imageio
import torch.utils.data as data
import skimage.color as sc

from utils import ndarray2tensor


class Benchmark(data.Dataset):
    def __init__(self, HR_folder, LR_folder, scale=2, colors=1):
        super(Benchmark, self).__init__()
        self.HR_folder = HR_folder
        self.LR_folder = LR_folder

        self.img_postfix = '.png'
        self.scale = scale
        self.colors = colors

        self.hr_filenames = []
        self.lr_filenames = []
        ## generate dataset
        tags = os.listdir(self.HR_folder)
        for tag in tags:
            hr_filename = os.path.join(self.HR_folder, tag)
            lr_filename = os.path.join(self.LR_folder, 'X{}'.format(scale),
                                       tag.replace('.png', 'x{}.png'.format(self.scale)))
            self.hr_filenames.append(hr_filename)
            self.lr_filenames.append(lr_filename)
        self.nums_dataset = len(self.hr_filenames)
        ## store all images in RAM
        self.hr_images = []
        self.lr_images = []

        LEN = len(self.hr_filenames)
        for i in range(LEN):
            lr_image, hr_image = imageio.imread(self.lr_filenames[i], pilmode="RGB"), imageio.imread(
                self.hr_filenames[i], pilmode="RGB")
            if self.colors == 1:
                lr_image, hr_image = sc.rgb2ycbcr(lr_image)[:, :, 0:1], sc.rgb2ycbcr(hr_image)[:, :, 0:1]
            self.hr_images.append(hr_image)
            self.lr_images.append(lr_image)

    def __len__(self):
        return len(self.hr_filenames)

    def __getitem__(self, idx):
        # get the whole image; images are stored in RAM by default
        lr, hr = self.lr_images[idx], self.hr_images[idx]
        lr_h, lr_w, _ = lr.shape
        hr = hr[0:lr_h * self.scale, 0:lr_w * self.scale, :]
        lr, hr = ndarray2tensor(lr), ndarray2tensor(hr)
        return lr, hr, self.hr_filenames[idx]
--------------------------------------------------------------------------------
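A minimal sketch of how `Benchmark` is typically consumed; the dataset paths and scale below are placeholder assumptions that mirror the folder layout expected by `datas/utils.py`:

```python
from torch.utils.data import DataLoader
from datas.benchmark import Benchmark

# placeholder paths; Benchmark expects HR files plus LR_bicubic/X{scale} siblings
set5 = Benchmark('/data/benchmark/Set5/HR', '/data/benchmark/Set5/LR_bicubic',
                 scale=4, colors=3)
loader = DataLoader(set5, batch_size=1, shuffle=False)
for lr, hr, filename in loader:
    print(filename[0], lr.shape, hr.shape)  # CHW float tensors in the range [0, 255]
```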
/models/m_network.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.m_block import MeanShift, BasicBlock


def create_model(args):
    return SRNet(args)


class SRNet(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.sub_mean = MeanShift(args.rgb_range)
        self.add_mean = MeanShift(args.rgb_range, sign=1)
        self.scale = args.scale
        self.num_heads = args.num_heads
        self.num_kernels = args.num_kernels
        self.colors = args.colors
        self.dim = args.dim
        self.num_blocks = args.num_blocks

        self.to_feat = nn.Conv2d(self.colors, self.dim, kernel_size=3, stride=1, padding=1)
        self.blocks = nn.Sequential(
            *[BasicBlock(self.dim, self.num_heads, self.num_kernels) for _ in range(self.num_blocks)]
        )

        if self.scale == 4:
            self.upsampling = nn.Sequential(
                nn.Conv2d(self.dim, self.dim * 4, 1, 1, 0),
                nn.PixelShuffle(2),
                nn.GELU(),
                nn.Conv2d(self.dim, self.dim * 4, 1, 1, 0),
                nn.PixelShuffle(2),
                nn.GELU()
            )
        else:
            self.upsampling = nn.Sequential(
                nn.Conv2d(self.dim, self.dim * self.scale * self.scale, 1, 1, 0),
                nn.PixelShuffle(self.scale),
                nn.GELU()
            )

        self.tail = nn.Conv2d(self.dim, self.colors, 3, 1, 1)

    def forward(self, x):
        base = x
        x = self.to_feat(x)
        x_init = x
        x = self.blocks(x) + x_init
        x = self.upsampling(x)
        x = self.tail(x)
        base = F.interpolate(base, scale_factor=self.scale, mode='bilinear', align_corners=False)
        return x + base

    def load(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            name = name[name.index('.') + 1:]  # strip the 'module.' prefix added by DataParallel
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('upsampling') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('upsampling') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
--------------------------------------------------------------------------------
/datas/utils.py:
--------------------------------------------------------------------------------
import os
from datas.benchmark import Benchmark
from datas.div2k import DIV2K
from torch.utils.data import DataLoader


def create_datasets(args):
    div2k = DIV2K(
        # os.path.join(args.data_path, 'DF2K/DF2K_train_HR'),
        # os.path.join(args.data_path, 'DF2K/DF2K_train_LR_bicubic'),
        # os.path.join(args.data_path, 'df2k_cache'),
        os.path.join(args.data_path, 'DIV2K/DIV2K_train_HR'),
        os.path.join(args.data_path, 'DIV2K/DIV2K_train_LR_bicubic'),
        os.path.join(args.data_path, 'div2k_cache'),
        train=True,
        augment=args.data_augment,
        scale=args.scale,
        colors=args.colors,
        patch_size=args.patch_size,
        repeat=args.data_repeat,
    )
    train_dataloader = DataLoader(dataset=div2k, batch_size=args.batch_size, shuffle=True,
                                  pin_memory=True, drop_last=True)

    valid_dataloaders = []
    if 'Set5' in args.eval_sets:
        set5_hr_path = os.path.join(args.data_path, 'benchmark/Set5/HR')
        set5_lr_path = os.path.join(args.data_path, 'benchmark/Set5/LR_bicubic')
        set5 = Benchmark(set5_hr_path, set5_lr_path, scale=args.scale, colors=args.colors)
        valid_dataloaders += [{'name': 'set5', 'dataloader': DataLoader(dataset=set5, batch_size=1, shuffle=False)}]
    if 'Set14' in args.eval_sets:
        set14_hr_path = os.path.join(args.data_path, 'benchmark/Set14/HR')
        set14_lr_path = os.path.join(args.data_path, 'benchmark/Set14/LR_bicubic')
        set14 = Benchmark(set14_hr_path, set14_lr_path, scale=args.scale, colors=args.colors)
        valid_dataloaders += [{'name': 'set14', 'dataloader': DataLoader(dataset=set14, batch_size=1, shuffle=False)}]
    if 'B100' in args.eval_sets:
        b100_hr_path = os.path.join(args.data_path, 'benchmark/B100/HR')
        b100_lr_path = os.path.join(args.data_path, 'benchmark/B100/LR_bicubic')
        b100 = Benchmark(b100_hr_path, b100_lr_path, scale=args.scale, colors=args.colors)
        valid_dataloaders += [{'name': 'b100', 'dataloader': DataLoader(dataset=b100, batch_size=1, shuffle=False)}]
    if 'Urban100' in args.eval_sets:
        u100_hr_path = os.path.join(args.data_path, 'benchmark/Urban100/HR')
        u100_lr_path = os.path.join(args.data_path, 'benchmark/Urban100/LR_bicubic')
        u100 = Benchmark(u100_hr_path, u100_lr_path, scale=args.scale, colors=args.colors)
        valid_dataloaders += [{'name': 'u100', 'dataloader': DataLoader(dataset=u100, batch_size=1, shuffle=False)}]
    if 'Manga109' in args.eval_sets:
        manga_hr_path = os.path.join(args.data_path, 'benchmark/Manga109/HR')
        manga_lr_path = os.path.join(args.data_path, 'benchmark/Manga109/LR_bicubic')
        manga = Benchmark(manga_hr_path, manga_lr_path, scale=args.scale, colors=args.colors)
        valid_dataloaders += [
            {'name': 'manga109', 'dataloader': DataLoader(dataset=manga, batch_size=1, shuffle=False)}]

    if len(valid_dataloaders) == 0:
        print('select no dataset for evaluation!')
    else:
        selected = ', '.join(d['name'] for d in valid_dataloaders)
        print('select {} for evaluation!'.format(selected))
    return train_dataloader, valid_dataloaders
--------------------------------------------------------------------------------
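The YAML config is merged into the argparse namespace before any of the factories run; a minimal sketch of that flow, mirroring what `train.py`/`test.py` do (the default config path follows `test.sh`):

```python
import argparse, yaml
import utils
from datas.utils import create_datasets

parser = argparse.ArgumentParser(description='config')
parser.add_argument('--config', type=str, default='./configs/config.yml')
args = parser.parse_args()

# YAML keys become attributes on args, exactly as train.py/test.py do it
vars(args).update(yaml.load(open(args.config), Loader=yaml.FullLoader))

# model: 'm' selects models/m_network.py; the loaders follow args.data_path
model = utils.import_module('models.{}_network'.format(args.model)).create_model(args)
train_loader, valid_loaders = create_datasets(args)
```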
/test_custom_image.py:
--------------------------------------------------------------------------------
import argparse, yaml
import utils
import os
from tqdm import tqdm
import imageio
import torch
import torch.nn as nn
from multiprocessing import Process
from multiprocessing import Queue
import time
from utils import ndarray2tensor


class save_img():
    def __init__(self):
        self.n_processes = 32

    def begin_background(self):
        self.queue = Queue()

        def bg_target(queue):
            while True:
                if not queue.empty():
                    filename, tensor = queue.get()
                    if filename is None: break
                    imageio.imwrite(filename, tensor.numpy())

        self.process = [
            Process(target=bg_target, args=(self.queue,)) \
            for _ in range(self.n_processes)
        ]
        for p in self.process: p.start()

    def end_background(self):
        for _ in range(self.n_processes): self.queue.put((None, None))
        while not self.queue.empty(): time.sleep(1)
        for p in self.process: p.join()

    def save_results(self, filename, img):
        tensor_cpu = img[0].byte().permute(1, 2, 0).cpu()
        self.queue.put((filename, tensor_cpu))


parser = argparse.ArgumentParser(description='config')

parser.add_argument('--config', type=str, default=None, help='pre-config file for training')
parser.add_argument('--resume', type=str, default=None, help='resume training or not')
parser.add_argument('--custom', type=str, default=None, help='use custom block')
parser.add_argument('--cloudlog', type=str, default=None, help='use cloud log')
parser.add_argument('--custom_image_path', type=str, default=None, help='path of the custom image')

args = parser.parse_args()

if args.config:
    opt = vars(args)
    yaml_args = yaml.load(open(args.config), Loader=yaml.FullLoader)
    opt.update(yaml_args)

## set visible gpus
gpu_ids_str = str(args.gpu_ids).replace('[', '').replace(']', '')
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(gpu_ids_str)

## select active gpu devices
device = None
if len(args.gpu_ids) > 0 and torch.cuda.is_available():
    print('use cuda & cudnn for acceleration!')
    print('the gpu id is: {}'.format(args.gpu_ids))
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True
else:
    print('use cpu for inference!')
    device = torch.device('cpu')
# torch.set_num_threads(args.threads)

## definitions of model
try:
    model = utils.import_module('models.{}_network'.format(args.model)).create_model(args)
except Exception:
    raise ValueError('unsupported model type: {}'.format(args.model))
if args.fp == 16:
    model.half()

## load pretrain
if args.pretrain is not None:
    print('load pretrained model: {}!'.format(args.pretrain))
    ckpt = torch.load(args.pretrain, map_location=device)
    model.load(ckpt['model_state_dict'])


model = nn.DataParallel(model).to(device)
print(model)
model = model.eval()
torch.set_grad_enabled(False)
save_path = args.log_path
si = save_img()
si.begin_background()

filePath = args.custom_image_path
for filename in tqdm(os.listdir(filePath), ncols=80):
    lr = imageio.imread(filePath + os.sep + filename)
    lr = ndarray2tensor(lr)
    lr = torch.unsqueeze(lr, 0)
    if args.fp == 16:
        lr = lr.type(torch.HalfTensor)
    lr = lr.to(device)
    sr = model(lr)

    # quantize output to [0, 255]
    sr = sr.clamp(0, 255).round()
    path = save_path + os.sep + 'custom' + os.sep
    if not os.path.exists(path):
        os.makedirs(path)
    fileUname, ext = '.'.join(filename.split('.')[:-1]), filename.split('.')[-1]
    path += (fileUname + '_x' + str(args.scale) + '_SR' + '.' + ext)
    si.save_results(path, sr)

si.end_background()
--------------------------------------------------------------------------------
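The `save_img` helper above is a sentinel-terminated worker pool: images are handed to background processes through a queue, and one `(None, None)` sentinel per worker shuts the pool down. The same pattern in miniature, as a self-contained sketch:

```python
from multiprocessing import Process, Queue


def worker(q):
    while True:
        item = q.get()
        if item is None:            # sentinel: no more work
            break
        print('writing', item)      # stand-in for imageio.imwrite


if __name__ == '__main__':
    q = Queue()
    workers = [Process(target=worker, args=(q,)) for _ in range(4)]
    for p in workers: p.start()
    for i in range(10): q.put('img_{}.png'.format(i))
    for _ in workers: q.put(None)   # one sentinel per worker
    for p in workers: p.join()
```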
/custom/lr_scheduler.py:
--------------------------------------------------------------------------------
# Code adapted from Detectron2 (https://github.com/facebookresearch/detectron2)
import math
from bisect import bisect_right
from typing import List
import torch


# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes
# only on epoch boundaries. We typically use iteration based schedules instead.
# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean
# "iteration" instead.

class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        milestones: List[int],
        gamma: float = 0.1,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        return [
            base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()

class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        max_iters: int,
        warmup_factor: float = 0.001,
        warmup_iters: int = 1000,
        warmup_method: str = "linear",
        last_epoch: int = -1,
        min_lr: float = 0.0,  # minimal learning rate
    ):
        self.max_iters = max_iters
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        # Different definitions of half-cosine with warmup are possible. For
        # simplicity we multiply the standard half-cosine schedule by the warmup
        # factor. An alternative is to start the period of the cosine at warmup_iters
        # instead of at 0. In the case that warmup_iters << max_iters the two are
        # very close to each other.
        return [max(self.min_lr,
                    base_lr
                    * warmup_factor
                    * 0.5
                    * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)))
                for base_lr in self.base_lrs
                ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()


def _get_warmup_factor_at_iter(
    method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
    """
    Return the learning rate warmup factor at a specific iteration.
    See https://arxiv.org/abs/1706.02677 for more details.

    Args:
        method (str): warmup method; either "constant" or "linear".
        iter (int): iteration at which to calculate the warmup factor.
        warmup_iters (int): the number of warmup iterations.
        warmup_factor (float): the base warmup factor (the meaning changes according
            to the method used).

    Returns:
        float: the effective warmup factor at the given iteration.
    """
    if iter >= warmup_iters:
        return 1.0

    if method == "constant":
        return warmup_factor
    elif method == "linear":
        alpha = iter / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
    else:
        raise ValueError("Unknown warmup method: {}".format(method))
--------------------------------------------------------------------------------
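Neither scheduler is wired into `train.py` (which uses PyTorch's `MultiStepLR`), but both follow the standard scheduler contract with "epoch" meaning iteration; a minimal sketch, assuming an iteration-based training loop:

```python
import torch
from custom.lr_scheduler import WarmupCosineLR

net = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
sched = WarmupCosineLR(opt, max_iters=10000, warmup_iters=1000)

for it in range(10000):
    # ... forward / backward would go here ...
    opt.step()
    sched.step()  # linear warmup for 1000 iters, then half-cosine decay
```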
/utils.py:
--------------------------------------------------------------------------------
import torch
import math
import numpy as np
import datetime
import sys
from pytorch_msssim import ssim
import importlib


def rgb_to_ycbcr(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an RGB image to YCbCr.

    Args:
        image (torch.Tensor): RGB Image to be converted to YCbCr.

    Returns:
        torch.Tensor: YCbCr version of the image.
    """

    if not torch.is_tensor(image):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(image)))

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError("Input size must have a shape of (*, 3, H, W). Got {}".format(image.shape))

    image = image / 255.  ## scale image to the range (0, 1)
    r: torch.Tensor = image[..., 0, :, :]
    g: torch.Tensor = image[..., 1, :, :]
    b: torch.Tensor = image[..., 2, :, :]

    y: torch.Tensor = 65.481 * r + 128.553 * g + 24.966 * b + 16.0
    cb: torch.Tensor = -37.797 * r + -74.203 * g + 112.0 * b + 128.0
    cr: torch.Tensor = 112.0 * r + -93.786 * g + -18.214 * b + 128.0

    return torch.stack((y, cb, cr), -3)


def prepare_qat(model):
    ## fuse model
    model.module.fuse_model()
    ## qconfig and qat-preparation & per-channel quantization
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    model = torch.quantization.prepare_qat(model, inplace=True)
    return model


def import_module(name):
    return importlib.import_module(name)


def calc_psnr(sr, hr):
    sr, hr = sr.double(), hr.double()
    diff = (sr - hr) / 255.00
    mse = diff.pow(2).mean()
    psnr = -10 * math.log10(mse)
    return float(psnr)


def calc_ssim(sr, hr):
    ssim_val = ssim(sr, hr, size_average=True)
    return float(ssim_val)


def ndarray2tensor(ndarray_hwc):
    ndarray_chw = np.ascontiguousarray(ndarray_hwc.transpose((2, 0, 1)))
    tensor = torch.from_numpy(ndarray_chw).float()
    return tensor


def cur_timestamp_str():
    now = datetime.datetime.now()
    year = str(now.year)
    month = str(now.month).zfill(2)
    day = str(now.day).zfill(2)
    hour = str(now.hour).zfill(2)
    minute = str(now.minute).zfill(2)

    content = "{}-{}{}-{}{}".format(year, month, day, hour, minute)
    return content


class ExperimentLogger(object):
    def __init__(self, filename='default.log', stream=sys.stdout):
        self.terminal = stream
        self.log = open(filename, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

def get_stat_dict():
    # one identical record per benchmark set, built in a loop instead of five
    # copy-pasted dict literals
    stat_dict = {
        'epochs': 0,
        'losses': [],
        'ema_loss': 0.0,
    }
    for name in ('set5', 'set14', 'b100', 'u100', 'manga109'):
        stat_dict[name] = {
            'psnrs': [],
            'ssims': [],
            'best_psnr': {'value': 0.0, 'epoch': 0},
            'best_ssim': {'value': 0.0, 'epoch': 0},
        }
    return stat_dict


if __name__ == '__main__':
    timestamp = cur_timestamp_str()
    print(timestamp)
--------------------------------------------------------------------------------
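`rgb_to_ycbcr` uses the BT.601 coefficients common in SR evaluation (the same convention as MATLAB's `rgb2ycbcr`), so PSNR/SSIM computed on the Y channel match the usual protocol. A quick self-check sketch:

```python
import torch
from utils import rgb_to_ycbcr, calc_psnr

x = torch.full((1, 3, 8, 8), 255.0)   # pure white, range [0, 255]
y = rgb_to_ycbcr(x)[:, 0:1]           # Y channel
print(float(y.mean()))                # ~235.0: 65.481 + 128.553 + 24.966 + 16

noisy = (x - 1.0).clamp(0, 255)
print(calc_psnr(noisy, x))            # uniform 1-level error: ~48.13 dB
```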
/test.py:
--------------------------------------------------------------------------------
import argparse, yaml
import utils
import os
from tqdm import tqdm
import imageio
import torch
import torch.nn as nn
from datas.utils import create_datasets
from multiprocessing import Process
from multiprocessing import Queue
import time


class save_img():
    def __init__(self):
        self.n_processes = 32

    def begin_background(self):
        self.queue = Queue()

        def bg_target(queue):
            while True:
                if not queue.empty():
                    filename, tensor = queue.get()
                    if filename is None: break
                    imageio.imwrite(filename, tensor.numpy())

        self.process = [
            Process(target=bg_target, args=(self.queue,)) \
            for _ in range(self.n_processes)
        ]
        for p in self.process: p.start()

    def end_background(self):
        for _ in range(self.n_processes): self.queue.put((None, None))
        while not self.queue.empty(): time.sleep(1)
        for p in self.process: p.join()

    def save_results(self, filename, img):
        tensor_cpu = img[0].byte().permute(1, 2, 0).cpu()
        self.queue.put((filename, tensor_cpu))


parser = argparse.ArgumentParser(description='config')

parser.add_argument('--config', type=str, default=None, help='pre-config file for training')
parser.add_argument('--resume', type=str, default=None, help='resume training or not')
parser.add_argument('--custom', type=str, default=None, help='use custom block')
parser.add_argument('--cloudlog', type=str, default=None, help='use cloud log')

args = parser.parse_args()

if args.config:
    opt = vars(args)
    yaml_args = yaml.load(open(args.config), Loader=yaml.FullLoader)
    opt.update(yaml_args)

## set visible gpus
gpu_ids_str = str(args.gpu_ids).replace('[', '').replace(']', '')
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(gpu_ids_str)

## select active gpu devices
device = None
if args.gpu_ids is not None and torch.cuda.is_available():
    print('use cuda & cudnn for acceleration!')
    print('the gpu id is: {}'.format(args.gpu_ids))
    device = torch.device('cuda')
    torch.backends.cudnn.benchmark = True
else:
    print('use cpu for testing!')
    device = torch.device('cpu')
# torch.set_num_threads(args.threads)
args.eval_sets = ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109']
# args.eval_sets = ['Set5']
## create dataset for training and validating
train_dataloader, valid_dataloaders = create_datasets(args)

## definitions of model
try:
    model = utils.import_module('models.{}_network'.format(args.model)).create_model(args)
except Exception:
    raise ValueError('unsupported model type: {}'.format(args.model))
if args.fp == 16:
    model.half()

## load pretrain
if args.pretrain is not None:
    print('load pretrained model: {}!'.format(args.pretrain))
    ckpt = torch.load(args.pretrain, map_location=device)
    model.load(ckpt['model_state_dict'])

model = nn.DataParallel(model).to(device)

model = model.eval()
torch.set_grad_enabled(False)
save_path = args.log_path
si = save_img()
si.begin_background()
for valid_dataloader in valid_dataloaders:

    avg_psnr, avg_ssim = 0.0, 0.0
    name = valid_dataloader['name']
    loader = valid_dataloader['dataloader']
    for lr, hr, filename in tqdm(loader, ncols=80):
        if args.fp == 16:
            lr, hr = lr.type(torch.HalfTensor), hr.type(torch.HalfTensor)
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr)

        # quantize output to [0, 255]
        hr = hr.clamp(0, 255).round()
        sr = sr.clamp(0, 255).round()

        path = save_path + os.sep + name + os.sep
        if not os.path.exists(path):
            os.makedirs(path)
        # filename[0] is the full HR path, so keep only its basename
        path += os.path.basename(filename[0]).replace('.png', '_x' + str(args.scale) + '_SR' + '.png')
        si.save_results(path, sr)
        # tensor_cpu = sr[0].byte().permute(1, 2, 0).cpu()
        # imageio.imwrite(path, tensor_cpu.numpy())

        # convert to YCbCr
        if args.colors == 3:
            hr_ycbcr = utils.rgb_to_ycbcr(hr)
            sr_ycbcr = utils.rgb_to_ycbcr(sr)
            hr = hr_ycbcr[:, 0:1, :, :]
            sr = sr_ycbcr[:, 0:1, :, :]
        # crop image for evaluation
        hr = hr[:, :, args.scale:-args.scale, args.scale:-args.scale]
        sr = sr[:, :, args.scale:-args.scale, args.scale:-args.scale]
        # calculate psnr and ssim
        psnr = utils.calc_psnr(sr, hr)
        ssim = utils.calc_ssim(sr, hr)
        avg_psnr += psnr
        avg_ssim += ssim
    avg_psnr = round(avg_psnr / len(loader), 2)
    avg_ssim = round(avg_ssim / len(loader), 4)
    test_log = '[{}x{} PSNR/SSIM: {:.2f}/{:.4f}]'.format(name, args.scale, float(avg_psnr), float(avg_ssim))
    print(test_log)
si.end_background()
--------------------------------------------------------------------------------
/datas/div2k.py:
--------------------------------------------------------------------------------
import os
import glob
import random

import numpy as np
import imageio
import torch.utils.data as data
import skimage.color as sc

from utils import ndarray2tensor


def crop_patch(lr, hr, patch_size, scale, augment=True):
    # crop a random, pixel-aligned LR/HR patch pair
    lr_h, lr_w, _ = lr.shape
    hp = patch_size
    lp = patch_size // scale
    lx, ly = random.randrange(0, lr_w - lp + 1), random.randrange(0, lr_h - lp + 1)
    hx, hy = lx * scale, ly * scale
    lr_patch, hr_patch = lr[ly:ly + lp, lx:lx + lp, :], hr[hy:hy + hp, hx:hx + hp, :]

    # augment data
    if augment:
        hflip = random.random() > 0.5
        vflip = random.random() > 0.5
        rot90 = random.random() > 0.5
        if hflip: lr_patch, hr_patch = lr_patch[:, ::-1, :], hr_patch[:, ::-1, :]
        if vflip: lr_patch, hr_patch = lr_patch[::-1, :, :], hr_patch[::-1, :, :]
        if rot90: lr_patch, hr_patch = lr_patch.transpose(1, 0, 2), hr_patch.transpose(1, 0, 2)
    # numpy to tensor
    lr_patch, hr_patch = ndarray2tensor(lr_patch), ndarray2tensor(hr_patch)
    return lr_patch, hr_patch


class DIV2K(data.Dataset):
    def __init__(
            self, HR_folder, LR_folder, CACHE_folder,
            train=True, augment=True, scale=2, colors=1,
            patch_size=96, repeat=168
    ):
        super(DIV2K, self).__init__()
        self.HR_folder = HR_folder
        self.LR_folder = LR_folder
        self.augment = augment
        self.img_postfix = '.png'
        self.scale = scale
        self.colors = colors
        self.patch_size = patch_size
        self.repeat = repeat
        self.nums_trainset = 0
        self.train = train
        self.cache_dir = CACHE_folder

        ## for raw png images
        self.hr_filenames = []
        self.lr_filenames = []
        ## for numpy array data
        self.hr_npy_names = []
        self.lr_npy_names = []
        ## store in ram
        self.hr_images = []
        self.lr_images = []

        ## generate dataset
        if self.train:
            self.start_idx = 1
            self.end_idx = 801
        else:
            self.start_idx = 801
            self.end_idx = 901  # validation images 0801-0900

        for i in range(self.start_idx, self.end_idx):
            idx = str(i).zfill(4)
            hr_filename = os.path.join(self.HR_folder, idx + self.img_postfix)
            lr_filename = os.path.join(self.LR_folder, 'X{}'.format(self.scale), idx + 'x{}'.format(self.scale) + self.img_postfix)
            self.hr_filenames.append(hr_filename)
            self.lr_filenames.append(lr_filename)
        self.nums_trainset = len(self.hr_filenames)

        LEN = self.end_idx - self.start_idx
        # hr_dir = os.path.join(self.cache_dir, 'df2k_hr', 'ycbcr' if self.colors==1 else 'rgb')
        # lr_dir = os.path.join(self.cache_dir, 'df2k_lr_x{}'.format(self.scale), 'ycbcr' if self.colors==1 else 'rgb')
        hr_dir = os.path.join(self.cache_dir, 'div2k_hr', 'ycbcr' if self.colors == 1 else 'rgb')
        lr_dir = os.path.join(self.cache_dir, 'div2k_lr_x{}'.format(self.scale), 'ycbcr' if self.colors == 1 else 'rgb')
        if not os.path.exists(hr_dir):
            os.makedirs(hr_dir)
        if not os.path.exists(lr_dir):
            os.makedirs(lr_dir)
        # build the cache file lists exactly once, whether or not the cache exists
        for i in range(LEN):
            hr_npy_name = self.hr_filenames[i].split('/')[-1].replace('.png', '.npy')
            self.hr_npy_names.append(os.path.join(hr_dir, hr_npy_name))
            lr_npy_name = self.lr_filenames[i].split('/')[-1].replace('.png', '.npy')
            self.lr_npy_names.append(os.path.join(lr_dir, lr_npy_name))

        ## prepare hr images
        if len(glob.glob(os.path.join(hr_dir, "*.npy"))) != len(self.hr_filenames):
            for i in range(LEN):
                if (i + 1) % 50 == 0:
                    print("convert {} hr images to npy data!".format(i + 1))
                hr_image = imageio.imread(self.hr_filenames[i], pilmode="RGB")
                if self.colors == 1:
                    hr_image = sc.rgb2ycbcr(hr_image)[:, :, 0:1]
                np.save(self.hr_npy_names[i], hr_image)
        else:
            print("hr npy data has already been prepared! hr: {}".format(len(self.hr_npy_names)))
        ## prepare lr images
        if len(glob.glob(os.path.join(lr_dir, "*.npy"))) != len(self.lr_filenames):
            for i in range(LEN):
                if (i + 1) % 50 == 0:
                    print("convert {} lr images to npy data!".format(i + 1))
                lr_image = imageio.imread(self.lr_filenames[i], pilmode="RGB")
                if self.colors == 1:
                    lr_image = sc.rgb2ycbcr(lr_image)[:, :, 0:1]
                np.save(self.lr_npy_names[i], lr_image)
        else:
            print("lr npy data has already been prepared! lr: {}".format(len(self.lr_npy_names)))

    def __len__(self):
        if self.train:
            return self.nums_trainset * self.repeat
        else:
            return self.nums_trainset

    def __getitem__(self, idx):
        # get periodic index
        idx = idx % self.nums_trainset
        # get whole image
        hr, lr = np.load(self.hr_npy_names[idx]), np.load(self.lr_npy_names[idx])
        if self.train:
            train_lr_patch, train_hr_patch = crop_patch(lr, hr, self.patch_size, self.scale, True)
            return train_lr_patch, train_hr_patch
        return lr, hr
--------------------------------------------------------------------------------
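`crop_patch` samples the LR window first and derives the aligned HR window from it, so the pair stays pixel-aligned at any scale; a quick sketch with random data (array sizes are arbitrary):

```python
import numpy as np
from datas.div2k import crop_patch

lr = np.random.rand(64, 64, 3).astype(np.float32)
hr = np.random.rand(256, 256, 3).astype(np.float32)  # 4x the LR size
lr_p, hr_p = crop_patch(lr, hr, patch_size=128, scale=4, augment=True)
print(lr_p.shape, hr_p.shape)  # torch.Size([3, 32, 32]) torch.Size([3, 128, 128])
```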
/models/m_block.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class MeanShift(nn.Conv2d):
    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False


class LayerNorm(nn.Module):
    r""" From ConvNeXt (https://arxiv.org/pdf/2201.03545.pdf)
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class FourierUnit(nn.Module):
    # NOTE: torch.rfft/torch.irfft below require PyTorch < 1.8; both were
    # removed in 1.8 (see the torch.fft-based sketch after this class).
    def __init__(self, dim, groups=1, fft_norm='ortho'):
        super().__init__()
        self.groups = groups
        self.fft_norm = fft_norm

        self.conv_layer = nn.Conv2d(in_channels=dim * 2, out_channels=dim * 2, kernel_size=1, stride=1,
                                    padding=0, groups=self.groups, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        batch, c, h, w = x.size()
        r_size = x.size()
        # (batch, c, h, w/2+1, 2)
        ffted = torch.rfft(x, signal_ndim=2, normalized=True)

        # (batch, c, 2, h, w/2+1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()
        ffted = ffted.view((batch, -1,) + ffted.size()[3:])
        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.act(ffted)

        # (batch, c, h, w/2+1, 2)
        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()
        output = torch.irfft(ffted, signal_ndim=2, signal_sizes=r_size[2:], normalized=True)
        return output
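
# torch.rfft/torch.irfft were removed in PyTorch 1.8. Below is a hedged sketch
# of an equivalent unit on the modern torch.fft API (an assumed drop-in, not
# part of the original model): real and imaginary parts are stacked along the
# channel axis so the same 1x1 conv mixing applies.
class FourierUnitFFT(nn.Module):
    def __init__(self, dim, groups=1, fft_norm='ortho'):
        super().__init__()
        self.fft_norm = fft_norm
        self.conv_layer = nn.Conv2d(dim * 2, dim * 2, kernel_size=1, groups=groups, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        _, _, h, w = x.size()
        ffted = torch.fft.rfft2(x, norm=self.fft_norm)      # (b, c, h, w/2+1), complex
        ffted = torch.cat([ffted.real, ffted.imag], dim=1)  # (b, 2c, h, w/2+1), real
        ffted = self.act(self.conv_layer(ffted))
        real, imag = torch.chunk(ffted, 2, dim=1)
        return torch.fft.irfft2(torch.complex(real, imag), s=(h, w), norm=self.fft_norm)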

class FConvMod(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        layer_scale_init_value = 1e-6
        self.num_heads = num_heads
        self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_first")
        self.a = FourierUnit(dim)
        self.v = nn.Conv2d(dim, dim, 1)
        self.act = nn.GELU()
        self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones(num_heads), requires_grad=True)
        self.CPE = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        shortcut = x
        pos_embed = self.CPE(x)
        x = self.norm(x)
        a = self.a(x)
        v = self.v(x)
        a = rearrange(a, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        # split the sequence into four chunks (ceil handles N not divisible by 4)
        a_all = torch.split(a, math.ceil(N / 4), dim=-1)
        v_all = torch.split(v, math.ceil(N / 4), dim=-1)
        attns = []
        for a, v in zip(a_all, v_all):
            attn = a * v
            attn = self.layer_scale.unsqueeze(-1).unsqueeze(-1) * attn
            attns.append(attn)
        x = torch.cat(attns, dim=-1)
        x = F.softmax(x, dim=-1)
        x = rearrange(x, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=H, w=W)
        x = x + pos_embed
        x = self.proj(x)
        out = x + shortcut

        return out


class KernelAggregation(nn.Module):
    def __init__(self, dim, kernel_size, groups, num_kernels, bias=True, init_weight=True):
        super().__init__()
        self.groups = groups
        self.num_kernels = num_kernels
        self.kernel_size = kernel_size
        self.dim = dim
        self.weight = nn.Parameter(torch.randn(num_kernels, dim, dim // groups, kernel_size, kernel_size),
                                   requires_grad=True)
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_kernels, dim))
        else:
            self.bias = None

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.num_kernels):
            nn.init.kaiming_uniform_(self.weight[i])

    def forward(self, x, attention):
        B, C, H, W = x.shape
        # fold the batch into the group dimension so every sample gets its own
        # attention-mixed kernel in a single grouped convolution
        x = x.contiguous().view(1, B * self.dim, H, W)

        weight = self.weight.contiguous().view(self.num_kernels, -1)
        weight = torch.mm(attention, weight).contiguous().view(B * self.dim, self.dim // self.groups,
                                                               self.kernel_size, self.kernel_size)
        if self.bias is not None:
            bias = torch.mm(attention, self.bias).contiguous().view(-1)
            x = F.conv2d(x, weight=weight, bias=bias, stride=1, padding=self.kernel_size // 2,
                         groups=self.groups * B)
        else:
            x = F.conv2d(x, weight=weight, bias=None, stride=1, padding=self.kernel_size // 2,
                         groups=self.groups * B)
        x = x.contiguous().view(B, self.dim, x.shape[-2], x.shape[-1])

        return x


class KernelAttention(nn.Module):
    def __init__(self, dim, reduction=8, num_kernels=8):
        super().__init__()
        if dim != 3:
            mid_channels = dim // reduction
        else:
            mid_channels = num_kernels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(dim, mid_channels, 1)
        self.act = nn.GELU()
        self.conv2 = nn.Conv2d(mid_channels, num_kernels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.avg_pool(x)
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = x.view(x.shape[0], -1)
        x = self.sigmoid(x)
        return x


class DynamicKernelAggregation(nn.Module):
    def __init__(self, dim, kernel_size, groups=1, num_kernels=4):
        super().__init__()
        assert dim % groups == 0
        self.attention = KernelAttention(dim, num_kernels=num_kernels)
        self.aggregation = KernelAggregation(dim, kernel_size=kernel_size, groups=groups, num_kernels=num_kernels)

    def forward(self, x):
        attention = x
        attention = self.attention(attention)
        x = self.aggregation(x, attention)
        return x


class DyConv(nn.Module):
    def __init__(self, dim, kernel_size, groups, num_kernels=1):
        super().__init__()
        if num_kernels > 1:
            self.conv = DynamicKernelAggregation(dim, kernel_size=kernel_size, groups=groups,
                                                 num_kernels=num_kernels)
        else:
            # padding keeps the spatial size, matching the dynamic branch
            self.conv = nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=groups,
                                  padding=kernel_size // 2)

    def forward(self, x):
        x = self.conv(x)
        return x


class MixFFN(nn.Module):
    def __init__(self, dim, num_kernels):
        super().__init__()
        self.proj_in = nn.Conv2d(dim, dim * 2, 1)
        self.conv1 = DyConv(dim, kernel_size=5, groups=dim, num_kernels=num_kernels)
        self.conv2 = DyConv(dim, kernel_size=7, groups=dim, num_kernels=num_kernels)
        self.proj_out = nn.Conv2d(dim * 2, dim, 1)
        self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_first")
        self.act = nn.GELU()

    def forward(self, x):
        shortcut = x
        x = self.norm(x)
        x = self.act(self.proj_in(x))
        x1, x2 = torch.chunk(x, 2, dim=1)
        x1 = self.act(self.conv1(x1)).unsqueeze(dim=2)
        x2 = self.act(self.conv2(x2)).unsqueeze(dim=2)
        x = torch.cat([x1, x2], dim=2)
        x = rearrange(x, 'b c g h w -> b (c g) h w')
        x = self.proj_out(x)
        x = x + shortcut
        return x


class BasicBlock(nn.Module):
    def __init__(self, dim, num_heads, num_kernels):
        super().__init__()
        self.attention = FConvMod(dim, num_heads)
        self.ffn = MixFFN(dim, num_kernels)

    def forward(self, x):
        x = self.attention(x)
        x = self.ffn(x)

        return x
--------------------------------------------------------------------------------
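`FConvMod` computes Fourier-domain attention and `MixFFN` mixes two dynamic depth-wise branches; both are residual and shape-preserving, so a block maps `(B, C, H, W)` to itself. A smoke-test sketch (dims follow `configs/config.yml`; note it only runs on PyTorch < 1.8, where `torch.rfft` still exists):

```python
import torch
from models.m_block import BasicBlock

block = BasicBlock(dim=64, num_heads=8, num_kernels=16)
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)  # expected: torch.Size([2, 64, 32, 32])
```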
/train.py:
--------------------------------------------------------------------------------
import math
import argparse, yaml
import utils
import os
from tqdm import tqdm
import sys
import time
from custom import optimizers as optim
from custom.caltime import RemainTime
from custom.serverLog import LogClass

parser = argparse.ArgumentParser(description='config')
## yaml configuration files
parser.add_argument('--config', type=str, default=None, help='pre-config file for training')
parser.add_argument('--resume', type=str, default=None, help='resume training or not')
parser.add_argument('--custom', type=str, default=None, help='use custom block')
parser.add_argument('--cloudlog', type=str, default=None, help='use cloud log')


def save_model(_path, _epoch, _model, _optimizer, _scheduler, _stat_dict):
    # torch.save(model.state_dict(), saved_model_path)
    torch.save({
        'epoch': _epoch,
        'model_state_dict': _model.state_dict(),
        'optimizer_state_dict': _optimizer.state_dict(),
        'scheduler_state_dict': _scheduler.state_dict(),
        'stat_dict': _stat_dict
    }, _path)


if __name__ == '__main__':
    args = parser.parse_args()
    if args.config:
        opt = vars(args)
        yaml_args = yaml.load(open(args.config), Loader=yaml.FullLoader)
        opt.update(yaml_args)
    ## set visible gpus (must happen before torch is imported)
    gpu_ids_str = str(args.gpu_ids).replace('[', '').replace(']', '')
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(gpu_ids_str)
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.optim.lr_scheduler import MultiStepLR, StepLR
    from datas.utils import create_datasets

    ## select active gpu devices
    device = None
    if args.gpu_ids is not None and torch.cuda.is_available():
        print('use cuda & cudnn for acceleration!')
        print('the gpu id is: {}'.format(args.gpu_ids))
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = True
    else:
        print('use cpu for training!')
        device = torch.device('cpu')

    ## create dataset for training and validating
    train_dataloader, valid_dataloaders = create_datasets(args)

    ## definitions of model
    try:
        model = utils.import_module('models.{}_network'.format(args.model)).create_model(args)
    except Exception:
        raise ValueError('unsupported model type: {}'.format(args.model))
    if args.fp == 16:
        model.half()

    ## load pretrain
    if args.pretrain is not None:
        print('load pretrained model: {}!'.format(args.pretrain))
        ckpt = torch.load(args.pretrain)
        model.load(ckpt['model_state_dict'])

    model = nn.DataParallel(model).to(device)

    ## definition of loss and optimizer
    loss_func = getattr(nn, args.loss)()  # e.g. nn.L1Loss(); avoids eval()
    if args.fp == 16:
        eps = 1e-3
    else:
        eps = 1e-8
    if args.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=eps)
    elif args.optim == 'lamb':
        optimizer = optim.Lamb(model.parameters(), lr=args.lr, eps=eps)
    scheduler = MultiStepLR(optimizer, milestones=args.decays, gamma=args.gamma)

    ## resume training
    start_epoch = 1
    if args.resume is not None:
        ckpt_files = os.path.join(args.resume, 'models', "model_x{}_latest.pt".format(args.scale))
        if os.path.exists(ckpt_files):
            ckpt = torch.load(ckpt_files)
            prev_epoch = ckpt['epoch']

            start_epoch = prev_epoch + 1
            model.load_state_dict(ckpt['model_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            scheduler.load_state_dict(ckpt['scheduler_state_dict'])
            stat_dict = ckpt['stat_dict']
            ## reset folder and param
            experiment_path = args.resume
            log_name = os.path.join(experiment_path, 'log.txt')
            experiment_model_path = os.path.join(experiment_path, 'models')
            print('select {}, resume training from epoch {}.'.format(ckpt_files, start_epoch))
    else:
        ## auto-generate the output logname
        experiment_name = None
        timestamp = utils.cur_timestamp_str()
        if args.log_name is None:
            experiment_name = '{}-x{}-{}'.format(args.model, args.scale, timestamp)
        else:
            experiment_name = '{}-{}'.format(args.log_name, timestamp)
        experiment_path = os.path.join(args.log_path, experiment_name)
        log_name = os.path.join(experiment_path, 'log.txt')
        stat_dict = utils.get_stat_dict()
        ## create folder for ckpt and stat
        if not os.path.exists(experiment_path):
            os.makedirs(experiment_path)
        experiment_model_path = os.path.join(experiment_path, 'models')
        if not os.path.exists(experiment_model_path):
            os.makedirs(experiment_model_path)
        ## save training parameters
        exp_params = vars(args)
        exp_params_name = os.path.join(experiment_path, 'config.yml')
        with open(exp_params_name, 'w') as exp_params_file:
            yaml.dump(exp_params, exp_params_file, default_flow_style=False)

    ## print architecture of model
    time.sleep(3)  # sleep 3 seconds
    sys.stdout = utils.ExperimentLogger(log_name, sys.stdout)
    # print(model)
    # num_params = 0
    # for param in model.parameters():
    #     num_params += param.numel()
    # print('Total number of parameters:' + str(num_params // 1024) + 'k')
    # sys.stdout.flush()

    ## start training
    timer_start = time.time()
    rt = RemainTime(args.epochs)
    cloudLogName = experiment_path.split(os.sep)[-1]
    log = LogClass(args.cloudlog == 'on')
    log.sendLog('start training', cloudLogName)
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_loss = 0.0
        stat_dict['epochs'] = epoch
        model = model.train()
        opt_lr = scheduler.get_last_lr()
        print('##===========-fp{}-training, Epoch: {}, lr: {} =============##'.format(args.fp, epoch, opt_lr))
        for iter, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            lr, hr = batch
            if args.fp == 16:
                lr, hr = lr.type(torch.HalfTensor), hr.type(torch.HalfTensor)
            lr, hr = lr.to(device), hr.to(device)
            sr = model(lr)
            loss = loss_func(sr, hr)
            loss.backward()
            optimizer.step()
            epoch_loss += float(loss)

            if (iter + 1) % args.log_every == 0:
                cur_steps = (iter + 1) * args.batch_size
                total_steps = len(train_dataloader.dataset)
                fill_width = math.ceil(math.log10(total_steps))
                cur_steps = str(cur_steps).zfill(fill_width)

                epoch_width = math.ceil(math.log10(args.epochs))
                cur_epoch = str(epoch).zfill(epoch_width)

                avg_loss = epoch_loss / (iter + 1)
                stat_dict['losses'].append(avg_loss)

                timer_end = time.time()
                duration = timer_end - timer_start
                timer_start = timer_end
                print('Epoch:{}, {}/{}, loss: {:.4f}, time: {:.3f}'.format(cur_epoch, cur_steps, total_steps, avg_loss,
                                                                           duration))

        if epoch % args.test_every == 0:
            torch.set_grad_enabled(False)
            test_log = ''
            model = model.eval()
            for valid_dataloader in valid_dataloaders:
                avg_psnr, avg_ssim = 0.0, 0.0
                name = valid_dataloader['name']
                loader = valid_dataloader['dataloader']
                for lr, hr, _ in tqdm(loader, ncols=80):  # Benchmark also yields the filename
                    if args.fp == 16:
                        lr, hr = lr.type(torch.HalfTensor), hr.type(torch.HalfTensor)
                    lr, hr = lr.to(device), hr.to(device)
                    sr = model(lr)
                    # quantize output to [0, 255]
                    hr = hr.clamp(0, 255)
                    sr = sr.clamp(0, 255)
                    # convert to YCbCr
                    if args.colors == 3:
                        hr_ycbcr = utils.rgb_to_ycbcr(hr)
                        sr_ycbcr = utils.rgb_to_ycbcr(sr)
                        hr = hr_ycbcr[:, 0:1, :, :]
                        sr = sr_ycbcr[:, 0:1, :, :]
                    # crop image for evaluation
                    hr = hr[:, :, args.scale:-args.scale, args.scale:-args.scale]
                    sr = sr[:, :, args.scale:-args.scale, args.scale:-args.scale]
                    # calculate psnr and ssim
                    psnr = utils.calc_psnr(sr, hr)
                    ssim = utils.calc_ssim(sr, hr)
                    avg_psnr += psnr
                    avg_ssim += ssim
                avg_psnr = round(avg_psnr / len(loader) + 5e-3, 2)
                avg_ssim = round(avg_ssim / len(loader) + 5e-5, 4)
                stat_dict[name]['psnrs'].append(avg_psnr)
                stat_dict[name]['ssims'].append(avg_ssim)
                save_model_flag = False
                if stat_dict[name]['best_psnr']['value'] < avg_psnr:
                    stat_dict[name]['best_psnr']['value'] = avg_psnr
                    stat_dict[name]['best_psnr']['epoch'] = epoch
                    save_model_flag = True
                    if name == 'set5':
                        log.sendLog('PSNR:{} epoch:{}/{}'.format(float(avg_psnr), epoch, args.epochs), cloudLogName)
                if stat_dict[name]['best_ssim']['value'] < avg_ssim:
                    stat_dict[name]['best_ssim']['value'] = avg_ssim
                    stat_dict[name]['best_ssim']['epoch'] = epoch
                    save_model_flag = True
                if save_model_flag:
                    # save best model
                    save_model(os.path.join(experiment_model_path, 'model_x{}_{}.pt'.format(args.scale, epoch)), epoch,
                               model, optimizer, scheduler, stat_dict)
                test_log += '[{}-X{}], PSNR/SSIM: {:.2f}/{:.4f} (Best: {:.2f}/{:.4f}, Epoch: {}/{})\n'.format(
                    name, args.scale, float(avg_psnr), float(avg_ssim),
                    stat_dict[name]['best_psnr']['value'], stat_dict[name]['best_ssim']['value'],
                    stat_dict[name]['best_psnr']['epoch'], stat_dict[name]['best_ssim']['epoch'])
            # print log & flush out
            print(test_log[:-1])
            sys.stdout.flush()
            save_model(os.path.join(experiment_model_path, 'model_x{}_latest.pt'.format(args.scale)), epoch, model,
                       optimizer, scheduler, stat_dict)
            torch.set_grad_enabled(True)
        # save stat dict
        # save training parameters
        stat_dict_name = os.path.join(experiment_path, 'stat_dict.yml')
        with open(stat_dict_name, 'w') as stat_dict_file:
            yaml.dump(stat_dict, stat_dict_file, default_flow_style=False)
        ## update scheduler
        scheduler.step()
        rt.update(epoch)
        print()
    log.sendLog('finish training', cloudLogName)
--------------------------------------------------------------------------------
/custom/optimizers.py:
--------------------------------------------------------------------------------
# Additional optimizers that have not been incorporated into the
# official PyTorch release (Oct 11, 2021). Adapted from:
# https://github.com/jettify/pytorch-optimizer
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import math
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer

Params = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
LossClosure = Callable[[], float]
OptLossClosure = Optional[LossClosure]
Betas2 = Tuple[float, float]
State = Dict[str, Any]
OptFloat = Optional[float]
Nus2 = Tuple[float, float]

__all__ = ('LARS', 'Lamb')


class LARS(Optimizer):
    r"""Extends SGD in PyTorch with LARS scaling from the paper
    `Large batch training of Convolutional Networks`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-2)
        momentum: momentum factor (default: 0)
        dampening: dampening for momentum (default: 0)
        weight_decay: weight decay (L2 penalty) (default: 0)
        nesterov: enables Nesterov momentum (default: False)
        trust_coefficient: trust coefficient for computing LR (default: 0.01)
        eps: term added to the division denominator to improve
            numerical stability (default: 1e-8)

    Example:
        # >>> import torch_optimizer as optim
        # >>> optimizer = optim.LARS(model.parameters(), lr=0.001)
        # >>> optimizer.zero_grad()
        # >>> loss_fn(model(input), target).backward()
        # >>> optimizer.step()

    .. note::
        The application of momentum in the SGD part is modified according to
        the PyTorch standards. LARS scaling fits into the equation in the
        following fashion.

        .. math::
            \begin{aligned}
                g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta`
        denote the parameters, gradient, velocity, momentum, and weight decay
        respectively. The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper.
        The Nesterov version is analogously modified.

    .. warning::
        Parameters with weight decay set to 0 will automatically be excluded
        from layer-wise LR scaling. This is to ensure consistency with papers
        like SimCLR and BYOL.
67 | 
68 | 
69 |     __ https://arxiv.org/pdf/1708.03888.pdf
70 | 
71 |     Note:
72 |         Reference code: https://github.com/PyTorchLightning/lightning-bolts/
73 |     """
74 | 
75 |     def __init__(
76 |         self,
77 |         params: Params,
78 |         lr: float = 1e-2,
79 |         momentum: float = 0.0,
80 |         dampening: float = 0.0,
81 |         weight_decay: float = 0.0,
82 |         nesterov: bool = False,
83 |         trust_coefficient: float = 0.01,
84 |         eps: float = 1e-8,
85 |     ):
86 |         if lr <= 0.0:
87 |             raise ValueError('Invalid learning rate: {}'.format(lr))
88 |         if eps < 0.0:
89 |             raise ValueError('Invalid epsilon value: {}'.format(eps))
90 |         if momentum < 0.0:
91 |             raise ValueError('Invalid momentum value: {}'.format(momentum))
92 |         if dampening < 0.0:
93 |             raise ValueError('Invalid dampening value: {}'.format(dampening))
94 |         if weight_decay < 0.0:
95 |             raise ValueError(
96 |                 'Invalid weight_decay value: {}'.format(weight_decay)
97 |             )
98 |         if trust_coefficient < 0.0:
99 |             raise ValueError(
100 |                 'Invalid trust_coefficient value: {}'.format(trust_coefficient)
101 |             )
102 | 
103 |         defaults = dict(
104 |             lr=lr,
105 |             momentum=momentum,
106 |             dampening=dampening,
107 |             weight_decay=weight_decay,
108 |             nesterov=nesterov,
109 |             trust_coefficient=trust_coefficient,
110 |             eps=eps,
111 |         )
112 |         if nesterov and (momentum <= 0 or dampening != 0):
113 |             raise ValueError(
114 |                 'Nesterov momentum requires a momentum and zero dampening'
115 |             )
116 | 
117 |         super().__init__(params, defaults)
118 | 
119 |     def __setstate__(self, state: State) -> None:
120 |         super().__setstate__(state)
121 | 
122 |         for group in self.param_groups:
123 |             group.setdefault('nesterov', False)
124 | 
125 |     @torch.no_grad()
126 |     def step(self, closure: OptLossClosure = None) -> OptFloat:
127 |         r"""Performs a single optimization step.
128 | 
129 |         Arguments:
130 |             closure: A closure that reevaluates the model and returns the loss.
131 |         """
132 |         loss = None
133 |         if closure is not None:
134 |             with torch.enable_grad():
135 |                 loss = closure()
136 | 
137 |         # exclude scaling for params with 0 weight decay
138 |         for group in self.param_groups:
139 |             weight_decay = group['weight_decay']
140 |             momentum = group['momentum']
141 |             dampening = group['dampening']
142 |             nesterov = group['nesterov']
143 | 
144 |             for p in group['params']:
145 |                 if p.grad is None:
146 |                     continue
147 | 
148 |                 d_p = p.grad
149 |                 p_norm = torch.norm(p.data)
150 |                 g_norm = torch.norm(p.grad.data)
151 | 
152 |                 # lars scaling + weight decay part
153 |                 if weight_decay != 0:
154 |                     if p_norm != 0 and g_norm != 0:
155 |                         lars_lr = p_norm / (
156 |                             g_norm + p_norm * weight_decay + group['eps']
157 |                         )
158 |                         lars_lr *= group['trust_coefficient']
159 | 
160 |                         d_p = d_p.add(p, alpha=weight_decay)
161 |                         d_p *= lars_lr
162 | 
163 |                 if momentum != 0:
164 |                     param_state = self.state[p]
165 |                     if 'momentum_buffer' not in param_state:
166 |                         buf = param_state['momentum_buffer'] = torch.clone(
167 |                             d_p
168 |                         ).detach()
169 |                     else:
170 |                         buf = param_state['momentum_buffer']
171 |                         buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
172 |                     if nesterov:
173 |                         d_p = d_p.add(buf, alpha=momentum)
174 |                     else:
175 |                         d_p = buf
176 | 
177 |                 p.add_(d_p, alpha=-group['lr'])
178 | 
179 |         return loss
180 | 
181 | 
182 | class Lamb(Optimizer):
183 |     r"""Implements Lamb algorithm.
184 | 
185 |     It has been proposed in `Large Batch Optimization for Deep Learning:
186 |     Training BERT in 76 minutes`__.
187 | 
188 |     Arguments:
189 |         params: iterable of parameters to optimize or dicts defining
190 |             parameter groups
191 |         lr: learning rate (default: 1e-6)
192 |         betas: coefficients used for computing
193 |             running averages of gradient and its square (default: (0.9, 0.999))
194 |         eps: term added to the denominator to improve
195 |             numerical stability (default: 1e-6)
196 |         weight_decay: weight decay (L2 penalty) (default: 0)
197 |         clamp_value: clamp weight_norm to (0, clamp_value) (default: 10);
198 |             set to a high value (e.g. 10e3) to effectively disable the clamp
199 |         adam: always use trust ratio = 1, which turns this
200 |             into Adam. Useful for comparison purposes. (default: False)
201 |         debias: debias adam by (1 - beta**step) (default: False)
202 | 
203 |     Example:
204 |         # >>> import torch_optimizer as optim
205 |         # >>> optimizer = optim.Lamb(model.parameters(), lr=0.1)
206 |         # >>> optimizer.zero_grad()
207 |         # >>> loss_fn(model(input), target).backward()
208 |         # >>> optimizer.step()
209 | 
210 |     __ https://arxiv.org/abs/1904.00962
211 | 
212 |     Note:
213 |         Reference code: https://github.com/cybertronai/pytorch-lamb
214 |     """
215 | 
216 |     def __init__(
217 |         self,
218 |         params: Params,
219 |         lr: float = 1e-6,
220 |         betas: Betas2 = (0.9, 0.999),
221 |         eps: float = 1e-6,
222 |         weight_decay: float = 0,
223 |         clamp_value: float = 10,
224 |         adam: bool = False,
225 |         debias: bool = False,
226 |     ) -> None:
227 |         if lr <= 0.0:
228 |             raise ValueError('Invalid learning rate: {}'.format(lr))
229 |         if eps < 0.0:
230 |             raise ValueError('Invalid epsilon value: {}'.format(eps))
231 |         if not 0.0 <= betas[0] < 1.0:
232 |             raise ValueError(
233 |                 'Invalid beta parameter at index 0: {}'.format(betas[0])
234 |             )
235 |         if not 0.0 <= betas[1] < 1.0:
236 |             raise ValueError(
237 |                 'Invalid beta parameter at index 1: {}'.format(betas[1])
238 |             )
239 |         if weight_decay < 0:
240 |             raise ValueError(
241 |                 'Invalid weight_decay value: {}'.format(weight_decay)
242 |             )
243 |         if clamp_value < 0.0:
244 |             raise ValueError('Invalid clamp value: {}'.format(clamp_value))
245 | 
246 |         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
247 |         self.clamp_value = clamp_value
248 |         self.adam = adam
249 |         self.debias = debias
250 | 
251 |         super(Lamb, self).__init__(params, defaults)
252 | 
253 |     @torch.no_grad()
254 |     def step(self, closure: OptLossClosure = None) -> OptFloat:
255 |         r"""Performs a single optimization step.
256 | 
257 |         Arguments:
258 |             closure: A closure that reevaluates the model and returns the loss.
259 | """ 260 | loss = None 261 | if closure is not None: 262 | with torch.enable_grad(): 263 | loss = closure() 264 | 265 | for group in self.param_groups: 266 | for p in group['params']: 267 | if p.grad is None: 268 | continue 269 | grad = p.grad.data 270 | if grad.is_sparse: 271 | msg = ( 272 | 'Lamb does not support sparse gradients, ' 273 | 'please consider SparseAdam instead' 274 | ) 275 | raise RuntimeError(msg) 276 | 277 | state = self.state[p] 278 | 279 | # State initialization 280 | if len(state) == 0: 281 | state['step'] = 0 282 | # Exponential moving average of gradient values 283 | state['exp_avg'] = torch.zeros_like( 284 | p, memory_format=torch.preserve_format 285 | ) 286 | # Exponential moving average of squared gradient values 287 | state['exp_avg_sq'] = torch.zeros_like( 288 | p, memory_format=torch.preserve_format 289 | ) 290 | 291 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 292 | beta1, beta2 = group['betas'] 293 | 294 | state['step'] += 1 295 | 296 | # Decay the first and second moment running average coefficient 297 | # m_t 298 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 299 | # v_t 300 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 301 | 302 | # Paper v3 does not use debiasing. 303 | if self.debias: 304 | bias_correction = math.sqrt(1 - beta2 ** state['step']) 305 | bias_correction /= 1 - beta1 ** state['step'] 306 | else: 307 | bias_correction = 1 308 | 309 | # Apply bias to lr to avoid broadcast. 310 | step_size = group['lr'] * bias_correction 311 | 312 | weight_norm = torch.norm(p.data).clamp(0, self.clamp_value) 313 | 314 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 315 | if group['weight_decay'] != 0: 316 | adam_step.add_(p.data, alpha=group['weight_decay']) 317 | 318 | adam_norm = torch.norm(adam_step) 319 | if weight_norm == 0 or adam_norm == 0: 320 | trust_ratio = 1 321 | else: 322 | trust_ratio = weight_norm / adam_norm 323 | state['weight_norm'] = weight_norm 324 | state['adam_norm'] = adam_norm 325 | state['trust_ratio'] = trust_ratio 326 | if self.adam: 327 | trust_ratio = 1 328 | 329 | p.data.add_(adam_step, alpha=-step_size * trust_ratio) 330 | 331 | return loss 332 | 333 | 334 | class Lamb16(Optimizer): 335 | r"""Implements Lamb algorithm for FP16 training. 336 | 337 | It has been proposed in `Large Batch Optimization for Deep Learning: 338 | Training BERT in 76 minutes`__. 339 | 340 | Arguments: 341 | params: iterable of parameters to optimize or dicts defining 342 | parameter groups 343 | lr: learning rate (default: 1e-3) 344 | betas: coefficients used for computing 345 | running averages of gradient and its square (default: (0.9, 0.999)) 346 | eps: term added to the denominator to improve 347 | numerical stability (default: 1e-8) 348 | weight_decay: weight decay (L2 penalty) (default: 0) 349 | clamp_value: clamp weight_norm in (0,clamp_value) (default: 10) 350 | set to a high value to avoid it (e.g 10e3) 351 | adam: always use trust ratio = 1, which turns this 352 | into Adam. Useful for comparison purposes. 
353 |         debias: debias adam by (1 - beta**step) (default: False)
354 | 
355 |     Example:
356 |         >>> from custom.optimizers import Lamb16
357 |         >>> optimizer = Lamb16(model.parameters(), lr=0.1)
358 |         >>> optimizer.zero_grad()
359 |         >>> loss_fn(model(input), target).backward()
360 |         >>> optimizer.step()
361 | 
362 |     __ https://arxiv.org/abs/1904.00962
363 | 
364 |     Note:
365 |         Reference code: https://github.com/cybertronai/pytorch-lamb
366 |     """
367 | 
368 |     def __init__(
369 |         self,
370 |         params: Params,
371 |         lr: float = 1e-3,
372 |         betas: Betas2 = (0.9, 0.999),
373 |         eps: float = 1e-4,  # larger than usual, for fp16 stability
374 |         weight_decay: float = 0,
375 |         clamp_value: float = 10,
376 |         clamp_trust_ratio: float = 1e6,
377 |         adam: bool = False,
378 |         debias: bool = False,
379 |     ) -> None:
380 |         if lr <= 0.0:
381 |             raise ValueError('Invalid learning rate: {}'.format(lr))
382 |         if eps < 0.0:
383 |             raise ValueError('Invalid epsilon value: {}'.format(eps))
384 |         if not 0.0 <= betas[0] < 1.0:
385 |             raise ValueError(
386 |                 'Invalid beta parameter at index 0: {}'.format(betas[0])
387 |             )
388 |         if not 0.0 <= betas[1] < 1.0:
389 |             raise ValueError(
390 |                 'Invalid beta parameter at index 1: {}'.format(betas[1])
391 |             )
392 |         if weight_decay < 0:
393 |             raise ValueError(
394 |                 'Invalid weight_decay value: {}'.format(weight_decay)
395 |             )
396 |         if clamp_value < 0.0:
397 |             raise ValueError('Invalid clamp value: {}'.format(clamp_value))
398 | 
399 |         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
400 |         self.clamp_value = clamp_value
401 |         self.clamp_trust_ratio = clamp_trust_ratio
402 |         self.adam, self.debias = adam, debias
403 |         super(Lamb16, self).__init__(params, defaults)
404 | 
405 |         # This version of Lamb keeps an fp32 copy of the parameters and
406 |         # does all of the parameter updates in fp32, while still doing the
407 |         # forward and backward passes in fp16 (i.e. fp16 copies of the
408 |         # parameters and fp16 activations).
409 |         #
410 |         # Note that this calls .float().cuda() on the params, which moves
411 |         # them to GPU 0; if you're using a different GPU or want to do
412 |         # multi-GPU you may need to deal with this.
413 |         self.fp32_param_groups = []
414 |         for group in self.param_groups:
415 |             self.fp32_param_groups.append(
416 |                 {'params': [p.data.float().cuda() for p in group['params']]})
417 | 
418 |     @torch.no_grad()
419 |     def step(self, closure: OptLossClosure = None) -> OptFloat:
420 |         r"""Performs a single optimization step.
421 | 
422 |         Arguments:
423 |             closure: A closure that reevaluates the model and returns the loss.
424 | """ 425 | loss = None 426 | if closure is not None: 427 | with torch.enable_grad(): 428 | loss = closure() 429 | 430 | for group, fp32_group in zip(self.param_groups, self.fp32_param_groups): 431 | for p, fp32_p in zip(group['params'], fp32_group['params']): 432 | if p.grad is None: 433 | continue 434 | 435 | grad = p.grad.data.float() # gradient in FP32 436 | if grad.is_sparse: 437 | msg = ( 438 | 'Lamb16 does not support sparse gradients, ' 439 | 'please consider SparseAdam instead' 440 | ) 441 | raise RuntimeError(msg) 442 | state = self.state[p] 443 | 444 | # State initialization 445 | if len(state) == 0: 446 | state['step'] = 0 447 | # Exponential moving average of gradient values 448 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 449 | # Exponential moving average of squared gradient values 450 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 451 | 452 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 453 | beta1, beta2 = group['betas'] 454 | 455 | state['step'] += 1 456 | 457 | # Decay the first and second moment running average coefficient 458 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 459 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 460 | 461 | # Paper v3 does not use debiasing. 462 | if self.debias: 463 | bias_correction = math.sqrt(1 - beta2 ** state['step']) 464 | bias_correction /= 1 - beta1 ** state['step'] 465 | else: 466 | bias_correction = 1 467 | 468 | # Apply bias to lr to avoid broadcast. 469 | step_size = group['lr'] * bias_correction 470 | 471 | weight_norm = torch.norm(fp32_p.data).clamp(0, self.clamp_value) 472 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 473 | if group['weight_decay'] != 0: 474 | adam_step.add_(fp32_p.data, alpha=group['weight_decay']) 475 | 476 | adam_norm = torch.norm(adam_step) 477 | if weight_norm == 0 or adam_norm == 0: 478 | trust_ratio = 1 479 | else: 480 | trust_ratio = (weight_norm / adam_norm).clamp(0, self.clamp_trust_ratio) 481 | state['weight_norm'] = weight_norm 482 | state['adam_norm'] = adam_norm 483 | state['trust_ratio'] = trust_ratio 484 | if self.adam: 485 | trust_ratio = 1 486 | 487 | fp32_p.add_(adam_step, alpha=-step_size * trust_ratio) 488 | p.data = fp32_p.half() 489 | 490 | return loss 491 | --------------------------------------------------------------------------------