├── configs
│   ├── __init__.py
│   ├── my_config.py
│   ├── base_config.py
│   └── parser.py
├── core
│   ├── __init__.py
│   ├── loss.py
│   ├── sr_trainer.py
│   └── base_trainer.py
├── utils
│   ├── __init__.py
│   ├── optimizer.py
│   ├── metrics.py
│   ├── scheduler.py
│   ├── parallel.py
│   ├── model_ema.py
│   └── utils.py
├── main.py
├── datasets
│   ├── sr_dataset.py
│   ├── val_datasets.py
│   ├── test_dataset.py
│   ├── __init__.py
│   └── sr_base_dataset.py
├── models
│   ├── __init__.py
│   ├── espcn.py
│   ├── srcnn.py
│   ├── vdsr.py
│   ├── fsrcnn.py
│   ├── edsr.py
│   ├── drcn.py
│   ├── drrn.py
│   ├── idn.py
│   ├── lapsrn.py
│   ├── srdensenet.py
│   ├── carn.py
│   └── modules.py
├── tools
│   └── get_model_infos.py
└── README.md
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from .my_config import MyConfig
2 | from .parser import load_parser
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_trainer import BaseTrainer
2 | from .sr_trainer import SRTrainer
3 | from .loss import get_loss_fn
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .parallel import *
3 | from .optimizer import get_optimizer
4 | from .scheduler import get_scheduler
5 | from .metrics import get_sr_metrics
6 | from .model_ema import get_ema_model
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from core import SRTrainer
2 | from configs import MyConfig, load_parser
3 |
4 | import warnings
5 | warnings.filterwarnings("ignore")
6 |
7 |
if __name__ == '__main__':
    config = MyConfig()

    # If you want to use command-line arguments, please uncomment the following line
    # config = load_parser(config)

    # Derive dependent settings (checkpoint path, tensorboard dir, channel
    # counts, ...) only after every override has been applied.
    # init_dependent_config() fills a field only while it is still None.
    # Running it before load_parser would therefore freeze e.g. in_channels
    # to the train_y default, even if train_y is changed on the command line
    # (unless load_parser re-derives it itself).
    # Calling it on an already-initialised config leaves the fields unchanged.
    config.init_dependent_config()

    trainer = SRTrainer(config)

    # Pick exactly one entry point: testing, benchmarking, or training.
    if config.is_testing:
        trainer.predict(config)
    elif config.benchmark:
        trainer.benchmark(config)
    else:
        trainer.run(config)
24 |
--------------------------------------------------------------------------------
/datasets/sr_dataset.py:
--------------------------------------------------------------------------------
1 | from .sr_base_dataset import SRBaseDataset
2 |
3 |
class SRDataset(SRBaseDataset):
    """Combined super-resolution dataset.

    Training uses BSDS200 + General100 + T91. Validation uses BSD100 + Set5 +
    Set14 at the scale given by ``config.upscale``.

    The relative folder lists are handed to ``SRBaseDataset`` as
    ``data_split``. How they are resolved and which split is used for a given
    ``mode`` is decided by the base class.
    """

    def __init__(self, config, mode):
        scale_dir = f'image_SRF_{config.upscale}'
        train_sets = ('BSDS200', 'General100', 'T91')
        val_sets = ('BSD100', 'Set5', 'Set14')

        data_split = {
            'train': [f'train/{name}' for name in train_sets],
            'val': [f'val/{name}/{scale_dir}' for name in val_sets],
        }

        super(SRDataset, self).__init__(config, data_split, mode)
17 |
--------------------------------------------------------------------------------
/configs/my_config.py:
--------------------------------------------------------------------------------
1 | from .base_config import BaseConfig
2 |
3 |
class MyConfig(BaseConfig):
    """Experiment-specific overrides applied on top of ``BaseConfig`` defaults.

    Call ``init_dependent_config()`` after constructing (and after any other
    overrides) so derived fields are filled in.
    """

    def __init__(self,):
        super(MyConfig, self).__init__()

        # ---- Dataset ----
        self.dataset = 'sr'
        self.data_root = '/path/to/your/dataset'  # placeholder: point at the local data directory
        self.upscale = 2
        self.train_y = True

        # ---- Model ----
        self.model = 'srcnn'

        # ---- Training ----
        self.total_epoch = 6400
        self.lr_policy = 'constant'  # alternatively 'step'
        self.logger_name = 'sr_trainer'  # used as the log file stem by get_logger

        # ---- Augmentation ----
        self.patch_size = 48
        self.rotate = 0.5
        self.multi_scale = True
25 |
--------------------------------------------------------------------------------
/core/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
class CharbonnierLoss(nn.Module):
    """Charbonnier (differentiable L1-like) loss: ``mean(sqrt((pred - label)^2 + eps))``.

    Note that ``eps`` is added directly under the square root; it is *not*
    squared. The default ``eps=0.01`` therefore equals a smoothing constant
    of 0.1 in the ``sqrt(x^2 + c^2)`` formulation.

    Args:
        eps: Constant added to the squared error before the square root.
    """

    def __init__(self, eps=0.01):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, label):
        """Return the scalar (0-dim) mean Charbonnier penalty between ``pred`` and ``label``."""
        # A single reduction is enough; the previous second .mean() on an
        # already-scalar tensor was a no-op.
        return torch.sqrt((pred - label) ** 2 + self.eps).mean()
13 |
14 |
def get_loss_fn(config, device):
    """Instantiate the training criterion named by ``config.loss_type``.

    Args:
        config: Object exposing ``loss_type`` ('mae', 'mse' or 'charbonnier').
        device: Accepted for interface compatibility; not used here.

    Returns:
        An ``nn.Module`` loss instance.

    Raises:
        NotImplementedError: if ``config.loss_type`` is not recognised.
    """
    loss_type = config.loss_type

    if loss_type == 'mae':
        return nn.L1Loss()
    if loss_type == 'mse':
        return nn.MSELoss()
    if loss_type == 'charbonnier':
        return CharbonnierLoss()

    raise NotImplementedError(f"Unsupport loss type: {config.loss_type}")
29 |
--------------------------------------------------------------------------------
/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | from torch.optim import SGD, Adam, AdamW
2 |
3 |
def get_optimizer(config, model):
    """Create the optimizer named by ``config.optimizer_type``.

    Side effect: sets ``config.lr = config.base_lr * config.gpu_num``, i.e.
    the learning rate is scaled linearly with the number of devices.

    Args:
        config: Needs ``base_lr``, ``gpu_num``, ``optimizer_type`` and, for
            SGD only, ``momentum`` and ``weight_decay``.
        model: Module whose ``parameters()`` are optimized.

    Returns:
        A ``torch.optim`` optimizer (SGD, Adam or AdamW).

    Raises:
        NotImplementedError: for an unknown ``optimizer_type``.
    """
    optimizer_hub = {'sgd': SGD, 'adam': Adam, 'adamw': AdamW}
    config.lr = config.base_lr * config.gpu_num
    params = model.parameters()
    opt_type = config.optimizer_type

    if opt_type == 'sgd':
        # Momentum and weight decay are only applied to SGD.
        extra_kwargs = {'momentum': config.momentum,
                        'weight_decay': config.weight_decay}
    elif opt_type in ['adam', 'adamw']:
        extra_kwargs = {}
    else:
        raise NotImplementedError(f'Unsupported optimizer type: {config.optimizer_type}')

    return optimizer_hub[opt_type](params=params, lr=config.lr, **extra_kwargs)
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
2 |
3 |
def get_sr_metrics(metrics_type):
    """Build the torchmetrics object used to evaluate super-resolution output.

    Both metrics are configured with ``data_range=1.0``. This presumably
    expects images scaled to [0, 1]; confirm against the data pipeline.

    Args:
        metrics_type: 'psnr' or 'ssim'.

    Returns:
        A ``PeakSignalNoiseRatio`` or ``StructuralSimilarityIndexMeasure`` instance.

    Raises:
        NotImplementedError: for any other ``metrics_type``. Previously the
            function fell through to an ``UnboundLocalError`` on ``metrics``.
    """
    if metrics_type == 'psnr':
        # dim=(2, 3): reduce over the last two (spatial, assuming NCHW) axes,
        # then average the per-slice values.
        return PeakSignalNoiseRatio(data_range=1.0, base=10.0,
                                    reduction='elementwise_mean', dim=(2, 3))
    if metrics_type == 'ssim':
        return StructuralSimilarityIndexMeasure(gaussian_kernel=True, sigma=1.5,
                                                kernel_size=11,
                                                reduction='elementwise_mean',
                                                data_range=1.0, k1=0.01, k2=0.03,
                                                return_full_image=False,
                                                return_contrast_sensitivity=False)
    raise NotImplementedError(f'Unsupported metrics type: {metrics_type}')
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os, torch
2 | from .carn import CARN
3 | from .drcn import DRCN
4 | from .drrn import DRRN
5 | from .edsr import EDSR
6 | from .espcn import ESPCN
7 | from .fsrcnn import FSRCNN
8 | from .idn import IDN
9 | from .lapsrn import LapSRN
10 | from .srcnn import SRCNN
11 | from .srdensenet import SRDenseNet
12 | from .vdsr import VDSR
13 |
14 |
# Registry mapping the ``config.model`` string to its model class.
model_hub = {
    'carn': CARN,
    'drcn': DRCN,
    'drrn': DRRN,
    'edsr': EDSR,
    'espcn': ESPCN,
    'fsrcnn': FSRCNN,
    'idn': IDN,
    'lapsrn': LapSRN,
    'srcnn': SRCNN,
    'srdensenet': SRDenseNet,
    'vdsr': VDSR,
}


def get_model(config):
    """Instantiate the model named by ``config.model``.

    Every registered class is built with the same keyword arguments:
    ``in_channels``, ``out_channels`` and ``upscale`` taken from ``config``.

    Raises:
        NotImplementedError: if ``config.model`` is not a key of ``model_hub``.
    """
    model_cls = model_hub.get(config.model)
    if model_cls is None:
        raise NotImplementedError(f"Unsupport model type: {config.model}")

    return model_cls(in_channels=config.in_channels,
                     out_channels=config.out_channels,
                     upscale=config.upscale)
28 |
--------------------------------------------------------------------------------
/models/espcn.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Real-Time Single Image and Video Super-Resolution Using an Efficient
3 | Sub-Pixel Convolutional Neural Network
4 | Url: https://arxiv.org/abs/1609.05158
5 | Create by: zh320
6 | Date: 2023/12/09
7 | """
8 |
9 | import torch.nn as nn
10 |
11 | from .modules import ConvAct, Activation, Upsample
12 |
13 |
class ESPCN(nn.Module):
    """ESPCN: two conv+activation layers followed by an ``Upsample`` stage and
    a final activation.

    Args:
        in_channels / out_channels: Image channels in and out.
        upscale: Upscaling factor passed to ``Upsample``.
        n1, n2: Widths of the two feature layers.
        act_type: Activation name used after every stage (default 'tanh').
        upsample_type: Upsampling method passed to ``Upsample``.
    """

    def __init__(self, in_channels, out_channels, upscale, n1=64, n2=32, act_type='tanh',
                 upsample_type='pixelshuffle'):
        super(ESPCN, self).__init__()
        # The numeric third argument is presumably the kernel size; confirm in modules.ConvAct.
        self.layer1 = ConvAct(in_channels, n1, 5, act_type=act_type)
        self.layer2 = ConvAct(n1, n2, 3, act_type=act_type)
        self.upsample = Upsample(n2, out_channels, upscale, upsample_type, 3)
        self.act3 = Activation(act_type)

    def forward(self, x):
        features = self.layer2(self.layer1(x))
        upscaled = self.upsample(features)
        return self.act3(upscaled)
28 |
--------------------------------------------------------------------------------
/datasets/val_datasets.py:
--------------------------------------------------------------------------------
1 | from .sr_base_dataset import SRBaseDataset
2 |
3 |
class Set5(SRBaseDataset):
    """Set5 benchmark: a single 'val' split at scale ``config.upscale``."""

    def __init__(self, config, mode):
        folder = f'val/Set5/image_SRF_{config.upscale}'
        super(Set5, self).__init__(config, {'val': [folder]}, mode)
13 |
14 |
class Set14(SRBaseDataset):
    """Set14 benchmark: a single 'val' split at scale ``config.upscale``."""

    def __init__(self, config, mode):
        folder = f'val/Set14/image_SRF_{config.upscale}'
        super(Set14, self).__init__(config, {'val': [folder]}, mode)
24 |
25 |
class BSD100(SRBaseDataset):
    """BSD100 benchmark: a single 'val' split at scale ``config.upscale``."""

    def __init__(self, config, mode):
        folder = f'val/BSD100/image_SRF_{config.upscale}'
        super(BSD100, self).__init__(config, {'val': [folder]}, mode)
--------------------------------------------------------------------------------
/models/srcnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Image Super-Resolution Using Deep Convolutional Networks
3 | Url: https://arxiv.org/abs/1501.00092
4 | Create by: zh320
5 | Date: 2023/12/09
6 | """
7 |
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from .modules import conv5x5, ConvAct
12 |
13 |
class SRCNN(nn.Module):
    """SRCNN: bicubic upsampling followed by three convolutions.

    ``kernel_setting`` names the 9-x-5 variants. Its middle digit selects the
    kernel size of the second layer (1, 3 or 5); the first layer uses 9 and
    the last uses a 5x5 conv.

    Args:
        in_channels / out_channels: Image channels in and out.
        upscale: Bicubic scale factor applied at the start of ``forward``.
        kernel_setting: One of '915', '935', '955'.
        act_type: Activation for the first two layers.

    Raises:
        ValueError: for an unknown ``kernel_setting``.
    """

    def __init__(self, in_channels, out_channels, upscale, kernel_setting='935',
                 act_type='relu'):
        super(SRCNN, self).__init__()
        kernel_map = {'915': 1, '935': 3, '955': 5}
        # One source of truth for valid settings. The message is built from
        # adjacent literals instead of a backslash continuation inside the
        # string, which used to embed the next line's indentation in the text.
        if not isinstance(kernel_setting, str) or kernel_setting not in kernel_map:
            raise ValueError(f'Unknown kernel setting: {kernel_setting}. You can choose '
                             f'from ["915", "935", "955"].\n')

        self.upscale = upscale
        self.layer1 = ConvAct(in_channels, 64, 9, act_type=act_type)
        self.layer2 = ConvAct(64, 32, kernel_map[kernel_setting], act_type=act_type)
        self.layer3 = conv5x5(32, out_channels)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        return x
