├── configs ├── __init__.py ├── my_config.py ├── base_config.py └── parser.py ├── core ├── __init__.py ├── loss.py ├── sr_trainer.py └── base_trainer.py ├── utils ├── __init__.py ├── optimizer.py ├── metrics.py ├── scheduler.py ├── parallel.py ├── model_ema.py └── utils.py ├── main.py ├── datasets ├── sr_dataset.py ├── val_datasets.py ├── test_dataset.py ├── __init__.py └── sr_base_dataset.py ├── models ├── __init__.py ├── espcn.py ├── srcnn.py ├── vdsr.py ├── fsrcnn.py ├── edsr.py ├── drcn.py ├── drrn.py ├── idn.py ├── lapsrn.py ├── srdensenet.py ├── carn.py └── modules.py ├── tools └── get_model_infos.py └── README.md /configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .my_config import MyConfig 2 | from .parser import load_parser -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_trainer import BaseTrainer 2 | from .sr_trainer import SRTrainer 3 | from .loss import get_loss_fn -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .parallel import * 3 | from .optimizer import get_optimizer 4 | from .scheduler import get_scheduler 5 | from .metrics import get_sr_metrics 6 | from .model_ema import get_ema_model -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from core import SRTrainer 2 | from configs import MyConfig, load_parser 3 | 4 | import warnings 5 | warnings.filterwarnings("ignore") 6 | 7 | 8 | if __name__ == '__main__': 9 | config = MyConfig() 10 | 11 | config.init_dependent_config() 12 | 13 | # If you want to use command-line arguments, please uncomment the following 
class SRDataset(SRBaseDataset):
    """Default train/val folder split handed to ``SRBaseDataset``.

    Training uses the BSDS200, General100 and T91 folders; validation uses the
    BSD100, Set5 and Set14 folders for the configured upscale factor.
    How these relative paths are resolved is decided by ``SRBaseDataset``
    (not visible here).
    """

    def __init__(self, config, mode):
        # Validation folders are keyed by the upscale factor (image_SRF_<upscale>).
        scale_dir = f'image_SRF_{config.upscale}'
        train_dirs = [f'train/{name}' for name in ('BSDS200', 'General100', 'T91')]
        val_dirs = [f'val/{name}/{scale_dir}' for name in ('BSD100', 'Set5', 'Set14')]

        super().__init__(config, {'train': train_dirs, 'val': val_dirs}, mode)
def get_sr_metrics(metrics_type):
    """Build the torchmetrics object used to score super-resolved images.

    Args:
        metrics_type: ``'psnr'`` or ``'ssim'``.

    Returns:
        A ``PeakSignalNoiseRatio`` or ``StructuralSimilarityIndexMeasure``
        instance, both configured with ``data_range=1.0``.

    Raises:
        NotImplementedError: if ``metrics_type`` is not one of the supported
            names (previously this surfaced as an ``UnboundLocalError``).
    """
    if metrics_type == 'psnr':
        # dim=(2, 3): presumably the spatial axes of NCHW batches -- TODO confirm.
        return PeakSignalNoiseRatio(data_range=1.0, base=10.0,
                                    reduction='elementwise_mean', dim=(2, 3))

    if metrics_type == 'ssim':
        return StructuralSimilarityIndexMeasure(gaussian_kernel=True, sigma=1.5,
                                                kernel_size=11,
                                                reduction='elementwise_mean',
                                                data_range=1.0, k1=0.01, k2=0.03,
                                                return_full_image=False,
                                                return_contrast_sensitivity=False)

    # Fail loudly, in the same way the other factories in this project do.
    raise NotImplementedError(f'Unsupported metrics type: {metrics_type}')
class Set5(SRBaseDataset):
    """Validation-only split that lists the Set5 folder for ``config.upscale``."""

    def __init__(self, config, mode):
        # Only a 'val' entry is provided; no training folders are defined here.
        val_folder = f'val/Set5/image_SRF_{config.upscale}'
        super().__init__(config, {'val': [val_folder]}, mode)
class VDSR(nn.Module):
    """VDSR: bicubic upsampling followed by a deep conv stack predicting a residual.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output image. The residual is added to
            the upsampled input, so the two must be broadcast-compatible.
        upscale: scale factor passed to ``F.interpolate``.
        layer_num: total number of conv layers, counting the first and last.
        hid_channels: width of the hidden layers.
        act_type: activation used by the hidden ``ConvAct`` layers. It was
            previously accepted but ignored.
    """

    def __init__(self, in_channels, out_channels, upscale, layer_num=20, hid_channels=64,
                 act_type='relu'):
        super().__init__()
        self.upscale = upscale
        self.first_layer = conv3x3(in_channels, hid_channels)

        # The original always asked for an in-place activation. Keep that for
        # the default 'relu', but do not forward the kwarg to other activation
        # types, which may not accept it.
        # NOTE(review): assumes ConvAct's own default activation is 'relu', so
        # passing act_type='relu' explicitly keeps the default model unchanged.
        act_kwargs = {'inplace': True} if act_type == 'relu' else {}
        hidden = [ConvAct(hid_channels, hid_channels, 3, act_type=act_type, **act_kwargs)
                  for _ in range(layer_num - 2)]
        self.mid_layer = nn.Sequential(*hidden)
        self.last_layer = conv3x3(hid_channels, out_channels)

        # NOTE(review): the init gain is fixed to 'relu' regardless of act_type.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        """Upsample ``x`` bicubically and add the predicted residual."""
        x = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')
        res = self.first_layer(x)
        res = self.mid_layer(res)
        res = self.last_layer(res)
        res += x

        return res
def cal_model_params(config, imgw=200, imgh=200):
    """Print the parameter count of the model selected by ``config``.

    ptflops is tried first. If it is not installed, or fails on this model,
    the count falls back to summing ``numel()`` over the model parameters.

    Args:
        config: object providing ``model`` and ``in_channels`` (plus whatever
            ``get_model`` reads).
        imgw: input width handed to ptflops.
        imgh: input height handed to ptflops.
    """
    model = get_model(config)
    print(f'\nModel: {config.model}')

    try:
        from ptflops import get_model_complexity_info
        model.eval()
        # Notice that ptflops doesn't take into account torch.nn.functional.*
        # operations. If you want to get correct macs result, you need to
        # modify the modules like torch.nn.functional.interpolate to
        # torch.nn.Upsample.
        _, params = get_model_complexity_info(model, (config.in_channels, imgh, imgw),
                                              as_strings=True,
                                              print_per_layer_stat=False, verbose=False)
    except Exception:
        # Best-effort fallback, deliberately broad. Unlike a bare `except:`
        # it lets KeyboardInterrupt / SystemExit propagate.
        params = sum(p.numel() for p in model.parameters())
        print(f'Number of parameters: {params / 1e3:.2f}K\n')
    else:
        print(f'Number of parameters: {params}\n')
class FSRCNN(nn.Module):
    """FSRCNN: feature extraction, a narrow mapping stack, then learned upsampling.

    ``d`` is the channel width at the start and end of the mapping stack and
    ``s`` is the width inside it. Every ``ConvAct`` is given
    ``num_parameters`` equal to its output width (presumably the per-channel
    PReLU parameter count -- confirm against ``ConvAct``).
    """

    def __init__(self, in_channels, out_channels, upscale, d=56, s=12, act_type='prelu',
                 upsample_type='deconvolution'):
        super().__init__()

        def conv_act(c_in, c_out, kernel):
            return ConvAct(c_in, c_out, kernel, act_type=act_type, num_parameters=c_out)

        self.first_part = conv_act(in_channels, d, 5)

        # Layers are created in forward order: shrink, four 3x3 mappings, expand.
        stack = [conv_act(d, s, 1)]
        for _ in range(4):
            stack.append(conv_act(s, s, 3))
        stack.append(conv_act(s, d, 1))
        self.mid_part = nn.Sequential(*stack)

        self.last_part = Upsample(d, out_channels, upscale, upsample_type, 9)

    def forward(self, x):
        feats = self.first_part(x)
        feats = self.mid_part(feats)
        return self.last_part(feats)
def set_device(config, rank):
    """Pick the compute device and derive the device-count dependent settings.

    Side effects on ``config``:
        * ``gpu_num``: world size under DDP, otherwise the number of visible
          CUDA devices (at least 1).
        * ``train_bs``: multiplied by ``gpu_num`` in the non-DDP branch only.
        * ``num_workers``: ``gpu_num * base_workers``.

    Args:
        config: object with ``DDP``, ``train_bs`` and ``base_workers``.
        rank: local device index, used only when ``config.DDP`` is true.

    Returns:
        The ``torch.device`` to run on.
    """
    if config.DDP:
        torch.cuda.set_device(rank)
        dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://')
        device = torch.device('cuda', rank)
        config.gpu_num = dist.get_world_size()
    else:   # DP
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # device_count() is 0 on a CPU-only host. Clamp to 1 so the batch size,
        # worker count and the lr derived from gpu_num do not collapse to zero.
        config.gpu_num = max(torch.cuda.device_count(), 1)
        config.train_bs *= config.gpu_num

    # Setup num_workers
    config.num_workers = config.gpu_num * config.base_workers

    return device
