├── README.md ├── base_networks.py ├── data.py ├── dataset.py ├── drcn.py ├── edsr.py ├── espcn.py ├── fsrcnn.py ├── lapsrn.py ├── logger.py ├── main.py ├── model.py ├── srcnn.py ├── srgan.py ├── utils.py └── vdsr.py /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-super-resolution-model-collection 2 | Collection of Super-Resolution models via PyTorch 3 | -------------------------------------------------------------------------------- /base_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DenseBlock(torch.nn.Module): 5 | def __init__(self, input_size, output_size, bias=True, activation='relu', norm='batch'): 6 | super(DenseBlock, self).__init__() 7 | self.fc = torch.nn.Linear(input_size, output_size, bias=bias) 8 | 9 | self.norm = norm 10 | if self.norm =='batch': 11 | self.bn = torch.nn.BatchNorm1d(output_size) 12 | elif self.norm == 'instance': 13 | self.bn = torch.nn.InstanceNorm1d(output_size) 14 | 15 | self.activation = activation 16 | if self.activation == 'relu': 17 | self.act = torch.nn.ReLU(True) 18 | elif self.activation == 'prelu': 19 | self.act = torch.nn.PReLU() 20 | elif self.activation == 'lrelu': 21 | self.act = torch.nn.LeakyReLU(0.2, True) 22 | elif self.activation == 'tanh': 23 | self.act = torch.nn.Tanh() 24 | elif self.activation == 'sigmoid': 25 | self.act = torch.nn.Sigmoid() 26 | 27 | def forward(self, x): 28 | if self.norm is not None: 29 | out = self.bn(self.fc(x)) 30 | else: 31 | out = self.fc(x) 32 | 33 | if self.activation is not None: 34 | return self.act(out) 35 | else: 36 | return out 37 | 38 | 39 | class ConvBlock(torch.nn.Module): 40 | def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='relu', norm='batch'): 41 | super(ConvBlock, self).__init__() 42 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 43 | 44 | self.norm = norm 45 | if self.norm =='batch': 46 | self.bn = torch.nn.BatchNorm2d(output_size) 47 | elif self.norm == 'instance': 48 | self.bn = torch.nn.InstanceNorm2d(output_size) 49 | 50 | self.activation = activation 51 | if self.activation == 'relu': 52 | self.act = torch.nn.ReLU(True) 53 | elif self.activation == 'prelu': 54 | self.act = torch.nn.PReLU() 55 | elif self.activation == 'lrelu': 56 | self.act = torch.nn.LeakyReLU(0.2, True) 57 | elif self.activation == 'tanh': 58 | self.act = torch.nn.Tanh() 59 | elif self.activation == 'sigmoid': 60 | self.act = torch.nn.Sigmoid() 61 | 62 | def forward(self, x): 63 | if self.norm is not None: 64 | out = self.bn(self.conv(x)) 65 | else: 66 | out = self.conv(x) 67 | 68 | if self.activation is not None: 69 | return self.act(out) 70 | else: 71 | return out 72 | 73 | 74 | class DeconvBlock(torch.nn.Module): 75 | def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='relu', norm='batch'): 76 | super(DeconvBlock, self).__init__() 77 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 78 | 79 | self.norm = norm 80 | if self.norm == 'batch': 81 | self.bn = torch.nn.BatchNorm2d(output_size) 82 | elif self.norm == 'instance': 83 | self.bn = torch.nn.InstanceNorm2d(output_size) 84 | 85 | self.activation = activation 86 | if self.activation == 'relu': 87 | self.act = torch.nn.ReLU(True) 88 | elif self.activation == 'prelu': 89 | self.act = torch.nn.PReLU() 90 | elif self.activation == 'lrelu': 91 | self.act = torch.nn.LeakyReLU(0.2, True) 92 | elif self.activation == 'tanh': 93 | self.act = torch.nn.Tanh() 94 | elif self.activation == 'sigmoid': 95 | self.act = torch.nn.Sigmoid() 96 | 97 | def forward(self, x): 98 | if self.norm is not None: 99 | out = self.bn(self.deconv(x)) 100 | else: 101 | out = self.deconv(x) 102 | 103 | if self.activation is not None: 104 | return self.act(out) 105 | else: 106 | return out 107 | 108 | 109 | class ResnetBlock(torch.nn.Module): 110 | def __init__(self, num_filter, kernel_size=3, stride=1, padding=1, bias=True, activation='relu', norm='batch'): 111 | super(ResnetBlock, self).__init__() 112 | self.conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias) 113 | self.conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias) 114 | 115 | self.norm = norm 116 | if self.norm == 'batch': 117 | self.bn = torch.nn.BatchNorm2d(num_filter) 118 | elif norm == 'instance': 119 | self.bn = torch.nn.InstanceNorm2d(num_filter) 120 | 121 | self.activation = activation 122 | if self.activation == 'relu': 123 | self.act = torch.nn.ReLU(True) 124 | elif self.activation == 'prelu': 125 | self.act = torch.nn.PReLU() 126 | elif self.activation == 'lrelu': 127 | self.act = torch.nn.LeakyReLU(0.2, True) 128 | elif self.activation == 'tanh': 129 | self.act = torch.nn.Tanh() 130 | elif self.activation == 'sigmoid': 131 | self.act = torch.nn.Sigmoid() 132 | 133 | 134 | def forward(self, x): 135 | residual = x 136 | if self.norm is not None: 137 | out = self.bn(self.conv1(x)) 138 | else: 139 | out = self.conv1(x) 140 | 141 | if self.activation is not None: 142 | out = self.act(out) 143 | 144 | if self.norm is not None: 145 | out = self.bn(self.conv2(out)) 146 | else: 147 | out = self.conv2(out) 148 | 149 | out = torch.add(out, residual) 150 | return out 151 | 152 | 153 | class PSBlock(torch.nn.Module): 154 | def __init__(self, input_size, output_size, scale_factor, kernel_size=3, stride=1, padding=1, bias=True, activation='relu', norm='batch'): 155 | super(PSBlock, self).__init__() 156 | self.conv = torch.nn.Conv2d(input_size, output_size * scale_factor**2, kernel_size, stride, padding, bias=bias) 157 | self.ps = torch.nn.PixelShuffle(scale_factor) 158 | 159 | self.norm = norm 160 | if self.norm == 'batch': 161 | self.bn = torch.nn.BatchNorm2d(output_size) 162 | elif norm == 'instance': 163 | self.bn = torch.nn.InstanceNorm2d(output_size) 164 | 165 | self.activation = activation 166 | if self.activation == 'relu': 167 | self.act = torch.nn.ReLU(True) 168 | elif self.activation == 'prelu': 169 | self.act = torch.nn.PReLU() 170 | elif self.activation == 'lrelu': 171 | self.act = torch.nn.LeakyReLU(0.2, True) 172 | elif self.activation == 'tanh': 173 | self.act = torch.nn.Tanh() 174 | elif self.activation == 'sigmoid': 175 | self.act = torch.nn.Sigmoid() 176 | 177 | def forward(self, x): 178 | if self.norm is not None: 179 | out = self.bn(self.ps(self.conv(x))) 180 | else: 181 | out = self.ps(self.conv(x)) 182 | 183 | if self.activation is not None: 184 | out = self.act(out) 185 | return out 186 | 187 | 188 | class Upsample2xBlock(torch.nn.Module): 189 | def __init__(self, input_size, output_size, bias=True, upsample='deconv', activation='relu', norm='batch'): 190 | super(Upsample2xBlock, self).__init__() 191 | scale_factor = 2 192 | # 1. Deconvolution (Transposed convolution) 193 | if upsample == 'deconv': 194 | self.upsample = DeconvBlock(input_size, output_size, 195 | kernel_size=4, stride=2, padding=1, 196 | bias=bias, activation=activation, norm=norm) 197 | 198 | # 2. Sub-pixel convolution (Pixel shuffler) 199 | elif upsample == 'ps': 200 | self.upsample = PSBlock(input_size, output_size, scale_factor=scale_factor, 201 | bias=bias, activation=activation, norm=norm) 202 | 203 | # 3. Resize and Convolution 204 | elif upsample == 'rnc': 205 | self.upsample = torch.nn.Sequential( 206 | torch.nn.Upsample(scale_factor=scale_factor, mode='nearest'), 207 | ConvBlock(input_size, output_size, 208 | kernel_size=3, stride=1, padding=1, 209 | bias=bias, activation=activation, norm=norm) 210 | ) 211 | 212 | def forward(self, x): 213 | out = self.upsample(x) 214 | return out 215 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from os.path import exists, join, basename 2 | from os import makedirs, remove 3 | from six.moves import urllib 4 | import tarfile 5 | from dataset import * 6 | 7 | 8 | def download_bsds300(dest="dataset"): 9 | output_image_dir = join(dest, "BSDS300/images") 10 | 11 | if not exists(output_image_dir): 12 | if not exists(dest): 13 | makedirs(dest) 14 | url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300-images.tgz" 15 | print("downloading url ", url) 16 | 17 | data = urllib.request.urlopen(url) 18 | 19 | file_path = join(dest, basename(url)) 20 | with open(file_path, 'wb') as f: 21 | f.write(data.read()) 22 | 23 | print("Extracting data") 24 | with tarfile.open(file_path) as tar: 25 | for item in tar: 26 | tar.extract(item, dest) 27 | 28 | remove(file_path) 29 | 30 | return output_image_dir 31 | 32 | 33 | def get_training_set(data_dir, datasets, crop_size, scale_factor, is_gray=False): 34 | train_dir = [] 35 | for dataset in datasets: 36 | if dataset == 'bsds300': 37 | root_dir = download_bsds300(data_dir) 38 | train_dir.append(join(root_dir, "train")) 39 | elif dataset == 'DIV2K': 40 | train_dir.append(join(data_dir, dataset, 'DIV2K_train_LR_bicubic/X4')) 41 | else: 42 | train_dir.append(join(data_dir, dataset)) 43 | 44 | return TrainDatasetFromFolder(train_dir, 45 | is_gray=is_gray, 46 | random_scale=True, # random scaling 47 | crop_size=crop_size, # random crop 48 | rotate=True, # random rotate 49 | fliplr=True, # random flip 50 | fliptb=True, 51 | scale_factor=scale_factor) 52 | 53 | 54 | def get_test_set(data_dir, dataset, scale_factor, is_gray=False): 55 | if dataset == 'bsds300': 56 | root_dir = download_bsds300(data_dir) 57 | test_dir = join(root_dir, "test") 58 | elif dataset == 'DIV2K': 59 | test_dir = join(data_dir, dataset, 'DIV2K_test_LR_bicubic/X4') 60 | else: 61 | test_dir = join(data_dir, dataset) 62 | 63 | return TestDatasetFromFolder(test_dir, 64 | is_gray=is_gray, 65 | scale_factor=scale_factor) 66 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torchvision.transforms import * 3 | from os import listdir 4 | from os.path import join 5 | from PIL import Image 6 | import random 7 | 8 | 9 | def is_image_file(filename): 10 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg", ".bmp"]) 11 | 12 | 13 | def load_img(filepath): 14 | img = Image.open(filepath).convert('RGB') 15 | return img 16 | 17 | 18 | def calculate_valid_crop_size(crop_size, scale_factor): 19 | return crop_size - (crop_size % scale_factor) 20 | 21 | 22 | class TrainDatasetFromFolder(data.Dataset): 23 | def __init__(self, image_dirs, is_gray=False, random_scale=True, crop_size=128, rotate=True, fliplr=True, 24 | fliptb=True, scale_factor=4): 25 | super(TrainDatasetFromFolder, self).__init__() 26 | 27 | self.image_filenames = [] 28 | for image_dir in image_dirs: 29 | self.image_filenames.extend(join(image_dir, x) for x in sorted(listdir(image_dir)) if is_image_file(x)) 30 | self.is_gray = is_gray 31 | self.random_scale = random_scale 32 | self.crop_size = crop_size 33 | self.rotate = rotate 34 | self.fliplr = fliplr 35 | self.fliptb = fliptb 36 | self.scale_factor = scale_factor 37 | 38 | def __getitem__(self, index): 39 | # load image 40 | img = load_img(self.image_filenames[index]) 41 | 42 | # determine valid HR image size with scale factor 43 | self.crop_size = calculate_valid_crop_size(self.crop_size, self.scale_factor) 44 | hr_img_w = self.crop_size 45 | hr_img_h = self.crop_size 46 | 47 | # determine LR image size 48 | lr_img_w = hr_img_w // self.scale_factor 49 | lr_img_h = hr_img_h // self.scale_factor 50 | 51 | # random scaling between [0.5, 1.0] 52 | if self.random_scale: 53 | eps = 1e-3 54 | ratio = random.randint(5, 10) * 0.1 55 | if hr_img_w * ratio < self.crop_size: 56 | ratio = self.crop_size / hr_img_w + eps 57 | if hr_img_h * ratio < self.crop_size: 58 | ratio = self.crop_size / hr_img_h + eps 59 | 60 | scale_w = int(hr_img_w * ratio) 61 | scale_h = int(hr_img_h * ratio) 62 | transform = Scale((scale_w, scale_h), interpolation=Image.BICUBIC) 63 | img = transform(img) 64 | 65 | # random crop 66 | transform = RandomCrop(self.crop_size) 67 | img = transform(img) 68 | 69 | # random rotation between [90, 180, 270] degrees 70 | if self.rotate: 71 | rv = random.randint(1, 3) 72 | img = img.rotate(90 * rv, expand=True) 73 | 74 | # random horizontal flip 75 | if self.fliplr: 76 | transform = RandomHorizontalFlip() 77 | img = transform(img) 78 | 79 | # random vertical flip 80 | if self.fliptb: 81 | if random.random() < 0.5: 82 | img = img.transpose(Image.FLIP_TOP_BOTTOM) 83 | 84 | # only Y-channel is super-resolved 85 | if self.is_gray: 86 | img = img.convert('YCbCr') 87 | # img, _, _ = img.split() 88 | 89 | # hr_img HR image 90 | hr_transform = Compose([Scale((hr_img_w, hr_img_h), interpolation=Image.BICUBIC), ToTensor()]) 91 | hr_img = hr_transform(img) 92 | 93 | # lr_img LR image 94 | lr_transform = Compose([Scale((lr_img_w, lr_img_h), interpolation=Image.BICUBIC), ToTensor()]) 95 | lr_img = lr_transform(img) 96 | 97 | # Bicubic interpolated image 98 | bc_transform = Compose([ToPILImage(), Scale((hr_img_w, hr_img_h), interpolation=Image.BICUBIC), ToTensor()]) 99 | bc_img = bc_transform(lr_img) 100 | 101 | return lr_img, hr_img, bc_img 102 | 103 | def __len__(self): 104 | return len(self.image_filenames) 105 | 106 | 107 | class TestDatasetFromFolder(data.Dataset): 108 | def __init__(self, image_dir, is_gray=False, scale_factor=4): 109 | super(TestDatasetFromFolder, self).__init__() 110 | 111 | self.image_filenames = [join(image_dir, x) for x in sorted(listdir(image_dir)) if is_image_file(x)] 112 | self.is_gray = is_gray 113 | self.scale_factor = scale_factor 114 | 115 | def __getitem__(self, index): 116 | # load image 117 | img = load_img(self.image_filenames[index]) 118 | 119 | # original HR image size 120 | w = img.size[0] 121 | h = img.size[1] 122 | 123 | # determine valid HR image size with scale factor 124 | hr_img_w = calculate_valid_crop_size(w, self.scale_factor) 125 | hr_img_h = calculate_valid_crop_size(h, self.scale_factor) 126 | 127 | # determine lr_img LR image size 128 | lr_img_w = hr_img_w // self.scale_factor 129 | lr_img_h = hr_img_h // self.scale_factor 130 | 131 | # only Y-channel is super-resolved 132 | if self.is_gray: 133 | img = img.convert('YCbCr') 134 | # img, _, _ = lr_img.split() 135 | 136 | # hr_img HR image 137 | hr_transform = Compose([Scale((hr_img_w, hr_img_h), interpolation=Image.BICUBIC), ToTensor()]) 138 | hr_img = hr_transform(img) 139 | 140 | # lr_img LR image 141 | lr_transform = Compose([Scale((lr_img_w, lr_img_h), interpolation=Image.BICUBIC), ToTensor()]) 142 | lr_img = lr_transform(img) 143 | 144 | # Bicubic interpolated image 145 | bc_transform = Compose([ToPILImage(), Scale((hr_img_w, hr_img_h), interpolation=Image.BICUBIC), ToTensor()]) 146 | bc_img = bc_transform(lr_img) 147 | 148 | return lr_img, hr_img, bc_img 149 | 150 | def __len__(self): 151 | return len(self.image_filenames) -------------------------------------------------------------------------------- /drcn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from base_networks import * 6 | from torch.utils.data import DataLoader 7 | from data import get_training_set, get_test_set 8 | import utils 9 | from logger import Logger 10 | from torchvision.transforms import * 11 | 12 | 13 | class Net(torch.nn.Module): 14 | def __init__(self, num_channels, base_filter, num_recursions): 15 | super(Net, self).__init__() 16 | self.num_recursions = num_recursions 17 | # embedding layer 18 | self.embedding_layer = nn.Sequential( 19 | ConvBlock(num_channels, base_filter, 3, 1, 1, norm=None), 20 | ConvBlock(base_filter, base_filter, 3, 1, 1, norm=None) 21 | ) 22 | 23 | # conv block of inference layer 24 | self.conv_block = ConvBlock(base_filter, base_filter, 3, 1, 1, norm=None) 25 | 26 | # reconstruction layer 27 | self.reconstruction_layer = nn.Sequential( 28 | ConvBlock(base_filter, base_filter, 3, 1, 1, activation=None, norm=None), 29 | ConvBlock(base_filter, num_channels, 3, 1, 1, activation=None, norm=None) 30 | ) 31 | 32 | # initial w 33 | self.w_init = torch.ones(self.num_recursions) / self.num_recursions 34 | self.w = Variable(self.w_init.cuda(), requires_grad=True) 35 | 36 | def forward(self, x): 37 | # embedding layer 38 | h0 = self.embedding_layer(x) 39 | 40 | # recursions 41 | h = [h0] 42 | for d in range(self.num_recursions): 43 | h.append(self.conv_block(h[d])) 44 | 45 | y_d_ = [] 46 | out_sum = 0 47 | for d in range(self.num_recursions): 48 | y_d_.append(self.reconstruction_layer(h[d+1])) 49 | out_sum += torch.mul(y_d_[d], self.w[d]) 50 | out_sum = torch.mul(out_sum, 1.0 / (torch.sum(self.w))) 51 | 52 | # skip connection 53 | final_out = torch.add(out_sum, x) 54 | 55 | return y_d_, final_out 56 | 57 | def weight_init(self): 58 | for m in self.modules(): 59 | utils.weights_init_kaming(m) 60 | 61 | 62 | class DRCN(object): 63 | def __init__(self, args): 64 | # parameters 65 | self.model_name = args.model_name 66 | self.train_dataset = args.train_dataset 67 | self.test_dataset = args.test_dataset 68 | self.crop_size = args.crop_size 69 | self.num_threads = args.num_threads 70 | self.num_channels = args.num_channels 71 | self.scale_factor = args.scale_factor 72 | self.num_epochs = args.num_epochs 73 | self.save_epochs = args.save_epochs 74 | self.batch_size = args.batch_size 75 | self.test_batch_size = args.test_batch_size 76 | self.lr = args.lr 77 | self.data_dir = args.data_dir 78 | self.save_dir = args.save_dir 79 | self.gpu_mode = args.gpu_mode 80 | 81 | def load_dataset(self, dataset='train'): 82 | if self.num_channels == 1: 83 | is_gray = True 84 | else: 85 | is_gray = False 86 | 87 | if dataset == 'train': 88 | print('Loading train datasets...') 89 | train_set = get_training_set(self.data_dir, self.train_dataset, self.crop_size, self.scale_factor, is_gray=is_gray, 90 | normalize=False) 91 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 92 | shuffle=True) 93 | elif dataset == 'test': 94 | print('Loading test datasets...') 95 | test_set = get_test_set(self.data_dir, self.test_dataset, self.scale_factor, is_gray=is_gray, 96 | normalize=False) 97 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 98 | batch_size=self.test_batch_size, 99 | shuffle=False) 100 | 101 | def train(self): 102 | # networks 103 | self.num_recursions = 16 104 | self.model = Net(num_channels=self.num_channels, base_filter=256, num_recursions=self.num_recursions) 105 | 106 | # weigh initialization 107 | self.model.weight_init() 108 | 109 | # optimizer 110 | self.momentum = 0.9 111 | self.weight_decay = 0.0001 112 | self.loss_alpha = 1.0 113 | self.loss_alpha_zero_epoch = 25 114 | self.loss_alpha_decay = self.loss_alpha / self.loss_alpha_zero_epoch 115 | self.loss_beta = 0.001 116 | 117 | # learnable parameters 118 | param_groups = list(self.model.parameters()) 119 | param_groups = [{'params': param_groups}] 120 | param_groups += [{'params': [self.model.w]}] 121 | self.optimizer = optim.Adam(param_groups, lr=self.lr) 122 | 123 | # loss function 124 | if self.gpu_mode: 125 | self.model.cuda() 126 | self.MSE_loss = nn.MSELoss().cuda() 127 | else: 128 | self.MSE_loss = nn.MSELoss() 129 | 130 | print('---------- Networks architecture -------------') 131 | utils.print_network(self.model) 132 | print('----------------------------------------------') 133 | 134 | # load dataset 135 | train_data_loader = self.load_dataset(dataset='train') 136 | test_data_loader = self.load_dataset(dataset='test') 137 | 138 | # set the logger 139 | log_dir = os.path.join(self.save_dir, 'logs') 140 | if not os.path.exists(log_dir): 141 | os.mkdir(log_dir) 142 | logger = Logger(log_dir) 143 | 144 | ################# Train ################# 145 | print('Training is started.') 146 | avg_loss = [] 147 | step = 0 148 | 149 | # test image 150 | test_input, test_target = test_data_loader.dataset.__getitem__(2) 151 | test_input = test_input.unsqueeze(0) 152 | test_target = test_target.unsqueeze(0) 153 | 154 | self.model.train() 155 | for epoch in range(self.num_epochs): 156 | 157 | # learning rate is decayed by a factor of 10 every 20 epochs 158 | if (epoch + 1) % 20 == 0: 159 | for param_group in self.optimizer.param_groups: 160 | param_group["lr"] /= 10.0 161 | print("Learning rate decay: lr={}".format(self.optimizer.param_groups[0]["lr"])) 162 | 163 | # loss_alpha decayed to zero after 25 epochs 164 | self.loss_alpha = max(0.0, self.loss_alpha - self.loss_alpha_decay) 165 | 166 | epoch_loss = 0 167 | for iter, (input, target) in enumerate(train_data_loader): 168 | # input data (bicubic interpolated image) 169 | if self.gpu_mode: 170 | y = Variable(target.cuda()) 171 | x = Variable(utils.img_interp(input, self.scale_factor).cuda()) 172 | else: 173 | y = Variable(target) 174 | x = Variable(utils.img_interp(input, self.scale_factor)) 175 | 176 | # update network 177 | self.optimizer.zero_grad() 178 | y_d_, y_ = self.model(x) 179 | 180 | # loss1 181 | loss1 = 0 182 | for d in range(self.num_recursions): 183 | loss1 += (self.MSE_loss(y_d_[d], y) / self.num_recursions) 184 | 185 | # loss2 186 | loss2 = self.MSE_loss(y_, y) 187 | 188 | # regularization 189 | reg_term = 0 190 | for theta in self.model.parameters(): 191 | reg_term += torch.mean(torch.sum(theta ** 2)) 192 | 193 | # total loss 194 | 195 | loss = self.loss_alpha * loss1 + (1-self.loss_alpha) * loss2 + self.loss_beta * reg_term 196 | loss.backward() 197 | self.optimizer.step() 198 | 199 | # log 200 | epoch_loss += loss.data[0] 201 | print("Epoch: [%2d] [%4d/%4d] loss: %.8f" % ((epoch + 1), (iter + 1), len(train_data_loader), loss.data[0])) 202 | 203 | # tensorboard logging 204 | logger.scalar_summary('loss', loss.data[0], step + 1) 205 | step += 1 206 | 207 | # avg. loss per epoch 208 | avg_loss.append(epoch_loss / len(train_data_loader)) 209 | 210 | # prediction 211 | _, recon_imgs = self.model(Variable(utils.img_interp(test_input, self.scale_factor).cuda())) 212 | recon_img = recon_imgs[0].cpu().data 213 | gt_img = test_target[0] 214 | lr_img = test_input[0] 215 | bc_img = utils.img_interp(test_input[0], self.scale_factor) 216 | 217 | # calculate psnrs 218 | bc_psnr = utils.PSNR(bc_img, gt_img) 219 | recon_psnr = utils.PSNR(recon_img, gt_img) 220 | 221 | # save result images 222 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 223 | psnrs = [None, None, bc_psnr, recon_psnr] 224 | utils.plot_test_result(result_imgs, psnrs, epoch + 1, save_dir=self.save_dir, is_training=True) 225 | 226 | print("Saving training result images at epoch %d" % (epoch + 1)) 227 | 228 | # Save trained parameters of model 229 | if (epoch + 1) % self.save_epochs == 0: 230 | self.save_model(epoch + 1) 231 | 232 | # Plot avg. loss 233 | utils.plot_loss([avg_loss], self.num_epochs, save_dir=self.save_dir) 234 | print("Training is finished.") 235 | 236 | # Save final trained parameters of model 237 | self.save_model(epoch=None) 238 | 239 | def test(self): 240 | # networks 241 | self.num_recursions = 16 242 | self.model = Net(num_channels=self.num_channels, base_filter=256, num_recursions=self.num_recursions) 243 | 244 | if self.gpu_mode: 245 | self.model.cuda() 246 | 247 | # load model 248 | self.load_model() 249 | 250 | # load dataset 251 | test_data_loader = self.load_dataset(dataset='test') 252 | 253 | # Test 254 | print('Test is started.') 255 | img_num = 0 256 | self.model.eval() 257 | for input, target in test_data_loader: 258 | # input data (bicubic interpolated image) 259 | if self.gpu_mode: 260 | y_ = Variable(utils.img_interp(input, self.scale_factor).cuda()) 261 | else: 262 | y_ = Variable(utils.img_interp(input, self.scale_factor)) 263 | 264 | # prediction 265 | _, recon_imgs = self.model(y_) 266 | for i, recon_img in enumerate(recon_imgs): 267 | img_num += 1 268 | recon_img = recon_imgs[i].cpu().data 269 | gt_img = target[i] 270 | lr_img = input[i] 271 | bc_img = utils.img_interp(input[i], self.scale_factor) 272 | 273 | # calculate psnrs 274 | bc_psnr = utils.PSNR(bc_img, gt_img) 275 | recon_psnr = utils.PSNR(recon_img, gt_img) 276 | 277 | # save result images 278 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 279 | psnrs = [None, None, bc_psnr, recon_psnr] 280 | utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=self.save_dir) 281 | 282 | print("Saving %d test result images..." % img_num) 283 | 284 | def test_single(self, img_fn): 285 | # networks 286 | self.model = Net(num_channels=self.num_channels, base_filter=256, num_recursions=self.num_recursions) 287 | 288 | if self.gpu_mode: 289 | self.model.cuda() 290 | 291 | # load model 292 | self.load_model() 293 | 294 | # load data 295 | img = Image.open(img_fn) 296 | img = img.convert('YCbCr') 297 | y, cb, cr = img.split() 298 | 299 | input = Variable(ToTensor()(y)).view(1, -1, y.size[1], y.size[0]) 300 | if self.gpu_mode: 301 | input = input.cuda() 302 | 303 | self.model.eval() 304 | recon_img = self.model(input) 305 | 306 | # save result images 307 | utils.save_img(recon_img.cpu().data, 1, save_dir=self.save_dir) 308 | 309 | out = recon_img.cpu() 310 | out_img_y = out.data[0] 311 | out_img_y = (((out_img_y - out_img_y.min()) * 255) / (out_img_y.max() - out_img_y.min())).numpy() 312 | # out_img_y *= 255.0 313 | # out_img_y = out_img_y.clip(0, 255) 314 | out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L') 315 | 316 | out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) 317 | out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) 318 | out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB') 319 | 320 | # save img 321 | result_dir = os.path.join(self.save_dir, 'result') 322 | if not os.path.exists(result_dir): 323 | os.mkdir(result_dir) 324 | save_fn = result_dir + '/SR_result.png' 325 | out_img.save(save_fn) 326 | 327 | def save_model(self, epoch=None): 328 | model_dir = os.path.join(self.save_dir, 'model') 329 | if not os.path.exists(model_dir): 330 | os.mkdir(model_dir) 331 | if epoch is not None: 332 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param_epoch_%d.pkl' % epoch) 333 | torch.save(self.model.w, model_dir + '/' + self.model_name + '_w_epoch_%d.pkl' % epoch) 334 | else: 335 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param.pkl') 336 | torch.save(self.model.w, model_dir + '/' + self.model_name + '_w.pkl') 337 | 338 | print('Trained model is saved.') 339 | 340 | def load_model(self): 341 | model_dir = os.path.join(self.save_dir, 'model') 342 | 343 | model_name = model_dir + '/' + self.model_name + '_param.pkl' 344 | if os.path.exists(model_name): 345 | self.model.load_state_dict(torch.load(model_name)) 346 | print('Trained model is loaded.') 347 | else: 348 | print('No model exists to load.') 349 | 350 | w_name = model_dir + '/' + self.model_name + '_w.pkl' 351 | if os.path.exists(w_name): 352 | self.model.w = torch.load(w_name) 353 | print('Trained weight is loaded.') 354 | else: 355 | print('No weight exists to load.') 356 | -------------------------------------------------------------------------------- /edsr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from base_networks import * 6 | from torch.utils.data import DataLoader 7 | from data import get_training_set, get_test_set 8 | import utils 9 | from logger import Logger 10 | from torchvision.transforms import * 11 | 12 | 13 | class Net(torch.nn.Module): 14 | def __init__(self, num_channels, base_filter, num_residuals): 15 | super(Net, self).__init__() 16 | 17 | self.input_conv = ConvBlock(num_channels, base_filter, 3, 1, 1, activation=None, norm=None) 18 | 19 | resnet_blocks = [] 20 | for _ in range(num_residuals): 21 | resnet_blocks.append(ResnetBlock(base_filter, norm=None)) 22 | self.residual_layers = nn.Sequential(*resnet_blocks) 23 | 24 | self.mid_conv = ConvBlock(base_filter, base_filter, 3, 1, 1, activation=None, norm=None) 25 | 26 | self.upscale4x = nn.Sequential( 27 | Upsample2xBlock(base_filter, base_filter, upsample='ps', activation=None, norm=None), 28 | Upsample2xBlock(base_filter, base_filter, upsample='ps', activation=None, norm=None), 29 | ) 30 | 31 | self.output_conv = ConvBlock(base_filter, num_channels, 3, 1, 1, activation=None, norm=None) 32 | 33 | def weight_init(self, mean=0.0, std=0.02): 34 | for m in self.modules(): 35 | utils.weights_init_normal(m, mean=mean, std=std) 36 | 37 | def forward(self, x): 38 | out = self.input_conv(x) 39 | residual = out 40 | out = self.residual_layers(out) 41 | out = self.mid_conv(out) 42 | out = torch.add(out, residual) 43 | out = self.upscale4x(out) 44 | out = self.output_conv(out) 45 | return out 46 | 47 | 48 | class EDSR(object): 49 | def __init__(self, args): 50 | # parameters 51 | self.model_name = args.model_name 52 | self.train_dataset = args.train_dataset 53 | self.test_dataset = args.test_dataset 54 | self.crop_size = args.crop_size 55 | self.num_threads = args.num_threads 56 | self.num_channels = args.num_channels 57 | self.scale_factor = args.scale_factor 58 | self.num_epochs = args.num_epochs 59 | self.save_epochs = args.save_epochs 60 | self.batch_size = args.batch_size 61 | self.test_batch_size = args.test_batch_size 62 | self.lr = args.lr 63 | self.data_dir = args.data_dir 64 | self.save_dir = args.save_dir 65 | self.gpu_mode = args.gpu_mode 66 | 67 | def load_dataset(self, dataset, is_train=True): 68 | if self.num_channels == 1: 69 | is_gray = True 70 | else: 71 | is_gray = False 72 | 73 | if is_train: 74 | print('Loading train datasets...') 75 | train_set = get_training_set(self.data_dir, dataset, self.crop_size, self.scale_factor, is_gray=is_gray) 76 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 77 | shuffle=True) 78 | else: 79 | print('Loading test datasets...') 80 | test_set = get_test_set(self.data_dir, dataset, self.scale_factor, is_gray=is_gray) 81 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 82 | batch_size=self.test_batch_size, 83 | shuffle=False) 84 | 85 | def train(self): 86 | # networks 87 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_residuals=16) 88 | 89 | # weigh initialization 90 | self.model.weight_init() 91 | 92 | # optimizer 93 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-8) 94 | 95 | # loss function 96 | if self.gpu_mode: 97 | self.model.cuda() 98 | self.L1_loss = nn.L1Loss().cuda() 99 | else: 100 | self.L1_loss = nn.L1Loss() 101 | 102 | print('---------- Networks architecture -------------') 103 | utils.print_network(self.model) 104 | print('----------------------------------------------') 105 | 106 | # load dataset 107 | train_data_loader = self.load_dataset(dataset=self.train_dataset, is_train=True) 108 | test_data_loader = self.load_dataset(dataset=self.test_dataset[0], is_train=False) 109 | 110 | # set the logger 111 | log_dir = os.path.join(self.save_dir, 'logs') 112 | if not os.path.exists(log_dir): 113 | os.makedirs(log_dir) 114 | logger = Logger(log_dir) 115 | 116 | ################# Train ################# 117 | print('Training is started.') 118 | avg_loss = [] 119 | step = 0 120 | 121 | # test image 122 | test_lr, test_hr, test_bc = test_data_loader.dataset.__getitem__(2) 123 | test_lr = test_lr.unsqueeze(0) 124 | test_hr = test_hr.unsqueeze(0) 125 | test_bc = test_bc.unsqueeze(0) 126 | 127 | self.model.train() 128 | for epoch in range(self.num_epochs): 129 | 130 | # learning rate is decayed by a factor of 2 every 40 epochs 131 | if (epoch+1) % 40 == 0: 132 | for param_group in self.optimizer.param_groups: 133 | param_group['lr'] /= 2.0 134 | print('Learning rate decay: lr={}'.format(self.optimizer.param_groups[0]['lr'])) 135 | 136 | epoch_loss = 0 137 | for iter, (lr, hr, _) in enumerate(train_data_loader): 138 | # input data (low resolution image) 139 | if self.num_channels == 1: 140 | x_ = Variable(hr[:, 0].unsqueeze(1)) 141 | y_ = Variable(lr[:, 0].unsqueeze(1)) 142 | else: 143 | x_ = Variable(hr) 144 | y_ = Variable(lr) 145 | 146 | if self.gpu_mode: 147 | x_ = x_.cuda() 148 | y_ = y_.cuda() 149 | 150 | # update network 151 | self.optimizer.zero_grad() 152 | recon_image = self.model(y_) 153 | loss = self.L1_loss(recon_image, x_) 154 | loss.backward() 155 | self.optimizer.step() 156 | 157 | # log 158 | epoch_loss += loss.data[0] 159 | print('Epoch: [%2d] [%4d/%4d] loss: %.8f' % ((epoch + 1), (iter + 1), len(train_data_loader), loss.data[0])) 160 | 161 | # tensorboard logging 162 | logger.scalar_summary('loss', loss.data[0], step + 1) 163 | step += 1 164 | 165 | # avg. loss per epoch 166 | avg_loss.append(epoch_loss / len(train_data_loader)) 167 | 168 | # prediction 169 | if self.num_channels == 1: 170 | y_ = Variable(test_lr[:, 0].unsqueeze(1)) 171 | else: 172 | y_ = Variable(test_lr) 173 | 174 | if self.gpu_mode: 175 | y_ = y_.cuda() 176 | 177 | recon_img = self.model(y_) 178 | sr_img = recon_img[0].cpu().data 179 | 180 | # save result image 181 | save_dir = os.path.join(self.save_dir, 'train_result') 182 | utils.save_img(sr_img, epoch + 1, save_dir=save_dir, is_training=True) 183 | print('Result image at epoch %d is saved.' % (epoch + 1)) 184 | 185 | # Save trained parameters of model 186 | if (epoch + 1) % self.save_epochs == 0: 187 | self.save_model(epoch + 1) 188 | 189 | # calculate psnrs 190 | if self.num_channels == 1: 191 | gt_img = test_hr[0][0].unsqueeze(0) 192 | lr_img = test_lr[0][0].unsqueeze(0) 193 | bc_img = test_bc[0][0].unsqueeze(0) 194 | else: 195 | gt_img = test_hr[0] 196 | lr_img = test_lr[0] 197 | bc_img = test_bc[0] 198 | 199 | bc_psnr = utils.PSNR(bc_img, gt_img) 200 | recon_psnr = utils.PSNR(sr_img, gt_img) 201 | 202 | # plot result images 203 | result_imgs = [gt_img, lr_img, bc_img, sr_img] 204 | psnrs = [None, None, bc_psnr, recon_psnr] 205 | utils.plot_test_result(result_imgs, psnrs, self.num_epochs, save_dir=save_dir, is_training=True) 206 | print('Training result image is saved.') 207 | 208 | # Plot avg. loss 209 | utils.plot_loss([avg_loss], self.num_epochs, save_dir=save_dir) 210 | print('Training is finished.') 211 | 212 | # Save final trained parameters of model 213 | self.save_model(epoch=None) 214 | 215 | def test(self): 216 | # networks 217 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_residuals=16) 218 | 219 | if self.gpu_mode: 220 | self.model.cuda() 221 | 222 | # load model 223 | self.load_model() 224 | 225 | # load dataset 226 | for test_dataset in self.test_dataset: 227 | test_data_loader = self.load_dataset(dataset=test_dataset, is_train=False) 228 | 229 | # Test 230 | print('Test is started.') 231 | img_num = 0 232 | total_img_num = len(test_data_loader) 233 | self.model.eval() 234 | for lr, hr, bc in test_data_loader: 235 | # input data (low resolution image) 236 | if self.num_channels == 1: 237 | y_ = Variable(lr[:, 0].unsqueeze(1)) 238 | else: 239 | y_ = Variable(lr) 240 | 241 | if self.gpu_mode: 242 | y_ = y_.cuda() 243 | 244 | # prediction 245 | recon_imgs = self.model(y_) 246 | for i, recon_img in enumerate(recon_imgs): 247 | img_num += 1 248 | sr_img = recon_img.cpu().data 249 | 250 | # save result image 251 | save_dir = os.path.join(self.save_dir, 'test_result', test_dataset) 252 | utils.save_img(sr_img, img_num, save_dir=save_dir) 253 | 254 | # calculate psnrs 255 | if self.num_channels == 1: 256 | gt_img = hr[i][0].unsqueeze(0) 257 | lr_img = lr[i][0].unsqueeze(0) 258 | bc_img = bc[i][0].unsqueeze(0) 259 | else: 260 | gt_img = hr[i] 261 | lr_img = lr[i] 262 | bc_img = bc[i] 263 | 264 | bc_psnr = utils.PSNR(bc_img, gt_img) 265 | recon_psnr = utils.PSNR(sr_img, gt_img) 266 | 267 | # plot result images 268 | result_imgs = [gt_img, lr_img, bc_img, sr_img] 269 | psnrs = [None, None, bc_psnr, recon_psnr] 270 | utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=save_dir) 271 | 272 | print('Test DB: %s, Saving result images...[%d/%d]' % (test_dataset, img_num, total_img_num)) 273 | 274 | print('Test is finishied.') 275 | 276 | def test_single(self, img_fn): 277 | # networks 278 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_residuals=16) 279 | 280 | if self.gpu_mode: 281 | self.model.cuda() 282 | 283 | # load model 284 | self.load_model() 285 | 286 | # load data 287 | img = Image.open(img_fn).convert('RGB') 288 | 289 | if self.num_channels == 1: 290 | img = img.convert('YCbCr') 291 | img_y, img_cb, img_cr = img.split() 292 | 293 | input = ToTensor()(img_y) 294 | y_ = Variable(input.unsqueeze(1)) 295 | else: 296 | input = ToTensor()(img).view(1, -1, img.height, img.width) 297 | y_ = Variable(input) 298 | 299 | if self.gpu_mode: 300 | y_ = y_.cuda() 301 | 302 | # prediction 303 | self.model.eval() 304 | recon_img = self.model(y_) 305 | recon_img = recon_img.cpu().data[0].clamp(0, 1) 306 | recon_img = ToPILImage()(recon_img) 307 | 308 | if self.num_channels == 1: 309 | # merge color channels with super-resolved Y-channel 310 | recon_y = recon_img 311 | recon_cb = img_cb.resize(recon_y.size, Image.BICUBIC) 312 | recon_cr = img_cr.resize(recon_y.size, Image.BICUBIC) 313 | recon_img = Image.merge('YCbCr', [recon_y, recon_cb, recon_cr]).convert('RGB') 314 | 315 | # save img 316 | result_dir = os.path.join(self.save_dir, 'test_result') 317 | if not os.path.exists(result_dir): 318 | os.makedirs(result_dir) 319 | save_fn = result_dir + '/SR_result.png' 320 | recon_img.save(save_fn) 321 | 322 | print('Single test result image is saved.') 323 | 324 | def save_model(self, epoch=None): 325 | model_dir = os.path.join(self.save_dir, 'model') 326 | if not os.path.exists(model_dir): 327 | os.makedirs(model_dir) 328 | if epoch is not None: 329 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + 330 | '_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 331 | % (self.num_channels, self.batch_size, epoch, self.lr)) 332 | else: 333 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + 334 | '_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 335 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr)) 336 | 337 | print('Trained model is saved.') 338 | 339 | def load_model(self): 340 | model_dir = os.path.join(self.save_dir, 'model') 341 | 342 | model_name = model_dir + '/' + self.model_name +\ 343 | '_param_ch%d_batch%d_epoch%d_lr%.g.pkl'\ 344 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr) 345 | if os.path.exists(model_name): 346 | self.model.load_state_dict(torch.load(model_name)) 347 | print('Trained model is loaded.') 348 | return True 349 | else: 350 | print('No model exists to load.') 351 | return False 352 | -------------------------------------------------------------------------------- /espcn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from base_networks import * 6 | from torch.utils.data import DataLoader 7 | from data import get_training_set, get_test_set 8 | import utils 9 | from logger import Logger 10 | from torchvision.transforms import * 11 | 12 | 13 | class Net(torch.nn.Module): 14 | def __init__(self, num_channels, base_filter, scale_factor): 15 | super(Net, self).__init__() 16 | 17 | self.layers = torch.nn.Sequential( 18 | ConvBlock(num_channels, base_filter, 5, 1, 0, activation='relu', norm=None), 19 | ConvBlock(base_filter, base_filter // 2, 3, 1, 0, activation='relu', norm=None), 20 | PSBlock(base_filter // 2, num_channels, scale_factor, 3, 1, 0, activation=None, norm=None) 21 | ) 22 | 23 | def forward(self, x): 24 | out = self.layers(x) 25 | return out 26 | 27 | def weight_init(self): 28 | for m in self.modules(): 29 | utils.weights_init_normal(m) 30 | 31 | 32 | class ESPCN(object): 33 | def __init__(self, args): 34 | # parameters 35 | self.model_name = args.model_name 36 | self.train_dataset = args.train_dataset 37 | self.test_dataset = args.test_dataset 38 | self.crop_size = args.crop_size 39 | self.num_threads = args.num_threads 40 | self.num_channels = args.num_channels 41 | self.scale_factor = args.scale_factor 42 | self.num_epochs = args.num_epochs 43 | self.save_epochs = args.save_epochs 44 | self.batch_size = args.batch_size 45 | self.test_batch_size = args.test_batch_size 46 | self.lr = args.lr 47 | self.data_dir = args.data_dir 48 | self.save_dir = args.save_dir 49 | self.gpu_mode = args.gpu_mode 50 | 51 | def load_dataset(self, dataset='train'): 52 | if self.num_channels == 1: 53 | is_gray = True 54 | else: 55 | is_gray = False 56 | 57 | if dataset == 'train': 58 | print('Loading train datasets...') 59 | train_set = get_training_set(self.data_dir, self.train_dataset, self.crop_size, self.scale_factor, is_gray=is_gray, 60 | normalize=False) 61 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 62 | shuffle=True) 63 | elif dataset == 'test': 64 | print('Loading test datasets...') 65 | test_set = get_test_set(self.data_dir, self.test_dataset, self.scale_factor, is_gray=is_gray, 66 | normalize=False) 67 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 68 | batch_size=self.test_batch_size, 69 | shuffle=False) 70 | 71 | def train(self): 72 | # networks 73 | self.model = Net(num_channels=self.num_channels, base_filter=64, scale_factor=self.scale_factor) 74 | 75 | # weigh initialization 76 | self.model.weight_init() 77 | 78 | # optimizer 79 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 80 | 81 | # loss function 82 | if self.gpu_mode: 83 | self.model.cuda() 84 | self.MSE_loss = nn.MSELoss().cuda() 85 | else: 86 | self.MSE_loss = nn.MSELoss() 87 | 88 | print('---------- Networks architecture -------------') 89 | utils.print_network(self.model) 90 | print('----------------------------------------------') 91 | 92 | # load dataset 93 | train_data_loader = self.load_dataset(dataset='train') 94 | test_data_loader = self.load_dataset(dataset='test') 95 | 96 | # set the logger 97 | log_dir = os.path.join(self.save_dir, 'logs') 98 | if not os.path.exists(log_dir): 99 | os.mkdir(log_dir) 100 | logger = Logger(log_dir) 101 | 102 | ################# Train ################# 103 | print('Training is started.') 104 | avg_loss = [] 105 | step = 0 106 | 107 | # test image 108 | test_input, test_target = test_data_loader.dataset.__getitem__(2) 109 | test_input = test_input.unsqueeze(0) 110 | test_target = test_target.unsqueeze(0) 111 | 112 | self.model.train() 113 | for epoch in range(self.num_epochs): 114 | 115 | epoch_loss = 0 116 | for iter, (input, target) in enumerate(train_data_loader): 117 | # input data (low resolution image) 118 | if self.gpu_mode: 119 | # exclude border pixels from loss computation 120 | x_ = Variable(utils.shave(target, border_size=8*self.scale_factor).cuda()) 121 | y_ = Variable(input.cuda()) 122 | else: 123 | x_ = Variable(utils.shave(target, border_size=8*self.scale_factor)) 124 | y_ = Variable(input) 125 | 126 | # update network 127 | self.optimizer.zero_grad() 128 | recon_image = self.model(y_) 129 | loss = self.MSE_loss(recon_image, x_) 130 | loss.backward() 131 | self.optimizer.step() 132 | 133 | # log 134 | epoch_loss += loss.data[0] 135 | print("Epoch: [%2d] [%4d/%4d] loss: %.8f" % ((epoch + 1), (iter + 1), len(train_data_loader), loss.data[0])) 136 | 137 | # tensorboard logging 138 | logger.scalar_summary('loss', loss.data[0], step + 1) 139 | step += 1 140 | 141 | # avg. loss per epoch 142 | avg_loss.append(epoch_loss / len(train_data_loader)) 143 | 144 | # prediction 145 | recon_imgs = self.model(Variable(test_input.cuda())) 146 | recon_img = recon_imgs[0].cpu().data 147 | gt_img = utils.shave(test_target[0], border_size=8*self.scale_factor) 148 | lr_img = utils.shave(test_input[0], border_size=8) 149 | bc_img = utils.shave(utils.img_interp(test_input[0], self.scale_factor), border_size=8*self.scale_factor) 150 | 151 | # calculate psnrs 152 | bc_psnr = utils.PSNR(bc_img, gt_img) 153 | recon_psnr = utils.PSNR(recon_img, gt_img) 154 | 155 | # save result images 156 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 157 | psnrs = [None, None, bc_psnr, recon_psnr] 158 | utils.plot_test_result(result_imgs, psnrs, epoch + 1, save_dir=self.save_dir, is_training=True) 159 | 160 | print("Saving training result images at epoch %d" % (epoch + 1)) 161 | 162 | # Save trained parameters of model 163 | if (epoch + 1) % self.save_epochs == 0: 164 | self.save_model(epoch + 1) 165 | 166 | # Plot avg. loss 167 | utils.plot_loss([avg_loss], self.num_epochs, save_dir=self.save_dir) 168 | print("Training is finished.") 169 | 170 | # Save final trained parameters of model 171 | self.save_model(epoch=None) 172 | 173 | def test(self): 174 | # networks 175 | self.model = Net(num_channels=self.num_channels, base_filter=64, scale_factor=self.scale_factor) 176 | 177 | if self.gpu_mode: 178 | self.model.cuda() 179 | 180 | # load model 181 | self.load_model() 182 | 183 | # load dataset 184 | test_data_loader = self.load_dataset(dataset='test') 185 | 186 | # Test 187 | print('Test is started.') 188 | img_num = 0 189 | self.model.eval() 190 | for input, target in test_data_loader: 191 | # input data (low resolution image) 192 | if self.gpu_mode: 193 | y_ = Variable(input.cuda()) 194 | else: 195 | y_ = Variable(input) 196 | 197 | # prediction 198 | recon_imgs = self.model(y_) 199 | for i in range(self.test_batch_size): 200 | img_num += 1 201 | recon_img = recon_imgs[i].cpu().data 202 | gt_img = utils.shave(target[i], border_size=8 * self.scale_factor) 203 | lr_img = utils.shave(input[i], border_size=8) 204 | bc_img = utils.shave(utils.img_interp(input[i], self.scale_factor), border_size=8 * self.scale_factor) 205 | 206 | # calculate psnrs 207 | bc_psnr = utils.PSNR(bc_img, gt_img) 208 | recon_psnr = utils.PSNR(recon_img, gt_img) 209 | 210 | # save result images 211 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 212 | psnrs = [None, None, bc_psnr, recon_psnr] 213 | utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=self.save_dir) 214 | 215 | print("Saving %d test result images..." % img_num) 216 | 217 | def test_single(self, img_fn): 218 | # networks 219 | self.model = Net(num_channels=self.num_channels, base_filter=64, scale_factor=self.scale_factor) 220 | 221 | if self.gpu_mode: 222 | self.model.cuda() 223 | 224 | # load model 225 | self.load_model() 226 | 227 | # load data 228 | img = Image.open(img_fn) 229 | img = img.convert('YCbCr') 230 | y, cb, cr = img.split() 231 | 232 | input = Variable(ToTensor()(y)).view(1, -1, y.size[1], y.size[0]) 233 | if self.gpu_mode: 234 | input = input.cuda() 235 | 236 | self.model.eval() 237 | recon_img = self.model(input) 238 | 239 | # save result images 240 | utils.save_img(recon_img.cpu().data, 1, save_dir=self.save_dir) 241 | 242 | out = recon_img.cpu() 243 | out_img_y = out.data[0] 244 | out_img_y = (((out_img_y - out_img_y.min()) * 255) / (out_img_y.max() - out_img_y.min())).numpy() 245 | # out_img_y *= 255.0 246 | # out_img_y = out_img_y.clip(0, 255) 247 | out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L') 248 | 249 | out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) 250 | out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) 251 | out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB') 252 | 253 | # save img 254 | result_dir = os.path.join(self.save_dir, 'result') 255 | if not os.path.exists(result_dir): 256 | os.mkdir(result_dir) 257 | save_fn = result_dir + '/SR_result.png' 258 | out_img.save(save_fn) 259 | 260 | def save_model(self, epoch=None): 261 | model_dir = os.path.join(self.save_dir, 'model') 262 | if not os.path.exists(model_dir): 263 | os.mkdir(model_dir) 264 | if epoch is not None: 265 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param_epoch_%d.pkl' % epoch) 266 | else: 267 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param.pkl') 268 | 269 | print('Trained model is saved.') 270 | 271 | def load_model(self): 272 | model_dir = os.path.join(self.save_dir, 'model') 273 | 274 | model_name = model_dir + '/' + self.model_name + '_param.pkl' 275 | if os.path.exists(model_name): 276 | self.model.load_state_dict(torch.load(model_name)) 277 | print('Trained model is loaded.') 278 | return True 279 | else: 280 | print('No model exists to load.') 281 | return False 282 | -------------------------------------------------------------------------------- /fsrcnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from base_networks import * 6 | from torch.utils.data import DataLoader 7 | from data import get_training_set, get_test_set 8 | import utils 9 | from logger import Logger 10 | from torchvision.transforms import * 11 | 12 | 13 | class Net(torch.nn.Module): 14 | def __init__(self, num_channels, scale_factor, d, s, m): 15 | super(Net, self).__init__() 16 | 17 | # Feature extraction 18 | self.first_part = ConvBlock(num_channels, d, 5, 1, 0, activation='prelu', norm=None) 19 | 20 | self.layers = [] 21 | # Shrinking 22 | self.layers.append(ConvBlock(d, s, 1, 1, 0, activation='prelu', norm=None)) 23 | # Non-linear Mapping 24 | for _ in range(m): 25 | self.layers.append(ConvBlock(s, s, 3, 1, 1, activation=None, norm=None)) 26 | self.layers.append(nn.PReLU()) 27 | # Expanding 28 | self.layers.append(ConvBlock(s, d, 1, 1, 0, activation='prelu', norm=None)) 29 | 30 | self.mid_part = torch.nn.Sequential(*self.layers) 31 | 32 | # Deconvolution 33 | self.last_part = nn.ConvTranspose2d(d, num_channels, 9, scale_factor, 3, output_padding=1) 34 | # self.last_part = torch.nn.Sequential( 35 | # Upsample2xBlock(d, d, upsample='rnc', activation=None, norm=None), 36 | # Upsample2xBlock(d, num_channels, upsample='rnc', activation=None, norm=None) 37 | # ) 38 | 39 | def forward(self, x): 40 | out = self.first_part(x) 41 | out = self.mid_part(out) 42 | out = self.last_part(out) 43 | return out 44 | 45 | def weight_init(self, mean=0.0, std=0.02): 46 | for m in self.modules(): 47 | # utils.weights_init_normal(m, mean=mean, std=std) 48 | if isinstance(m, nn.Conv2d): 49 | m.weight.data.normal_(mean, std) 50 | if m.bias is not None: 51 | m.bias.data.zero_() 52 | if isinstance(m, nn.ConvTranspose2d): 53 | m.weight.data.normal_(0.0, 0.0001) 54 | if m.bias is not None: 55 | m.bias.data.zero_() 56 | 57 | 58 | class FSRCNN(object): 59 | def __init__(self, args): 60 | # parameters 61 | self.model_name = args.model_name 62 | self.train_dataset = args.train_dataset 63 | self.test_dataset = args.test_dataset 64 | self.crop_size = args.crop_size 65 | self.num_threads = args.num_threads 66 | self.num_channels = args.num_channels 67 | self.scale_factor = args.scale_factor 68 | self.num_epochs = args.num_epochs 69 | self.save_epochs = args.save_epochs 70 | self.batch_size = args.batch_size 71 | self.test_batch_size = args.test_batch_size 72 | self.lr = args.lr 73 | self.data_dir = args.data_dir 74 | self.save_dir = args.save_dir 75 | self.gpu_mode = args.gpu_mode 76 | 77 | def load_dataset(self, dataset='train'): 78 | if self.num_channels == 1: 79 | is_gray = True 80 | else: 81 | is_gray = False 82 | 83 | if dataset == 'train': 84 | print('Loading train datasets...') 85 | train_set = get_training_set(self.data_dir, self.train_dataset, self.crop_size, self.scale_factor, is_gray=is_gray, 86 | normalize=False) 87 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 88 | shuffle=True) 89 | elif dataset == 'test': 90 | print('Loading test datasets...') 91 | test_set = get_test_set(self.data_dir, self.test_dataset, self.scale_factor, is_gray=is_gray, 92 | normalize=False) 93 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 94 | batch_size=self.test_batch_size, 95 | shuffle=False) 96 | 97 | def train(self): 98 | # networks 99 | self.model = Net(num_channels=self.num_channels, scale_factor=self.scale_factor, d=56, s=12, m=4) 100 | 101 | # weigh initialization 102 | self.model.weight_init(mean=0.0, std=0.02) 103 | 104 | # optimizer 105 | self.momentum = 0.9 106 | self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum) 107 | 108 | # loss function 109 | if self.gpu_mode: 110 | self.model.cuda() 111 | self.MSE_loss = nn.MSELoss().cuda() 112 | else: 113 | self.MSE_loss = nn.MSELoss() 114 | 115 | print('---------- Networks architecture -------------') 116 | utils.print_network(self.model) 117 | print('----------------------------------------------') 118 | 119 | # load dataset 120 | train_data_loader = self.load_dataset(dataset='train') 121 | test_data_loader = self.load_dataset(dataset='test') 122 | 123 | # set the logger 124 | log_dir = os.path.join(self.save_dir, 'logs') 125 | if not os.path.exists(log_dir): 126 | os.mkdir(log_dir) 127 | logger = Logger(log_dir) 128 | 129 | ################# Train ################# 130 | print('Training is started.') 131 | avg_loss = [] 132 | step = 0 133 | 134 | # test image 135 | test_input, test_target = test_data_loader.dataset.__getitem__(2) 136 | test_input = test_input.unsqueeze(0) 137 | test_target = test_target.unsqueeze(0) 138 | 139 | self.model.train() 140 | for epoch in range(self.num_epochs): 141 | 142 | epoch_loss = 0 143 | for iter, (input, target) in enumerate(train_data_loader): 144 | # input data (low resolution image) 145 | if self.gpu_mode: 146 | x_ = Variable(utils.shave(target, border_size=2*self.scale_factor).cuda()) 147 | y_ = Variable(input.cuda()) 148 | else: 149 | x_ = Variable(utils.shave(target, border_size=2*self.scale_factor)) 150 | y_ = Variable(input) 151 | 152 | # update network 153 | self.optimizer.zero_grad() 154 | recon_image = self.model(y_) 155 | loss = self.MSE_loss(recon_image, x_) 156 | loss.backward() 157 | self.optimizer.step() 158 | 159 | # log 160 | epoch_loss += loss.data[0] 161 | print("Epoch: [%2d] [%4d/%4d] loss: %.8f" % ((epoch + 1), (iter + 1), len(train_data_loader), loss.data[0])) 162 | 163 | # tensorboard logging 164 | logger.scalar_summary('loss', loss.data[0], step + 1) 165 | step += 1 166 | 167 | # avg. loss per epoch 168 | avg_loss.append(epoch_loss / len(train_data_loader)) 169 | 170 | # prediction 171 | recon_imgs = self.model(Variable(test_input.cuda())) 172 | recon_img = recon_imgs[0].cpu().data 173 | gt_img = utils.shave(test_target[0], border_size=2 * self.scale_factor) 174 | lr_img = utils.shave(test_input[0], border_size=2) 175 | bc_img = utils.shave(utils.img_interp(test_input[0], self.scale_factor), border_size=2 * self.scale_factor) 176 | 177 | # calculate psnrs 178 | bc_psnr = utils.PSNR(bc_img, gt_img) 179 | recon_psnr = utils.PSNR(recon_img, gt_img) 180 | 181 | # save result images 182 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 183 | psnrs = [None, None, bc_psnr, recon_psnr] 184 | utils.plot_test_result(result_imgs, psnrs, epoch + 1, save_dir=self.save_dir, is_training=True) 185 | 186 | print("Saving training result images at epoch %d" % (epoch + 1)) 187 | 188 | # Save trained parameters of model 189 | if (epoch + 1) % self.save_epochs == 0: 190 | self.save_model(epoch + 1) 191 | 192 | # Plot avg. loss 193 | utils.plot_loss([avg_loss], self.num_epochs, save_dir=self.save_dir) 194 | print("Training is finished.") 195 | 196 | # Save final trained parameters of model 197 | self.save_model(epoch=None) 198 | 199 | def test(self): 200 | # networks 201 | self.model = Net(num_channels=self.num_channels, scale_factor=self.scale_factor, d=56, s=12, m=4) 202 | 203 | if self.gpu_mode: 204 | self.model.cuda() 205 | 206 | # load model 207 | self.load_model() 208 | 209 | # load dataset 210 | test_data_loader = self.load_dataset(dataset='test') 211 | 212 | # Test 213 | print('Test is started.') 214 | img_num = 0 215 | self.model.eval() 216 | for input, target in test_data_loader: 217 | # input data (low resolution image) 218 | if self.gpu_mode: 219 | y_ = Variable(input.cuda()) 220 | else: 221 | y_ = Variable(input) 222 | 223 | # prediction 224 | recon_imgs = self.model(y_) 225 | for i, recon_img in enumerate(recon_imgs): 226 | img_num += 1 227 | recon_img = recon_imgs[i].cpu().data 228 | gt_img = utils.shave(target[i], border_size=2 * self.scale_factor) 229 | lr_img = utils.shave(input[i], border_size=2) 230 | bc_img = utils.shave(utils.img_interp(input[i], self.scale_factor), border_size=2 * self.scale_factor) 231 | 232 | # calculate psnrs 233 | bc_psnr = utils.PSNR(bc_img, gt_img) 234 | recon_psnr = utils.PSNR(recon_img, gt_img) 235 | 236 | # save result images 237 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 238 | psnrs = [None, None, bc_psnr, recon_psnr] 239 | utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=self.save_dir) 240 | 241 | print("Saving %d test result images..." % img_num) 242 | 243 | def test_single(self, img_fn): 244 | # networks 245 | self.model = Net(num_channels=self.num_channels, scale_factor=self.scale_factor, d=56, s=12, m=4) 246 | 247 | if self.gpu_mode: 248 | self.model.cuda() 249 | 250 | # load model 251 | self.load_model() 252 | 253 | # load data 254 | img = Image.open(img_fn) 255 | img = img.convert('YCbCr') 256 | y, cb, cr = img.split() 257 | 258 | input = Variable(ToTensor()(y)).view(1, -1, y.size[1], y.size[0]) 259 | if self.gpu_mode: 260 | input = input.cuda() 261 | 262 | self.model.eval() 263 | recon_img = self.model(input) 264 | 265 | # save result images 266 | utils.save_img(recon_img.cpu().data, 1, save_dir=self.save_dir) 267 | 268 | out = recon_img.cpu() 269 | out_img_y = out.data[0] 270 | out_img_y = (((out_img_y - out_img_y.min()) * 255) / (out_img_y.max() - out_img_y.min())).numpy() 271 | # out_img_y *= 255.0 272 | # out_img_y = out_img_y.clip(0, 255) 273 | out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L') 274 | 275 | out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) 276 | out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) 277 | out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB') 278 | 279 | # save img 280 | result_dir = os.path.join(self.save_dir, 'result') 281 | if not os.path.exists(result_dir): 282 | os.mkdir(result_dir) 283 | save_fn = result_dir + '/SR_result.png' 284 | out_img.save(save_fn) 285 | 286 | def save_model(self, epoch=None): 287 | model_dir = os.path.join(self.save_dir, 'model') 288 | if not os.path.exists(model_dir): 289 | os.mkdir(model_dir) 290 | if epoch is not None: 291 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param_epoch_%d.pkl' % epoch) 292 | else: 293 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param.pkl') 294 | 295 | print('Trained model is saved.') 296 | 297 | def load_model(self): 298 | model_dir = os.path.join(self.save_dir, 'model') 299 | 300 | model_name = model_dir + '/' + self.model_name + '_param.pkl' 301 | if os.path.exists(model_name): 302 | self.model.load_state_dict(torch.load(model_name)) 303 | print('Trained generator model is loaded.') 304 | return True 305 | else: 306 | print('No model exists to load.') 307 | return False 308 | -------------------------------------------------------------------------------- /lapsrn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from base_networks import * 6 | from torch.utils.data import DataLoader 7 | from torchvision.transforms import * 8 | from data import get_training_set, get_test_set 9 | import utils 10 | from logger import Logger 11 | from torchvision.transforms import * 12 | 13 | 14 | def get_upsample_filter(size): 15 | """Make a 2D bilinear kernel suitable for upsampling""" 16 | factor = (size + 1) // 2 17 | if size % 2 == 1: 18 | center = factor - 1 19 | else: 20 | center = factor - 0.5 21 | og = np.ogrid[:size, :size] 22 | filter = (1 - abs(og[0] - center) / factor) * \ 23 | (1 - abs(og[1] - center) / factor) 24 | return torch.from_numpy(filter).float() 25 | 26 | 27 | class Net(torch.nn.Module): 28 | def __init__(self, num_channels, base_filter, num_convs): 29 | super(Net, self).__init__() 30 | 31 | self.input_conv = ConvBlock(num_channels, base_filter, 3, 1, 1, activation='lrelu', norm=None, bias=False) 32 | 33 | conv_blocks = [] 34 | for _ in range(num_convs): 35 | conv_blocks.append(ConvBlock(base_filter, base_filter, 3, 1, 1, activation='lrelu', norm=None, bias=False)) 36 | conv_blocks.append(DeconvBlock(base_filter, base_filter, 4, 2, 1, activation='lrelu', norm=None, bias=False)) 37 | 38 | self.convt_I1 = DeconvBlock(num_channels, num_channels, 4, 2, 1, activation=None, norm=None, bias=False) 39 | self.convt_R1 = ConvBlock(base_filter, num_channels, 3, 1, 1, activation=None, norm=None, bias=False) 40 | self.convt_F1 = nn.Sequential(*conv_blocks) 41 | 42 | self.convt_I2 = DeconvBlock(num_channels, num_channels, 4, 2, 1, activation=None, norm=None, bias=False) 43 | self.convt_R2 = ConvBlock(base_filter, num_channels, 3, 1, 1, activation=None, norm=None, bias=False) 44 | self.convt_F2 = nn.Sequential(*conv_blocks) 45 | 46 | def weight_init(self): 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 50 | m.weight.data.normal_(0, math.sqrt(2. / n)) 51 | if m.bias is not None: 52 | m.bias.data.zero_() 53 | if isinstance(m, nn.ConvTranspose2d): 54 | c1, c2, h, w = m.weight.data.size() 55 | weight = get_upsample_filter(h) 56 | m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1) 57 | if m.bias is not None: 58 | m.bias.data.zero_() 59 | 60 | def forward(self, x): 61 | out = self.input_conv(x) 62 | convt_F1 = self.convt_F1(out) 63 | convt_I1 = self.convt_I1(x) 64 | convt_R1 = self.convt_R1(convt_F1) 65 | x_coarse_ = convt_I1 + convt_R1 66 | 67 | convt_F2 = self.convt_F2(convt_F1) 68 | convt_I2 = self.convt_I2(x_coarse_) 69 | convt_R2 = self.convt_R2(convt_F2) 70 | x_finer_ = convt_I2 + convt_R2 71 | 72 | return x_coarse_, x_finer_ 73 | 74 | 75 | class L1_Charbonnier_loss(torch.nn.Module): 76 | """L1 Charbonnierloss.""" 77 | def __init__(self): 78 | super(L1_Charbonnier_loss, self).__init__() 79 | self.eps = 1e-6 80 | 81 | def forward(self, x, y): 82 | diff = torch.add(x, -y) 83 | error = torch.sqrt(diff * diff + self.eps) 84 | loss = torch.mean(error) 85 | return loss 86 | 87 | 88 | class LapSRN(object): 89 | def __init__(self, args): 90 | # parameters 91 | self.model_name = args.model_name 92 | self.train_dataset = args.train_dataset 93 | self.test_dataset = args.test_dataset 94 | self.crop_size = args.crop_size 95 | self.num_threads = args.num_threads 96 | self.num_channels = args.num_channels 97 | self.scale_factor = args.scale_factor 98 | self.num_epochs = args.num_epochs 99 | self.save_epochs = args.save_epochs 100 | self.batch_size = args.batch_size 101 | self.test_batch_size = args.test_batch_size 102 | self.lr = args.lr 103 | self.data_dir = args.data_dir 104 | self.save_dir = args.save_dir 105 | self.gpu_mode = args.gpu_mode 106 | 107 | def load_dataset(self, dataset='train'): 108 | if self.num_channels == 1: 109 | is_gray = True 110 | else: 111 | is_gray = False 112 | 113 | if dataset == 'train': 114 | print('Loading train datasets...') 115 | train_set = get_training_set(self.data_dir, self.train_dataset, self.crop_size, self.scale_factor, is_gray=is_gray, 116 | normalize=False) 117 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 118 | shuffle=True) 119 | elif dataset == 'test': 120 | print('Loading test datasets...') 121 | test_set = get_test_set(self.data_dir, self.test_dataset, self.scale_factor, is_gray=is_gray, 122 | normalize=False) 123 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 124 | batch_size=self.test_batch_size, 125 | shuffle=False) 126 | 127 | def train(self): 128 | # networks 129 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_convs=10) 130 | 131 | # weigh initialization 132 | self.model.weight_init() 133 | 134 | # optimizer 135 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) 136 | 137 | # loss function 138 | if self.gpu_mode: 139 | self.model.cuda() 140 | self.loss = L1_Charbonnier_loss().cuda() 141 | # self.loss = nn.L1Loss().cuda() 142 | else: 143 | self.loss = L1_Charbonnier_loss() 144 | 145 | print('---------- Networks architecture -------------') 146 | utils.print_network(self.model) 147 | print('----------------------------------------------') 148 | 149 | # load dataset 150 | train_data_loader = self.load_dataset(dataset='train') 151 | test_data_loader = self.load_dataset(dataset='test') 152 | 153 | # set the logger 154 | log_dir = os.path.join(self.save_dir, 'logs') 155 | if not os.path.exists(log_dir): 156 | os.mkdir(log_dir) 157 | logger = Logger(log_dir) 158 | 159 | ################# Train ################# 160 | print('Training is started.') 161 | avg_loss = [] 162 | step = 0 163 | 164 | # test image 165 | test_input, test_target = test_data_loader.dataset.__getitem__(2) 166 | test_input = test_input.unsqueeze(0) 167 | test_target = test_target.unsqueeze(0) 168 | 169 | self.model.train() 170 | for epoch in range(self.num_epochs): 171 | 172 | # learning rate is decayed by a factor of 10 every 10 epochs 173 | if (epoch+1) % 100 == 0: 174 | for param_group in self.optimizer.param_groups: 175 | param_group["lr"] /= 10.0 176 | print("Learning rate decay: lr={}".format(self.optimizer.param_groups[0]["lr"])) 177 | 178 | epoch_loss = 0 179 | for iter, (input, target) in enumerate(train_data_loader): 180 | # input data (low resolution image) 181 | if self.gpu_mode: 182 | x_finer_ = Variable(target.cuda()) 183 | x_coarse_ = Variable(utils.img_interp(target, 1/self.scale_factor*2).cuda()) 184 | y_ = Variable(input.cuda()) 185 | else: 186 | x_finer_ = Variable(target) 187 | x_coarse_ = Variable(utils.img_interp(target, 1/self.scale_factor*2)) 188 | y_ = Variable(input) 189 | 190 | # update network 191 | self.optimizer.zero_grad() 192 | recon_coarse_, recon_finer_ = self.model(y_) 193 | loss_coarse = self.loss(recon_coarse_, x_coarse_) 194 | loss_finer = self.loss(recon_finer_, x_finer_) 195 | 196 | loss = loss_coarse + loss_finer 197 | loss_coarse.backward(retain_variables=True) 198 | loss_finer.backward() 199 | self.optimizer.step() 200 | 201 | # log 202 | epoch_loss += loss.data[0] 203 | print("Epoch: [%2d] [%4d/%4d] loss: %.8f" % ((epoch + 1), (iter + 1), len(train_data_loader), loss.data[0])) 204 | 205 | # tensorboard logging 206 | logger.scalar_summary('loss', loss.data[0], step + 1) 207 | step += 1 208 | 209 | # avg. loss per epoch 210 | avg_loss.append(epoch_loss / len(train_data_loader)) 211 | 212 | # prediction 213 | _, recon_imgs = self.model(Variable(test_input.cuda())) 214 | recon_img = recon_imgs[0].cpu().data 215 | gt_img = test_target[0] 216 | lr_img = test_input[0] 217 | bc_img = utils.img_interp(test_input[0], self.scale_factor) 218 | 219 | # calculate psnrs 220 | bc_psnr = utils.PSNR(bc_img, gt_img) 221 | recon_psnr = utils.PSNR(recon_img, gt_img) 222 | 223 | # save result images 224 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 225 | psnrs = [None, None, bc_psnr, recon_psnr] 226 | utils.plot_test_result(result_imgs, psnrs, epoch + 1, save_dir=self.save_dir, is_training=True) 227 | 228 | print("Saving training result images at epoch %d" % (epoch + 1)) 229 | 230 | # Save trained parameters of model 231 | if (epoch + 1) % self.save_epochs == 0: 232 | self.save_model(epoch + 1) 233 | 234 | # Plot avg. loss 235 | utils.plot_loss([avg_loss], self.num_epochs, save_dir=self.save_dir) 236 | print("Training is finished.") 237 | 238 | # Save final trained parameters of model 239 | self.save_model(epoch=None) 240 | 241 | def test(self): 242 | # networks 243 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_convs=10) 244 | 245 | if self.gpu_mode: 246 | self.model.cuda() 247 | 248 | # load model 249 | self.load_model() 250 | 251 | # load dataset 252 | test_data_loader = self.load_dataset(dataset='test') 253 | 254 | # Test 255 | print('Test is started.') 256 | img_num = 0 257 | self.model.eval() 258 | for input, target in test_data_loader: 259 | # input data (low resolution image) 260 | if self.gpu_mode: 261 | y_ = Variable(input.cuda()) 262 | else: 263 | y_ = Variable(input) 264 | 265 | # prediction 266 | _, recon_imgs = self.model(y_) 267 | for i, recon_img in enumerate(recon_imgs): 268 | img_num += 1 269 | recon_img = recon_imgs[i].cpu().data 270 | gt_img = target[i] 271 | lr_img = input[i] 272 | bc_img = utils.img_interp(input[i], self.scale_factor) 273 | 274 | # calculate psnrs 275 | bc_psnr = utils.PSNR(bc_img, gt_img) 276 | recon_psnr = utils.PSNR(recon_img, gt_img) 277 | 278 | # save result images 279 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 280 | psnrs = [None, None, bc_psnr, recon_psnr] 281 | utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=self.save_dir) 282 | 283 | print("Saving %d test result images..." % img_num) 284 | 285 | def test_single(self, img_fn): 286 | # networks 287 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_convs=10) 288 | 289 | if self.gpu_mode: 290 | self.model.cuda() 291 | 292 | # load model 293 | self.load_model() 294 | 295 | # load data 296 | img = Image.open(img_fn) 297 | img = img.convert('YCbCr') 298 | y, cb, cr = img.split() 299 | 300 | input = Variable(ToTensor()(y)).view(1, -1, y.size[1], y.size[0]) 301 | if self.gpu_mode: 302 | input = input.cuda() 303 | 304 | self.model.eval() 305 | recon_img = self.model(input) 306 | 307 | # save result images 308 | utils.save_img(recon_img.cpu().data, 1, save_dir=self.save_dir) 309 | 310 | out = recon_img.cpu() 311 | out_img_y = out.data[0] 312 | out_img_y = (((out_img_y - out_img_y.min()) * 255) / (out_img_y.max() - out_img_y.min())).numpy() 313 | # out_img_y *= 255.0 314 | # out_img_y = out_img_y.clip(0, 255) 315 | out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L') 316 | 317 | out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) 318 | out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) 319 | out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB') 320 | 321 | # save img 322 | result_dir = os.path.join(self.save_dir, 'result') 323 | if not os.path.exists(result_dir): 324 | os.mkdir(result_dir) 325 | save_fn = result_dir + '/SR_result.png' 326 | out_img.save(save_fn) 327 | 328 | def save_model(self, epoch=None): 329 | model_dir = os.path.join(self.save_dir, 'model') 330 | if not os.path.exists(model_dir): 331 | os.mkdir(model_dir) 332 | if epoch is not None: 333 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param_epoch_%d.pkl' % epoch) 334 | else: 335 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param.pkl') 336 | 337 | print('Trained model is saved.') 338 | 339 | def load_model(self): 340 | model_dir = os.path.join(self.save_dir, 'model') 341 | 342 | model_name = model_dir + '/' + self.model_name + '_param.pkl' 343 | if os.path.exists(model_name): 344 | self.model.load_state_dict(torch.load(model_name)) 345 | print('Trained model is loaded.') 346 | return True 347 | else: 348 | print('No model exists to load.') 349 | return False 350 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy as np 4 | # import scipy.misc 5 | import matplotlib.pyplot as plt 6 | 7 | try: 8 | from StringIO import StringIO # Python 2.7 9 | except ImportError: 10 | from io import BytesIO # Python 3.x 11 | 12 | 13 | class Logger(object): 14 | def __init__(self, log_dir): 15 | """Create a summary writer logging to log_dir.""" 16 | self.writer = tf.summary.FileWriter(log_dir) 17 | 18 | def scalar_summary(self, tag, value, step): 19 | """Log a scalar variable.""" 20 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 21 | self.writer.add_summary(summary, step) 22 | self.writer.flush() 23 | 24 | def image_summary(self, tag, images, step): 25 | """Log a list of images.""" 26 | 27 | img_summaries = [] 28 | for i, img in enumerate(images): 29 | # Write the image to a string 30 | try: 31 | s = StringIO() 32 | except: 33 | s = BytesIO() 34 | # scipy.misc.toimage(img).save(s, format="png") 35 | plt.imsave(s, img, format='png') 36 | 37 | # Create an Image object 38 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 39 | height=img.shape[0], 40 | width=img.shape[1]) 41 | # Create a Summary value 42 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 43 | 44 | # Create and write Summary 45 | summary = tf.Summary(value=img_summaries) 46 | self.writer.add_summary(summary, step) 47 | self.writer.flush() 48 | 49 | def histo_summary(self, tag, values, step, bins=1000): 50 | """Log a histogram of the tensor of values.""" 51 | 52 | # Create a histogram using numpy 53 | counts, bin_edges = np.histogram(values, bins=bins) 54 | 55 | # Fill the fields of the histogram proto 56 | hist = tf.HistogramProto() 57 | hist.min = float(np.min(values)) 58 | hist.max = float(np.max(values)) 59 | hist.num = int(np.prod(values.shape)) 60 | hist.sum = float(np.sum(values)) 61 | hist.sum_squares = float(np.sum(values ** 2)) 62 | 63 | # Drop the start of the first bin 64 | bin_edges = bin_edges[1:] 65 | 66 | # Add bin edges and counts 67 | for edge in bin_edges: 68 | hist.bucket_limit.append(edge) 69 | for c in counts: 70 | hist.bucket.append(c) 71 | 72 | # Create and write Summary 73 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 74 | self.writer.add_summary(summary, step) 75 | self.writer.flush() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, argparse 3 | from srcnn import SRCNN 4 | from vdsr import VDSR 5 | from fsrcnn import FSRCNN 6 | from srgan import SRGAN 7 | from drcn import DRCN 8 | from espcn import ESPCN 9 | from edsr import EDSR 10 | from lapsrn import LapSRN 11 | 12 | """parsing and configuration""" 13 | def parse_args(): 14 | desc = "PyTorch implementation of SR collections" 15 | parser = argparse.ArgumentParser(description=desc) 16 | parser.add_argument('--model_name', type=str, default='SRGAN', 17 | choices=['SRCNN', 'VDSR', 'DRCN', 'ESPCN', 'FastNeuralStyle', 'FSRCNN', 'SRGAN', 'LapSRN', 18 | 'EnhanceNet', 'EDSR', 'EnhanceGAN'], help='The type of model') 19 | parser.add_argument('--data_dir', type=str, default='../Data') 20 | parser.add_argument('--train_dataset', type=list, default=['DIV2K'], choices=['bsds300', 'General100', 'T91'], 21 | help='The name of training dataset') 22 | parser.add_argument('--test_dataset', type=list, default=['Set5', 'Set14', 'Urban100'], choices=['Set5', 'Set14', 'Urban100'], 23 | help='The name of test dataset') 24 | parser.add_argument('--crop_size', type=int, default=128, help='Size of cropped HR image') 25 | parser.add_argument('--num_threads', type=int, default=4, help='number of threads for data loader to use') 26 | parser.add_argument('--num_channels', type=int, default=3, help='The number of channels to super-resolve') 27 | parser.add_argument('--scale_factor', type=int, default=4, help='Size of scale factor') 28 | parser.add_argument('--num_epochs', type=int, default=100, help='The number of epochs to run') 29 | parser.add_argument('--save_epochs', type=int, default=10, help='Save trained model every this epochs') 30 | parser.add_argument('--batch_size', type=int, default=16, help='training batch size') 31 | parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size') 32 | parser.add_argument('--save_dir', type=str, default='Result_DIV2K', help='Directory name to save the results') 33 | parser.add_argument('--lr', type=float, default=0.00001) 34 | parser.add_argument('--gpu_mode', type=bool, default=True) 35 | 36 | return check_args(parser.parse_args()) 37 | 38 | """checking arguments""" 39 | def check_args(args): 40 | # --save_dir 41 | args.save_dir = os.path.join(args.save_dir, args.model_name) 42 | if not os.path.exists(args.save_dir): 43 | os.makedirs(args.save_dir) 44 | 45 | # --epoch 46 | try: 47 | assert args.num_epochs >= 1 48 | except: 49 | print('number of epochs must be larger than or equal to one') 50 | 51 | # --batch_size 52 | try: 53 | assert args.batch_size >= 1 54 | except: 55 | print('batch size must be larger than or equal to one') 56 | 57 | return args 58 | 59 | """main""" 60 | def main(): 61 | # parse arguments 62 | args = parse_args() 63 | if args is None: 64 | exit() 65 | 66 | if args.gpu_mode and not torch.cuda.is_available(): 67 | raise Exception("No GPU found, please run without --gpu_mode=False") 68 | 69 | # model 70 | if args.model_name == 'SRCNN': 71 | net = SRCNN(args) 72 | elif args.model_name == 'VDSR': 73 | net = VDSR(args) 74 | elif args.model_name == 'DRCN': 75 | net = DRCN(args) 76 | elif args.model_name == 'ESPCN': 77 | net = ESPCN(args) 78 | # elif args.model_name == 'FastNeuralStyle': 79 | # net = FastNeuralStyle(args) 80 | elif args.model_name == 'FSRCNN': 81 | net = FSRCNN(args) 82 | elif args.model_name == 'SRGAN': 83 | net = SRGAN(args) 84 | elif args.model_name == 'LapSRN': 85 | net = LapSRN(args) 86 | # elif args.model_name == 'EnhanceNet': 87 | # net = EnhanceNet(args) 88 | elif args.model_name == 'EDSR': 89 | net = EDSR(args) 90 | # elif args.model_name == 'EnhanceGAN': 91 | # net = EnhanceGAN(args) 92 | else: 93 | raise Exception("[!] There is no option for " + args.model_name) 94 | 95 | # train 96 | net.train() 97 | 98 | # test 99 | net.test() 100 | # net.test_single('getchu_full.jpg') 101 | 102 | if __name__ == '__main__': 103 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # 1. SRCNN (Image Super-Resolution Using Deep Convolutional Networks) - http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html, https://github.com/tegg89/SRCNN-Tensorflow, https://github.com/nagadomi/waifu2x 2 | # 2. VDSR (Accurate Image Super-Resolution Using Very Deep Convolutional Networks) - https://github.com/twtygqyy/pytorch-vdsr, https://github.com/pytorch/examples/tree/master/super_resolution, https://github.com/Jongchan/tensorflow-vdsr 3 | # 3. DRCN (Deeply-Recursive Convolutional Network For Image Super-Resolution) - https://github.com/jiny2001/deeply-recursive-cnn-tf 4 | # 4. ESPCN (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network) - https://github.com/Tetrachrome/subpixel, https://github.com/rampage644/super-resolution 5 | # 5. FastNeuralStyle (Perceptual Losses for Real-Time Style Transfer and Super-Resolution) - https://github.com/jcjohnson/fast-neural-style, https://github.com/bengxy/FastNeuralStyle, https://github.com/abhiskk/fast-neural-style, https://github.com/ceshine/fast-neural-style, https://github.com/vishal1796/pytorch-fast-neural-style, https://github.com/bguisard/SuperResolution 6 | # 6. FSRCNN (Accelerating the Super-Resolution Convolutional Neural Network) - http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html, https://github.com/drakelevy/FSRCNN-TensorFlow 7 | # 7. SRResNet (SRGAN) (Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network) - https://github.com/twtygqyy/pytorch-SRResNet, https://github.com/buriburisuri/SRGAN, https://github.com/junhocho/SRGAN, https://github.com/zsdonghao/SRGAN, https://github.com/tadax/srgan 8 | # 8. LapSRN (Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution) - https://github.com/twtygqyy/pytorch-LapSRN, https://github.com/phoenix104104/LapSRN 9 | # 9. EnhanceNet (EnhanceNet: Single Image Super-Resolution Through Automated Texture Synthesis) - https://github.com/alexjc/neural-enhance 10 | # 10. EDSR (Enhanced Deep Residual Networks for Single Image Super-Resolution) - https://github.com/LimBee/NTIRE2017, https://github.com/twtygqyy/pytorch-edsr, https://github.com/jmiller656/EDSR-Tensorflow 11 | # 11. EnhanceGAN (Pixel Recursive Super Resolution) - https://github.com/nilboy/pixel-recursive-super-resolution 12 | # 13 | # 14 | # 12. CSCN (Deep Networks for Image Super-Resolution with Sparse Prior) - https://github.com/huangzehao/SCN_Matlab 15 | # 13. ShCNN (Shepard Convolutional Neural Networks) - https://github.com/jimmy-ren/vcnn_double-bladed/tree/master/applications/Shepard_CNN 16 | # 14. AffGAN (Amortised MAP Inference for Image Super-resolution) - 17 | # 15. DRRN (Image Super-Resolution via Deep Recursive Residual Network) - https://github.com/tyshiwo/DRRN_CVPR17 18 | # 16. IRCNN (Learning Deep CNN Denoiser Prior for Image Restoration) - https://github.com/cszn/ircnn 19 | # 17. DCSCN (Fast and Accurate Image Super Resolution by Deep CNN with Skip Connection and Network in Network) - https://github.com/jiny2001/dcscn-super-resolution, http://cv.snu.ac.kr/research/DRCN/ 20 | # 18. Lab402 (Beyond Deep Residual Learning for Image Restoration : Persistent Homology-Guided Manifold Simplification) - https://github.com/iorism/CNN -------------------------------------------------------------------------------- /srcnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from base_networks import * 6 | from torch.utils.data import DataLoader 7 | from data import get_training_set, get_test_set 8 | import utils 9 | from logger import Logger 10 | from torchvision.transforms import * 11 | 12 | 13 | class Net(torch.nn.Module): 14 | def __init__(self, num_channels, base_filter): 15 | super(Net, self).__init__() 16 | 17 | self.layers = torch.nn.Sequential( 18 | ConvBlock(num_channels, base_filter, 9, 1, 0, norm=None), 19 | ConvBlock(base_filter, base_filter // 2, 5, 1, 0, norm=None), 20 | ConvBlock(base_filter // 2, num_channels, 5, 1, 0, activation=None, norm=None) 21 | ) 22 | 23 | def forward(self, x): 24 | out = self.layers(x) 25 | return out 26 | 27 | def weight_init(self, mean=0.0, std=0.001): 28 | for m in self.modules(): 29 | utils.weights_init_normal(m, mean=mean, std=std) 30 | 31 | 32 | class SRCNN(object): 33 | def __init__(self, args): 34 | # parameters 35 | self.model_name = args.model_name 36 | self.train_dataset = args.train_dataset 37 | self.test_dataset = args.test_dataset 38 | self.crop_size = args.crop_size 39 | self.num_threads = args.num_threads 40 | self.num_channels = args.num_channels 41 | self.scale_factor = args.scale_factor 42 | self.num_epochs = args.num_epochs 43 | self.save_epochs = args.save_epochs 44 | self.batch_size = args.batch_size 45 | self.test_batch_size = args.test_batch_size 46 | self.lr = args.lr 47 | self.data_dir = args.data_dir 48 | self.save_dir = args.save_dir 49 | self.gpu_mode = args.gpu_mode 50 | 51 | def load_dataset(self, dataset='train'): 52 | if self.num_channels == 1: 53 | is_gray = True 54 | else: 55 | is_gray = False 56 | 57 | if dataset == 'train': 58 | print('Loading train datasets...') 59 | train_set = get_training_set(self.data_dir, self.train_dataset, self.crop_size, self.scale_factor, is_gray=is_gray, 60 | normalize=False) 61 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 62 | shuffle=True) 63 | elif dataset == 'test': 64 | print('Loading test datasets...') 65 | test_set = get_test_set(self.data_dir, self.test_dataset, self.scale_factor, is_gray=is_gray, 66 | normalize=False) 67 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 68 | batch_size=self.test_batch_size, 69 | shuffle=False) 70 | 71 | def train(self): 72 | # networks 73 | self.model = Net(num_channels=self.num_channels, base_filter=64) 74 | 75 | # weigh initialization 76 | self.model.weight_init(mean=0.0, std=0.001) 77 | 78 | # optimizer 79 | self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr) 80 | 81 | # loss function 82 | if self.gpu_mode: 83 | self.model.cuda() 84 | self.MSE_loss = nn.MSELoss().cuda() 85 | else: 86 | self.MSE_loss = nn.MSELoss() 87 | 88 | print('---------- Networks architecture -------------') 89 | utils.print_network(self.model) 90 | print('----------------------------------------------') 91 | 92 | # load dataset 93 | train_data_loader = self.load_dataset(dataset='train') 94 | test_data_loader = self.load_dataset(dataset='test') 95 | 96 | # set the logger 97 | log_dir = os.path.join(self.save_dir, 'logs') 98 | if not os.path.exists(log_dir): 99 | os.mkdir(log_dir) 100 | logger = Logger(log_dir) 101 | 102 | ################# Train ################# 103 | print('Training is started.') 104 | avg_loss = [] 105 | step = 0 106 | 107 | # test image 108 | test_input, test_target = test_data_loader.dataset.__getitem__(2) 109 | test_input = test_input.unsqueeze(0) 110 | test_target = test_target.unsqueeze(0) 111 | 112 | self.model.train() 113 | for epoch in range(self.num_epochs): 114 | 115 | epoch_loss = 0 116 | for iter, (input, target) in enumerate(train_data_loader): 117 | # input data (bicubic interpolated image) 118 | if self.gpu_mode: 119 | # exclude border pixels from loss computation 120 | x_ = Variable(utils.shave(target, border_size=8).cuda()) 121 | y_ = Variable(utils.img_interp(input, self.scale_factor).cuda()) 122 | else: 123 | x_ = Variable(utils.shave(target, border_size=8)) 124 | y_ = Variable(utils.img_interp(input, self.scale_factor)) 125 | 126 | # update network 127 | self.optimizer.zero_grad() 128 | recon_image = self.model(y_) 129 | loss = self.MSE_loss(recon_image, x_) 130 | loss.backward() 131 | self.optimizer.step() 132 | 133 | # log 134 | epoch_loss += loss.data[0] 135 | print("Epoch: [%2d] [%4d/%4d] loss: %.8f" % ((epoch + 1), (iter + 1), len(train_data_loader), loss.data[0])) 136 | 137 | # tensorboard logging 138 | logger.scalar_summary('loss', loss.data[0], step + 1) 139 | step += 1 140 | 141 | # avg. loss per epoch 142 | avg_loss.append(epoch_loss / len(train_data_loader)) 143 | 144 | # prediction 145 | recon_imgs = self.model(Variable(utils.img_interp(test_input, self.scale_factor).cuda())) 146 | recon_img = recon_imgs[0].cpu().data 147 | gt_img = utils.shave(test_target[0], border_size=8) 148 | lr_img = test_input[0] 149 | bc_img = utils.shave(utils.img_interp(test_input[0], self.scale_factor), border_size=8) 150 | 151 | # calculate psnrs 152 | bc_psnr = utils.PSNR(bc_img, gt_img) 153 | recon_psnr = utils.PSNR(recon_img, gt_img) 154 | 155 | # save result images 156 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 157 | psnrs = [None, None, bc_psnr, recon_psnr] 158 | utils.plot_test_result(result_imgs, psnrs, epoch + 1, save_dir=self.save_dir, is_training=True) 159 | 160 | print("Saving training result images at epoch %d" % (epoch + 1)) 161 | 162 | # Save trained parameters of model 163 | if (epoch + 1) % self.save_epochs == 0: 164 | self.save_model(epoch + 1) 165 | 166 | # Plot avg. loss 167 | utils.plot_loss([avg_loss], self.num_epochs, save_dir=self.save_dir) 168 | print("Training is finished.") 169 | 170 | # Save final trained parameters of model 171 | self.save_model(epoch=None) 172 | 173 | def test(self): 174 | # networks 175 | self.model = Net(num_channels=self.num_channels, base_filter=64) 176 | 177 | if self.gpu_mode: 178 | self.model.cuda() 179 | 180 | # load model 181 | self.load_model() 182 | 183 | # load dataset 184 | test_data_loader = self.load_dataset(dataset='test') 185 | 186 | # Test 187 | print('Test is started.') 188 | img_num = 0 189 | self.model.eval() 190 | for input, target in test_data_loader: 191 | # input data (bicubic interpolated image) 192 | if self.gpu_mode: 193 | y_ = Variable(utils.img_interp(input, self.scale_factor).cuda()) 194 | else: 195 | y_ = Variable(utils.img_interp(input, self.scale_factor)) 196 | 197 | # prediction 198 | recon_imgs = self.model(y_) 199 | for i in range(self.test_batch_size): 200 | img_num += 1 201 | recon_img = recon_imgs[i].cpu().data 202 | gt_img = utils.shave(target[i], border_size=8) 203 | lr_img = input[i] 204 | bc_img = utils.shave(utils.img_interp(input[i], self.scale_factor), border_size=8) 205 | 206 | # calculate psnrs 207 | bc_psnr = utils.PSNR(bc_img, gt_img) 208 | recon_psnr = utils.PSNR(recon_img, gt_img) 209 | 210 | # save result images 211 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 212 | psnrs = [None, None, bc_psnr, recon_psnr] 213 | utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=self.save_dir) 214 | 215 | print("Saving %d test result images..." % img_num) 216 | 217 | def test_single(self, img_fn): 218 | # networks 219 | self.model = Net(num_channels=self.num_channels, base_filter=64) 220 | 221 | if self.gpu_mode: 222 | self.model.cuda() 223 | 224 | # load model 225 | self.load_model() 226 | 227 | # load data 228 | img = Image.open(img_fn) 229 | img = img.convert('YCbCr') 230 | y, cb, cr = img.split() 231 | 232 | input = Variable(ToTensor()(y)).view(1, -1, y.size[1], y.size[0]) 233 | if self.gpu_mode: 234 | input = input.cuda() 235 | 236 | self.model.eval() 237 | recon_img = self.model(input) 238 | 239 | # save result images 240 | utils.save_img(recon_img.cpu().data, 1, save_dir=self.save_dir) 241 | 242 | out = recon_img.cpu() 243 | out_img_y = out.data[0] 244 | out_img_y = (((out_img_y - out_img_y.min()) * 255) / (out_img_y.max() - out_img_y.min())).numpy() 245 | # out_img_y *= 255.0 246 | # out_img_y = out_img_y.clip(0, 255) 247 | out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L') 248 | 249 | out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) 250 | out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) 251 | out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB') 252 | 253 | # save img 254 | result_dir = os.path.join(self.save_dir, 'result') 255 | if not os.path.exists(result_dir): 256 | os.mkdir(result_dir) 257 | save_fn = result_dir + '/SR_result.png' 258 | out_img.save(save_fn) 259 | 260 | def save_model(self, epoch=None): 261 | model_dir = os.path.join(self.save_dir, 'model') 262 | if not os.path.exists(model_dir): 263 | os.mkdir(model_dir) 264 | if epoch is not None: 265 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param_epoch_%d.pkl' % epoch) 266 | else: 267 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param.pkl') 268 | 269 | print('Trained model is saved.') 270 | 271 | def load_model(self): 272 | model_dir = os.path.join(self.save_dir, 'model') 273 | 274 | model_name = model_dir + '/' + self.model_name + '_param.pkl' 275 | if os.path.exists(model_name): 276 | self.model.load_state_dict(torch.load(model_name)) 277 | print('Trained model is loaded.') 278 | return True 279 | else: 280 | print('No model exists to load.') 281 | return False 282 | -------------------------------------------------------------------------------- /srgan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from torchvision import models 6 | from base_networks import * 7 | from torch.utils.data import DataLoader 8 | from data import get_training_set, get_test_set 9 | import utils 10 | from logger import Logger 11 | from torchvision.transforms import * 12 | 13 | 14 | class Generator(torch.nn.Module): 15 | def __init__(self, num_channels, base_filter, num_residuals): 16 | super(Generator, self).__init__() 17 | 18 | self.input_conv = ConvBlock(num_channels, base_filter, 9, 1, 4, activation='prelu', norm=None) 19 | 20 | resnet_blocks = [] 21 | for _ in range(num_residuals): 22 | resnet_blocks.append(ResnetBlock(base_filter, activation='prelu')) 23 | self.residual_layers = nn.Sequential(*resnet_blocks) 24 | 25 | self.mid_conv = ConvBlock(base_filter, base_filter, 3, 1, 1, activation=None) 26 | 27 | self.upscale4x = nn.Sequential( 28 | Upsample2xBlock(base_filter, base_filter, upsample='ps', activation='prelu', norm=None), 29 | Upsample2xBlock(base_filter, base_filter, upsample='ps', activation='prelu', norm=None) 30 | ) 31 | 32 | self.output_conv = ConvBlock(base_filter, num_channels, 9, 1, 4, activation=None, norm=None) 33 | 34 | def forward(self, x): 35 | out = self.input_conv(x) 36 | residual = out 37 | out = self.residual_layers(out) 38 | out = self.mid_conv(out) 39 | out = torch.add(out, residual) 40 | out = self.upscale4x(out) 41 | out = self.output_conv(out) 42 | return out 43 | 44 | def weight_init(self, mean=0.0, std=0.02): 45 | for m in self.modules(): 46 | utils.weights_init_normal(m, mean=mean, std=std) 47 | 48 | 49 | class Discriminator(torch.nn.Module): 50 | def __init__(self, num_channels, base_filter, image_size): 51 | super(Discriminator, self).__init__() 52 | self.image_size = image_size 53 | 54 | self.input_conv = ConvBlock(num_channels, base_filter, 3, 1, 1, activation='lrelu', norm=None) 55 | 56 | self.conv_blocks = nn.Sequential( 57 | ConvBlock(base_filter, base_filter, 3, 2, 1, activation='lrelu'), 58 | ConvBlock(base_filter, base_filter * 2, 3, 1, 1, activation='lrelu'), 59 | ConvBlock(base_filter * 2, base_filter * 2, 3, 2, 1, activation='lrelu'), 60 | ConvBlock(base_filter * 2, base_filter * 4, 3, 1, 1, activation='lrelu'), 61 | ConvBlock(base_filter * 4, base_filter * 4, 3, 2, 1, activation='lrelu'), 62 | ConvBlock(base_filter * 4, base_filter * 8, 3, 1, 1, activation='lrelu'), 63 | ConvBlock(base_filter * 8, base_filter * 8, 3, 2, 1, activation='lrelu'), 64 | ) 65 | 66 | self.dense_layers = nn.Sequential( 67 | DenseBlock(base_filter * 8 * image_size // 16 * image_size // 16, base_filter * 16, activation='lrelu', 68 | norm=None), 69 | DenseBlock(base_filter * 16, 1, activation='sigmoid', norm=None) 70 | ) 71 | 72 | def forward(self, x): 73 | out = self.input_conv(x) 74 | out = self.conv_blocks(out) 75 | out = out.view(out.size()[0], -1) 76 | out = self.dense_layers(out) 77 | return out 78 | 79 | def weight_init(self, mean=0.0, std=0.02): 80 | for m in self.modules(): 81 | utils.weights_init_normal(m, mean=mean, std=std) 82 | 83 | 84 | class FeatureExtractor(torch.nn.Module): 85 | def __init__(self, netVGG, feature_layer=8): 86 | super(FeatureExtractor, self).__init__() 87 | self.features = nn.Sequential(*list(netVGG.features.children())[:(feature_layer + 1)]) 88 | 89 | def forward(self, x): 90 | return self.features(x) 91 | 92 | 93 | class SRGAN(object): 94 | def __init__(self, args): 95 | # parameters 96 | self.model_name = args.model_name 97 | self.train_dataset = args.train_dataset 98 | self.test_dataset = args.test_dataset 99 | self.crop_size = args.crop_size 100 | self.num_threads = args.num_threads 101 | self.num_channels = args.num_channels 102 | self.scale_factor = args.scale_factor 103 | self.num_epochs = args.num_epochs 104 | self.save_epochs = args.save_epochs 105 | self.batch_size = args.batch_size 106 | self.test_batch_size = args.test_batch_size 107 | self.lr = args.lr 108 | self.data_dir = args.data_dir 109 | self.save_dir = args.save_dir 110 | self.gpu_mode = args.gpu_mode 111 | 112 | def load_dataset(self, dataset, is_train=True): 113 | if self.num_channels == 1: 114 | is_gray = True 115 | else: 116 | is_gray = False 117 | 118 | if is_train: 119 | print('Loading train datasets...') 120 | train_set = get_training_set(self.data_dir, dataset, self.crop_size, self.scale_factor, is_gray=is_gray) 121 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 122 | shuffle=True) 123 | else: 124 | print('Loading test datasets...') 125 | test_set = get_test_set(self.data_dir, dataset, self.scale_factor, is_gray=is_gray) 126 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 127 | batch_size=self.test_batch_size, 128 | shuffle=False) 129 | 130 | def train(self): 131 | # load dataset 132 | train_data_loader = self.load_dataset(dataset=self.train_dataset, is_train=True) 133 | test_data_loader = self.load_dataset(dataset=self.test_dataset[0], is_train=False) 134 | 135 | # networks 136 | self.G = Generator(num_channels=self.num_channels, base_filter=64, num_residuals=16) 137 | self.D = Discriminator(num_channels=self.num_channels, base_filter=64, image_size=self.crop_size) 138 | 139 | # weigh initialization 140 | self.G.weight_init() 141 | self.D.weight_init() 142 | 143 | # For the content loss 144 | self.feature_extractor = FeatureExtractor(models.vgg19(pretrained=True)) 145 | 146 | # optimizer 147 | self.G_optimizer = optim.Adam(self.G.parameters(), lr=self.lr, betas=(0.9, 0.999)) 148 | # self.D_optimizer = optim.Adam(self.D.parameters(), lr=self.lr, betas=(0.9, 0.999)) 149 | self.D_optimizer = optim.SGD(self.D.parameters(), lr=self.lr/100, momentum=0.9, nesterov=True) 150 | 151 | # loss function 152 | if self.gpu_mode: 153 | self.G.cuda() 154 | self.D.cuda() 155 | self.feature_extractor.cuda() 156 | self.MSE_loss = nn.MSELoss().cuda() 157 | self.BCE_loss = nn.BCELoss().cuda() 158 | else: 159 | self.MSE_loss = nn.MSELoss() 160 | self.BCE_loss = nn.BCELoss() 161 | 162 | print('---------- Networks architecture -------------') 163 | utils.print_network(self.G) 164 | utils.print_network(self.D) 165 | print('----------------------------------------------') 166 | 167 | # set the logger 168 | G_log_dir = os.path.join(self.save_dir, 'G_logs') 169 | if not os.path.exists(G_log_dir): 170 | os.mkdir(G_log_dir) 171 | G_logger = Logger(G_log_dir) 172 | 173 | D_log_dir = os.path.join(self.save_dir, 'D_logs') 174 | if not os.path.exists(D_log_dir): 175 | os.mkdir(D_log_dir) 176 | D_logger = Logger(D_log_dir) 177 | 178 | ################# Pre-train generator ################# 179 | self.epoch_pretrain = 50 180 | 181 | # Load pre-trained parameters of generator 182 | if not self.load_model(is_pretrain=True): 183 | # Pre-training generator for 50 epochs 184 | print('Pre-training is started.') 185 | self.G.train() 186 | for epoch in range(self.epoch_pretrain): 187 | for iter, (lr, hr, _) in enumerate(train_data_loader): 188 | # input data (low resolution image) 189 | if self.num_channels == 1: 190 | x_ = Variable(utils.norm(hr[:, 0].unsqueeze(1), vgg=True)) 191 | y_ = Variable(utils.norm(lr[:, 0].unsqueeze(1), vgg=True)) 192 | else: 193 | x_ = Variable(utils.norm(hr, vgg=True)) 194 | y_ = Variable(utils.norm(lr, vgg=True)) 195 | 196 | if self.gpu_mode: 197 | x_ = x_.cuda() 198 | y_ = y_.cuda() 199 | 200 | # Train generator 201 | self.G_optimizer.zero_grad() 202 | recon_image = self.G(y_) 203 | 204 | # Content losses 205 | content_loss = self.MSE_loss(recon_image, x_) 206 | 207 | # Back propagation 208 | G_loss_pretrain = content_loss 209 | G_loss_pretrain.backward() 210 | self.G_optimizer.step() 211 | 212 | # log 213 | print("Epoch: [%2d] [%4d/%4d] G_loss_pretrain: %.8f" 214 | % ((epoch + 1), (iter + 1), len(train_data_loader), G_loss_pretrain.data[0])) 215 | 216 | print('Pre-training is finished.') 217 | 218 | # Save pre-trained parameters of generator 219 | self.save_model(is_pretrain=True) 220 | 221 | ################# Adversarial train ################# 222 | print('Training is started.') 223 | # Avg. losses 224 | G_avg_loss = [] 225 | D_avg_loss = [] 226 | step = 0 227 | 228 | # test image 229 | test_lr, test_hr, test_bc = test_data_loader.dataset.__getitem__(2) 230 | test_lr = test_lr.unsqueeze(0) 231 | test_hr = test_hr.unsqueeze(0) 232 | test_bc = test_bc.unsqueeze(0) 233 | 234 | self.G.train() 235 | self.D.train() 236 | for epoch in range(self.num_epochs): 237 | 238 | # learning rate is decayed by a factor of 10 every 20 epoch 239 | if (epoch + 1) % 20 == 0: 240 | for param_group in self.G_optimizer.param_groups: 241 | param_group["lr"] /= 10.0 242 | print("Learning rate decay for G: lr={}".format(self.G_optimizer.param_groups[0]["lr"])) 243 | for param_group in self.D_optimizer.param_groups: 244 | param_group["lr"] /= 10.0 245 | print("Learning rate decay for D: lr={}".format(self.D_optimizer.param_groups[0]["lr"])) 246 | 247 | G_epoch_loss = 0 248 | D_epoch_loss = 0 249 | for iter, (lr, hr, _) in enumerate(train_data_loader): 250 | # input data (low resolution image) 251 | mini_batch = lr.size()[0] 252 | 253 | if self.num_channels == 1: 254 | x_ = Variable(utils.norm(hr[:, 0].unsqueeze(1), vgg=True)) 255 | y_ = Variable(utils.norm(lr[:, 0].unsqueeze(1), vgg=True)) 256 | else: 257 | x_ = Variable(utils.norm(hr, vgg=True)) 258 | y_ = Variable(utils.norm(lr, vgg=True)) 259 | 260 | if self.gpu_mode: 261 | x_ = x_.cuda() 262 | y_ = y_.cuda() 263 | # labels 264 | real_label = Variable(torch.ones(mini_batch).cuda()) 265 | fake_label = Variable(torch.zeros(mini_batch).cuda()) 266 | else: 267 | # labels 268 | real_label = Variable(torch.ones(mini_batch)) 269 | fake_label = Variable(torch.zeros(mini_batch)) 270 | 271 | # Reset gradient 272 | self.D_optimizer.zero_grad() 273 | 274 | # Train discriminator with real data 275 | D_real_decision = self.D(x_) 276 | D_real_loss = self.BCE_loss(D_real_decision, real_label) 277 | 278 | # Train discriminator with fake data 279 | recon_image = self.G(y_) 280 | D_fake_decision = self.D(recon_image) 281 | D_fake_loss = self.BCE_loss(D_fake_decision, fake_label) 282 | 283 | D_loss = D_real_loss + D_fake_loss 284 | 285 | # Back propagation 286 | D_loss.backward() 287 | self.D_optimizer.step() 288 | 289 | # Reset gradient 290 | self.G_optimizer.zero_grad() 291 | 292 | # Train generator 293 | recon_image = self.G(y_) 294 | D_fake_decision = self.D(recon_image) 295 | 296 | # Adversarial loss 297 | GAN_loss = self.BCE_loss(D_fake_decision, real_label) 298 | 299 | # Content losses 300 | mse_loss = self.MSE_loss(recon_image, x_) 301 | x_VGG = Variable(utils.norm(hr, vgg=True).cuda()) 302 | recon_VGG = Variable(utils.norm(recon_image.data, vgg=True).cuda()) 303 | real_feature = self.feature_extractor(x_VGG) 304 | fake_feature = self.feature_extractor(recon_VGG) 305 | vgg_loss = self.MSE_loss(fake_feature, real_feature.detach()) 306 | 307 | # Back propagation 308 | G_loss = mse_loss + 6e-3 * vgg_loss + 1e-3 * GAN_loss 309 | G_loss.backward() 310 | self.G_optimizer.step() 311 | 312 | # log 313 | G_epoch_loss += G_loss.data[0] 314 | D_epoch_loss += D_loss.data[0] 315 | print("Epoch: [%2d] [%4d/%4d] G_loss: %.8f, D_loss: %.8f" 316 | % ((epoch + 1), (iter + 1), len(train_data_loader), G_loss.data[0], D_loss.data[0])) 317 | 318 | # tensorboard logging 319 | G_logger.scalar_summary('losses', G_loss.data[0], step + 1) 320 | D_logger.scalar_summary('losses', D_loss.data[0], step + 1) 321 | step += 1 322 | 323 | # avg. loss per epoch 324 | G_avg_loss.append(G_epoch_loss / len(train_data_loader)) 325 | D_avg_loss.append(D_epoch_loss / len(train_data_loader)) 326 | 327 | # prediction 328 | if self.num_channels == 1: 329 | y_ = Variable(utils.norm(test_lr[:, 0].unsqueeze(1), vgg=True)) 330 | else: 331 | y_ = Variable(utils.norm(test_lr, vgg=True)) 332 | 333 | if self.gpu_mode: 334 | y_ = y_.cuda() 335 | 336 | recon_img = self.G(y_) 337 | sr_img = utils.denorm(recon_img[0].cpu().data, vgg=True) 338 | 339 | # save result image 340 | save_dir = os.path.join(self.save_dir, 'train_result') 341 | utils.save_img(sr_img, epoch + 1, save_dir=save_dir, is_training=True) 342 | print('Result image at epoch %d is saved.' % (epoch + 1)) 343 | 344 | # Save trained parameters of model 345 | if (epoch + 1) % self.save_epochs == 0: 346 | self.save_model(epoch + 1) 347 | 348 | # calculate psnrs 349 | if self.num_channels == 1: 350 | gt_img = test_hr[0][0].unsqueeze(0) 351 | lr_img = test_lr[0][0].unsqueeze(0) 352 | bc_img = test_bc[0][0].unsqueeze(0) 353 | else: 354 | gt_img = test_hr[0] 355 | lr_img = test_lr[0] 356 | bc_img = test_bc[0] 357 | 358 | bc_psnr = utils.PSNR(bc_img, gt_img) 359 | recon_psnr = utils.PSNR(sr_img, gt_img) 360 | 361 | # plot result images 362 | result_imgs = [gt_img, lr_img, bc_img, sr_img] 363 | psnrs = [None, None, bc_psnr, recon_psnr] 364 | utils.plot_test_result(result_imgs, psnrs, self.num_epochs, save_dir=save_dir, is_training=True) 365 | print('Training result image is saved.') 366 | 367 | # Plot avg. loss 368 | utils.plot_loss([G_avg_loss, D_avg_loss], self.num_epochs, save_dir=self.save_dir) 369 | print("Training is finished.") 370 | 371 | # Save final trained parameters of model 372 | self.save_model(epoch=None) 373 | 374 | def test(self): 375 | # networks 376 | self.G = Generator(num_channels=self.num_channels, base_filter=64, num_residuals=16) 377 | 378 | if self.gpu_mode: 379 | self.G.cuda() 380 | 381 | # load model 382 | self.load_model() 383 | 384 | # load dataset 385 | for test_dataset in self.test_dataset: 386 | test_data_loader = self.load_dataset(dataset=test_dataset, is_train=False) 387 | 388 | # Test 389 | print('Test is started.') 390 | img_num = 0 391 | total_img_num = len(test_data_loader) 392 | self.G.eval() 393 | for lr, hr, bc in test_data_loader: 394 | # input data (low resolution image) 395 | if self.num_channels == 1: 396 | y_ = Variable(utils.norm(lr[:, 0].unsqueeze(1), vgg=True)) 397 | else: 398 | y_ = Variable(utils.norm(lr, vgg=True)) 399 | 400 | if self.gpu_mode: 401 | y_ = y_.cuda() 402 | 403 | # prediction 404 | recon_imgs = self.G(y_) 405 | for i, recon_img in enumerate(recon_imgs): 406 | img_num += 1 407 | sr_img = utils.denorm(recon_img.cpu().data, vgg=True) 408 | 409 | # save result image 410 | save_dir = os.path.join(self.save_dir, 'test_result', test_dataset) 411 | utils.save_img(sr_img, img_num, save_dir=save_dir) 412 | 413 | # calculate psnrs 414 | if self.num_channels == 1: 415 | gt_img = hr[i][0].unsqueeze(0) 416 | lr_img = lr[i][0].unsqueeze(0) 417 | bc_img = bc[i][0].unsqueeze(0) 418 | else: 419 | gt_img = hr[i] 420 | lr_img = lr[i] 421 | bc_img = bc[i] 422 | 423 | bc_psnr = utils.PSNR(bc_img, gt_img) 424 | recon_psnr = utils.PSNR(sr_img, gt_img) 425 | 426 | # plot result images 427 | result_imgs = [gt_img, lr_img, bc_img, sr_img] 428 | psnrs = [None, None, bc_psnr, recon_psnr] 429 | utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=save_dir) 430 | 431 | print('Test DB: %s, Saving result images...[%d/%d]' % (test_dataset, img_num, total_img_num)) 432 | 433 | print('Test is finishied.') 434 | 435 | def test_single(self, img_fn): 436 | # networks 437 | self.G = Generator(num_channels=self.num_channels, base_filter=64, num_residuals=16) 438 | 439 | if self.gpu_mode: 440 | self.G.cuda() 441 | 442 | # load model 443 | self.load_model() 444 | 445 | # load data 446 | img = Image.open(img_fn).convert('RGB') 447 | 448 | if self.num_channels == 1: 449 | img = img.convert('YCbCr') 450 | img_y, img_cb, img_cr = img.split() 451 | 452 | input = ToTensor()(img_y) 453 | y_ = Variable(utils.norm(input.unsqueeze(1), vgg=True)) 454 | else: 455 | input = ToTensor()(img).view(1, -1, img.height, img.width) 456 | y_ = Variable(utils.norm(input, vgg=True)) 457 | 458 | if self.gpu_mode: 459 | y_ = y_.cuda() 460 | 461 | # prediction 462 | self.G.eval() 463 | recon_img = self.G(y_) 464 | recon_img = utils.denorm(recon_img.cpu().data[0].clamp(0, 1), vgg=True) 465 | recon_img = ToPILImage()(recon_img) 466 | 467 | if self.num_channels == 1: 468 | # merge color channels with super-resolved Y-channel 469 | recon_y = recon_img 470 | recon_cb = img_cb.resize(recon_y.size, Image.BICUBIC) 471 | recon_cr = img_cr.resize(recon_y.size, Image.BICUBIC) 472 | recon_img = Image.merge('YCbCr', [recon_y, recon_cb, recon_cr]).convert('RGB') 473 | 474 | # save img 475 | result_dir = os.path.join(self.save_dir, 'test_result') 476 | if not os.path.exists(result_dir): 477 | os.makedirs(result_dir) 478 | save_fn = result_dir + '/SR_result.png' 479 | recon_img.save(save_fn) 480 | 481 | print('Single test result image is saved.') 482 | 483 | def save_model(self, epoch=None, is_pretrain=False): 484 | model_dir = os.path.join(self.save_dir, 'model') 485 | if not os.path.exists(model_dir): 486 | os.mkdir(model_dir) 487 | 488 | if is_pretrain: 489 | torch.save(self.G.state_dict(), model_dir + '/' + self.model_name + '_G_param_pretrain.pkl') 490 | print('Pre-trained generator model is saved.') 491 | else: 492 | if epoch is not None: 493 | torch.save(self.G.state_dict(), model_dir + '/' + self.model_name + 494 | '_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 495 | % (self.num_channels, self.batch_size, epoch, self.lr)) 496 | torch.save(self.D.state_dict(), model_dir + '/' + self.model_name + 497 | '_D_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 498 | % (self.num_channels, self.batch_size, epoch, self.lr)) 499 | else: 500 | torch.save(self.G.state_dict(), model_dir + '/' + self.model_name + 501 | '_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 502 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr)) 503 | torch.save(self.D.state_dict(), model_dir + '/' + self.model_name + 504 | '_D_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 505 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr)) 506 | print('Trained models are saved.') 507 | 508 | def load_model(self, is_pretrain=False): 509 | model_dir = os.path.join(self.save_dir, 'model') 510 | 511 | if is_pretrain: 512 | model_name = model_dir + '/' + self.model_name + '_G_param_pretrain.pkl' 513 | if os.path.exists(model_name): 514 | self.G.load_state_dict(torch.load(model_name)) 515 | print('Pre-trained generator model is loaded.') 516 | return True 517 | else: 518 | model_name = model_dir + '/' + self.model_name + \ 519 | '_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' \ 520 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr) 521 | if os.path.exists(model_name): 522 | self.G.load_state_dict(torch.load(model_name)) 523 | print('Trained generator model is loaded.') 524 | return True 525 | 526 | return False 527 | 528 | 529 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torchvision.transforms as transforms 4 | from PIL import Image 5 | from math import log10 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import os 9 | import imageio 10 | from scipy.misc import imsave 11 | # from edge_detector import edge_detect 12 | 13 | 14 | def print_network(net): 15 | num_params = 0 16 | for param in net.parameters(): 17 | num_params += param.numel() 18 | print(net) 19 | print('Total number of parameters: %d' % num_params) 20 | 21 | 22 | # For logger 23 | def to_np(x): 24 | return x.data.cpu().numpy() 25 | 26 | 27 | def to_var(x): 28 | if torch.cuda.is_available(): 29 | x = torch.from_numpy(x).cuda() 30 | return Variable(x) 31 | 32 | 33 | # Plot losses 34 | def plot_loss(avg_losses, num_epochs, save_dir='', show=False): 35 | fig, ax = plt.subplots() 36 | ax.set_xlim(0, num_epochs) 37 | temp = 0.0 38 | for i in range(len(avg_losses)): 39 | temp = max(np.max(avg_losses[i]), temp) 40 | ax.set_ylim(0, temp*1.1) 41 | plt.xlabel('# of Epochs') 42 | plt.ylabel('Loss values') 43 | 44 | if len(avg_losses) == 1: 45 | plt.plot(avg_losses[0], label='loss') 46 | else: 47 | plt.plot(avg_losses[0], label='G_loss') 48 | plt.plot(avg_losses[1], label='D_loss') 49 | plt.legend() 50 | 51 | # save figure 52 | if not os.path.exists(save_dir): 53 | os.makedirs(save_dir) 54 | 55 | save_fn = 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png' 56 | save_fn = os.path.join(save_dir, save_fn) 57 | plt.savefig(save_fn) 58 | 59 | if show: 60 | plt.show() 61 | else: 62 | plt.close() 63 | 64 | 65 | # Make gif 66 | def make_gif(dataset, num_epochs, save_dir='results/'): 67 | gen_image_plots = [] 68 | for epoch in range(num_epochs): 69 | # plot for generating gif 70 | save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png' 71 | gen_image_plots.append(imageio.imread(save_fn)) 72 | 73 | imageio.mimsave(save_dir + dataset + '_result_epochs_{:d}'.format(num_epochs) + '.gif', gen_image_plots, fps=5) 74 | 75 | 76 | def weights_init_normal(m, mean=0.0, std=0.02): 77 | classname = m.__class__.__name__ 78 | if classname.find('Linear') != -1: 79 | m.weight.data.normal_(mean, std) 80 | if m.bias is not None: 81 | m.bias.data.zero_() 82 | elif classname.find('Conv2d') != -1: 83 | m.weight.data.normal_(mean, std) 84 | if m.bias is not None: 85 | m.bias.data.zero_() 86 | elif classname.find('ConvTranspose2d') != -1: 87 | m.weight.data.normal_(mean, std) 88 | if m.bias is not None: 89 | m.bias.data.zero_() 90 | elif classname.find('Norm') != -1: 91 | m.weight.data.normal_(1.0, 0.02) 92 | if m.bias is not None: 93 | m.bias.data.zero_() 94 | 95 | 96 | def weights_init_kaming(m): 97 | classname = m.__class__.__name__ 98 | if classname.find('Linear') != -1: 99 | torch.nn.init.kaiming_normal(m.weight) 100 | if m.bias is not None: 101 | m.bias.data.zero_() 102 | elif classname.find('Conv2d') != -1: 103 | torch.nn.init.kaiming_normal(m.weight) 104 | if m.bias is not None: 105 | m.bias.data.zero_() 106 | elif classname.find('ConvTranspose2d') != -1: 107 | torch.nn.init.kaiming_normal(m.weight) 108 | if m.bias is not None: 109 | m.bias.data.zero_() 110 | elif classname.find('Norm') != -1: 111 | m.weight.data.normal_(1.0, 0.02) 112 | if m.bias is not None: 113 | m.bias.data.zero_() 114 | 115 | 116 | def save_img(img, img_num, save_dir='', is_training=False): 117 | # img.clamp(0, 1) 118 | if list(img.shape)[0] == 3: 119 | save_img = img*255.0 120 | save_img = save_img.clamp(0, 255).numpy().transpose(1, 2, 0).astype(np.uint8) 121 | # img = (((img - img.min()) * 255) / (img.max() - img.min())).numpy().transpose(1, 2, 0).astype(np.uint8) 122 | else: 123 | save_img = img.squeeze().clamp(0, 1).numpy() 124 | 125 | # save img 126 | if not os.path.exists(save_dir): 127 | os.makedirs(save_dir) 128 | if is_training: 129 | save_fn = save_dir + '/SR_result_epoch_{:d}'.format(img_num) + '.png' 130 | else: 131 | save_fn = save_dir + '/SR_result_{:d}'.format(img_num) + '.png' 132 | imsave(save_fn, save_img) 133 | 134 | 135 | def plot_test_result(imgs, psnrs, img_num, save_dir='', is_training=False, show_label=True, show=False): 136 | size = list(imgs[0].shape) 137 | if show_label: 138 | h = 3 139 | w = h * len(imgs) 140 | else: 141 | h = size[2] / 100 142 | w = size[1] * len(imgs) / 100 143 | 144 | fig, axes = plt.subplots(1, len(imgs), figsize=(w, h)) 145 | # axes.axis('off') 146 | for i, (ax, img, psnr) in enumerate(zip(axes.flatten(), imgs, psnrs)): 147 | ax.axis('off') 148 | ax.set_adjustable('box-forced') 149 | if list(img.shape)[0] == 3: 150 | # Scale to 0-255 151 | # img = (((img - img.min()) * 255) / (img.max() - img.min())).numpy().transpose(1, 2, 0).astype(np.uint8) 152 | img *= 255.0 153 | img = img.clamp(0, 255).numpy().transpose(1, 2, 0).astype(np.uint8) 154 | 155 | ax.imshow(img, cmap=None, aspect='equal') 156 | else: 157 | # img = ((img - img.min()) / (img.max() - img.min())).numpy().transpose(1, 2, 0) 158 | img = img.squeeze().clamp(0, 1).numpy() 159 | ax.imshow(img, cmap='gray', aspect='equal') 160 | 161 | if show_label: 162 | ax.axis('on') 163 | if i == 0: 164 | ax.set_xlabel('HR image') 165 | elif i == 1: 166 | ax.set_xlabel('LR image') 167 | elif i == 2: 168 | ax.set_xlabel('Bicubic (PSNR: %.2fdB)' % psnr) 169 | elif i == 3: 170 | ax.set_xlabel('SR image (PSNR: %.2fdB)' % psnr) 171 | 172 | if show_label: 173 | plt.tight_layout() 174 | else: 175 | plt.subplots_adjust(wspace=0, hspace=0) 176 | plt.subplots_adjust(bottom=0) 177 | plt.subplots_adjust(top=1) 178 | plt.subplots_adjust(right=1) 179 | plt.subplots_adjust(left=0) 180 | 181 | # save figure 182 | result_dir = os.path.join(save_dir, 'plot') 183 | if not os.path.exists(result_dir): 184 | os.makedirs(result_dir) 185 | if is_training: 186 | save_fn = result_dir + '/Train_result_epoch_{:d}'.format(img_num) + '.png' 187 | else: 188 | save_fn = result_dir + '/Test_result_{:d}'.format(img_num) + '.png' 189 | plt.savefig(save_fn) 190 | 191 | if show: 192 | plt.show() 193 | else: 194 | plt.close() 195 | 196 | 197 | def shave(imgs, border_size=0): 198 | size = list(imgs.shape) 199 | if len(size) == 4: 200 | shave_imgs = torch.FloatTensor(size[0], size[1], size[2]-border_size*2, size[3]-border_size*2) 201 | for i, img in enumerate(imgs): 202 | shave_imgs[i, :, :, :] = img[:, border_size:-border_size, border_size:-border_size] 203 | return shave_imgs 204 | else: 205 | return imgs[:, border_size:-border_size, border_size:-border_size] 206 | 207 | 208 | def PSNR(pred, gt): 209 | pred = pred.clamp(0, 1) 210 | # pred = (pred - pred.min()) / (pred.max() - pred.min()) 211 | 212 | diff = pred - gt 213 | mse = np.mean(diff.numpy() ** 2) 214 | if mse == 0: 215 | return 100 216 | return 10 * log10(1.0 / mse) 217 | 218 | 219 | def norm(img, vgg=False): 220 | if vgg: 221 | # normalize for pre-trained vgg model 222 | # https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101 223 | transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], 224 | std=[0.229, 0.224, 0.225]) 225 | else: 226 | # normalize [-1, 1] 227 | transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], 228 | std=[0.5, 0.5, 0.5]) 229 | return transform(img) 230 | 231 | 232 | def denorm(img, vgg=False): 233 | if vgg: 234 | transform = transforms.Normalize(mean=[-2.118, -2.036, -1.804], 235 | std=[4.367, 4.464, 4.444]) 236 | return transform(img) 237 | else: 238 | out = (img + 1) / 2 239 | return out.clamp(0, 1) 240 | 241 | 242 | def img_interp(imgs, scale_factor, interpolation='bicubic'): 243 | if interpolation == 'bicubic': 244 | interpolation = Image.BICUBIC 245 | elif interpolation == 'bilinear': 246 | interpolation = Image.BILINEAR 247 | elif interpolation == 'nearest': 248 | interpolation = Image.NEAREST 249 | 250 | size = list(imgs.shape) 251 | 252 | if len(size) == 4: 253 | target_height = int(size[2] * scale_factor) 254 | target_width = int(size[3] * scale_factor) 255 | interp_imgs = torch.FloatTensor(size[0], size[1], target_height, target_width) 256 | for i, img in enumerate(imgs): 257 | transform = transforms.Compose([transforms.ToPILImage(), 258 | transforms.Scale((target_width, target_height), interpolation=interpolation), 259 | transforms.ToTensor()]) 260 | 261 | interp_imgs[i, :, :, :] = transform(img) 262 | return interp_imgs 263 | else: 264 | target_height = int(size[1] * scale_factor) 265 | target_width = int(size[2] * scale_factor) 266 | transform = transforms.Compose([transforms.ToPILImage(), 267 | transforms.Scale((target_width, target_height), interpolation=interpolation), 268 | transforms.ToTensor()]) 269 | return transform(imgs) 270 | -------------------------------------------------------------------------------- /vdsr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from base_networks import * 6 | from torch.utils.data import DataLoader 7 | from data import get_training_set, get_test_set 8 | import utils 9 | from logger import Logger 10 | from torchvision.transforms import * 11 | 12 | 13 | class Net(torch.nn.Module): 14 | def __init__(self, num_channels, base_filter, num_residuals): 15 | super(Net, self).__init__() 16 | 17 | self.input_conv = ConvBlock(num_channels, base_filter, 3, 1, 1, norm=None, bias=False) 18 | 19 | conv_blocks = [] 20 | for _ in range(num_residuals): 21 | conv_blocks.append(ConvBlock(base_filter, base_filter, 3, 1, 1, norm=None, bias=False)) 22 | self.residual_layers = nn.Sequential(*conv_blocks) 23 | 24 | self.output_conv = ConvBlock(base_filter, num_channels, 3, 1, 1, activation=None, norm=None, bias=False) 25 | 26 | def forward(self, x): 27 | residual = x 28 | out = self.input_conv(x) 29 | out = self.residual_layers(out) 30 | out = self.output_conv(out) 31 | out = torch.add(out, residual) 32 | return out 33 | 34 | def weight_init(self): 35 | for m in self.modules(): 36 | utils.weights_init_kaming(m) 37 | 38 | 39 | class VDSR(object): 40 | def __init__(self, args): 41 | # parameters 42 | self.model_name = args.model_name 43 | self.train_dataset = args.train_dataset 44 | self.test_dataset = args.test_dataset 45 | self.crop_size = args.crop_size 46 | self.num_threads = args.num_threads 47 | self.num_channels = args.num_channels 48 | self.scale_factor = args.scale_factor 49 | self.num_epochs = args.num_epochs 50 | self.save_epochs = args.save_epochs 51 | self.batch_size = args.batch_size 52 | self.test_batch_size = args.test_batch_size 53 | self.lr = args.lr 54 | self.data_dir = args.data_dir 55 | self.save_dir = args.save_dir 56 | self.gpu_mode = args.gpu_mode 57 | 58 | def load_dataset(self, dataset='train'): 59 | if self.num_channels == 1: 60 | is_gray = True 61 | else: 62 | is_gray = False 63 | 64 | if dataset == 'train': 65 | print('Loading train datasets...') 66 | train_set = get_training_set(self.data_dir, self.train_dataset, self.crop_size, self.scale_factor, is_gray=is_gray, 67 | normalize=False) 68 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 69 | shuffle=True) 70 | elif dataset == 'test': 71 | print('Loading test datasets...') 72 | test_set = get_test_set(self.data_dir, self.test_dataset, self.scale_factor, is_gray=is_gray, 73 | normalize=False) 74 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 75 | batch_size=self.test_batch_size, 76 | shuffle=False) 77 | 78 | def train(self): 79 | # networks 80 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_residuals=18) 81 | 82 | # weigh initialization 83 | self.model.weight_init() 84 | 85 | # optimizer 86 | self.momentum = 0.9 87 | self.weight_decay = 0.0001 88 | self.clip = 0.4 89 | self.optimizer = optim.SGD(self.model.parameters(), 90 | lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay) 91 | 92 | # loss function 93 | if self.gpu_mode: 94 | self.model.cuda() 95 | self.MSE_loss = nn.MSELoss().cuda() 96 | else: 97 | self.MSE_loss = nn.MSELoss() 98 | 99 | print('---------- Networks architecture -------------') 100 | utils.print_network(self.model) 101 | print('----------------------------------------------') 102 | 103 | # load dataset 104 | train_data_loader = self.load_dataset(dataset='train') 105 | test_data_loader = self.load_dataset(dataset='test') 106 | 107 | # set the logger 108 | log_dir = os.path.join(self.save_dir, 'logs') 109 | if not os.path.exists(log_dir): 110 | os.mkdir(log_dir) 111 | logger = Logger(log_dir) 112 | 113 | ################# Train ################# 114 | print('Training is started.') 115 | avg_loss = [] 116 | step = 0 117 | 118 | # test image 119 | test_input, test_target = test_data_loader.dataset.__getitem__(2) 120 | test_input = test_input.unsqueeze(0) 121 | test_target = test_target.unsqueeze(0) 122 | 123 | self.model.train() 124 | for epoch in range(self.num_epochs): 125 | 126 | # learning rate is decayed by a factor of 10 every 20 epochs 127 | if (epoch+1) % 20 == 0: 128 | for param_group in self.optimizer.param_groups: 129 | param_group["lr"] /= 10.0 130 | print("Learning rate decay: lr={}".format(self.optimizer.param_groups[0]["lr"])) 131 | 132 | epoch_loss = 0 133 | for iter, (input, target) in enumerate(train_data_loader): 134 | # input data (bicubic interpolated image) 135 | if self.gpu_mode: 136 | x_ = Variable(target.cuda()) 137 | y_ = Variable(utils.img_interp(input, self.scale_factor).cuda()) 138 | else: 139 | x_ = Variable(target) 140 | y_ = Variable(utils.img_interp(input, self.scale_factor)) 141 | 142 | # update network 143 | self.optimizer.zero_grad() 144 | recon_image = self.model(y_) 145 | loss = self.MSE_loss(recon_image, x_) 146 | loss.backward() 147 | 148 | # gradient clipping 149 | nn.utils.clip_grad_norm(self.model.parameters(), self.clip) 150 | self.optimizer.step() 151 | 152 | # log 153 | epoch_loss += loss.data[0] 154 | print("Epoch: [%2d] [%4d/%4d] loss: %.8f" % ((epoch + 1), (iter + 1), len(train_data_loader), loss.data[0])) 155 | 156 | # tensorboard logging 157 | logger.scalar_summary('loss', loss.data[0], step + 1) 158 | step += 1 159 | 160 | # avg. loss per epoch 161 | avg_loss.append(epoch_loss / len(train_data_loader)) 162 | 163 | # prediction 164 | recon_imgs = self.model(Variable(utils.img_interp(test_input, self.scale_factor).cuda())) 165 | recon_img = recon_imgs[0].cpu().data 166 | gt_img = test_target[0] 167 | lr_img = test_input[0] 168 | bc_img = utils.img_interp(test_input[0], self.scale_factor) 169 | 170 | # calculate psnrs 171 | bc_psnr = utils.PSNR(bc_img, gt_img) 172 | recon_psnr = utils.PSNR(recon_img, gt_img) 173 | 174 | # save result images 175 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 176 | psnrs = [None, None, bc_psnr, recon_psnr] 177 | utils.plot_test_result(result_imgs, psnrs, epoch + 1, save_dir=self.save_dir, is_training=True) 178 | 179 | print("Saving training result images at epoch %d" % (epoch + 1)) 180 | 181 | # Save trained parameters of model 182 | if (epoch + 1) % self.save_epochs == 0: 183 | self.save_model(epoch + 1) 184 | 185 | # Plot avg. loss 186 | utils.plot_loss([avg_loss], self.num_epochs, save_dir=self.save_dir) 187 | print("Training is finished.") 188 | 189 | # Save final trained parameters of model 190 | self.save_model(epoch=None) 191 | 192 | def test(self): 193 | # networks 194 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_residuals=18) 195 | 196 | if self.gpu_mode: 197 | self.model.cuda() 198 | 199 | # load model 200 | self.load_model() 201 | 202 | # load dataset 203 | test_data_loader = self.load_dataset(dataset='test') 204 | 205 | # Test 206 | print('Test is started.') 207 | img_num = 0 208 | self.model.eval() 209 | for input, target in test_data_loader: 210 | # input data (bicubic interpolated image) 211 | if self.gpu_mode: 212 | y_ = Variable(utils.img_interp(input, self.scale_factor).cuda()) 213 | else: 214 | y_ = Variable(utils.img_interp(input, self.scale_factor)) 215 | 216 | # prediction 217 | recon_imgs = self.model(y_) 218 | for i, recon_img in enumerate(recon_imgs): 219 | img_num += 1 220 | recon_img = recon_imgs[i].cpu().data 221 | gt_img = target[i] 222 | lr_img = input[i] 223 | bc_img = utils.img_interp(input[i], self.scale_factor) 224 | 225 | # calculate psnrs 226 | bc_psnr = utils.PSNR(bc_img, gt_img) 227 | recon_psnr = utils.PSNR(recon_img, gt_img) 228 | 229 | # save result images 230 | result_imgs = [gt_img, lr_img, bc_img, recon_img] 231 | psnrs = [None, None, bc_psnr, recon_psnr] 232 | utils.plot_test_result(result_imgs, psnrs, img_num, save_dir=self.save_dir) 233 | 234 | print("Saving %d test result images..." % img_num) 235 | 236 | def test_single(self, img_fn): 237 | # networks 238 | self.model = Net(num_channels=self.num_channels, base_filter=64, num_residuals=18) 239 | 240 | if self.gpu_mode: 241 | self.model.cuda() 242 | 243 | # load model 244 | self.load_model() 245 | 246 | # load data 247 | img = Image.open(img_fn) 248 | img = img.convert('YCbCr') 249 | y, cb, cr = img.split() 250 | y = y.resize((y.size[0] * self.scale_factor, y.size[1] * self.scale_factor), Image.BICUBIC) 251 | 252 | input = Variable(ToTensor()(y)).view(1, -1, y.size[1], y.size[0]) 253 | if self.gpu_mode: 254 | input = input.cuda() 255 | 256 | self.model.eval() 257 | recon_img = self.model(input) 258 | 259 | # save result images 260 | utils.save_img(recon_img.cpu().data, 1, save_dir=self.save_dir) 261 | 262 | out = recon_img.cpu() 263 | out_img_y = out.data[0] 264 | out_img_y = (((out_img_y - out_img_y.min()) * 255) / (out_img_y.max() - out_img_y.min())).numpy() 265 | # out_img_y *= 255.0 266 | # out_img_y = out_img_y.clip(0, 255) 267 | out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L') 268 | 269 | out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) 270 | out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) 271 | out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB') 272 | 273 | # save img 274 | result_dir = os.path.join(self.save_dir, 'result') 275 | if not os.path.exists(result_dir): 276 | os.mkdir(result_dir) 277 | save_fn = result_dir + '/SR_result.png' 278 | out_img.save(save_fn) 279 | 280 | def save_model(self, epoch=None): 281 | model_dir = os.path.join(self.save_dir, 'model') 282 | if not os.path.exists(model_dir): 283 | os.mkdir(model_dir) 284 | if epoch is not None: 285 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param_epoch_%d.pkl' % epoch) 286 | else: 287 | torch.save(self.model.state_dict(), model_dir + '/' + self.model_name + '_param.pkl') 288 | 289 | print('Trained model is saved.') 290 | 291 | def load_model(self): 292 | model_dir = os.path.join(self.save_dir, 'model') 293 | 294 | model_name = model_dir + '/' + self.model_name + '_param.pkl' 295 | if os.path.exists(model_name): 296 | self.model.load_state_dict(torch.load(model_name)) 297 | print('Trained model is loaded.') 298 | return True 299 | else: 300 | print('No model exists to load.') 301 | return False 302 | --------------------------------------------------------------------------------