35 |
--------------------------------------------------------------------------------
/models/vdsr.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Accurate Image Super-Resolution Using Very Deep Convolutional Networks
3 | Url: https://arxiv.org/abs/1511.04587
4 | Create by: zh320
5 | Date: 2023/12/16
6 | """
7 |
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from .modules import conv3x3, ConvAct
12 |
13 |
class VDSR(nn.Module):
    """VDSR: bicubic-upsample the input, run a deep stack of 3x3 convolutions,
    and add the stack's output back onto the upsampled image (global residual).

    Args:
        in_channels / out_channels: Image channels in and out. ``res += x``
            requires the output to be broadcast-compatible with the input,
            so in practice out_channels == in_channels.
        upscale: Bicubic scale factor applied at the start of ``forward``.
        layer_num: Total conv layers: first + (layer_num - 2) middle + last.
        hid_channels: Width of the hidden layers.
        act_type: Activation used by the middle ``ConvAct`` layers.
    """

    def __init__(self, in_channels, out_channels, upscale, layer_num=20, hid_channels=64,
                 act_type='relu'):
        super(VDSR, self).__init__()
        self.upscale = upscale
        self.first_layer = conv3x3(in_channels, hid_channels)
        # act_type used to be accepted but never forwarded, so the middle
        # layers silently used ConvAct's default activation.
        layers = [ConvAct(hid_channels, hid_channels, 3, act_type=act_type, inplace=True)
                  for _ in range(layer_num - 2)]
        self.mid_layer = nn.Sequential(*layers)
        self.last_layer = conv3x3(hid_channels, out_channels)

        # Kaiming init is computed for ReLU regardless of act_type.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')
        res = self.first_layer(x)
        res = self.mid_layer(res)
        res = self.last_layer(res)
        res += x

        return res
36 |
--------------------------------------------------------------------------------
/tools/get_model_infos.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from os import path
3 | sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) )
4 |
5 | from configs import MyConfig, load_parser
6 | from models import get_model
7 |
8 |
def cal_model_params(config, imgw=200, imgh=200):
    """Print the parameter count of the model described by ``config``.

    ptflops is used when it is installed and succeeds. Otherwise the function
    falls back to summing ``numel()`` over ``model.parameters()``.

    Args:
        config: Passed to ``get_model``; ``config.in_channels`` sets the input
            channel count for ptflops.
        imgw, imgh: Spatial size of the dummy input given to ptflops.
    """
    model = get_model(config)
    print(f'\nModel: {config.model}')

    try:
        from ptflops import get_model_complexity_info
        model.eval()
        # Notice that ptflops doesn't take into account torch.nn.functional.*
        # operations. If you want to get correct macs result, you need to modify
        # the modules like torch.nn.functional.interpolate to torch.nn.Upsample.
        _, params = get_model_complexity_info(model, (config.in_channels, imgh, imgw),
                                              as_strings=True,
                                              print_per_layer_stat=False, verbose=False)
        print(f'Number of parameters: {params}\n')
    except Exception:
        # Best-effort fallback: ptflops missing or unable to trace this model.
        # A bare `except:` used to also swallow KeyboardInterrupt/SystemExit.
        params = sum(p.numel() for p in model.parameters())
        print(f'Number of parameters: {params / 1e3:.2f}K\n')
28 |
29 |
if __name__ == '__main__':
    config = MyConfig()
    config = load_parser(config)
    # Fill the derived fields get_model reads (in_channels/out_channels, ...).
    # BaseConfig leaves them None until init_dependent_config() runs, and
    # main.py calls it explicitly. Each step only fills fields that are still
    # unset, so this is harmless if load_parser already did it.
    config.init_dependent_config()

    cal_model_params(config)
--------------------------------------------------------------------------------
/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import OneCycleLR, StepLR, LambdaLR
2 | from math import ceil
3 |
4 |
def get_scheduler(config, optimizer):
    """Build the learning-rate scheduler named by ``config.lr_policy``.

    Side effects: sets ``config.iters_per_epoch`` and ``config.total_itrs``.
    Under DDP each process handles ``1 / gpu_num`` of the batches, which is
    reflected in ``iters_per_epoch``.

    Supported policies:
        'cos_warmup' -- OneCycleLR peaking at ``config.lr`` after ``warmup_epochs``.
        'linear'     -- OneCycleLR with linear annealing from the start.
        'step'       -- StepLR(step_size, step_gamma); ``step_size`` must be set.
        'constant'   -- LR multiplier fixed at 1.0.

    Raises:
        ValueError: 'step' policy without ``config.step_size``.
        NotImplementedError: unknown ``lr_policy``.
    """
    if config.DDP:
        config.iters_per_epoch = ceil(config.train_num/config.train_bs/config.gpu_num)
    else:
        config.iters_per_epoch = ceil(config.train_num/config.train_bs)
    config.total_itrs = int(config.total_epoch*config.iters_per_epoch)

    if config.lr_policy == 'cos_warmup':
        warmup_ratio = config.warmup_epochs / config.total_epoch
        scheduler = OneCycleLR(optimizer, max_lr=config.lr, total_steps=config.total_itrs,
                               pct_start=warmup_ratio)

    elif config.lr_policy == 'linear':
        scheduler = OneCycleLR(optimizer, max_lr=config.lr, total_steps=config.total_itrs,
                               pct_start=0., anneal_strategy='linear')

    elif config.lr_policy == 'step':
        # BaseConfig defaults step_size to None; StepLR would only fail later,
        # at a scheduler.step() call in the middle of training.
        if config.step_size is None:
            raise ValueError("lr_policy 'step' requires config.step_size to be set.")
        scheduler = StepLR(optimizer, step_size=config.step_size, gamma=config.step_gamma)

    elif config.lr_policy == 'constant':
        # A single callable applies to every param group. The previous
        # one-element list made LambdaLR reject optimizers with more than one
        # group.
        scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1.0)

    else:
        raise NotImplementedError(f'Unsupported scheduler type: {config.lr_policy}')
    return scheduler
--------------------------------------------------------------------------------
/models/fsrcnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Accelerating the Super-Resolution Convolutional Neural Network
3 | Url: https://arxiv.org/abs/1608.00367
4 | Create by: zh320
5 | Date: 2023/12/09
6 | """
7 |
8 | import torch.nn as nn
9 |
10 | from .modules import ConvAct, Upsample
11 |
12 |
class FSRCNN(nn.Module):
    """FSRCNN: feature extraction, shrink -> four 3x3 mappings -> expand, then
    an ``Upsample`` stage that produces the high-resolution output.

    Args:
        in_channels / out_channels: Image channels in and out.
        upscale: Upscaling factor passed to ``Upsample``.
        d: Width of the feature-extraction / expanded representation.
        s: Width of the shrunk mapping space.
        act_type: Activation name; ``num_parameters`` is forwarded per layer
            (used for per-channel PReLU, presumably; confirm in modules).
        upsample_type: Upsampling method passed to ``Upsample``.
    """

    def __init__(self, in_channels, out_channels, upscale, d=56, s=12, act_type='prelu',
                 upsample_type='deconvolution'):
        super(FSRCNN, self).__init__()
        self.first_part = ConvAct(in_channels, d, 5, act_type=act_type, num_parameters=d)

        # Shrink (1x1) -> four 3x3 mapping layers -> expand (1x1).
        mapping = [ConvAct(d, s, 1, act_type=act_type, num_parameters=s)]
        mapping += [ConvAct(s, s, 3, act_type=act_type, num_parameters=s) for _ in range(4)]
        mapping.append(ConvAct(s, d, 1, act_type=act_type, num_parameters=d))
        self.mid_part = nn.Sequential(*mapping)

        self.last_part = Upsample(d, out_channels, upscale, upsample_type, 9)

    def forward(self, x):
        features = self.mid_part(self.first_part(x))
        return self.last_part(features)
33 |
--------------------------------------------------------------------------------
/utils/parallel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.distributed as dist
4 | from torch.nn.parallel import DistributedDataParallel as DDP
5 |
6 |
def is_parallel(model):
    """Return True when ``model`` is exactly a DataParallel or DistributedDataParallel wrapper.

    The check is on the exact type, so subclasses of these wrappers are not detected.
    """
    wrapper_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return type(model) in wrapper_types
10 |
11 |
def de_parallel(model):
    """Return the wrapped ``.module`` of a DP/DDP model, or the model itself otherwise."""
    if is_parallel(model):
        return model.module
    return model
15 |
16 |
def set_device(config, rank):
    """Select the compute device and derive device-dependent config values.

    DDP: pins this process to GPU ``rank``, initialises the NCCL process group
    from the environment, and sets ``config.gpu_num`` to the world size.
    ``train_bs`` stays per-process.

    Otherwise (DataParallel or CPU): uses CUDA if available and sets
    ``config.gpu_num`` to the number of visible GPUs, at least 1. It then
    multiplies ``train_bs`` by ``gpu_num``.

    In both cases ``config.num_workers = gpu_num * base_workers``.

    Returns:
        The ``torch.device`` to place tensors on.
    """
    if config.DDP:
        torch.cuda.set_device(rank)
        dist.init_process_group(backend=dist.Backend.NCCL, init_method='env://')
        device = torch.device('cuda', rank)
        config.gpu_num = dist.get_world_size()
    else:  # DP
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # On a CPU-only host device_count() is 0. Multiplying by 0 would zero
        # train_bs, num_workers and the base_lr * gpu_num scaling in
        # get_optimizer, so treat CPU as a single worker.
        config.gpu_num = max(torch.cuda.device_count(), 1)
        config.train_bs *= config.gpu_num

    # Setup num_workers
    config.num_workers = config.gpu_num * config.base_workers

    return device
32 |
33 |
def parallel_model(config, model, rank, device):
    """Wrap ``model`` for multi-device training.

    Without DDP the model is wrapped in ``nn.DataParallel`` and moved to
    ``device``. With DDP it is optionally converted to SyncBatchNorm
    (``config.synBN``), moved to GPU ``rank``, and wrapped in
    DistributedDataParallel.
    """
    if not config.DDP:
        wrapped = nn.DataParallel(model)
        wrapped.to(device)
        return wrapped

    if config.synBN:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return DDP(model.to(rank), device_ids=[rank], output_device=rank)
44 |
45 |
def destroy_ddp_process(config):
    """Tear down the distributed process group when running under DDP; otherwise do nothing."""
    if not config.DDP:
        return
    dist.destroy_process_group()
49 |
50 |
def sampler_set_epoch(config, loader, cur_epochs):
    """Forward the epoch index to ``loader.sampler.set_epoch`` under DDP.

    The distributed sampler needs this to reshuffle each epoch. Without DDP it
    is a no-op.
    """
    if not config.DDP:
        return
    loader.sampler.set_epoch(cur_epochs)
54 |
--------------------------------------------------------------------------------
/models/edsr.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution
3 | Url: https://arxiv.org/abs/1707.02921
4 | Create by: zh320
5 | Date: 2023/12/16
6 | """
7 |
8 | import torch.nn as nn
9 |
10 | from .modules import conv3x3, ConvAct, Upsample
11 |
12 |
class EDSR(nn.Module):
    """EDSR: head conv, ``B`` residual blocks plus a conv with a long skip
    connection, then an ``Upsample`` stage and an output conv.

    Args:
        in_channels / out_channels: Image channels in and out.
        upscale: Upscaling factor passed to ``Upsample``.
        B: Number of residual blocks.
        F: Feature width.
        scale_factor: Residual scaling inside each block. When None it
            defaults to 0.1 for B > 16 and 1.0 (no scaling) otherwise.
        act_type: Activation inside residual blocks.
        upsample_type: Upsampling method passed to ``Upsample``.
    """

    def __init__(self, in_channels, out_channels, upscale, B=16, F=64, scale_factor=None,
                 act_type='relu', upsample_type='pixelshuffle'):
        super(EDSR, self).__init__()
        if scale_factor is None:
            scale_factor = 0.1 if B > 16 else 1.0

        self.first_layer = conv3x3(in_channels, F)
        self.res_layers = nn.Sequential(
            *[ResidualBlock(F, scale_factor, act_type) for _ in range(B)]
        )
        self.mid_layer = conv3x3(F, F)
        self.last_layers = nn.Sequential(
            Upsample(F, F, upscale, upsample_type, 3),
            conv3x3(F, out_channels),
        )

    def forward(self, x):
        head = self.first_layer(x)
        body = self.mid_layer(self.res_layers(head))
        body += head  # long skip connection around the residual trunk
        return self.last_layers(body)
42 |
43 |
class ResidualBlock(nn.Module):
    """conv+act -> conv, optionally scaled, plus the block input.

    The branch output is multiplied by ``scale_factor`` only when it is < 1.
    """

    def __init__(self, channels, scale_factor, act_type):
        super(ResidualBlock, self).__init__()
        self.scale_factor = scale_factor
        self.conv = nn.Sequential(
            ConvAct(channels, channels, 3, act_type=act_type),
            conv3x3(channels, channels),
        )

    def forward(self, x):
        branch = self.conv(x)
        if self.scale_factor < 1:
            branch = branch * self.scale_factor
        branch += x
        return branch
61 |
--------------------------------------------------------------------------------
/utils/model_ema.py:
--------------------------------------------------------------------------------
1 | '''
2 | Codes are based on
3 | https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | from copy import deepcopy
9 | from .parallel import de_parallel
10 |
11 |
def get_ema_model(config, model, device):
    """Create a ``ModelEmaV2`` tracker for ``model``, keeping the EMA copy on ``device``."""
    ema_model = ModelEmaV2(config, model, device=device)
    return ema_model
14 |
15 |
class ModelEmaV2(nn.Module):
    """Shadow copy of a model whose weights track an exponential moving average.

    Adapted from timm's ``ModelEmaV2``.

    Config fields read:
        use_ema     -- when False, ``update`` copies the model weights verbatim.
        ema_decay   -- constant decay in (0, 1), or None for a linear ramp of
                       ``cur_itrs / total_itrs`` clamped to [0, 1].
        total_itrs  -- denominator for the linear ramp.

    Raises:
        ValueError: if ``ema_decay`` is set but not strictly between 0 and 1.
    """

    def __init__(self, config, model, device=None):
        super(ModelEmaV2, self).__init__()
        # Independent copy of the unwrapped model; it is only written via update().
        self.ema = deepcopy(de_parallel(model))
        self.ema.eval()
        # Optionally keep the EMA weights on a different device from the model.
        self.device = device
        if self.device is not None:
            self.ema.to(device=device)
        self.use_ema = config.use_ema

        decay = config.ema_decay
        if decay is None:
            self.decay = None
        else:
            if decay >= 1. or decay <= 0.:
                raise ValueError('EMA decay rate out of range.\n')
            self.decay = decay
        self.total_itrs = config.total_itrs

    @torch.no_grad()
    def _update(self, model, update_fn):
        # Pair entries positionally; both state dicts come from the same architecture.
        ema_state = self.ema.state_dict()
        model_state = model.state_dict()
        for ema_tensor, model_tensor in zip(ema_state.values(), model_state.values()):
            if self.device is not None:
                model_tensor = model_tensor.to(device=self.device)
            ema_tensor.copy_(update_fn(ema_tensor, model_tensor))

    def update(self, model, cur_itrs):
        if not self.use_ema:
            # EMA disabled: mirror the live weights.
            self._update(de_parallel(model), update_fn=lambda e, m: m)
            return

        if self.decay is None:
            # Linear ramp: decay grows from 0 to 1 over training.
            decay = min(max(cur_itrs / self.total_itrs, 0), 1)
        else:
            decay = self.decay
        self._update(de_parallel(model), update_fn=lambda e, m: decay * e + (1. - decay) * m)
50 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os, random, torch, json
2 | import numpy as np
3 |
4 |
def mkdir(path):
    """Create directory ``path`` if nothing exists there yet.

    Missing parent directories are created too. The previous
    ``os.mkdir`` raised FileNotFoundError for nested paths, and could raise
    FileExistsError if another process created the directory between the
    check and the call. ``exist_ok=True`` closes that race. If ``path``
    already exists (even as a file), nothing is done, as before.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