class EDSR(nn.Module):
    """EDSR: a stack of ``B`` residual blocks of width ``F`` with a global skip.

    When ``scale_factor`` is not given, residual scaling defaults to 0.1 for
    more than 16 blocks and to 1.0 otherwise.
    """

    def __init__(self, in_channels, out_channels, upscale, B=16, F=64, scale_factor=None,
                 act_type='relu', upsample_type='pixelshuffle'):
        super().__init__()
        if scale_factor is None:
            scale_factor = 1.0 if B <= 16 else 0.1

        self.first_layer = conv3x3(in_channels, F)
        self.res_layers = nn.Sequential(
            *[ResidualBlock(F, scale_factor, act_type) for _ in range(B)]
        )
        self.mid_layer = conv3x3(F, F)

        upsampler = Upsample(F, F, upscale, upsample_type, 3)
        self.last_layers = nn.Sequential(upsampler, conv3x3(F, out_channels))

    def forward(self, x):
        shallow = self.first_layer(x)
        deep = self.mid_layer(self.res_layers(shallow))
        # Global skip connection around the whole residual stack.
        deep += shallow

        return self.last_layers(deep)
def get_ema_model(config, model, device):
    """Create the ``ModelEmaV2`` wrapper for ``model``, held on ``device``."""
    ema_wrapper = ModelEmaV2(config, model, device=device)
    return ema_wrapper