8 |
9 |
def set_seed(seed):
    """Seed Python's ``random``, NumPy, and torch (CPU and current CUDA device) RNGs."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
15 |
16 |
def get_writer(config, main_rank):
    """Return a TensorBoard ``SummaryWriter`` on the main rank when ``config.use_tb`` is set, else None.

    The tensorboard import is deferred so it is only required when logging is enabled.
    """
    if not (config.use_tb and main_rank):
        return None

    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(config.tb_log_dir)
24 |
25 |
def get_logger(config, main_rank):
    """Configure and return the loguru logger on the main rank; return None elsewhere.

    Existing loguru sinks are removed. INFO-level sinks are then added for
    stderr and for ``{config.save_dir}/{config.logger_name}.log``, both using
    the same timestamped format.
    """
    if not main_rank:
        return None

    import sys
    from loguru import logger

    line_format = "[{time:YYYY-MM-DD HH:mm}] {message}"
    log_path = f'{config.save_dir}/{config.logger_name}.log'

    logger.remove()
    logger.add(sys.stderr, format=line_format, level="INFO")
    logger.add(log_path, format=line_format, level="INFO")
    return logger
38 |
39 |
def save_config(config):
    """Dump every instance attribute of ``config`` to ``{config.save_dir}/config.json`` (4-space indent).

    All attribute values must be JSON-serialisable.
    """
    target_path = f'{config.save_dir}/config.json'
    with open(target_path, 'w') as fp:
        json.dump(vars(config), fp, indent=4)
44 |
45 |
def log_config(config, logger):
    """Log a banner-framed ``key: value`` summary of selected config fields.

    Benchmark runs log only the dataset/model description. Training runs also
    log the optimisation and hardware settings.
    """
    common = ['upscale', 'train_y', 'model', 'in_channels', 'out_channels']
    if config.benchmark:
        keys = ['benchmark_datasets'] + common
    else:
        keys = ['dataset'] + common + [
            'loss_type', 'optimizer_type', 'lr_policy', 'total_epoch', 'lr',
            'train_bs', 'val_bs', 'train_num', 'val_num', 'gpu_num',
            'num_workers', 'amp_training', 'DDP', 'use_ema', 'patch_size',
            'multi_scale',
        ]

    settings = vars(config)
    header = f"\n\n\n{'#' * 25} Config Informations {'#' * 25}\n"
    body = '\n'.join('%s: %s' % (name, settings[name]) for name in keys)
    footer = f"\n{'#' * 71}\n\n"
    logger.info(header + body + footer)
60 |
--------------------------------------------------------------------------------
/models/drcn.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Deeply-Recursive Convolutional Network for Image Super-Resolution
3 | Url: https://arxiv.org/abs/1511.04491
4 | Create by: zh320
5 | Date: 2023/12/23
6 | """
7 |
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from .modules import ConvAct
12 |
13 |
class DRCN(nn.Module):
    """DRCN: bicubic upsampling, an embedding net, one weight-shared inference
    conv applied ``recursions`` times, and a reconstruction net.

    arch_type:
        'basic'    -- reconstruct only from the final recursion's features.
        'advanced' -- after every recursion, reconstruct from
                      ``features + upsampled input`` and sum all reconstructions.

    NOTE(review): in 'advanced' mode the upsampled input is added to the
    hid_channels feature map by broadcasting, not to the output image. The
    per-recursion predictions are also summed without weighting or averaging.
    This appears to differ from the DRCN paper's skip connection and weighted
    average. It is kept unchanged so existing checkpoints behave the same;
    confirm intent.

    Raises:
        ValueError: unknown ``arch_type``. For 'advanced' also: an
            ``in_channels`` that cannot broadcast onto ``hid_channels``
            features, or ``recursions < 1``.
    """

    def __init__(self, in_channels, out_channels, upscale, recursions=16,
                 hid_channels=256, act_type='relu', arch_type='advanced'):
        super(DRCN, self).__init__()
        if arch_type not in ['basic', 'advanced']:
            raise ValueError(f'Unsupported model type: {arch_type}\n')
        if arch_type == 'advanced':
            # forward() computes `x + skip` with x: hid_channels, skip: in_channels.
            # That only broadcasts when the channel counts match or skip has 1
            # channel. Fail here rather than with an opaque shape error mid-forward.
            if in_channels not in (1, hid_channels):
                raise ValueError(f"arch_type='advanced' adds the {in_channels}-channel input to "
                                 f"{hid_channels}-channel features; in_channels must be 1 or "
                                 f"equal hid_channels.\n")
            # At least one recursion is needed to produce an output
            # (otherwise `res` is never assigned).
            if recursions < 1:
                raise ValueError("arch_type='advanced' requires recursions >= 1.\n")
        self.upscale = upscale
        self.recursions = recursions
        self.arch_type = arch_type

        self.embedding_net = nn.Sequential(
            ConvAct(in_channels, hid_channels, 3, act_type=act_type),
            ConvAct(hid_channels, hid_channels, 3, act_type=act_type)
        )
        # A single module reused on every recursion (shared weights).
        self.inference_net = ConvAct(hid_channels, hid_channels, 3, act_type=act_type)
        self.reconstruction_net = nn.Sequential(
            ConvAct(hid_channels, hid_channels, 3, act_type=act_type),
            ConvAct(hid_channels, out_channels, 3, act_type=act_type)
        )

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')
        if self.arch_type == 'advanced':
            skip = x

        x = self.embedding_net(x)

        for i in range(self.recursions):
            x = self.inference_net(x)

            if self.arch_type == 'advanced':
                if i == 0:
                    res = self.reconstruction_net(x + skip)
                else:
                    res += self.reconstruction_net(x + skip)

        if self.arch_type == 'basic':
            res = self.reconstruction_net(x)

        return res
54 |
--------------------------------------------------------------------------------
/models/drrn.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Image Super-Resolution via Deep Recursive Residual Network
3 | Url: https://openaccess.thecvf.com/content_cvpr_2017/html/Tai_Image_Super-Resolution_via_CVPR_2017_paper.html
4 | Create by: zh320
5 | Date: 2023/12/23
6 | """
7 |
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from .modules import conv3x3, ConvAct, ConvBNAct
12 |
13 |
class DRRN(nn.Module):
    """DRRN: bicubic upsampling, ``B`` chained recursive blocks, a final 3x3
    conv, and a global residual back onto the upsampled image.

    Each recursive block starts with its own conv (``init_blocks[i]``). It
    then applies one weight-shared ``ResidualUnit`` ``U`` times, adding the
    block's first-conv output after every application.

    Args:
        in_channels / out_channels: Image channels in and out.
        upscale: Bicubic scale factor applied at the start of ``forward``.
        B: Number of recursive blocks (>= 1).
        U: Residual-unit applications per block.
        hid_channels: Feature width.
        act_type: Activation name.
        use_bn: Use ``ConvBNAct`` instead of ``ConvAct``.

    Raises:
        ValueError: if ``B < 1``.
    """

    def __init__(self, in_channels, out_channels, upscale, B=1, U=9, hid_channels=128,
                 act_type='relu', use_bn=False):
        super(DRRN, self).__init__()
        if B < 1:
            raise ValueError(f'DRRN needs at least one recursive block, got B={B}.\n')
        self.upscale = upscale
        self.B = B
        self.U = U

        ConvBlock = ConvBNAct if use_bn else ConvAct
        self.init_blocks = nn.ModuleList()
        self.recursive_blocks = nn.ModuleList()
        for i in range(B):
            # Only the first block sees the image; later blocks consume the
            # previous block's hid_channels features. The old `i == B-1` test
            # gave the wrong block the image channel count whenever B > 1.
            in_ch = in_channels if i == 0 else hid_channels
            self.init_blocks.append(ConvBlock(in_ch, hid_channels, 3, act_type=act_type))
            self.recursive_blocks.append(ResidualUnit(hid_channels, act_type, use_bn))
        self.last_layer = conv3x3(hid_channels, out_channels)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')

        # Chain the blocks. Previously every block restarted from the image `x`,
        # discarding all earlier blocks' output (identical when B == 1).
        res = x
        for i in range(self.B):
            res = self.init_blocks[i](res)
            residual = res
            for _ in range(self.U):
                res = self.recursive_blocks[i](res)
                res += residual

        res = self.last_layer(res)
        x = x + res

        return x
45 |
46 |
class ResidualUnit(nn.Module):
    """Two 3x3 conv blocks plus an identity shortcut from the unit's own input.

    NOTE(review): DRRN's forward additionally adds the block-level residual
    after each application. This means every step gets both this local
    shortcut and the block shortcut; confirm the local shortcut is intended.
    """

    def __init__(self, channels, act_type, use_bn):
        super(ResidualUnit, self).__init__()
        block_cls = ConvBNAct if use_bn else ConvAct
        self.conv1 = block_cls(channels, channels, 3, act_type=act_type)
        self.conv2 = block_cls(channels, channels, 3, act_type=act_type)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out += x
        return out
59 |
--------------------------------------------------------------------------------
/models/idn.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Fast and Accurate Single Image Super-Resolution via Information Distillation Network
3 | Url: https://arxiv.org/abs/1803.09454
4 | Create by: zh320
5 | Date: 2023/12/30
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from .modules import conv1x1, ConvAct, Upsample
13 |
14 |
class IDN(nn.Module):
    """Information Distillation Network.

    Pipeline: FBlock (two conv+act layers) -> ``num_blocks`` DBlocks ->
    ``Upsample`` RBlock. The result is added to a bicubic-upsampled copy of
    the input.

    Args:
        in_channels / out_channels: Image channels in and out.
        upscale: Scale factor for both the RBlock and the bicubic skip.
        num_blocks: Number of distillation blocks.
        D3: Feature width.
        s: Slice divisor; ``d = int((1 / s) * D3)`` channels are split off in
            each enhancement unit. Must be > 1.
        act_type: Activation name.
        upsample_type: Upsampling method for the RBlock.
    """

    def __init__(self, in_channels, out_channels, upscale, num_blocks=4, D3=64, s=4,
                 act_type='leakyrelu', upsample_type='deconvolution'):
        super(IDN, self).__init__()
        assert s > 1, 's should be larger than 1, otherwise split_ratio will be out of range.\n'
        split_ratio = 1 / s
        d = int(split_ratio * D3)
        self.upscale = upscale

        self.fblock = nn.Sequential(
            ConvAct(in_channels, D3, 3, act_type=act_type),
            ConvAct(D3, D3, 3, act_type=act_type),
        )
        self.dblocks = nn.Sequential(*[DBlock(D3, d, act_type) for _ in range(num_blocks)])
        self.rblock = Upsample(D3, out_channels, upscale, upsample_type, 17)

    def forward(self, x):
        bicubic = F.interpolate(x, scale_factor=self.upscale, mode='bicubic')

        out = self.rblock(self.dblocks(self.fblock(x)))
        out += bicubic

        return out
46 |
47 |
class DBlock(nn.Sequential):
    """Distillation block.

    An ``EnhancementUnit`` (outputs ``D3 + d`` channels) followed by a 1x1
    compression back to ``D3`` channels.
    """

    def __init__(self, D3, d, act_type):
        enhancement = EnhancementUnit(D3, d, act_type)
        compression = conv1x1(D3 + d, D3)
        super(DBlock, self).__init__(enhancement, compression)
54 |
55 |
class EnhancementUnit(nn.Module):
    """IDN enhancement unit.

    ``conv1`` maps D3 -> D3 channels. Its first ``d`` channels are concatenated
    with the unit input (D3 + d channels). The remaining ``D3 - d`` channels
    pass through ``conv2`` to D3 + d channels, and the two are summed.

    Args:
        D3: Input/feature width.
        d: Number of channels split off after ``conv1``.
        act_type: Activation name.
        groups: Six conv group counts, one per ConvAct in conv1 then conv2.
            A tuple default avoids the shared mutable-list default argument.
    """

    def __init__(self, D3, d, act_type, groups=(1, 4, 1, 4, 1, 1)):
        super(EnhancementUnit, self).__init__()
        assert len(groups) == 6, 'Length of groups should be 6.\n'
        self.d = d

        self.conv1 = nn.Sequential(
            ConvAct(D3, D3 - d, 3, groups=groups[0], act_type=act_type),
            ConvAct(D3 - d, D3 - 2*d, 3, groups=groups[1], act_type=act_type),
            ConvAct(D3 - 2*d, D3, 3, groups=groups[2], act_type=act_type),
        )

        self.conv2 = nn.Sequential(
            ConvAct(D3 - d, D3, 3, groups=groups[3], act_type=act_type),
            ConvAct(D3, D3 - d, 3, groups=groups[4], act_type=act_type),
            ConvAct(D3 - d, D3 + d, 3, groups=groups[5], act_type=act_type),
        )

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        # Retained slice: first d channels, concatenated with the unit input.
        x_c = x[:, :self.d, :, :]
        x_c = torch.cat([x_c, residual], dim=1)
        # Remaining channels are refined further by conv2.
        x_s = x[:, self.d:, :, :]
        x_s = self.conv2(x_s)

        return x_s + x_c
83 |
--------------------------------------------------------------------------------
/configs/base_config.py:
--------------------------------------------------------------------------------
class BaseConfig:
    """Default hyper-parameters shared by all experiments.

    Subclasses (e.g. ``MyConfig``) override attributes in ``__init__``. Call
    ``init_dependent_config()`` once all overrides are applied to fill the
    fields that are derived from others.
    """

    def __init__(self,):
        # Dataset
        self.dataset = None
        self.dataroot = None
        # NOTE(review): MyConfig assigns ``data_root`` while this class declares
        # ``dataroot``. The dataset code (not visible here) presumably reads
        # only one of them. Both now have a default so neither lookup raises
        # AttributeError. Confirm which name is used and drop the other.
        self.data_root = None
        self.upscale = 2
        self.train_y = True  # True -> 1-channel models, False -> 3-channel (see init_dependent_config)

        # Model
        self.model = None
        self.in_channels = None   # derived from train_y when left as None
        self.out_channels = None  # derived from train_y when left as None

        # Training
        self.total_epoch = 3200
        self.base_lr = 0.001
        self.train_bs = 16      # For each GPU
        self.early_stop_epoch = 1000
        self.max_itrs_per_epoch = 200

        # Validating
        self.val_bs = 1         # For each GPU
        self.begin_val_epoch = 0    # Epoch to start validation
        self.val_interval = 1   # Epoch interval between validation
        self.metrics = 'psnr'

        # Testing
        self.is_testing = False
        self.test_bs = 1
        self.test_data_folder = None
        self.test_lr = False    # Test downscaled image

        # Benchmark
        self.benchmark = False
        self.benchmark_datasets = ['set5', 'set14', 'bsd100']

        # Loss
        self.loss_type = 'mse'

        # Scheduler
        self.lr_policy = 'constant'
        self.warmup_epochs = 3
        self.step_size = None   # required when lr_policy == 'step'
        self.step_gamma = 0.1

        # Optimizer
        self.optimizer_type = 'adam'
        self.momentum = 0.9     # For SGD
        self.weight_decay = 1e-4    # For SGD

        # Monitoring
        self.save_ckpt = True
        self.save_dir = 'save'
        self.use_tb = True      # tensorboard
        self.tb_log_dir = None  # defaults to '{save_dir}/tb_logs/'
        self.ckpt_name = None

        # Training setting
        self.amp_training = False
        self.resume_training = True
        self.load_ckpt = True
        self.load_ckpt_path = None
        self.base_workers = 8
        self.random_seed = 1
        self.use_ema = True
        self.ema_decay = 0.999
        self.ema_start_epoch = 0

        # Augmentation
        self.patch_size = None  # int is expanded to [size, size]
        self.rotate = 0.0
        self.multi_scale = False
        self.hflip = 0.0
        self.vflip = 0.0

        # DDP
        self.synBN = False

    def init_dependent_config(self):
        """Fill fields whose defaults depend on other settings.

        Each field is only set while it is still None (or, for ``patch_size``,
        still an int), so calling this more than once leaves existing values
        unchanged.
        """
        if self.load_ckpt_path is None and not (self.is_testing or self.benchmark):
            self.load_ckpt_path = f'{self.save_dir}/last.pth'

        if (self.is_testing or self.benchmark) and (self.load_ckpt_path is None):
            # NOTE(review): unlike the training default above, this path is
            # relative to the working directory, not save_dir; confirm intended.
            self.load_ckpt_path = 'best.pth'

        if self.tb_log_dir is None:
            self.tb_log_dir = f'{self.save_dir}/tb_logs/'

        if isinstance(self.patch_size, int):
            self.patch_size = [self.patch_size, self.patch_size]

        if self.in_channels is None:
            self.in_channels = 1 if self.train_y else 3

        if self.out_channels is None:
            self.out_channels = 1 if self.train_y else 3
--------------------------------------------------------------------------------
/models/lapsrn.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution
3 | Url: https://arxiv.org/abs/1704.03915
4 | Create by: zh320
5 | Date: 2023/12/16
6 | """
7 |
8 | import torch.nn as nn
9 | from math import log2
10 |
11 | from .modules import ConvAct, Upsample
12 |
13 |
class LapSRN(nn.Module):
    """Laplacian Pyramid Super-Resolution Network.

    x2/x4/x8 are reached through 1/2/3 pyramid levels of x2 each; x3 uses a
    single level of x3. The output is the sum of the progressively upsampled
    input and the residuals predicted by the feature-extraction branch.
    """

    def __init__(self, in_channels, out_channels, upscale, hid_channels=64, fe_layer_num=8,
                 act_type='leakyrelu', upsample_type='deconvolution'):
        super(LapSRN, self).__init__()
        assert fe_layer_num > 3, 'Layer number should be larger than 3.\n'

        # upscale -> (number of pyramid levels, per-level scale factor)
        level_plan = {2: (1, 2), 4: (2, 2), 8: (3, 2), 3: (1, 3)}
        if upscale not in level_plan:
            raise ValueError(f'Unsupported scale factor: {upscale}\n')
        self.num_stage, scale_factor = level_plan[upscale]

        self.fe_branch = FeatureExtraction(in_channels, out_channels, hid_channels, self.num_stage,
                                           fe_layer_num, scale_factor, upsample_type, act_type)
        self.ir_branch = ImageReconstruction(in_channels, out_channels, self.num_stage,
                                             scale_factor, upsample_type)

    def forward(self, x):
        return self.ir_branch(x, self.fe_branch(x))
38 |
39 |
class FeatureExtraction(nn.Module):
    """Feature-extraction branch of LapSRN.

    Each pyramid level runs a stack of conv+activation layers followed by an
    upsampling layer (``self.conv[level]``). A conv head (``self.out[level]``)
    then turns the upsampled features into that level's residual image. The
    features, not the residual, are carried on to the next level.
    """

    def __init__(self, in_ch, out_ch, hid_ch, num_stage, layer_num, scale_factor,
                 upsample_type, act_type):
        super(FeatureExtraction, self).__init__()
        self.num_stage = num_stage
        self.conv = nn.ModuleList()
        self.out = nn.ModuleList()
        for level in range(num_stage):
            first_ch = hid_ch if level else in_ch
            body = [ConvAct(first_ch, hid_ch, 3, act_type=act_type)]
            body += [ConvAct(hid_ch, hid_ch, 3, act_type=act_type) for _ in range(layer_num - 3)]
            body.append(Upsample(hid_ch, hid_ch, scale_factor, upsample_type))
            self.conv.append(nn.Sequential(*body))
            self.out.append(ConvAct(hid_ch, out_ch, 3, act_type=act_type))

    def forward(self, x):
        residuals = []
        for body, head in zip(self.conv, self.out):
            x = body(x)
            residuals.append(head(x))
        return residuals
66 |
67 |
class ImageReconstruction(nn.Module):
    """Image-reconstruction branch of LapSRN.

    At every pyramid level the running image is upsampled, and the residual
    predicted for that level (``feats[level]``) is added in place.
    """

    def __init__(self, in_ch, out_ch, num_stage, scale_factor, upsample_type):
        super(ImageReconstruction, self).__init__()
        self.num_stage = num_stage
        self.up = nn.ModuleList([
            Upsample(out_ch if level > 0 else in_ch, out_ch, scale_factor, upsample_type)
            for level in range(num_stage)
        ])

    def forward(self, img, feats):
        for level, upsample in enumerate(self.up):
            img = upsample(img)
            img += feats[level]
        return img