def mkdir(path):
    """Create directory ``path`` if nothing exists there yet.

    Missing parent directories are created as well. A directory that appears
    between the existence check and the creation (for example when several
    processes start together) is not treated as an error.

    Args:
        path: directory to create.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
class DRCN(nn.Module):
    """DRCN: one shared inference layer applied recursively.

    The input is first upsampled bicubically. With ``arch_type='basic'`` only
    the state after the last recursion is reconstructed. With
    ``arch_type='advanced'`` the upsampled input is added to the state after
    every recursion, each sum is reconstructed, and the reconstructions are
    summed.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output image.
        upscale: scale factor passed to ``F.interpolate``.
        recursions: how many times ``inference_net`` is applied.
        hid_channels: width of the hidden features.
        act_type: activation used by every ``ConvAct``.
        arch_type: ``'basic'`` or ``'advanced'``.

    Raises:
        ValueError: if ``arch_type`` is not supported.
    """

    def __init__(self, in_channels, out_channels, upscale, recursions=16,
                 hid_channels=256, act_type='relu', arch_type='advanced'):
        super().__init__()
        if arch_type not in ['basic', 'advanced']:
            raise ValueError(f'Unsupported model type: {arch_type}\n')
        self.upscale = upscale
        self.recursions = recursions
        self.arch_type = arch_type

        self.embedding_net = nn.Sequential(
            ConvAct(in_channels, hid_channels, 3, act_type=act_type),
            ConvAct(hid_channels, hid_channels, 3, act_type=act_type)
        )
        self.inference_net = ConvAct(hid_channels, hid_channels, 3, act_type=act_type)
        self.reconstruction_net = nn.Sequential(
            ConvAct(hid_channels, hid_channels, 3, act_type=act_type),
            ConvAct(hid_channels, out_channels, 3, act_type=act_type)
        )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')
        if self.arch_type == 'advanced':
            skip = x

        x = self.embedding_net(x)

        for i in range(self.recursions):
            x = self.inference_net(x)

            if self.arch_type == 'advanced':
                # NOTE(review): x has hid_channels channels while skip has the
                # input's channel count, so this sum relies on broadcasting
                # (e.g. a single-channel input) -- confirm this is intended.
                out = self.reconstruction_net(x + skip)
                # Accumulate out-of-place. The first output comes straight
                # from an activation, and modifying it in place can invalidate
                # a tensor autograd saved for the backward pass.
                res = out if i == 0 else res + out

        if self.arch_type == 'basic':
            res = self.reconstruction_net(x)

        return res
class DRRN(nn.Module):
    """DRRN: ``B`` recursive blocks, each applying one shared residual unit ``U`` times.

    The input is upsampled bicubically, passed through the chained recursive
    blocks, projected back to image channels, and added to the upsampled
    input.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output image. The result is added to
            the upsampled input, so the two must be broadcast-compatible.
        upscale: scale factor passed to ``F.interpolate``.
        B: number of recursive blocks.
        U: number of times each block's residual unit is applied.
        hid_channels: width of the hidden features.
        act_type: activation used by the conv blocks.
        use_bn: use ``ConvBNAct`` instead of ``ConvAct``.
    """

    def __init__(self, in_channels, out_channels, upscale, B=1, U=9, hid_channels=128,
                 act_type='relu', use_bn=False):
        super().__init__()
        self.upscale = upscale
        self.B = B
        self.U = U

        ConvBlock = ConvBNAct if use_bn else ConvAct
        self.init_blocks = nn.ModuleList()
        self.recursive_blocks = nn.ModuleList()
        for i in range(B):
            # Only the first block sees the image. Later blocks consume the
            # hid_channels features produced by the block before them.
            in_ch = in_channels if i == 0 else hid_channels
            self.init_blocks.append(ConvBlock(in_ch, hid_channels, 3, act_type=act_type))
            self.recursive_blocks.append(ResidualUnit(hid_channels, act_type, use_bn))
        self.last_layer = conv3x3(hid_channels, out_channels)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')

        # Chain the blocks: each one starts from the previous block's output.
        # For B == 1 this is identical to feeding x directly.
        res = x
        for i in range(self.B):
            res = self.init_blocks[i](res)
            residual = res
            for _ in range(self.U):
                # The same unit (shared weights) is reused for every step, and
                # the block's entry features are re-added each time.
                res = self.recursive_blocks[i](res)
                res += residual

        res = self.last_layer(res)
        x = x + res

        return x
class IDN(nn.Module):
    """IDN: feature block, ``num_blocks`` distillation blocks, learned upsampling.

    The upsampled features are added to a bicubically upsampled copy of the
    input. ``d = int((1 / s) * D3)`` is the channel slice size handed to each
    ``DBlock``.
    """

    def __init__(self, in_channels, out_channels, upscale, num_blocks=4, D3=64, s=4,
                 act_type='leakyrelu', upsample_type='deconvolution'):
        super().__init__()
        assert s > 1, 's should be larger than 1, otherwise split_ratio will be out of range.\n'
        d = int((1 / s) * D3)
        self.upscale = upscale

        self.fblock = nn.Sequential(
            ConvAct(in_channels, D3, 3, act_type=act_type),
            ConvAct(D3, D3, 3, act_type=act_type),
        )
        self.dblocks = nn.Sequential(*[DBlock(D3, d, act_type) for _ in range(num_blocks)])
        self.rblock = Upsample(D3, out_channels, upscale, upsample_type, 17)

    def forward(self, x):
        # Bicubic copy of the input, used as the global skip path.
        x_up = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')

        out = self.rblock(self.dblocks(self.fblock(x)))
        out += x_up

        return out
63 | ConvAct(D3, D3 - d, 3, groups=groups[0], act_type=act_type), 64 | ConvAct(D3 - d, D3 - 2*d, 3, groups=groups[1], act_type=act_type), 65 | ConvAct(D3 - 2*d, D3, 3, groups=groups[2], act_type=act_type), 66 | ) 67 | 68 | self.conv2 = nn.Sequential( 69 | ConvAct(D3 - d, D3, 3, groups=groups[3], act_type=act_type), 70 | ConvAct(D3, D3 - d, 3, groups=groups[4], act_type=act_type), 71 | ConvAct(D3 - d, D3 + d, 3, groups=groups[5], act_type=act_type), 72 | ) 73 | 74 | def forward(self, x): 75 | residual = x 76 | x = self.conv1(x) 77 | x_c = x[:, :self.d, :, :] 78 | x_c = torch.cat([x_c, residual], dim=1) 79 | x_s = x[:, self.d:, :, :] 80 | x_s = self.conv2(x_s) 81 | 82 | return x_s + x_c 83 | -------------------------------------------------------------------------------- /configs/base_config.py: -------------------------------------------------------------------------------- 1 | class BaseConfig: 2 | def __init__(self,): 3 | # Dataset 4 | self.dataset = None 5 | self.dataroot = None 6 | self.upscale = 2 7 | self.train_y = True 8 | 9 | # Model 10 | self.model = None 11 | self.in_channels = None 12 | self.out_channels = None 13 | 14 | # Training 15 | self.total_epoch = 3200 16 | self.base_lr = 0.001 17 | self.train_bs = 16 # For each GPU 18 | self.early_stop_epoch = 1000 19 | self.max_itrs_per_epoch = 200 20 | 21 | # Validating 22 | self.val_bs = 1 # For each GPU 23 | self.begin_val_epoch = 0 # Epoch to start validation 24 | self.val_interval = 1 # Epoch interval between validation 25 | self.metrics = 'psnr' 26 | 27 | # Testing 28 | self.is_testing = False 29 | self.test_bs = 1 30 | self.test_data_folder = None 31 | self.test_lr = False # Test downscaled image 32 | 33 | # Benchmark 34 | self.benchmark = False 35 | self.benchmark_datasets = ['set5', 'set14', 'bsd100'] 36 | 37 | # Loss 38 | self.loss_type = 'mse' 39 | 40 | # Scheduler 41 | self.lr_policy = 'constant' 42 | self.warmup_epochs = 3 43 | self.step_size = None 44 | self.step_gamma = 0.1 45 | 46 | # Optimizer 
47 | self.optimizer_type = 'adam' 48 | self.momentum = 0.9 # For SGD 49 | self.weight_decay = 1e-4 # For SGD 50 | 51 | # Monitoring 52 | self.save_ckpt = True 53 | self.save_dir = 'save' 54 | self.use_tb = True # tensorboard 55 | self.tb_log_dir = None 56 | self.ckpt_name = None 57 | 58 | # Training setting 59 | self.amp_training = False 60 | self.resume_training = True 61 | self.load_ckpt = True 62 | self.load_ckpt_path = None 63 | self.base_workers = 8 64 | self.random_seed = 1 65 | self.use_ema = True 66 | self.ema_decay = 0.999 67 | self.ema_start_epoch = 0 68 | 69 | # Augmentation 70 | self.patch_size = None 71 | self.rotate = 0.0 72 | self.multi_scale = False 73 | self.hflip = 0.0 74 | self.vflip = 0.0 75 | 76 | # DDP 77 | self.synBN = False 78 | 79 | def init_dependent_config(self): 80 | if self.load_ckpt_path is None and not (self.is_testing or self.benchmark): 81 | self.load_ckpt_path = f'{self.save_dir}/last.pth' 82 | 83 | if (self.is_testing or self.benchmark) and (self.load_ckpt_path is None): 84 | self.load_ckpt_path = 'best.pth' 85 | 86 | if self.tb_log_dir is None: 87 | self.tb_log_dir = f'{self.save_dir}/tb_logs/' 88 | 89 | if isinstance(self.patch_size, int): 90 | self.patch_size = [self.patch_size, self.patch_size] 91 | 92 | if self.in_channels is None: 93 | self.in_channels = 1 if self.train_y else 3 94 | 95 | if self.out_channels is None: 96 | self.out_channels = 1 if self.train_y else 3 97 | -------------------------------------------------------------------------------- /models/lapsrn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Paper: Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution 3 | Url: https://arxiv.org/abs/1704.03915 4 | Create by: zh320 5 | Date: 2023/12/16 6 | """ 7 | 8 | import torch.nn as nn 9 | from math import log2 10 | 11 | from .modules import ConvAct, Upsample 12 | 13 | 14 | class LapSRN(nn.Module): 15 | def __init__(self, in_channels, out_channels, 
class FeatureExtraction(nn.Module):
    """Feature-extraction branch: one conv stack + upsampler + output conv per stage.

    ``forward`` returns a list with one feature map per stage; every stage
    consumes the upsampled hidden features of the previous one.
    """

    def __init__(self, in_ch, out_ch, hid_ch, num_stage, layer_num, scale_factor,
                 upsample_type, act_type):
        super().__init__()
        self.num_stage = num_stage
        self.conv = nn.ModuleList()
        self.out = nn.ModuleList()

        for stage in range(num_stage):
            # Only the very first stage sees the input image channels
            first_in = hid_ch if stage else in_ch

            body = [ConvAct(first_in, hid_ch, 3, act_type=act_type)]
            body.extend(ConvAct(hid_ch, hid_ch, 3, act_type=act_type)
                        for _ in range(layer_num - 3))
            body.append(Upsample(hid_ch, hid_ch, scale_factor, upsample_type))
            self.conv.append(nn.Sequential(*body))

            # Projects the upsampled hidden features to the output channels
            self.out.append(ConvAct(hid_ch, out_ch, 3, act_type=act_type))

    def forward(self, x):
        feats = []
        for trunk, head in zip(self.conv, self.out):
            x = trunk(x)
            feats.append(head(x))

        return feats
class TestDataset(Dataset):
    """Dataset serving every file of ``config.test_data_folder`` for prediction.

    Each item is ``(images, img_name)`` where ``images`` is the list
    ``[hr, bicubic]``, followed by ``lr`` when ``config.test_lr`` is set and by
    ``ycbcr`` when ``config.train_y`` is set (in that order).
    """

    def __init__(self, config):
        if config.test_data_folder is None:
            raise RuntimeError('Test image directory is not set (config.test_data_folder is None).')

        data_folder = os.path.expanduser(config.test_data_folder)
        self.train_y = config.train_y
        self.scale = config.upscale
        self.test_lr = config.test_lr

        if not os.path.isdir(data_folder):
            raise RuntimeError(f'Test image directory: {data_folder} does not exist.')

        self.hr_images = []
        self.img_names = []

        # os.listdir() order is arbitrary, sort it so that runs are reproducible
        for file_name in sorted(os.listdir(data_folder)):
            file_path = os.path.join(data_folder, file_name)
            # Sub-directories cannot be opened as images, skip them
            if not os.path.isfile(file_path):
                continue
            self.hr_images.append(file_path)
            self.img_names.append(file_name)

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, index):
        # convert() returns a loaded copy, so the file handle can be closed right away
        with Image.open(self.hr_images[index]) as img:
            hr = img.convert('RGB')
        img_name = self.img_names[index]

        lr = None
        if self.test_lr:
            # Resize image to make it compatible for upscale factor if needed
            hr_width = (hr.width // self.scale) * self.scale
            hr_height = (hr.height // self.scale) * self.scale
            if hr_width != hr.width or hr_height != hr.height:
                hr = hr.resize((hr_width, hr_height), resample=Image.BICUBIC)

            # Generate low resolution image using bicubic interpolation of HR image
            lr_width = hr_width // self.scale
            lr_height = hr_height // self.scale
            lr = hr.resize((lr_width, lr_height), resample=Image.BICUBIC)

            # Need interpolated CbCr channels to recover hr images if train with y channel
            bicubic = lr.resize((lr.width * self.scale, lr.height * self.scale), resample=Image.BICUBIC)
        else:   # test hr
            bicubic = hr.resize((hr.width * self.scale, hr.height * self.scale), resample=Image.BICUBIC)

        hr = np.array(hr).astype(np.float32)
        bicubic = np.array(bicubic).astype(np.float32)
        if lr is not None:
            lr = np.array(lr).astype(np.float32)

        ycbcr = None
        if self.train_y:
            # RGB to YCbCr (get interpolated CbCr channels here); computed for
            # both test modes because it is always part of the returned list
            ycbcr = SRBaseDataset.rgb_to_ycbcr(bicubic)

            # RGB to YCbCr (only need Y channel here), then HW to CHW --> normalize
            hr = np.expand_dims(SRBaseDataset.rgb_to_ycbcr(hr, y_only=True) / 255., 0)
            if lr is not None:
                lr = np.expand_dims(SRBaseDataset.rgb_to_ycbcr(lr, y_only=True) / 255., 0)
        else:
            # HWC to CHW --> normalize
            hr = hr.transpose((2, 0, 1)) / 255.
            if lr is not None:
                lr = lr.transpose((2, 0, 1)) / 255.

        images = [np.ascontiguousarray(hr), bicubic]
        if lr is not None:
            images.append(np.ascontiguousarray(lr))

        if ycbcr is not None:
            images.append(np.ascontiguousarray(ycbcr))

        return images, img_name
class DenseBlock(nn.Module):
    """Dense block: every conv sees the concatenation of all previous outputs.

    ``conv0`` maps the input to ``growth_rate`` channels; each following conv
    takes the activated concatenation of all features produced so far.  The
    last conv outputs ``out_channels`` channels, followed by a final activation.

    Raises:
        ValueError: if ``num_layer`` is smaller than 2 (no dense connection
            could be built).
    """

    def __init__(self, in_channels, out_channels, num_layer, act_type):
        super(DenseBlock, self).__init__()
        if num_layer < 2:
            raise ValueError(f'num_layer should be at least 2, but got {num_layer}.\n')
        assert out_channels % num_layer == 0, 'out_channels should be evenly divided by num_layer.\n'
        self.num_layer = num_layer
        growth_rate = out_channels // num_layer

        self.conv0 = conv3x3(in_channels, growth_rate)
        self.act = nn.ModuleList([Activation(act_type) for _ in range(num_layer)])
        self.conv = nn.ModuleList([])
        # conv[i] consumes the (i+1) feature maps gathered before it
        for i in range(1, num_layer-1):
            self.conv.append(conv3x3(i*growth_rate, growth_rate))
        self.conv.append(conv3x3((num_layer-1)*growth_rate, out_channels))

    def forward(self, x):
        feats = [self.conv0(x)]

        # self.conv holds num_layer - 1 convs, so zip() stops before the last
        # activation, which is applied to the block output below
        for act, conv in zip(self.act, self.conv):
            feat = conv(act(torch.cat(feats, dim=1)))
            feats.append(feat)

        return self.act[-1](feat)
def get_loader(config, rank, pin_memory=True):
    """Build the training and validation data loaders.

    Also stores ``config.train_num`` (number of training samples rounded down
    to a multiple of ``config.train_bs``) and ``config.val_num``.

    Args:
        config: experiment configuration; reads ``DDP``, ``gpu_num``,
            ``train_bs``, ``val_bs`` and ``num_workers``.
        rank: rank of the current process, used by the distributed samplers.
        pin_memory: forwarded to both loaders (DDP or not).

    Returns:
        ``(train_loader, val_loader)``
    """
    train_dataset, val_dataset = get_dataset(config)

    # Make sure train number is divisible by train batch size
    config.train_num = int(len(train_dataset) // config.train_bs * config.train_bs)
    config.val_num = len(val_dataset)

    if config.DDP:
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(train_dataset, num_replicas=config.gpu_num,
                                           rank=rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=config.gpu_num,
                                         rank=rank, shuffle=False)
        # Shuffling is delegated to the sampler; DataLoader forbids both at once
        train_shuffle = False
    else:
        train_sampler = None
        val_sampler = None
        train_shuffle = True

    train_loader = DataLoader(train_dataset, batch_size=config.train_bs, shuffle=train_shuffle,
                              num_workers=config.num_workers, pin_memory=pin_memory,
                              sampler=train_sampler, drop_last=True)

    val_loader = DataLoader(val_dataset, batch_size=config.val_bs, shuffle=False,
                            num_workers=config.num_workers, pin_memory=pin_memory,
                            sampler=val_sampler)

    return train_loader, val_loader
def get_test_loader(config):
    """Create the loader used in predict mode (batch size 1, no shuffling).

    Stores the number of test images in ``config.test_num``.

    Raises:
        NotImplementedError: when ``config.DDP`` is enabled.
    """
    from .test_dataset import TestDataset

    test_set = TestDataset(config)
    config.test_num = len(test_set)

    # Distributed prediction is not supported
    if config.DDP:
        raise NotImplementedError()

    return DataLoader(test_set, batch_size=1, shuffle=False, num_workers=config.num_workers)
class CascadingBlock(nn.Module):
    """Three residual units whose outputs are cascaded through 1x1 fusion convs.

    After each unit, the block input and all unit outputs so far are
    concatenated along the channel axis and squeezed back to ``channels``.
    """

    def __init__(self, block, channels, act_type):
        super().__init__()
        # Registers res1, conv1, res2, conv2, res3, conv3 (in that order);
        # conv<k> fuses k + 1 concatenated feature maps
        for k in (1, 2, 3):
            setattr(self, f'res{k}', block(channels, act_type))
            setattr(self, f'conv{k}', conv1x1((k + 1) * channels, channels))

    def forward(self, x):
        skip = x

        out1 = self.res1(skip)
        fused = self.conv1(torch.cat([skip, out1], dim=1))

        out2 = self.res2(fused)
        fused = self.conv2(torch.cat([skip, out1, out2], dim=1))

        out3 = self.res3(fused)
        # The last unit's output comes first in this concatenation
        return self.conv3(torch.cat([out3, skip, out1, out2], dim=1))
# Regular convolution with kernel size 5x5
def conv5x5(in_channels, out_channels, stride=1):
    """Return a biased 5x5 ``nn.Conv2d`` with padding 2."""
    conv = nn.Conv2d(in_channels, out_channels, 5, stride, 2, bias=True)
    return conv
class Upsample(nn.Module):
    """Learnable upsampling layer.

    ``upsample_type`` selects the implementation:
      * ``'deconvolution'``: a single transposed convolution with stride
        ``scale_factor`` (default kernel ``2 * scale_factor + 1``);
      * ``'pixelshuffle'``: convolution to ``out_channels * scale_factor**2``
        channels followed by ``nn.PixelShuffle``;
      * anything else: convolution followed by bicubic interpolation.
    For the last two the default kernel size is 3.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2, upsample_type=None,
                 kernel_size=None,):
        super().__init__()

        if upsample_type == 'deconvolution':
            k = 2*scale_factor + 1 if kernel_size is None else kernel_size
            self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k,
                                              stride=scale_factor, padding=(k - 1) // 2,
                                              output_padding=scale_factor - 1, bias=True)
            return

        # Both remaining variants are "conv -> parameter-free resize"
        k = 3 if kernel_size is None else kernel_size
        pad = (k - 1) // 2
        if upsample_type == 'pixelshuffle':
            conv = nn.Conv2d(in_channels, out_channels * (scale_factor**2), k, 1, pad)
            resize = nn.PixelShuffle(scale_factor)
        else:
            conv = nn.Conv2d(in_channels, out_channels, k, 1, pad)
            resize = nn.Upsample(scale_factor=scale_factor, mode='bicubic')

        self.up_conv = nn.Sequential(conv, resize)

    def forward(self, x):
        return self.up_conv(x)
class Activation(nn.Module):
    """Wrap a ``torch.nn`` activation selected by its (case-insensitive) name.

    Extra keyword arguments are forwarded to the activation constructor.

    Raises:
        NotImplementedError: if the name is not one of the supported types.
    """

    def __init__(self, act_type, **kwargs):
        super().__init__()
        registry = dict(
            relu=nn.ReLU, relu6=nn.ReLU6,
            leakyrelu=nn.LeakyReLU, prelu=nn.PReLU,
            celu=nn.CELU, elu=nn.ELU,
            hardswish=nn.Hardswish, hardtanh=nn.Hardtanh,
            gelu=nn.GELU, glu=nn.GLU,
            selu=nn.SELU, silu=nn.SiLU,
            sigmoid=nn.Sigmoid, softmax=nn.Softmax,
            tanh=nn.Tanh, none=nn.Identity,
        )

        act_type = act_type.lower()
        factory = registry.get(act_type)
        if factory is None:
            raise NotImplementedError(f'Unsupport activation type: {act_type}')

        self.activation = factory(**kwargs)

    def forward(self, x):
        return self.activation(x)
4 | 5 | 6 | 7 | # 8 | 9 | # Requirements 10 | 11 | torch == 1.8.1 12 | torchmetrics 13 | loguru 14 | tqdm 15 | 16 | # Supported models 17 | 18 | - [CARN](models/carn.py) [^carn] 19 | - [DRCN](models/drcn.py) [^drcn] 20 | - [DRRN](models/drrn.py) [^drrn] 21 | - [EDSR](models/edsr.py) [^edsr] 22 | - [ESPCN](models/espcn.py) [^espcn] 23 | - [FSRCNN](models/fsrcnn.py) [^fsrcnn] 24 | - [IDN](models/idn.py) [^idn] 25 | - [LapSRN](models/lapsrn.py) [^lapsrn] 26 | - [SRCNN](models/srcnn.py) [^srcnn] 27 | - [SRDenseNet](models/srdensenet.py) [^srdensenet] 28 | - [VDSR](models/vdsr.py) [^vdsr] 29 | 30 | [^carn]: [Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network](https://arxiv.org/abs/1803.08664) 31 | [^drcn]: [Deeply-Recursive Convolutional Network for Image Super-Resolution](https://arxiv.org/abs/1511.04491) 32 | [^drrn]: [ Image Super-Resolution via Deep Recursive Residual Network](https://openaccess.thecvf.com/content_cvpr_2017/html/Tai_Image_Super-Resolution_via_CVPR_2017_paper.html) 33 | [^edsr]: [Enhanced Deep Residual Networks for Single Image Super-Resolution](https://arxiv.org/abs/1707.02921) 34 | [^espcn]: [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158) 35 | [^fsrcnn]: [Accelerating the Super-Resolution Convolutional Neural Network](https://arxiv.org/abs/1608.00367) 36 | [^idn]: [Fast and Accurate Single Image Super-Resolution via Information Distillation Network](https://arxiv.org/abs/1803.09454) 37 | [^lapsrn]: [Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution](https://arxiv.org/abs/1704.03915) 38 | [^srcnn]: [Image Super-Resolution Using Deep Convolutional Networks](https://arxiv.org/abs/1501.00092) 39 | [^srdensenet]: [Image Super-Resolution Using Dense Skip Connections](https://openaccess.thecvf.com/content_ICCV_2017/papers/Tong_Image_Super-Resolution_Using_ICCV_2017_paper.pdf) 40 | [^vdsr]: [Accurate Image 
Super-Resolution Using Very Deep Convolutional Networks](https://arxiv.org/abs/1511.04587) 41 | 42 | # How to use 43 | 44 | ## DDP training (recommend) 45 | 46 | ``` 47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py 48 | ``` 49 | 50 | ## DP training 51 | 52 | ``` 53 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py 54 | ``` 55 | 56 | # Performances and checkpoints 57 | 58 | | Model | Year | Train on1 | Set5 | | Set14 | | BSD100 | | 59 | |:------------------------------------------------------------------------------------------------------------------------:|:----:|:--------------------:|:---------------:|:-------------:|:-----------:|:-------------:|:-----------:|:-------------:| 60 | | 2x | | | PSNR (paper/my) | SSIM | PSNR | SSIM | PSNR | SSIM | 61 | | [CARN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/carn_2x.pth) | 2018 | T+B+D | 37.76/37.90 | 0.9590/0.9605 | 33.52/33.14 | 0.9166/0.9152 | 32.09/32.06 | 0.8978/0.8985 | 62 | | [DRCN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/drcn_2x.pth) | 2015 | T | 37.63/37.85 | 0.9588/0.9604 | 33.04/33.22 | 0.9118/0.916 | 31.85/32.05 | 0.8942/0.8982 | 63 | | [DRRN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/drrn_2x.pth) | 2017 | T+B | 37.74/37.76 | 0.9591/0.9599 | 33.23/33.14 | 0.9136/0.9149 | 32.05/31.99 | 0.8973/0.8974 | 64 | | [EDSR](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/edsr_2x.pth) | 2017 | D | 37.99/37.90 | 0.9604/0.9606 | 33.57/33.22 | 0.9175/0.9163 | 32.16/32.10 | 0.8994/0.899 | 65 | | [ESPCN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/espcn_2x.pth) | 2016 | I+T | n.a./36.85 | n.a./0.9559 | n.a./32.31 | n.a./0.9087 | n.a./31.40 | n.a./0.8897 | 66 | | 
[FSRCNN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/fsrcnn_2x.pth) | 2016 | T+G | 37.00/37.27 | 0.9558/0.958 | 32.63/32.65 | 0.9088/0.9115 | 31.53/31.67 | 0.8920/0.8934 | 67 | | [IDN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/idn_2x.pth) | 2018 | T+B | 37.83/37.84 | 0.96/0.9604 | 33.30/33.12 | 0.9148/0.9155 | 32.08/32.06 | 0.8985/0.8985 | 68 | | [LapSRN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/lapsrn_2x.pth) | 2017 | T+B | 37.52/37.59 | 0.9591/0.9592 | 32.99/32.96 | 0.9124/0.9138 | 31.80/31.89 | 0.8952/0.8961 | 69 | | [SRCNN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/srcnn_2x.pth) | 2014 | I+T | 36.66/36.88 | 0.9542/0.9561 | 32.45/32.42 | 0.9067/0.9092 | 31.36/31.50 | 0.8879/0.8907 | 70 | | [SRDenseNet](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/srdensenet_2x.pth) | 2017 | I | n.a./37.67 | n.a./0.9596 | n.a./33.05 | n.a./0.9142 | n.a./31.93 | n.a./0.8967 | 71 | | [VDSR](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/vdsr_2x.pth) | 2015 | T+B | 37.53/37.74 | 0.9587/0.9598 | 33.03/33.06 | 0.9124/0.9145 | 31.90/31.97 | 0.8960/0.8973 | 72 | 73 | [1 Original training datasets, abbreviated as B (BSD200), D (DIV2K), G (General100), I (ImageNet), and T (T91). In my experiments, the training dataset is T + G + B.] 
class SRBaseDataset(Dataset):
    """Base dataset yielding ``(lr, hr)`` pairs built from folders of HR images.

    The low-resolution image is generated on the fly by bicubic down-sampling
    of the HR image by ``config.upscale``.  In ``'train'`` mode the HR image is
    first optionally rescaled (multi-scale), randomly cropped to
    ``config.patch_size`` and randomly rotated / flipped.

    Args:
        config: experiment configuration.
        data_split: dict mapping ``'train'`` / ``'val'`` to a list of folders
            relative to ``config.data_root``.
        mode: ``'train'`` or ``'val'``.

    Raises:
        RuntimeError: if a folder is missing or no image file is found.
    """
    IMG_SUFFIX = ['jpg', 'jpeg', 'png', 'bmp']

    def __init__(self, config, data_split, mode='train'):
        assert mode in ['train', 'val'], f'Unsupported dataset mode: {mode}.\n'

        data_root = os.path.expanduser(config.data_root)
        img_dirs = [os.path.join(data_root, img_dir) for img_dir in data_split[mode]]

        if len(img_dirs) == 0:
            raise RuntimeError('No image directory found.')

        for img_dir in img_dirs:
            if not os.path.isdir(img_dir):
                raise RuntimeError(f'Image directory: {img_dir} does not exist.')

        self.mode = mode
        self.train_y = config.train_y
        self.scale = config.upscale
        self.patch_size = config.patch_size
        self.random_rotate = config.rotate
        self.multi_scale = config.multi_scale
        self.hflip = config.hflip
        self.vflip = config.vflip

        hr_images = []
        for img_dir in img_dirs:
            for file_name in os.listdir(img_dir):
                if file_name.split('.')[-1].lower() in SRBaseDataset.IMG_SUFFIX:
                    hr_images.append(os.path.join(img_dir, file_name))

        # Fail early with a clear message instead of a ZeroDivisionError below
        # (train) or a silently empty dataset (val)
        if not hr_images:
            raise RuntimeError(f'No image found in: {img_dirs}.')

        if self.mode == 'train':
            # Repeat the file list so that one epoch provides about
            # max_itrs_per_epoch iterations.  Never repeat less than once:
            # a factor of 0 would empty the dataset when there are more
            # images than requested samples.
            samples_per_epoch = config.max_itrs_per_epoch*config.train_bs*config.gpu_num
            dataset_repeat_times = max(1, samples_per_epoch // len(hr_images))
            hr_images *= dataset_repeat_times

        self.hr_images = hr_images

        self.num = len(self.hr_images)

    @staticmethod
    def rgb_to_ycbcr(img, y_only=False):
        """Convert an HxWx3 RGB array in [0, 255] to YCbCr (ITU-R BT.601).

        Returns the HxW luma channel when ``y_only`` is set, otherwise an
        HxWx3 float32 array ordered (Y, Cb, Cr).
        """
        if not isinstance(img, np.ndarray):
            raise ValueError(f'\nInput must be np.ndarray, but got type: {type(img)} instead.')
        if img.shape[2] != 3:
            raise ValueError(f'\nInput should be RGB channel array, but got {img.shape[2]} channel instead.')

        y = (np.dot(img, [65.481, 128.553, 24.966]) / 255. + 16.).astype(np.float32)
        if y_only:
            return y

        # The chroma coefficients below are the 8-bit ones
        # (112.439 = 112 * 256 / 255), so they pair with a division by 256
        cb = (np.dot(img, [-37.945, -74.494, 112.439]) / 256. + 128.).astype(np.float32)
        cr = (np.dot(img, [112.439, -94.154, -18.285]) / 256. + 128.).astype(np.float32)
        return np.array([y, cb, cr]).transpose([1, 2, 0])

    @staticmethod
    def ycbcr_to_rgb(img):
        """Convert an HxWx3 YCbCr array back to RGB in [0, 255] (inverse of ``rgb_to_ycbcr``)."""
        if not isinstance(img, np.ndarray):
            raise ValueError(f'\nInput must be np.ndarray, but got type: {type(img)} instead.')
        if img.shape[2] != 3:
            raise ValueError(f'\nInput should be 3-channel array, but got {img.shape[2]} channel instead.')

        # 8-bit BT.601 inverse: the coefficients and offsets assume a division by 256
        r = (np.dot(img, [298.082, 0, 408.583]) / 256. - 222.921).astype(np.float32)
        g = (np.dot(img, [298.082, -100.291, - 208.12]) / 256. + 135.576).astype(np.float32)
        b = (np.dot(img, [298.082, 516.412, 0]) / 256. - 276.836).astype(np.float32)
        return np.array([r, g, b]).transpose([1, 2, 0])

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, index):
        # convert() returns a loaded copy, so the file handle can be closed right away
        with Image.open(self.hr_images[index]) as img:
            hr = img.convert('RGB')

        if self.mode == 'train':
            patch_w, patch_h = self.patch_size[0], self.patch_size[1]

            # Perform multiscale augmentation
            if self.multi_scale:
                # Smallest scale that still leaves room for a full patch
                min_scale = max(patch_w/hr.width, patch_h/hr.height)
                if min_scale > 1:
                    # In case the image is too small to perform random crop
                    scale = min_scale
                else:
                    # Never shrink below the patch size
                    scale = random.uniform(max(0.75, min_scale), 1.0)

                # max() guards against int() truncating to one pixel short
                new_w = max(int(hr.width * scale), patch_w)
                new_h = max(int(hr.height * scale), patch_h)
                hr = hr.resize((new_w, new_h), resample=Image.BICUBIC)

            # Calculate the random crop location x0
            if hr.width > patch_w:
                x0 = random.randint(0, hr.width-patch_w)
            else:
                x0 = 0

            # Calculate the random crop location y0
            if hr.height > patch_h:
                y0 = random.randint(0, hr.height-patch_h)
            else:
                y0 = 0

            # Random crop a patch using patch_size
            hr = hr.crop([x0, y0, x0+patch_w, y0+patch_h])

            # Random rotation
            if random.random() < self.random_rotate:
                angle = random.randint(0, 3) * 90
                hr = hr.rotate(angle, expand=True)

            # Random horizontal flip
            if random.random() < self.hflip:
                hr = hr.transpose(Image.FLIP_LEFT_RIGHT)

            # Random vertical flip
            if random.random() < self.vflip:
                hr = hr.transpose(Image.FLIP_TOP_BOTTOM)

        # Resize image to make it compatible for training if needed
        hr_width = (hr.width // self.scale) * self.scale
        hr_height = (hr.height // self.scale) * self.scale
        if hr_width != hr.width or hr_height != hr.height:
            hr = hr.resize((hr_width, hr_height), resample=Image.BICUBIC)

        # Generate low resolution image using bicubic interpolation of HR image
        lr_width = hr_width // self.scale
        lr_height = hr_height // self.scale
        lr = hr.resize((lr_width, lr_height), resample=Image.BICUBIC)

        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)

        if self.train_y:
            # RGB to YCbCr (only need Y channel here)
            hr = self.rgb_to_ycbcr(hr, y_only=True)
            lr = self.rgb_to_ycbcr(lr, y_only=True)

            # HW to CHW --> normalize
            hr = np.expand_dims(hr / 255., 0)
            lr = np.expand_dims(lr / 255., 0)
        else:
            # HWC to CHW --> normalize
            hr = hr.transpose((2, 0, 1)) / 255.
            lr = lr.transpose((2, 0, 1)) / 255.

        return np.ascontiguousarray(lr), np.ascontiguousarray(hr)
Forward path 38 | with amp.autocast(enabled=config.amp_training): 39 | preds = self.model(images) 40 | loss = self.loss_fn(preds, labels) 41 | 42 | if config.use_tb and self.main_rank: 43 | self.writer.add_scalar('train/loss', loss.detach(), self.train_itrs) 44 | 45 | # Backward path 46 | self.scaler.scale(loss).backward() 47 | self.scaler.step(self.optimizer) 48 | self.scaler.update() 49 | self.scheduler.step() 50 | 51 | if self.cur_epoch >= config.ema_start_epoch: 52 | self.ema_model.update(self.model, self.train_itrs) 53 | else: 54 | self.ema_model.ema = deepcopy(de_parallel(self.model)) 55 | 56 | if self.main_rank: 57 | pbar.set_description(('%s'*2) % 58 | (f'Epoch:{self.cur_epoch}/{config.total_epoch}{" "*4}|', 59 | f'Loss:{loss.detach():4.4g}{" "*4}|',) 60 | ) 61 | 62 | return 63 | 64 | @torch.no_grad() 65 | def validate(self, config, val_best=False): 66 | pbar = tqdm(self.val_loader) if self.main_rank else self.val_loader 67 | for (images, labels) in pbar: 68 | images = images.to(self.device) 69 | labels = labels.to(self.device) 70 | 71 | preds = self.ema_model.ema(images).clamp(0.0, 1.0) 72 | self.psnr.update(preds.detach(), labels) 73 | self.ssim.update(preds.detach(), labels) 74 | 75 | if self.main_rank: 76 | pbar.set_description(('%s'*1) % (f'Validating:{" "*4}|',)) 77 | 78 | psnr = self.psnr.compute() 79 | ssim = self.ssim.compute() 80 | if config.metrics == 'psnr': 81 | score = psnr 82 | elif config.metrics == 'ssim': 83 | score = ssim 84 | 85 | if self.main_rank: 86 | if val_best: 87 | self.logger.info(f'\n\nTrain {config.total_epoch} epochs finished.' 
+ 88 | f'\n\nBest {config.metrics.upper()} is: {psnr:.4f}\n') 89 | else: 90 | self.logger.info(f' Epoch{self.cur_epoch} PSNR: {psnr:.4f} SSIM: {ssim:.4f} | ' + 91 | f'best {config.metrics.upper()} so far: {self.best_score:.4f}\n') 92 | 93 | if config.use_tb and self.cur_epoch < config.total_epoch: 94 | self.writer.add_scalar('val/PSNR', psnr.cpu(), self.cur_epoch+1) 95 | self.writer.add_scalar('val/SSIM', ssim.cpu(), self.cur_epoch+1) 96 | self.psnr.reset() 97 | self.ssim.reset() 98 | return score 99 | 100 | @torch.no_grad() 101 | def benchmark(self, config): 102 | if self.main_rank: 103 | log_config(config, self.logger) 104 | 105 | print(f'{"-"*25} Start benchmarking {"-"*25}\n') 106 | for i, val_loader in enumerate(self.val_loaders): 107 | print(f"\nStart validating dataset: {config.benchmark_datasets[i]}...") 108 | 109 | pbar = tqdm(val_loader) if self.main_rank else val_loader 110 | for (images, labels) in pbar: 111 | images = images.to(self.device) 112 | labels = labels.to(self.device) 113 | 114 | preds = self.model(images).clamp(0.0, 1.0) 115 | self.psnr.update(preds.detach(), labels) 116 | self.ssim.update(preds.detach(), labels) 117 | 118 | if self.main_rank: 119 | pbar.set_description(('%s'*1) % (f'Validating:{" "*4}|',)) 120 | 121 | psnr = self.psnr.compute() 122 | ssim = self.ssim.compute() 123 | 124 | if self.main_rank: 125 | self.logger.info(f' PSNR: {psnr:.4f} SSIM: {ssim:.4f}\n') 126 | 127 | self.psnr.reset() 128 | self.ssim.reset() 129 | print(f'{"-"*25} Finish benchmarking {"-"*25}\n') 130 | 131 | @torch.no_grad() 132 | def predict(self, config): 133 | from datasets import SRBaseDataset 134 | if config.DDP: 135 | raise ValueError('Predict mode currently does not support DDP.') 136 | 137 | if config.test_bs != 1: 138 | self.logger.info('Warning: Predict mode only support batch size 1\n') 139 | 140 | self.logger.info('\nStart predicting...\n') 141 | 142 | if config.test_lr: 143 | test_score_path = 
f'{config.save_dir}/test_score_{config.model}_x{config.upscale}.txt' 144 | if os.path.isfile(test_score_path): 145 | os.remove(test_score_path) 146 | 147 | for (images, img_names) in tqdm(self.test_loader): 148 | hr = images[0] 149 | bicubic = images[1] 150 | img_name = img_names[0] 151 | 152 | hr = hr.to(self.device, dtype=torch.float32) 153 | if config.test_lr: 154 | lr = images[2].to(self.device, dtype=torch.float32) 155 | pred = self.model(lr).clamp(0.0, 1.0) 156 | 157 | # Compute PSNR and SSIM if test lr image 158 | self.psnr.update(pred.detach(), hr) 159 | self.ssim.update(pred.detach(), hr) 160 | 161 | psnr = self.psnr.compute() 162 | ssim = self.ssim.compute() 163 | 164 | self.psnr.reset() 165 | self.ssim.reset() 166 | 167 | with open(test_score_path, 'a+') as f: 168 | f.write(f'{img_name}\tPSNR: {psnr:.2f}\tSSIM: {ssim:.4f}\n') 169 | else: 170 | pred = self.model(hr).clamp(0.0, 1.0) 171 | 172 | # BCHW --> HWC 173 | if config.train_y: 174 | # Need to concatenate [pred_y, bicubic_cb, bicubic_cr] to obtain predicted image 175 | ycbcr = images[-1].numpy().squeeze(0) 176 | pred = pred.mul(255.0).cpu().unsqueeze(-1).numpy().squeeze(0).squeeze(0) 177 | pred = np.concatenate([pred, ycbcr[...,1:]], axis=-1) 178 | pred = SRBaseDataset.ycbcr_to_rgb(pred) 179 | else: 180 | pred = pred.mul(255.0).cpu().numpy().squeeze(0).transpose([1, 2, 0]) 181 | 182 | pred = np.clip(pred, 0., 255.).astype(np.uint8) 183 | 184 | # Saving results 185 | img_suffix = img_name.split('.')[-1] 186 | img_prefix = img_name[:-len(img_suffix)-1] 187 | pred_path = f'{config.save_dir}/{img_prefix}_{config.model}_x{config.upscale}.{img_suffix}' 188 | bicubic_path = f'{config.save_dir}/{img_prefix}_bicubic_x{config.upscale}.{img_suffix}' 189 | 190 | pred = Image.fromarray(pred.astype(np.uint8)) 191 | pred.save(pred_path) 192 | 193 | bicubic = Image.fromarray(bicubic.numpy().squeeze().astype(np.uint8)) 194 | bicubic.save(bicubic_path) 195 | 
def load_parser(config, argv=None):
    """Override attributes of ``config`` with command-line arguments.

    Only options that were actually given on the command line (i.e. whose
    parsed value is not ``None``) are written to ``config``.

    Args:
        config: configuration object whose attributes are overridden.
        argv: optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The same ``config`` object, updated in place.

    Raises:
        RuntimeError: if an attribute cannot be assigned on ``config``.
    """
    args = get_parser(argv)

    for key, value in vars(args).items():
        if value is None:
            continue
        try:
            # Plain attribute assignment; no need for `exec`.
            setattr(config, key, value)
        except Exception as err:
            raise RuntimeError(f'Unable to assign value to config.{key}') from err
    return config


def get_parser(argv=None):
    """Define the command-line options and return the parsed arguments.

    Every option defaults to ``None`` so that ``load_parser`` can tell
    "not given" apart from an explicit value. Options declared with
    ``store_false`` become ``False`` when the flag is passed, and options
    declared with ``store_true`` become ``True``.

    Args:
        argv: optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args``.
    """
    parser = argparse.ArgumentParser()
    # Dataset
    # No `choices` restriction here: the previous list only allowed
    # 'cityscapes', which rejected the 'sr' dataset used by this project.
    parser.add_argument('--dataset', type=str, default=None,
        help='choose which dataset you want to use')
    # Stored as `data_root` to match the config attribute name; the old
    # `--dataroot` spelling is kept as an alias.
    parser.add_argument('--data_root', '--dataroot', dest='data_root', type=str, default=None,
        help='path to your dataset')
    parser.add_argument('--upscale', type=int, default=None,
        help='scale factor for super resolution')
    parser.add_argument('--train_y', action='store_false', default=None,
        help='whether to train Y channel in YCbCr space or not (default: True)')

    # Model
    parser.add_argument('--model', type=str, default=None,
        choices=['carn', 'drcn', 'drrn', 'edsr', 'espcn', 'fsrcnn',
                 'idn', 'lapsrn', 'srcnn', 'srdensenet', 'vdsr'],
        help='choose which model you want to use')
    parser.add_argument('--in_channels', type=int, default=None,
        help='number of input channel for given model (default: 1 for Y input else 3 for RGB input)')
    parser.add_argument('--out_channels', type=int, default=None,
        help='number of output channel for given model (default: 1 for Y input else 3 for RGB input)')

    # Training
    parser.add_argument('--total_epoch', type=int, default=None,
        help='number of total training epochs')
    parser.add_argument('--base_lr', type=float, default=None,
        help='base learning rate for single GPU, total learning rate *= gpu number')
    parser.add_argument('--train_bs', type=int, default=None,
        help='training batch size for single GPU, total batch size *= gpu number')
    parser.add_argument('--early_stop_epoch', type=int, default=None,
        help='epoch number to stop training if validation score does not increase')
    parser.add_argument('--max_itrs_per_epoch', type=int, default=None,
        help='increase the number of training iterations (suppose there are few training samples)')

    # Validating
    parser.add_argument('--val_bs', type=int, default=None,
        help='validating batch size for single GPU, total batch size *= gpu number')
    parser.add_argument('--begin_val_epoch', type=int, default=None,
        help='which epoch to start validating')
    parser.add_argument('--val_interval', type=int, default=None,
        help='epoch interval between two validations')
    parser.add_argument('--metrics', type=str, default=None, choices=['psnr', 'ssim'],
        help='choose which validation metric you want to use (default: psnr)')

    # Testing
    parser.add_argument('--is_testing', action='store_true', default=None,
        help='whether to perform testing/predicting or not (default: False)')
    parser.add_argument('--test_bs', type=int, default=None,
        help='testing batch size (currently only support single GPU)')
    parser.add_argument('--test_data_folder', type=str, default=None,
        help='path to your testing image folder')
    parser.add_argument('--test_lr', action='store_false', default=None,
        help='whether to test the downscaled/low-resolution image or not (default: True)')

    # Benchmark
    parser.add_argument('--benchmark', action='store_true', default=None,
        help='whether to perform benchmarking for given datasets or not (default: False)')
    # `nargs='+'` collects whole dataset names; `type=list` would split a
    # single string into its characters.
    parser.add_argument('--benchmark_datasets', type=str, nargs='+', default=None,
        help='select which datasets to benchmark (default: [set5, set14, bsd100])')

    # Loss
    parser.add_argument('--loss_type', type=str, default=None, choices=['mse', 'mae', 'charbonnier'],
        help='choose which loss you want to use')

    # Scheduler
    parser.add_argument('--lr_policy', type=str, default=None,
        choices=['constant', 'step', 'linear', 'cos_warmup'],
        help='choose which learning rate policy you want to use (default: constant)')
    parser.add_argument('--warmup_epochs', type=int, default=None,
        help='warmup epoch number for `cos_warmup` learning rate policy')
    parser.add_argument('--step_size', type=int, default=None,
        help='step size for `step` learning rate policy')
    parser.add_argument('--step_gamma', type=float, default=None,
        help='lr reduction factor for `step` learning rate policy (default: 0.1)')

    # Optimizer
    parser.add_argument('--optimizer_type', type=str, default=None,
        choices=['sgd', 'adam', 'adamw'],
        help='choose which optimizer you want to use (default: adam)')
    parser.add_argument('--momentum', type=float, default=None,
        help='momentum of SGD optimizer')
    parser.add_argument('--weight_decay', type=float, default=None,
        help='weight decay rate of SGD optimizer')

    # Monitoring
    parser.add_argument('--save_ckpt', action='store_false', default=None,
        help='whether to save checkpoint or not (default: True)')
    parser.add_argument('--save_dir', type=str, default=None,
        help='path to save checkpoints and training configurations etc.')
    parser.add_argument('--use_tb', action='store_false', default=None,
        help='whether to use tensorboard or not (default: True)')
    parser.add_argument('--tb_log_dir', type=str, default=None,
        help='path to save tensorboard logs')
    parser.add_argument('--ckpt_name', type=str, default=None,
        help='given name of the saved checkpoint, otherwise use `last` and `best`')

    # Training setting
    parser.add_argument('--amp_training', action='store_true', default=None,
        help='whether to use automatic mixed precision training or not (default: False)')
    parser.add_argument('--resume_training', action='store_false', default=None,
        help='whether to load training state from specific checkpoint or not if present (default: True)')
    parser.add_argument('--load_ckpt', action='store_false', default=None,
        help='whether to load given checkpoint or not if exist (default: True)')
    parser.add_argument('--load_ckpt_path', type=str, default=None,
        help='path to load specific checkpoint, otherwise try to load `last.pth`')
    parser.add_argument('--base_workers', type=int, default=None,
        help='number of workers for single GPU, total workers *= number of GPU')
    parser.add_argument('--random_seed', type=int, default=None,
        help='random seed')
    parser.add_argument('--use_ema', action='store_false', default=None,
        help='whether to use exponetial moving average to update weights or not (default: True)')
    parser.add_argument('--ema_decay', type=float, default=None,
        help='constant decay factor for EMA update, if not given use linear decay instead')
    parser.add_argument('--ema_start_epoch', type=int, default=None,
        help='epoch number to start EMA update')

    # Augmentation
    parser.add_argument('--patch_size', type=int, default=None,
        help='crop size for single training patch')
    parser.add_argument('--rotate', type=float, default=None,
        help='probability to perform rotation')
    parser.add_argument('--multi_scale', action='store_false', default=None,
        help='whether to perform multi-scale training or not (default: False)')
    parser.add_argument('--hflip', type=float, default=None,
        help='probability to perform horizontal flip')
    parser.add_argument('--vflip', type=float, default=None,
        help='probability to perform vertical flip')

    # DDP
    parser.add_argument('--synBN', action='store_true', default=None,
        help='whether to use SyncBatchNorm or not if trained with DDP (default: False)')
    parser.add_argument('--local_rank', type=int, default=None,
        help='used for DDP, DO NOT CHANGE')

    args = parser.parse_args(argv)
    return args
class BaseTrainer:
    """Common training scaffold shared by task-specific trainers.

    Sets up the logger, device, AMP scaler, model, data loaders, optimizer,
    scheduler and EMA model according to ``config``, and implements the
    epoch loop plus checkpoint loading and saving. Subclasses provide
    ``train_one_epoch``, ``validate``, ``benchmark`` and ``predict``.
    """

    def __init__(self, config):
        """Create every component needed for the mode selected in ``config``.

        Testing mode only builds the test loader, benchmark mode only the
        validation loaders; otherwise the full training setup is created.
        """
        super().__init__()
        # DDP parameters, DO NOT CHANGE
        self.rank = int(os.getenv('RANK', -1))
        self.local_rank = int(os.getenv('LOCAL_RANK', -1))
        self.world_size = int(os.getenv('WORLD_SIZE', 1))
        config.DDP = self.local_rank != -1
        self.main_rank = self.local_rank in [-1, 0]

        # Logger compatible with ddp training
        self.logger = get_logger(config, self.main_rank)

        # Select device to train the model
        self.device = set_device(config, self.local_rank)

        # Automatic mixed precision training scaler
        self.scaler = amp.GradScaler(enabled=config.amp_training)

        # Create directory to save checkpoints and logs
        mkdir(config.save_dir)

        # Set random seed to obtain reproducible results
        set_seed(config.random_seed)

        # Define model and put it to the selected device
        self.model = get_model(config).to(self.device)

        if config.is_testing:
            self.test_loader = get_test_loader(config)
        elif config.benchmark:
            self.val_loaders = get_val_loader(config)
        else:
            # Tensorboard monitor
            self.writer = get_writer(config, self.main_rank)

            # Define loss function
            self.loss_fn = get_loss_fn(config, self.device)

            # Get train and validate loader
            self.train_loader, self.val_loader = get_loader(config, self.local_rank)

            # Define optimizer
            self.optimizer = get_optimizer(config, self.model)

            # Define scheduler to control how learning rate changes
            self.scheduler = get_scheduler(config, self.optimizer)

        # Define variables to monitor training process
        self.best_score = 0.
        self.cur_epoch = 0
        self.train_itrs = 0
        self.best_epoch = 0

        # Load specific checkpoints if needed
        self.load_ckpt(config)

        # Use exponential moving average of checkpoint update if needed
        if not config.is_testing and not config.benchmark:
            self.ema_model = get_ema_model(config, self.model, self.device)

    def run(self, config):
        """Train from ``self.cur_epoch`` up to ``config.total_epoch``.

        Validates on the configured schedule, tracks the best score, saves
        checkpoints on the main rank, stops early when the score has not
        improved for more than ``config.early_stop_epoch`` epochs, and
        finally re-validates the best checkpoint.
        """
        # Parallel the model using DP or DDP
        self.parallel_model(config)

        # Output the training/validating configs (only in rank 0 if DDP)
        if self.main_rank:
            save_config(config)
            log_config(config, self.logger)

        # Start training from the latest epoch or from scratch
        start_epoch = self.cur_epoch
        for cur_epoch in range(start_epoch, config.total_epoch):
            self.cur_epoch = cur_epoch

            self.train_one_epoch(config)

            if cur_epoch >= config.begin_val_epoch and cur_epoch % config.val_interval == 0:
                val_score = self.validate(config)

                if val_score > self.best_score:
                    # Track the best result on every rank so that the early
                    # stopping test below does not depend on the rank.
                    # NOTE(review): assumes `validate` returns the same
                    # score on all ranks - confirm the metric is synced.
                    self.best_score = val_score
                    # Set before saving so the checkpoint stores the epoch
                    # that produced this score.
                    self.best_epoch = cur_epoch
                    if self.main_rank and config.save_ckpt:
                        # Save best model
                        self.save_ckpt(config, save_best=True)

            if self.main_rank and config.save_ckpt:
                # Save last model
                self.save_ckpt(config)

            # Use early stopping if needed
            if cur_epoch - self.best_epoch > config.early_stop_epoch:
                break

        # Close tensorboard after training
        if config.use_tb and self.main_rank:
            self.writer.flush()
            self.writer.close()

        # Validate for the best model
        if config.save_ckpt:
            self.val_best(config)

        destroy_ddp_process(config)

    def parallel_model(self, config):
        """Wrap ``self.model`` using the project's ``parallel_model`` helper."""
        self.model = parallel_model(config, self.model, self.local_rank, self.device)

    def train_one_epoch(self, config):
        '''You may implement whatever training process you like here.
        '''
        raise NotImplementedError()

    def validate(self, config):
        """Evaluate the model and return a score; must be overridden."""
        raise NotImplementedError()

    def benchmark(self, config):
        """Benchmark the model on the validation loaders; must be overridden."""
        raise NotImplementedError()

    def predict(self, config):
        """Run prediction on the test loader; must be overridden."""
        raise NotImplementedError()

    def load_ckpt(self, config):
        """Load model weights, and the training state when resuming.

        Raises:
            ValueError: in testing mode when no checkpoint file is found at
                ``config.load_ckpt_path``.
        """
        ckpt_path = config.load_ckpt_path
        # An unset path counts as "no checkpoint" instead of making
        # `os.path.isfile` raise on a non-path value.
        if config.load_ckpt and ckpt_path and os.path.isfile(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=torch.device(self.device))
            self.model.load_state_dict(checkpoint['state_dict'])
            if self.main_rank:
                self.logger.info(f"Load model state dict from {config.load_ckpt_path}")

            if not config.is_testing and not config.benchmark and config.resume_training:
                # The stored epoch was completed, so continue from the next one.
                self.cur_epoch = checkpoint['cur_epoch'] + 1
                self.best_score = checkpoint['best_score']
                self.best_epoch = checkpoint['best_epoch']
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.scheduler.load_state_dict(checkpoint['scheduler'])
                self.train_itrs = self.cur_epoch * config.iters_per_epoch

                if self.main_rank:
                    self.logger.info(f"Resume training from {config.load_ckpt_path}")

            del checkpoint
        else:
            if config.is_testing:
                raise ValueError(f'Could not find any pretrained checkpoint at path: {config.load_ckpt_path}.')
            else:
                if self.main_rank:
                    self.logger.info('[!] Train from scratch')

    def save_ckpt(self, config, save_best=False):
        """Write a checkpoint into ``config.save_dir``.

        The best checkpoint stores the EMA weights without optimizer and
        scheduler state; the regular checkpoint stores the de-paralleled
        model together with optimizer and scheduler state.

        Args:
            config: run configuration (``save_dir``, ``ckpt_name``).
            save_best: write the best checkpoint instead of the last one.
        """
        if save_best:
            # `val_best` reads this fixed file name.
            save_name = 'best.pth'
        elif config.ckpt_name is None:
            save_name = 'last.pth'
        else:
            # NOTE(review): a custom name previously left no usable file
            # name; here it names the regular checkpoint - confirm intent.
            save_name = config.ckpt_name
        save_path = f'{config.save_dir}/{save_name}'
        state_dict = self.ema_model.ema.state_dict() if save_best else de_parallel(self.model).state_dict()

        torch.save({
            'cur_epoch': self.cur_epoch,
            'best_score': self.best_score,
            'best_epoch': self.best_epoch,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict() if not save_best else None,
            'scheduler': self.scheduler.state_dict() if not save_best else None,
        }, save_path)

    def val_best(self, config, ckpt_path=None):
        """Reload the best checkpoint and validate it once more.

        Args:
            config: run configuration.
            ckpt_path: checkpoint to evaluate; defaults to
                ``{config.save_dir}/best.pth``.

        Raises:
            ValueError: if the checkpoint file does not exist.
        """
        ckpt_path = f"{config.save_dir}/best.pth" if ckpt_path is None else ckpt_path
        if not os.path.isfile(ckpt_path):
            raise ValueError(f'Best checkpoint does not exist at {ckpt_path}')

        if self.main_rank:
            self.logger.info(f"\nTrain {self.cur_epoch+1}/{config.total_epoch} epochs finished!\n")
            self.logger.info(f'{"#"*50}\nValidation for the best checkpoint...')

        # Load into the unwrapped model.
        self.model = de_parallel(self.model)
        checkpoint = torch.load(ckpt_path, map_location=torch.device(self.device))
        self.model.load_state_dict(checkpoint['state_dict'])

        self.model.to(self.device)
        del checkpoint

        # `validate` evaluates the EMA slot, so point it at the loaded weights.
        self.ema_model.ema = deepcopy(de_parallel(self.model)).eval()

        val_score = self.validate(config, val_best=True)

        if self.main_rank:
            self.logger.info(f'Best validation score is {val_score}.\n')