83 |
--------------------------------------------------------------------------------
/datasets/test_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 |
6 | from .sr_base_dataset import SRBaseDataset
7 |
8 |
class TestDataset(Dataset):
    """Inference dataset over the image files of ``config.test_data_folder``.

    Each item is ``(images, img_name)``, where ``images`` is a list:

    * ``hr``: the loaded image (made divisible by ``upscale`` when ``test_lr``),
      CHW, scaled to [0, 1]. It holds only the Y channel when ``train_y`` is set.
    * ``bicubic``: a bicubic-upscaled RGB image as an HWC float32 array in
      0-255. With ``test_lr`` it is the downscaled image upscaled back;
      otherwise it is ``hr`` upscaled by ``upscale``.
    * ``lr`` (only with ``test_lr``): ``hr`` bicubic-downscaled by
      ``upscale``, in the same layout as ``hr``.
    * ``ycbcr`` (only with ``train_y``): ``bicubic`` converted to YCbCr (HWC).
      Its CbCr channels are needed to rebuild a colour image from a Y-only
      prediction.
    """

    def __init__(self, config):
        data_folder = os.path.expanduser(config.test_data_folder)
        self.train_y = config.train_y
        self.scale = config.upscale
        self.test_lr = config.test_lr

        if not os.path.isdir(data_folder):
            raise RuntimeError(f'Test image directory: {data_folder} does not exist.')

        # Keep only image files, using the same suffix whitelist as the
        # training data. Sort them so the iteration order is reproducible
        # (os.listdir order is arbitrary).
        self.img_names = sorted(
            file_name for file_name in os.listdir(data_folder)
            if file_name.split('.')[-1].lower() in SRBaseDataset.IMG_SUFFIX
            and os.path.isfile(os.path.join(data_folder, file_name))
        )
        self.hr_images = [os.path.join(data_folder, file_name) for file_name in self.img_names]

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, index):
        hr = Image.open(self.hr_images[index]).convert('RGB')
        img_name = self.img_names[index]

        if self.test_lr:
            # Resize image to make it compatible for upscale factor if needed
            hr_width = (hr.width // self.scale) * self.scale
            hr_height = (hr.height // self.scale) * self.scale
            if hr_width != hr.width or hr_height != hr.height:
                hr = hr.resize((hr_width, hr_height), resample=Image.BICUBIC)

            # Generate low resolution image using bicubic interpolation of HR image
            lr = hr.resize((hr_width // self.scale, hr_height // self.scale), resample=Image.BICUBIC)
            bicubic = lr.resize((lr.width * self.scale, lr.height * self.scale), resample=Image.BICUBIC)
        else:  # test hr: upscale the given image itself
            bicubic = hr.resize((hr.width * self.scale, hr.height * self.scale), resample=Image.BICUBIC)

        hr = np.array(hr).astype(np.float32)
        bicubic = np.array(bicubic).astype(np.float32)
        if self.test_lr:
            lr = np.array(lr).astype(np.float32)

        if self.train_y:
            # Interpolated CbCr channels are needed to recover a colour image in
            # BOTH test modes (previously only computed when test_lr was set,
            # which raised NameError below for test_lr=False).
            ycbcr = SRBaseDataset.rgb_to_ycbcr(bicubic)

            # RGB to Y, HW to CHW, normalize
            hr = np.expand_dims(SRBaseDataset.rgb_to_ycbcr(hr, y_only=True) / 255., 0)
            if self.test_lr:
                lr = np.expand_dims(SRBaseDataset.rgb_to_ycbcr(lr, y_only=True) / 255., 0)
        else:
            # HWC to CHW --> normalize
            hr = hr.transpose((2, 0, 1)) / 255.
            if self.test_lr:
                lr = lr.transpose((2, 0, 1)) / 255.

        images = [np.ascontiguousarray(hr), bicubic]
        if self.test_lr:
            images.append(np.ascontiguousarray(lr))
        if self.train_y:
            images.append(np.ascontiguousarray(ycbcr))

        return images, img_name
83 |
--------------------------------------------------------------------------------
/models/srdensenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Image Super-Resolution Using Dense Skip Connections
3 | Url: https://openaccess.thecvf.com/content_ICCV_2017/papers/Tong_Image_Super-Resolution_Using_ICCV_2017_paper.pdf
4 | Create by: zh320
5 | Date: 2024/01/27
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from .modules import conv3x3, ConvAct, Activation, Upsample
12 |
13 |
class SRDenseNet(nn.Module):
    """SRDenseNet: dense blocks whose outputs (plus the stem) are all fused.

    The stem output and every dense-block output are concatenated, squeezed
    by a 1x1 bottleneck to ``2 * hid_channels``, upsampled (x4 uses two x2
    stages), and projected to ``out_channels`` by a final 3x3 conv.
    """

    def __init__(self, in_channels, out_channels, upscale, hid_channels=128, num_block=8, num_layer=8,
                 act_type='relu', upsample_type='deconvolution'):
        super(SRDenseNet, self).__init__()
        assert upscale in [2,3,4], f'Unsupported upscale factor: {upscale}.\n'
        self.num_block = num_block
        wide_ch = hid_channels * 2

        # Initial Convolution
        self.conv = ConvAct(in_channels, hid_channels, 3, act_type=act_type)

        # Dense Blocks
        self.dense_blocks = nn.ModuleList(
            [DenseBlock(hid_channels, hid_channels, num_layer, act_type) for _ in range(num_block)]
        )

        # Bottleneck over the stem + all block outputs
        self.bottleneck = ConvAct(hid_channels * (num_block + 1), wide_ch, 1, act_type=act_type)

        # Deconvolution Layers
        if upscale in [2, 3]:
            self.deconvolution = Upsample(wide_ch, wide_ch, upscale, upsample_type)
        elif upscale == 4:
            self.deconvolution = nn.Sequential(
                Upsample(wide_ch, wide_ch, 2, upsample_type),
                Upsample(wide_ch, wide_ch, 2, upsample_type),
            )

        # Reconstruction Layer
        self.reconstruction = conv3x3(wide_ch, out_channels)

    def forward(self, x):
        x = self.conv(x)
        skips = [x]
        for block in self.dense_blocks:
            x = block(x)
            skips.append(x)

        fused = self.bottleneck(torch.cat(skips, dim=1))
        return self.reconstruction(self.deconvolution(fused))
57 |
58 |
class DenseBlock(nn.Module):
    """Dense block used by SRDenseNet.

    ``conv0`` maps the input to ``growth_rate = out_channels // num_layer``
    channels. Each following conv sees the activated concatenation of all
    earlier feature maps. The intermediate convs add ``growth_rate`` channels;
    the last conv produces ``out_channels``, and its activated output is the
    block output.
    """

    def __init__(self, in_channels, out_channels, num_layer, act_type):
        super(DenseBlock, self).__init__()
        if num_layer < 2:
            # With a single layer the final conv would have 0 input channels
            # and forward() would never produce an output.
            raise ValueError(f'num_layer should be at least 2, but got {num_layer}.')
        assert out_channels % num_layer == 0, 'out_channels should be evenly divided by num_layer.\n'
        self.num_layer = num_layer
        growth_rate = out_channels // num_layer

        self.conv0 = conv3x3(in_channels, growth_rate)
        # num_layer - 1 pre-activations inside the loop + 1 on the output
        self.act = nn.ModuleList([Activation(act_type) for _ in range(num_layer)])
        self.conv = nn.ModuleList([])
        for i in range(1, num_layer-1):
            self.conv.append(conv3x3(i*growth_rate, growth_rate))
        self.conv.append(conv3x3((num_layer-1)*growth_rate, out_channels))

    def forward(self, x):
        feats = [self.conv0(x)]
        last = self.num_layer - 2  # index of the final conv in self.conv

        for i in range(self.num_layer - 1):
            feat = self.conv[i](self.act[i](torch.cat(feats, dim=1)))
            # Only intermediate outputs feed later layers; the old check
            # `i != num_layer - 1` was always true.
            if i < last:
                feats.append(feat)

        return self.act[-1](feat)
87 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from .sr_base_dataset import SRBaseDataset
3 | from .sr_dataset import SRDataset
4 | from .val_datasets import Set5, Set14, BSD100
5 |
# Maps the names used in ``config.dataset`` / ``config.benchmark_datasets`` to dataset classes.
dataset_hub = dict(sr=SRDataset, set5=Set5, set14=Set14, bsd100=BSD100)
7 |
8 |
def get_dataset(config):
    """Build the ``(train, val)`` dataset pair for ``config.dataset``.

    Raises:
        NotImplementedError: if ``config.dataset`` is not registered in ``dataset_hub``.
    """
    dataset_cls = dataset_hub.get(config.dataset)
    if dataset_cls is None:
        raise NotImplementedError('Unsupported dataset!')

    return dataset_cls(config=config, mode='train'), dataset_cls(config=config, mode='val')
17 |
18 |
def get_loader(config, rank, pin_memory=True):
    """Create the training and validation ``DataLoader`` objects.

    Side effects: sets ``config.train_num`` (training samples rounded down to a
    multiple of ``config.train_bs``) and ``config.val_num``.

    Args:
        config: experiment config. Reads ``train_bs``, ``val_bs``,
            ``num_workers``, ``DDP`` and, under DDP, ``gpu_num``.
        rank: process rank passed to ``DistributedSampler`` under DDP.
        pin_memory: pin host memory of loaded batches. It is honoured in both
            the DDP and the non-DDP branch (it used to be ignored in the
            latter). It is switched off when CUDA is unavailable, since
            pinning needs it.

    Returns:
        ``(train_loader, val_loader)``.
    """
    import torch

    train_dataset, val_dataset = get_dataset(config)

    # Make sure train number is divisible by train batch size
    config.train_num = int(len(train_dataset) // config.train_bs * config.train_bs)
    config.val_num = len(val_dataset)

    pin_memory = pin_memory and torch.cuda.is_available()

    if config.DDP:
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(train_dataset, num_replicas=config.gpu_num,
                                           rank=rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=config.gpu_num,
                                         rank=rank, shuffle=False)
        train_shuffle = False  # Shuffling is delegated to the sampler
    else:
        train_sampler = val_sampler = None
        train_shuffle = True

    train_loader = DataLoader(train_dataset, batch_size=config.train_bs, shuffle=train_shuffle,
                              num_workers=config.num_workers, pin_memory=pin_memory,
                              sampler=train_sampler, drop_last=True)

    val_loader = DataLoader(val_dataset, batch_size=config.val_bs, shuffle=False,
                            num_workers=config.num_workers, pin_memory=pin_memory,
                            sampler=val_sampler)

    return train_loader, val_loader
48 |
49 |
def get_val_dataset(config):
    """Instantiate (in ``'val'`` mode) every dataset named in ``config.benchmark_datasets``.

    Raises:
        NotImplementedError: on the first name not registered in ``dataset_hub``.
    """
    datasets = []
    for name in config.benchmark_datasets:
        if name not in dataset_hub:
            raise NotImplementedError('Unsupported dataset!')
        datasets.append(dataset_hub[name](config=config, mode='val'))

    return datasets
60 |
61 |
def get_val_loader(config):
    """Wrap each benchmark dataset in an unshuffled ``DataLoader`` (batch size ``config.val_bs``)."""
    return [
        DataLoader(dataset, batch_size=config.val_bs, shuffle=False, num_workers=config.num_workers)
        for dataset in get_val_dataset(config)
    ]
72 |
73 |
def get_test_loader(config):
    """Build the inference loader over ``config.test_data_folder``.

    Sets ``config.test_num``. Only non-DDP runs are supported, and images are
    served one per batch.

    Raises:
        NotImplementedError: when ``config.DDP`` is set.
    """
    from .test_dataset import TestDataset

    test_set = TestDataset(config)
    config.test_num = len(test_set)

    if not config.DDP:
        return DataLoader(test_set, batch_size=1, shuffle=False, num_workers=config.num_workers)

    raise NotImplementedError()
87 |
--------------------------------------------------------------------------------
/models/carn.py:
--------------------------------------------------------------------------------
1 | """
2 | Paper: Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network
3 | Url: https://arxiv.org/abs/1803.08664
4 | Create by: zh320
5 | Date: 2023/12/30
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from .modules import conv1x1, conv3x3, ConvAct, Activation, Upsample
12 |
13 |
class CARN(nn.Module):
    """Cascading Residual Network (CARN / CARN-M).

    ``arch_type='carn'`` builds its cascading blocks from ``ResidualBlock``;
    ``'carn-m'`` uses the grouped ``ResidualEBlock``. Upscale factors 2 and 3
    use one upsampling stage; 4 uses two x2 stages.
    """

    def __init__(self, in_channels, out_channels, upscale, arch_type='carn',
                 hid_channels=64, act_type='relu', upsample_type='pixelshuffle'):
        super(CARN, self).__init__()
        if arch_type == 'carn':
            block = ResidualBlock
        elif arch_type == 'carn-m':
            block = ResidualEBlock
        else:
            raise ValueError(f'Unsupported arch_type: {arch_type}\n')

        ch = hid_channels
        # Global cascade: conv{k+1} fuses the entry features with the outputs
        # of all cascading blocks produced so far.
        self.conv1 = conv3x3(in_channels, ch)
        self.cascading_block1 = CascadingBlock(block, ch, act_type)
        self.conv2 = conv1x1(2 * ch, ch)
        self.cascading_block2 = CascadingBlock(block, ch, act_type)
        self.conv3 = conv1x1(3 * ch, ch)
        self.cascading_block3 = CascadingBlock(block, ch, act_type)
        self.conv4 = conv1x1(4 * ch, ch)

        if upscale in [2, 3]:
            up_factors = [upscale]
        elif upscale == 4:
            up_factors = [2, 2]
        else:
            raise NotImplementedError(f'Unsupported upscale factor: {upscale}\n')

        up_layers = []
        for factor in up_factors:
            up_layers.append(conv3x3(ch, ch))
            up_layers.append(Upsample(ch, ch, factor, upsample_type, 3))
        self.upsample = nn.Sequential(*up_layers)

        self.conv_last = conv3x3(ch, out_channels)

    def forward(self, x):
        x = self.conv1(x)
        skips = [x]
        stages = ((self.cascading_block1, self.conv2),
                  (self.cascading_block2, self.conv3),
                  (self.cascading_block3, self.conv4))
        for cascade, fuse in stages:
            skips.append(cascade(x))
            x = fuse(torch.cat(skips, dim=1))

        return self.conv_last(self.upsample(x))
63 |
64 |
class CascadingBlock(nn.Module):
    """Local cascade of three residual units joined by 1x1 fusion convs."""

    def __init__(self, block, channels, act_type):
        super(CascadingBlock, self).__init__()
        self.res1 = block(channels, act_type)
        self.conv1 = conv1x1(channels * 2, channels)
        self.res2 = block(channels, act_type)
        self.conv2 = conv1x1(channels * 3, channels)
        self.res3 = block(channels, act_type)
        self.conv3 = conv1x1(channels * 4, channels)

    def forward(self, x):
        x0 = x
        x1 = self.res1(x0)
        x2 = self.res2(self.conv1(torch.cat([x0, x1], dim=1)))
        x3 = self.res3(self.conv2(torch.cat([x0, x1, x2], dim=1)))
        # Unlike the two concatenations above, the newest output comes first
        # here. Keep this order: reordering would change what conv3's weights
        # mean for existing checkpoints.
        return self.conv3(torch.cat([x3, x0, x1, x2], dim=1))
91 |
92 |
class ResidualBlock(nn.Module):
    """conv3x3-act-conv3x3 with an identity skip, followed by an activation."""

    def __init__(self, channels, act_type):
        super(ResidualBlock, self).__init__()
        self.conv = nn.Sequential(ConvAct(channels, channels, 3, act_type=act_type),
                                  conv3x3(channels, channels))
        self.act = Activation(act_type)

    def forward(self, x):
        out = self.conv(x)
        out += x
        return self.act(out)
108 |
109 |
class ResidualEBlock(nn.Module):
    """Efficient residual unit (CARN-M).

    Two grouped 3x3 conv+act layers and a 1x1 conv, with an identity skip,
    followed by an activation.
    """

    def __init__(self, channels, act_type, groups=4):
        super(ResidualEBlock, self).__init__()
        layers = [ConvAct(channels, channels, 3, groups=groups, act_type=act_type),
                  ConvAct(channels, channels, 3, groups=groups, act_type=act_type),
                  conv1x1(channels, channels)]
        self.conv = nn.Sequential(*layers)
        self.act = Activation(act_type)

    def forward(self, x):
        out = self.conv(x)
        out += x
        return self.act(out)
126 |
--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
# Regular convolution with kernel size 5x5
def conv5x5(in_channels, out_channels, stride=1):
    """Return a biased 5x5 ``nn.Conv2d`` with padding 2.

    At ``stride=1`` the spatial size of the input is preserved.
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride,
                     padding=2, bias=True)
8 |
9 |
10 | # Regular convolution with kernel size 3x3
def conv3x3(in_channels, out_channels, stride=1):
    """Return a biased 3x3 ``nn.Conv2d`` with padding 1 (size-preserving at stride 1)."""
    conv = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=True)
    return conv
14 |
15 |
16 | # Regular convolution with kernel size 1x1, a.k.a. point-wise convolution
def conv1x1(in_channels, out_channels, stride=1):
    """Return a biased point-wise (1x1, no padding) ``nn.Conv2d``."""
    conv = nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=True)
    return conv
20 |
21 |
class Upsample(nn.Module):
    """Learnable upsampling by ``scale_factor``.

    ``upsample_type`` selects the implementation stored in ``self.up_conv``:

    * ``'deconvolution'``: one ``ConvTranspose2d``. ``kernel_size`` defaults to
      ``2 * scale_factor + 1``, and padding/output_padding are chosen so an
      odd kernel yields exactly ``scale_factor`` times the input size.
    * ``'pixelshuffle'``: a conv to ``out_channels * scale_factor**2``
      channels followed by ``nn.PixelShuffle``.
    * anything else (including ``None``): a conv followed by bicubic
      ``nn.Upsample``.

    The conv-based variants default to a 3x3 kernel with 'same' padding.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2, upsample_type=None,
                 kernel_size=None,):
        super(Upsample, self).__init__()
        if upsample_type == 'deconvolution':
            ks = 2 * scale_factor + 1 if kernel_size is None else kernel_size
            self.up_conv = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=ks, stride=scale_factor,
                padding=(ks - 1) // 2, output_padding=scale_factor - 1, bias=True,
            )
            return

        ks = 3 if kernel_size is None else kernel_size
        pad = (ks - 1) // 2
        if upsample_type == 'pixelshuffle':
            layers = [nn.Conv2d(in_channels, out_channels * scale_factor ** 2, ks, 1, pad),
                      nn.PixelShuffle(scale_factor)]
        else:
            layers = [nn.Conv2d(in_channels, out_channels, ks, 1, pad),
                      nn.Upsample(scale_factor=scale_factor, mode='bicubic')]
        self.up_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.up_conv(x)
51 |
52 |
53 | # Regular convolution -> activation
class ConvAct(nn.Sequential):
    """``nn.Conv2d`` followed by ``Activation``, with 'same'-style padding.

    ``kernel_size`` may be an int or a 2-element list/tuple. Padding is
    ``(k - 1) // 2 * dilation`` per dimension, so odd kernels keep the
    spatial size at ``stride=1``. Extra ``**kwargs`` go to the activation.

    Raises:
        TypeError: if ``kernel_size`` is neither an int nor a list/tuple.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 groups=1, bias=True, act_type='relu', **kwargs):
        if isinstance(kernel_size, (list, tuple)):
            padding = ((kernel_size[0] - 1) // 2 * dilation, (kernel_size[1] - 1) // 2 * dilation)
        elif isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2 * dilation
        else:
            # Fail clearly here rather than with an opaque UnboundLocalError on `padding`.
            raise TypeError(f'kernel_size must be an int or a list/tuple of two ints, '
                            f'but got {type(kernel_size).__name__}.')

        super(ConvAct, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias),
            Activation(act_type, **kwargs)
        )
66 |
67 |
68 | # Regular convolution -> batchnorm -> activation
class ConvBNAct(nn.Sequential):
    """``nn.Conv2d`` -> ``nn.BatchNorm2d`` -> ``Activation``, with 'same'-style padding.

    ``kernel_size`` may be an int or a 2-element list/tuple. Padding is
    ``(k - 1) // 2 * dilation`` per dimension. Extra ``**kwargs`` go to the
    activation.

    Raises:
        TypeError: if ``kernel_size`` is neither an int nor a list/tuple.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 groups=1, bias=True, act_type='relu', **kwargs):
        if isinstance(kernel_size, (list, tuple)):
            padding = ((kernel_size[0] - 1) // 2 * dilation, (kernel_size[1] - 1) // 2 * dilation)
        elif isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2 * dilation
        else:
            # Fail clearly here rather than with an opaque UnboundLocalError on `padding`.
            raise TypeError(f'kernel_size must be an int or a list/tuple of two ints, '
                            f'but got {type(kernel_size).__name__}.')

        super(ConvBNAct, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias),
            nn.BatchNorm2d(out_channels),
            Activation(act_type, **kwargs)
        )
82 |
83 |
class Activation(nn.Module):
    """Wrap a ``torch.nn`` activation chosen by a case-insensitive name.

    ``act_type`` is lower-cased and looked up in a fixed table (e.g.
    ``'relu'``, ``'leakyrelu'``, ``'prelu'``, ``'none'`` for identity). Extra
    ``**kwargs`` are passed to the activation's constructor.

    Raises:
        NotImplementedError: if ``act_type`` is not in the table.
    """

    def __init__(self, act_type, **kwargs):
        super(Activation, self).__init__()
        act_type = act_type.lower()
        hub = dict(
            relu=nn.ReLU, relu6=nn.ReLU6,
            leakyrelu=nn.LeakyReLU, prelu=nn.PReLU,
            celu=nn.CELU, elu=nn.ELU,
            hardswish=nn.Hardswish, hardtanh=nn.Hardtanh,
            gelu=nn.GELU, glu=nn.GLU,
            selu=nn.SELU, silu=nn.SiLU,
            sigmoid=nn.Sigmoid, softmax=nn.Softmax,
            tanh=nn.Tanh, none=nn.Identity,
        )

        act_cls = hub.get(act_type)
        if act_cls is None:
            raise NotImplementedError(f'Unsupport activation type: {act_type}')
        self.activation = act_cls(**kwargs)

    def forward(self, x):
        return self.activation(x)
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | PyTorch implementation of efficient image super-resolution models.
4 |
5 |
6 |
7 | #
8 |
9 | # Requirements
10 |
11 | torch == 1.8.1
12 | torchmetrics
13 | loguru
14 | tqdm
15 |
16 | # Supported models
17 |
18 | - [CARN](models/carn.py) [^carn]
19 | - [DRCN](models/drcn.py) [^drcn]
20 | - [DRRN](models/drrn.py) [^drrn]
21 | - [EDSR](models/edsr.py) [^edsr]
22 | - [ESPCN](models/espcn.py) [^espcn]
23 | - [FSRCNN](models/fsrcnn.py) [^fsrcnn]
24 | - [IDN](models/idn.py) [^idn]
25 | - [LapSRN](models/lapsrn.py) [^lapsrn]
26 | - [SRCNN](models/srcnn.py) [^srcnn]
27 | - [SRDenseNet](models/srdensenet.py) [^srdensenet]
28 | - [VDSR](models/vdsr.py) [^vdsr]
29 |
30 | [^carn]: [Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network](https://arxiv.org/abs/1803.08664)
31 | [^drcn]: [Deeply-Recursive Convolutional Network for Image Super-Resolution](https://arxiv.org/abs/1511.04491)
32 | [^drrn]: [Image Super-Resolution via Deep Recursive Residual Network](https://openaccess.thecvf.com/content_cvpr_2017/html/Tai_Image_Super-Resolution_via_CVPR_2017_paper.html)
33 | [^edsr]: [Enhanced Deep Residual Networks for Single Image Super-Resolution](https://arxiv.org/abs/1707.02921)
34 | [^espcn]: [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158)
35 | [^fsrcnn]: [Accelerating the Super-Resolution Convolutional Neural Network](https://arxiv.org/abs/1608.00367)
36 | [^idn]: [Fast and Accurate Single Image Super-Resolution via Information Distillation Network](https://arxiv.org/abs/1803.09454)
37 | [^lapsrn]: [Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution](https://arxiv.org/abs/1704.03915)
38 | [^srcnn]: [Image Super-Resolution Using Deep Convolutional Networks](https://arxiv.org/abs/1501.00092)
39 | [^srdensenet]: [Image Super-Resolution Using Dense Skip Connections](https://openaccess.thecvf.com/content_ICCV_2017/papers/Tong_Image_Super-Resolution_Using_ICCV_2017_paper.pdf)
40 | [^vdsr]: [Accurate Image Super-Resolution Using Very Deep Convolutional Networks](https://arxiv.org/abs/1511.04587)
41 |
42 | # How to use
43 |
44 | ## DDP training (recommended)
45 |
46 | ```
47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py
48 | ```
49 |
50 | ## DP training
51 |
52 | ```
53 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py
54 | ```
55 |
56 | # Performance and checkpoints
57 |
58 | | Model | Year | Train on<sup>1</sup> | Set5 | | Set14 | | BSD100 | |
59 | |:------------------------------------------------------------------------------------------------------------------------:|:----:|:--------------------:|:---------------:|:-------------:|:-----------:|:-------------:|:-----------:|:-------------:|
60 | | 2x | | | PSNR (paper/my) | SSIM | PSNR | SSIM | PSNR | SSIM |
61 | | [CARN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/carn_2x.pth) | 2018 | T+B+D | 37.76/37.90 | 0.9590/0.9605 | 33.52/33.14 | 0.9166/0.9152 | 32.09/32.06 | 0.8978/0.8985 |
62 | | [DRCN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/drcn_2x.pth) | 2015 | T | 37.63/37.85 | 0.9588/0.9604 | 33.04/33.22 | 0.9118/0.916 | 31.85/32.05 | 0.8942/0.8982 |
63 | | [DRRN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/drrn_2x.pth) | 2017 | T+B | 37.74/37.76 | 0.9591/0.9599 | 33.23/33.14 | 0.9136/0.9149 | 32.05/31.99 | 0.8973/0.8974 |
64 | | [EDSR](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/edsr_2x.pth) | 2017 | D | 37.99/37.90 | 0.9604/0.9606 | 33.57/33.22 | 0.9175/0.9163 | 32.16/32.10 | 0.8994/0.899 |
65 | | [ESPCN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/espcn_2x.pth) | 2016 | I+T | n.a./36.85 | n.a./0.9559 | n.a./32.31 | n.a./0.9087 | n.a./31.40 | n.a./0.8897 |
66 | | [FSRCNN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/fsrcnn_2x.pth) | 2016 | T+G | 37.00/37.27 | 0.9558/0.958 | 32.63/32.65 | 0.9088/0.9115 | 31.53/31.67 | 0.8920/0.8934 |
67 | | [IDN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/idn_2x.pth) | 2018 | T+B | 37.83/37.84 | 0.96/0.9604 | 33.30/33.12 | 0.9148/0.9155 | 32.08/32.06 | 0.8985/0.8985 |
68 | | [LapSRN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/lapsrn_2x.pth) | 2017 | T+B | 37.52/37.59 | 0.9591/0.9592 | 32.99/32.96 | 0.9124/0.9138 | 31.80/31.89 | 0.8952/0.8961 |
69 | | [SRCNN](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/srcnn_2x.pth) | 2014 | I+T | 36.66/36.88 | 0.9542/0.9561 | 32.45/32.42 | 0.9067/0.9092 | 31.36/31.50 | 0.8879/0.8907 |
70 | | [SRDenseNet](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/srdensenet_2x.pth) | 2017 | I | n.a./37.67 | n.a./0.9596 | n.a./33.05 | n.a./0.9142 | n.a./31.93 | n.a./0.8967 |
71 | | [VDSR](https://github.com/zh320/efficient-image-super-resolution-pytorch/releases/download/v1.0/vdsr_2x.pth) | 2015 | T+B | 37.53/37.74 | 0.9587/0.9598 | 33.03/33.06 | 0.9124/0.9145 | 31.90/31.97 | 0.8960/0.8973 |
72 |
73 | <sup>1</sup> Original training datasets, abbreviated as B (BSD200), D (DIV2K), G (General100), I (ImageNet) and T (T91). In my experiments, the training set is T + G + B.
74 |
75 | # Prepare the dataset
76 |
77 | ```
78 | /train
79 | /T91
80 | /General100
81 | /BSD200
82 | /val
83 | /Set5
84 | /Set14
85 | /BSD100
86 | ```
87 |
88 | # References
--------------------------------------------------------------------------------
/datasets/sr_base_dataset.py:
--------------------------------------------------------------------------------
1 | import os, random
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 |
6 |
class SRBaseDataset(Dataset):
    """Base super-resolution dataset yielding ``(lr, hr)`` pairs built on the fly.

    HR images are collected from ``config.data_root/<dir>`` for every
    ``<dir>`` in ``data_split[mode]``. The LR input is the HR image
    bicubic-downscaled by ``config.upscale``. In ``'train'`` mode the image
    list is repeated so that an epoch has about
    ``max_itrs_per_epoch * train_bs * gpu_num`` samples, and random patches /
    augmentations are drawn.
    """
    IMG_SUFFIX = ['jpg', 'jpeg', 'png', 'bmp']

    def __init__(self, config, data_split, mode='train'):
        assert mode in ['train', 'val'], f'Unsupported dataset mode: {mode}.\n'

        data_root = os.path.expanduser(config.data_root)
        img_dirs = []
        for img_dir in data_split[mode]:
            img_dirs.append(os.path.join(data_root, img_dir))

        if len(img_dirs) == 0:
            raise RuntimeError('No image directory found.')

        for img_dir in img_dirs:
            if not os.path.isdir(img_dir):
                raise RuntimeError(f'Image directory: {img_dir} does not exist.')

        self.mode = mode
        self.train_y = config.train_y
        self.scale = config.upscale
        self.patch_size = config.patch_size
        self.random_rotate = config.rotate
        self.multi_scale = config.multi_scale
        self.hflip = config.hflip
        self.vflip = config.vflip

        hr_images = []
        for img_dir in img_dirs:
            for file_name in os.listdir(img_dir):
                if file_name.split('.')[-1].lower() in SRBaseDataset.IMG_SUFFIX:
                    img_path = os.path.join(img_dir, file_name)
                    hr_images.append(img_path)

        if len(hr_images) == 0:
            raise RuntimeError(f'No image with suffix in {SRBaseDataset.IMG_SUFFIX} found in: {img_dirs}')

        if self.mode == 'train':
            # Repeat the list so that one epoch covers max_itrs_per_epoch
            # iterations across all GPUs. The factor is clamped to 1: with more
            # images than that target, floor division gave 0 and emptied the set.
            target_num = config.max_itrs_per_epoch*config.train_bs*config.gpu_num
            hr_images *= max(1, target_num // len(hr_images))

        self.hr_images = hr_images

        self.num = len(self.hr_images)

    @staticmethod
    def rgb_to_ycbcr(img, y_only=False):
        """Convert an ``HxWx3`` RGB array in [0, 255] to ITU-R BT.601 YCbCr.

        Y lies in [16, 235] and Cb/Cr are centred on 128. Returns the ``HxW``
        float32 Y plane when ``y_only`` is True, otherwise an ``HxWx3`` float32
        array stacked as (Y, Cb, Cr).
        """
        if not isinstance(img, np.ndarray):
            raise ValueError(f'\nInput must be np.ndarray, but got type: {type(img)} instead.')
        if img.shape[2] != 3:
            raise ValueError(f'\nInput should be RGB channel array, but got {img.shape[2]} channel instead.')

        # Y: MATLAB rgb2ycbcr form (coefficients defined with a divisor of 255).
        y = (np.dot(img, [65.481, 128.553, 24.966]) / 255. + 16.).astype(np.float32)
        if y_only:
            return y
        else:
            # Cb/Cr: BT.601 coefficients of the form defined with a divisor of
            # 256 (112.439 * 255 / 256 ~= 112, the full chroma excursion).
            # Dividing them by 255 overshot the chroma scale by ~0.4%.
            cb = (np.dot(img, [-37.945, -74.494, 112.439]) / 256. + 128.).astype(np.float32)
            cr = (np.dot(img, [112.439, -94.154, -18.285]) / 256. + 128.).astype(np.float32)
            return np.array([y, cb, cr]).transpose([1, 2, 0])

    @staticmethod
    def ycbcr_to_rgb(img):
        """Convert an ``HxWx3`` BT.601 YCbCr array back to RGB in [0, 255] (not clipped).

        This is the inverse of :meth:`rgb_to_ycbcr`. Returns an ``HxWx3``
        float32 array.
        """
        if not isinstance(img, np.ndarray):
            raise ValueError(f'\nInput must be np.ndarray, but got type: {type(img)} instead.')
        if img.shape[2] != 3:
            raise ValueError(f'\nInput should be 3-channel array, but got {img.shape[2]} channel instead.')

        # These coefficients belong to the divisor-256 form
        # (298.082 / 256 == 255 / 219). Dividing by 255 mapped white to ~257.
        r = (np.dot(img, [298.082, 0, 408.583]) / 256. - 222.921).astype(np.float32)
        g = (np.dot(img, [298.082, -100.291, - 208.12]) / 256. + 135.576).astype(np.float32)
        b = (np.dot(img, [298.082, 516.412, 0]) / 256. - 276.836).astype(np.float32)
        return np.array([r, g, b]).transpose([1, 2, 0])

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, index):
        hr = Image.open(self.hr_images[index]).convert('RGB')

        if self.mode == 'train':
            # Perform multiscale augmentation
            if self.multi_scale:
                # Smallest scale at which the image still covers one patch
                min_scale = max(self.patch_size[0]/hr.width, self.patch_size[1]/hr.height)
                if min_scale > 1:
                    # Image too small: enlarge just enough to crop a patch
                    scale = min_scale
                else:
                    # Random down-scaling, but never below the patch size;
                    # otherwise the crop below runs past the image border
                    # (PIL pads it with black).
                    scale = random.uniform(max(0.75, min_scale), 1.0)

                # The clamp guards against float truncation (e.g. 95.999 -> 95).
                new_w = max(int(hr.width * scale), self.patch_size[0])
                new_h = max(int(hr.height * scale), self.patch_size[1])
                hr = hr.resize((new_w, new_h), resample=Image.BICUBIC)

            # Calculate the random crop location x0
            if hr.width > self.patch_size[0]:
                x0 = random.randint(0, hr.width-self.patch_size[0])
            else:
                x0 = 0

            # Calculate the random crop location y0
            if hr.height > self.patch_size[1]:
                y0 = random.randint(0, hr.height-self.patch_size[1])
            else:
                y0 = 0

            # Random crop a patch using patch_size
            hr = hr.crop([x0, y0, x0+self.patch_size[0], y0+self.patch_size[1]])

            # Random rotation
            if random.random() < self.random_rotate:
                angle = random.randint(0, 3) * 90
                hr = hr.rotate(angle, expand=True)

            # Random horizontal flip
            if random.random() < self.hflip:
                hr = hr.transpose(Image.FLIP_LEFT_RIGHT)

            # Random vertical flip
            if random.random() < self.vflip:
                hr = hr.transpose(Image.FLIP_TOP_BOTTOM)

        # Resize image to make it compatible for training if needed
        hr_width = (hr.width // self.scale) * self.scale
        hr_height = (hr.height // self.scale) * self.scale
        if hr_width != hr.width or hr_height != hr.height:
            hr = hr.resize((hr_width, hr_height), resample=Image.BICUBIC)

        # Generate low resolution image using bicubic interpolation of HR image
        lr_width = hr_width // self.scale
        lr_height = hr_height // self.scale
        lr = hr.resize((lr_width, lr_height), resample=Image.BICUBIC)

        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)

        if self.train_y:
            # RGB to YCbCr (only need Y channel here)
            hr = self.rgb_to_ycbcr(hr, y_only=True)
            lr = self.rgb_to_ycbcr(lr, y_only=True)

            # HW to CHW --> normalize
            hr = np.expand_dims(hr / 255., 0)
            lr = np.expand_dims(lr / 255., 0)
        else:
            # HWC to CHW --> normalize
            hr = hr.transpose((2, 0, 1)) / 255.
            lr = lr.transpose((2, 0, 1)) / 255.

        return np.ascontiguousarray(lr), np.ascontiguousarray(hr)
152 |
--------------------------------------------------------------------------------
/core/sr_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from tqdm import tqdm
6 | from torch.cuda import amp
7 | from copy import deepcopy
8 |
9 | from .base_trainer import BaseTrainer
10 | from utils import (get_sr_metrics, sampler_set_epoch, log_config, de_parallel)
11 |
12 |
class SRTrainer(BaseTrainer):
    """Trainer for single-image super-resolution models.

    Adds the SR-specific training loop, validation through the EMA model,
    multi-dataset benchmarking and per-image prediction on top of
    ``BaseTrainer``. PSNR and SSIM are always accumulated; ``config.metrics``
    selects which of the two is returned as the model-selection score.
    """

    def __init__(self, config):
        # Reject an unknown metric before BaseTrainer does its expensive setup
        # (directories, model, loaders, checkpoint loading).
        if config.metrics not in ['psnr', 'ssim']:
            raise ValueError(f'Unsupport metrics type: {config.metrics}')
        super().__init__(config)
        self.psnr = get_sr_metrics('psnr').to(self.device)
        self.ssim = get_sr_metrics('ssim').to(self.device)

    def train_one_epoch(self, config):
        """Run one epoch over ``self.train_loader``.

        Each iteration does an (optionally mixed-precision) forward pass, a
        scaled backward pass, an optimizer step, a scheduler step and, from
        ``config.ema_start_epoch`` on, an EMA update.
        """
        self.model.train()

        sampler_set_epoch(config, self.train_loader, self.cur_epoch)

        pbar = tqdm(self.train_loader) if self.main_rank else self.train_loader

        for cur_itrs, (images, labels) in enumerate(pbar):
            self.cur_itrs = cur_itrs
            self.train_itrs += 1

            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward path
            with amp.autocast(enabled=config.amp_training):
                preds = self.model(images)
                loss = self.loss_fn(preds, labels)

            if config.use_tb and self.main_rank:
                self.writer.add_scalar('train/loss', loss.detach(), self.train_itrs)

            # Backward path
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

            if self.cur_epoch >= config.ema_start_epoch:
                self.ema_model.update(self.model, self.train_itrs)

            if self.main_rank:
                pbar.set_description(('%s'*2) %
                                     (f'Epoch:{self.cur_epoch}/{config.total_epoch}{" "*4}|',
                                      f'Loss:{loss.detach():4.4g}{" "*4}|',)
                                     )

        if self.cur_epoch < config.ema_start_epoch:
            # Before EMA starts, the EMA model just mirrors the current weights.
            # Only the copy taken after the last iteration is ever used, so copy
            # once per epoch instead of every iteration, and switch it to eval
            # mode because validate() runs inference through it.
            self.ema_model.ema = deepcopy(de_parallel(self.model)).eval()

        return

    @torch.no_grad()
    def validate(self, config, val_best=False):
        """Evaluate the EMA model on ``self.val_loader``.

        Args:
            config: Run configuration.
            val_best: True when re-validating the best checkpoint after
                training; changes the log message and skips TensorBoard.

        Returns:
            The PSNR or SSIM (per ``config.metrics``) over the validation set.
        """
        # Inference must not use training-mode layer behavior.
        self.ema_model.ema.eval()

        pbar = tqdm(self.val_loader) if self.main_rank else self.val_loader
        for (images, labels) in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)

            preds = self.ema_model.ema(images).clamp(0.0, 1.0)
            self.psnr.update(preds.detach(), labels)
            self.ssim.update(preds.detach(), labels)

            if self.main_rank:
                pbar.set_description(('%s'*1) % (f'Validating:{" "*4}|',))

        psnr = self.psnr.compute()
        ssim = self.ssim.compute()
        score = psnr if config.metrics == 'psnr' else ssim

        if self.main_rank:
            if val_best:
                # Report the configured selection metric (previously always PSNR)
                self.logger.info(f'\n\nTrain {config.total_epoch} epochs finished.' +
                                 f'\n\nBest {config.metrics.upper()} is: {score:.4f}\n')
            else:
                self.logger.info(f' Epoch{self.cur_epoch} PSNR: {psnr:.4f} SSIM: {ssim:.4f} | ' +
                                 f'best {config.metrics.upper()} so far: {self.best_score:.4f}\n')

                # Not during val_best: BaseTrainer.run closes the writer before
                # re-validating the best checkpoint.
                if config.use_tb:
                    self.writer.add_scalar('val/PSNR', psnr.cpu(), self.cur_epoch+1)
                    self.writer.add_scalar('val/SSIM', ssim.cpu(), self.cur_epoch+1)

        self.psnr.reset()
        self.ssim.reset()
        return score

    @torch.no_grad()
    def benchmark(self, config):
        """Evaluate ``self.model`` on each loader in ``self.val_loaders`` and log PSNR/SSIM."""
        # The model is never switched out of its construction-time mode otherwise
        self.model.eval()

        if self.main_rank:
            log_config(config, self.logger)

        print(f'{"-"*25} Start benchmarking {"-"*25}\n')
        for i, val_loader in enumerate(self.val_loaders):
            print(f"\nStart validating dataset: {config.benchmark_datasets[i]}...")

            pbar = tqdm(val_loader) if self.main_rank else val_loader
            for (images, labels) in pbar:
                images = images.to(self.device)
                labels = labels.to(self.device)

                preds = self.model(images).clamp(0.0, 1.0)
                self.psnr.update(preds.detach(), labels)
                self.ssim.update(preds.detach(), labels)

                if self.main_rank:
                    pbar.set_description(('%s'*1) % (f'Validating:{" "*4}|',))

            psnr = self.psnr.compute()
            ssim = self.ssim.compute()

            if self.main_rank:
                self.logger.info(f' PSNR: {psnr:.4f} SSIM: {ssim:.4f}\n')

            self.psnr.reset()
            self.ssim.reset()
        print(f'{"-"*25} Finish benchmarking {"-"*25}\n')

    @torch.no_grad()
    def predict(self, config):
        """Super-resolve every image in ``self.test_loader`` and save the results.

        When ``config.test_lr`` is set, the model is fed the LR image and the
        per-image PSNR/SSIM against the HR image is appended to a text file in
        ``config.save_dir``; otherwise the HR image itself is fed to the model.
        The prediction and the bicubic baseline are both written to disk.

        Raises:
            ValueError: If called under DDP.
        """
        from datasets import SRBaseDataset
        if config.DDP:
            raise ValueError('Predict mode currently does not support DDP.')

        if config.test_bs != 1:
            self.logger.info('Warning: Predict mode only support batch size 1\n')

        self.logger.info('\nStart predicting...\n')

        # Inference only
        self.model.eval()

        if config.test_lr:
            test_score_path = f'{config.save_dir}/test_score_{config.model}_x{config.upscale}.txt'
            if os.path.isfile(test_score_path):
                os.remove(test_score_path)

        for (images, img_names) in tqdm(self.test_loader):
            hr = images[0]
            bicubic = images[1]
            img_name = img_names[0]

            hr = hr.to(self.device, dtype=torch.float32)
            if config.test_lr:
                lr = images[2].to(self.device, dtype=torch.float32)
                pred = self.model(lr).clamp(0.0, 1.0)

                # Compute PSNR and SSIM if test lr image
                self.psnr.update(pred.detach(), hr)
                self.ssim.update(pred.detach(), hr)

                psnr = self.psnr.compute()
                ssim = self.ssim.compute()

                self.psnr.reset()
                self.ssim.reset()

                with open(test_score_path, 'a+') as f:
                    f.write(f'{img_name}\tPSNR: {psnr:.2f}\tSSIM: {ssim:.4f}\n')
            else:
                pred = self.model(hr).clamp(0.0, 1.0)

            # BCHW --> HWC
            if config.train_y:
                # Need to concatenate [pred_y, bicubic_cb, bicubic_cr] to obtain predicted image
                ycbcr = images[-1].numpy().squeeze(0)
                pred = pred.mul(255.0).cpu().unsqueeze(-1).numpy().squeeze(0).squeeze(0)
                pred = np.concatenate([pred, ycbcr[...,1:]], axis=-1)
                pred = SRBaseDataset.ycbcr_to_rgb(pred)
            else:
                pred = pred.mul(255.0).cpu().numpy().squeeze(0).transpose([1, 2, 0])

            pred = np.clip(pred, 0., 255.).astype(np.uint8)

            # Saving results
            img_suffix = img_name.split('.')[-1]
            img_prefix = img_name[:-len(img_suffix)-1]
            pred_path = f'{config.save_dir}/{img_prefix}_{config.model}_x{config.upscale}.{img_suffix}'
            bicubic_path = f'{config.save_dir}/{img_prefix}_bicubic_x{config.upscale}.{img_suffix}'

            pred = Image.fromarray(pred)
            pred.save(pred_path)

            bicubic = Image.fromarray(bicubic.numpy().squeeze().astype(np.uint8))
            bicubic.save(bicubic_path)
195 |
--------------------------------------------------------------------------------
/configs/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
def load_parser(config, args=None):
    """Override attributes of ``config`` with command-line arguments.

    Every argument whose value is not ``None`` is assigned onto ``config``
    under the same name. ``get_parser()`` defaults every option to ``None``,
    so only options actually given on the command line override the config.

    Args:
        config: Configuration object, updated in place.
        args: Optional pre-parsed ``argparse.Namespace``; when omitted the
            command line is parsed with ``get_parser()``.

    Returns:
        The same ``config`` object.

    Raises:
        RuntimeError: If a value cannot be assigned on ``config``.
    """
    if args is None:
        args = get_parser()

    for k, v in vars(args).items():
        if v is None:
            continue
        try:
            # setattr is the safe, direct equivalent of exec(f"config.{k} = v")
            setattr(config, k, v)
        except Exception as e:
            raise RuntimeError(f'Unable to assign value to config.{k}') from e
    return config
14 |
15 |
def get_parser(argv=None):
    """Build the command-line parser and parse the arguments.

    Every option defaults to ``None`` so that ``load_parser`` only overrides
    config values that were explicitly given.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # Dataset
    # NOTE(review): the only choice is 'cityscapes', which looks carried over
    # from a segmentation project — confirm against the SR dataset loaders.
    parser.add_argument('--dataset', type=str, default=None, choices=['cityscapes'],
                        help='choose which dataset you want to use')
    parser.add_argument('--dataroot', type=str, default=None,
                        help='path to your dataset')
    parser.add_argument('--upscale', type=int, default=None,
                        help='scale factor for super resolution')
    parser.add_argument('--train_y', action='store_false', default=None,
                        help='whether to train Y channel in YCbCr space or not (default: True)')

    # Model
    parser.add_argument('--model', type=str, default=None,
                        choices=['carn', 'drcn', 'drrn', 'edsr', 'espcn', 'fsrcnn',
                                 'idn', 'lapsrn', 'srcnn', 'srdensenet', 'vdsr'],
                        help='choose which model you want to use')
    parser.add_argument('--in_channels', type=int, default=None,
                        help='number of input channel for given model (default: 1 for Y input else 3 for RGB input)')
    parser.add_argument('--out_channels', type=int, default=None,
                        help='number of output channel for given model (default: 1 for Y input else 3 for RGB input)')

    # Training
    parser.add_argument('--total_epoch', type=int, default=None,
                        help='number of total training epochs')
    parser.add_argument('--base_lr', type=float, default=None,
                        help='base learning rate for single GPU, total learning rate *= gpu number')
    parser.add_argument('--train_bs', type=int, default=None,
                        help='training batch size for single GPU, total batch size *= gpu number')
    parser.add_argument('--early_stop_epoch', type=int, default=None,
                        help='epoch number to stop training if validation score does not increase')
    parser.add_argument('--max_itrs_per_epoch', type=int, default=None,
                        help='increase the number of training iterations (suppose there are few training samples)')

    # Validating
    parser.add_argument('--val_bs', type=int, default=None,
                        help='validating batch size for single GPU, total batch size *= gpu number')
    parser.add_argument('--begin_val_epoch', type=int, default=None,
                        help='which epoch to start validating')
    parser.add_argument('--val_interval', type=int, default=None,
                        help='epoch interval between two validations')
    parser.add_argument('--metrics', type=str, default=None, choices = ['psnr', 'ssim'],
                        help='choose which validation metric you want to use (default: psnr)')

    # Testing
    parser.add_argument('--is_testing', action='store_true', default=None,
                        help='whether to perform testing/predicting or not (default: False)')
    parser.add_argument('--test_bs', type=int, default=None,
                        help='testing batch size (currently only support single GPU)')
    parser.add_argument('--test_data_folder', type=str, default=None,
                        help='path to your testing image folder')
    parser.add_argument('--test_lr', action='store_false', default=None,
                        help='whether to test the downscaled/low-resolution image or not (default: True)')

    # Benchmark
    parser.add_argument('--benchmark', action='store_true', default=None,
                        help='whether to perform benchmarking for given datasets or not (default: False)')
    # nargs='+' collects space-separated names into a list; the former
    # type=list split a single argument string into its characters.
    parser.add_argument('--benchmark_datasets', type=str, nargs='+', default=None,
                        help='select which datasets to benchmark (default: [set5, set14, bsd100])')

    # Loss
    parser.add_argument('--loss_type', type=str, default=None, choices = ['mse', 'mae', 'charbonnier'],
                        help='choose which loss you want to use')

    # Scheduler
    parser.add_argument('--lr_policy', type=str, default=None,
                        choices = ['constant', 'step', 'linear', 'cos_warmup'],
                        help='choose which learning rate policy you want to use (default: constant)')
    parser.add_argument('--warmup_epochs', type=int, default=None,
                        help='warmup epoch number for `cos_warmup` learning rate policy')
    parser.add_argument('--step_size', type=int, default=None,
                        help='step size for `step` learning rate policy')
    parser.add_argument('--step_gamma', type=float, default=None,
                        help='lr reduction factor for `step` learning rate policy (default: 0.1)')

    # Optimizer
    parser.add_argument('--optimizer_type', type=str, default=None,
                        choices = ['sgd', 'adam', 'adamw'],
                        help='choose which optimizer you want to use (default: adam)')
    parser.add_argument('--momentum', type=float, default=None,
                        help='momentum of SGD optimizer')
    parser.add_argument('--weight_decay', type=float, default=None,
                        help='weight decay rate of SGD optimizer')

    # Monitoring
    parser.add_argument('--save_ckpt', action='store_false', default=None,
                        help='whether to save checkpoint or not (default: True)')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='path to save checkpoints and training configurations etc.')
    parser.add_argument('--use_tb', action='store_false', default=None,
                        help='whether to use tensorboard or not (default: True)')
    parser.add_argument('--tb_log_dir', type=str, default=None,
                        help='path to save tensorboard logs')
    parser.add_argument('--ckpt_name', type=str, default=None,
                        help='given name of the saved checkpoint, otherwise use `last` and `best`')

    # Training setting
    parser.add_argument('--amp_training', action='store_true', default=None,
                        help='whether to use automatic mixed precision training or not (default: False)')
    parser.add_argument('--resume_training', action='store_false', default=None,
                        help='whether to load training state from specific checkpoint or not if present (default: True)')
    parser.add_argument('--load_ckpt', action='store_false', default=None,
                        help='whether to load given checkpoint or not if exist (default: True)')
    parser.add_argument('--load_ckpt_path', type=str, default=None,
                        help='path to load specific checkpoint, otherwise try to load `last.pth`')
    parser.add_argument('--base_workers', type=int, default=None,
                        help='number of workers for single GPU, total workers *= number of GPU')
    parser.add_argument('--random_seed', type=int, default=None,
                        help='random seed')
    parser.add_argument('--use_ema', action='store_false', default=None,
                        help='whether to use exponetial moving average to update weights or not (default: True)')
    parser.add_argument('--ema_decay', type=float, default=None,
                        help='constant decay factor for EMA update, if not given use linear decay instead')
    parser.add_argument('--ema_start_epoch', type=int, default=None,
                        help='epoch number to start EMA update')

    # Augmentation
    parser.add_argument('--patch_size', type=int, default=None,
                        help='crop size for single training patch')
    parser.add_argument('--rotate', type=float, default=None,
                        help='probability to perform rotation')
    # NOTE(review): store_false sets False when passed, yet the help says the
    # default is False — either the action or the help text looks wrong.
    parser.add_argument('--multi_scale', action='store_false', default=None,
                        help='whether to perform multi-scale training or not (default: False)')
    parser.add_argument('--hflip', type=float, default=None,
                        help='probability to perform horizontal flip')
    parser.add_argument('--vflip', type=float, default=None,
                        help='probability to perform vertical flip')

    # DDP
    parser.add_argument('--synBN', action='store_true', default=None,
                        help='whether to use SyncBatchNorm or not if trained with DDP (default: False)')
    parser.add_argument('--local_rank', type=int, default=None,
                        help='used for DDP, DO NOT CHANGE')

    args = parser.parse_args(argv)
    return args
152 |
--------------------------------------------------------------------------------
/core/base_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.cuda import amp
4 | from copy import deepcopy
5 | from .loss import get_loss_fn
6 | from models import get_model
7 | from datasets import get_loader, get_test_loader, get_val_loader
8 | from utils import (get_optimizer, get_scheduler, parallel_model, de_parallel,
9 | get_ema_model, set_seed, set_device, get_writer, get_logger,
10 | destroy_ddp_process, mkdir, save_config, log_config,)
11 |
12 |
class BaseTrainer:
    """Shared scaffolding for training, validation, benchmarking and testing.

    The constructor sets up logging, the device, the AMP grad scaler, the
    model, and then either the test loader, the benchmark loaders, or the full
    training state (writer, loss, train/val loaders, optimizer, scheduler).
    It also loads a checkpoint if configured and creates the EMA model for
    training. Subclasses implement ``train_one_epoch``, ``validate``,
    ``benchmark`` and ``predict``.
    """

    def __init__(self, config):
        super(BaseTrainer, self).__init__()
        # DDP parameters, DO NOT CHANGE
        self.rank = int(os.getenv('RANK', -1))
        self.local_rank = int(os.getenv('LOCAL_RANK', -1))
        self.world_size = int(os.getenv('WORLD_SIZE', 1))
        config.DDP = self.local_rank != -1
        self.main_rank = self.local_rank in [-1, 0]

        # Logger compatible with ddp training
        self.logger = get_logger(config, self.main_rank)

        # Select device to train the model
        self.device = set_device(config, self.local_rank)

        # Automatic mixed precision training scaler
        self.scaler = amp.GradScaler(enabled=config.amp_training)

        # Create directory to save checkpoints and logs
        mkdir(config.save_dir)

        # Set random seed to obtain reproducible results
        set_seed(config.random_seed)

        # Define model and put it to the selected device
        self.model = get_model(config).to(self.device)

        if config.is_testing:
            self.test_loader = get_test_loader(config)
        elif config.benchmark:
            self.val_loaders = get_val_loader(config)
        else:
            # Tensorboard monitor
            self.writer = get_writer(config, self.main_rank)

            # Define loss function
            self.loss_fn = get_loss_fn(config, self.device)

            # Get train and validate loader
            self.train_loader, self.val_loader = get_loader(config, self.local_rank)

            # Define optimizer
            self.optimizer = get_optimizer(config, self.model)

            # Define scheduler to control how learning rate changes
            self.scheduler = get_scheduler(config, self.optimizer)

        # Define variables to monitor training process
        self.best_score = 0.
        self.cur_epoch = 0
        self.train_itrs = 0
        self.best_epoch = 0

        # Load specific checkpoints if needed
        self.load_ckpt(config)

        # Use exponential moving average of checkpoint update if needed
        if not config.is_testing and not config.benchmark:
            self.ema_model = get_ema_model(config, self.model, self.device)

    def run(self, config):
        """Train until ``config.total_epoch`` (or early stop), then re-validate the best checkpoint."""
        # Parallel the model using DP or DDP
        self.parallel_model(config)

        # Output the training/validating configs (only in rank 0 if DDP)
        if self.main_rank:
            save_config(config)
            log_config(config, self.logger)

        # Start training from the latest epoch or from scratch
        start_epoch = self.cur_epoch
        for cur_epoch in range(start_epoch, config.total_epoch):
            self.cur_epoch = cur_epoch

            self.train_one_epoch(config)

            if cur_epoch >= config.begin_val_epoch and cur_epoch % config.val_interval == 0:
                val_score = self.validate(config)

                # Track the best score on EVERY rank: the early-stopping check
                # below must reach the same decision on all DDP processes, or
                # ranks that stop early leave the others blocked in collectives.
                # NOTE(review): assumes validate() returns the same (synced)
                # score on every rank — confirm for the metric implementation.
                if val_score > self.best_score:
                    self.best_score = val_score
                    # Set before saving so best.pth records this epoch
                    self.best_epoch = cur_epoch
                    # Save best model (rank 0 only)
                    if self.main_rank and config.save_ckpt:
                        self.save_ckpt(config, save_best=True)

            if self.main_rank and config.save_ckpt:
                # Save last model
                self.save_ckpt(config)

            # Use early stopping if needed
            if cur_epoch - self.best_epoch > config.early_stop_epoch:
                break

        # Close tensorboard after training
        if config.use_tb and self.main_rank:
            self.writer.flush()
            self.writer.close()

        # Validate for the best model
        if config.save_ckpt:
            self.val_best(config)

        destroy_ddp_process(config)

    def parallel_model(self, config):
        self.model = parallel_model(config, self.model, self.local_rank, self.device)

    def train_one_epoch(self, config):
        '''You may implement whatever training process you like here.
        '''
        raise NotImplementedError()

    def validate(self, config, val_best=False):
        # Signature matches the call made in val_best()
        raise NotImplementedError()

    def benchmark(self, config):
        raise NotImplementedError()

    def predict(self, config):
        raise NotImplementedError()

    def load_ckpt(self, config):
        """Load weights (and, when resuming, training state) from ``config.load_ckpt_path``.

        Loads only when ``config.load_ckpt`` is set and the path is an existing
        file. During training with ``config.resume_training``, the epoch
        counter, best score/epoch, optimizer and scheduler state are restored.

        Raises:
            ValueError: In testing mode when no checkpoint could be loaded.
        """
        # Guard against an unset path: os.path.isfile(None) can raise TypeError
        if config.load_ckpt and config.load_ckpt_path and os.path.isfile(config.load_ckpt_path):
            checkpoint = torch.load(config.load_ckpt_path, map_location=torch.device(self.device))
            self.model.load_state_dict(checkpoint['state_dict'])
            if self.main_rank:
                self.logger.info(f"Load model state dict from {config.load_ckpt_path}")

            if not config.is_testing and not config.benchmark and config.resume_training:
                self.cur_epoch = checkpoint['cur_epoch'] + 1
                self.best_score = checkpoint['best_score']
                self.best_epoch = checkpoint['best_epoch']
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.scheduler.load_state_dict(checkpoint['scheduler'])
                self.train_itrs = self.cur_epoch * config.iters_per_epoch

                if self.main_rank:
                    self.logger.info(f"Resume training from {config.load_ckpt_path}")

            del checkpoint
        else:
            if config.is_testing:
                raise ValueError(f'Could not find any pretrained checkpoint at path: {config.load_ckpt_path}.')
            else:
                if self.main_rank:
                    self.logger.info('[!] Train from scratch')

    def _ckpt_path(self, config, best):
        """Return the best/last checkpoint path inside ``config.save_dir``.

        Uses ``best.pth`` / ``last.pth`` by default; with ``config.ckpt_name``
        set, uses ``{ckpt_name}_best.pth`` / ``{ckpt_name}_last.pth`` so the
        two checkpoints never overwrite each other.
        """
        kind = 'best' if best else 'last'
        name = f'{kind}.pth' if config.ckpt_name is None else f'{config.ckpt_name}_{kind}.pth'
        return f'{config.save_dir}/{name}'

    def save_ckpt(self, config, save_best=False):
        """Save the EMA weights (best) or the raw weights plus optimizer/scheduler state (last)."""
        # Previously nothing usable was saved when config.ckpt_name was given
        save_path = self._ckpt_path(config, best=save_best)
        state_dict = self.ema_model.ema.state_dict() if save_best else de_parallel(self.model).state_dict()

        torch.save({
            'cur_epoch': self.cur_epoch,
            'best_score': self.best_score,
            'best_epoch': self.best_epoch,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict() if not save_best else None,
            'scheduler': self.scheduler.state_dict() if not save_best else None,
        }, save_path)

    def val_best(self, config, ckpt_path=None):
        """Load the best checkpoint into the model and EMA model and validate it.

        Raises:
            ValueError: If the checkpoint file does not exist.
        """
        ckpt_path = self._ckpt_path(config, best=True) if ckpt_path is None else ckpt_path
        if not os.path.isfile(ckpt_path):
            raise ValueError(f'Best checkpoint does not exist at {ckpt_path}')

        if self.main_rank:
            self.logger.info(f"\nTrain {self.cur_epoch+1}/{config.total_epoch} epochs finished!\n")
            self.logger.info(f'{"#"*50}\nValidation for the best checkpoint...')

        self.model = de_parallel(self.model)
        checkpoint = torch.load(ckpt_path, map_location=torch.device(self.device))
        self.model.load_state_dict(checkpoint['state_dict'])

        self.model.to(self.device)
        del checkpoint

        self.ema_model.ema = deepcopy(de_parallel(self.model)).eval()

        val_score = self.validate(config, val_best=True)

        if self.main_rank:
            self.logger.info(f'Best validation score is {val_score}.\n')
199 |
--------------------------------------------------------------------------------