├── LICENSE ├── README.md ├── model.py ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 hzwer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MFSR-TSM 2 | Multi-Frame Super-Resolution based on Temporal Shift Module 3 | 4 | We took part in the [Megvii 1st Open-Source Super Resolution Competition](https://studio.brainpp.com/competition) and got PSNR31.07 in Round 1. We need 10 hours for training our model in a V100 GPU card. 5 | 6 | We use TSM modules to fuse information from multiple frames, and the backbone of the network is a simplified DenseNet. 7 | 8 | # Dependencies 9 | opencv-python 10 | 11 | tensorboardX 12 | 13 | [megengine](https://megengine.org.cn/install/) 14 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import megengine as mge 2 | import megengine.module as M 3 | import megengine.functional as F 4 | from megengine.core import Parameter 5 | from utils import * 6 | 7 | def addLeakyRelu(x): 8 | return M.Sequential(x, M.LeakyReLU(0.1)) 9 | 10 | def addSig(x): 11 | return M.Sequential(x, M.Sigmoid()) 12 | 13 | def up_block(x, ic, oc): 14 | return M.ConvTranspose2d(ic, oc, 4, stride=2, padding=1) 15 | 16 | def down_block(x, ic, oc): 17 | return M.Conv2d(ic, oc, 3, padding=1, stride=2) 18 | 19 | class BasicBlock(M.Module): 20 | expansion = 1 21 | 22 | def __init__( 23 | self, 24 | in_channels, 25 | channels, 26 | stride=1, 27 | groups=1, 28 | base_width=64, 29 | dilation=1, 30 | norm=M.BatchNorm2d, 31 | ): 32 | super().__init__() 33 | if groups != 1 or base_width != 64: 34 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | self.conv1 = M.Conv2d( 38 | in_channels, channels, 3, stride, padding=dilation, bias=True 39 | ) 40 | self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=True) 41 | if in_channels == channels and stride == 1: 42 | self.downsample = M.Identity() 43 | elif stride == 1: 44 | self.downsample = M.Conv2d(in_channels, channels, 1, stride, bias=False) 45 | else: 46 | self.downsample = M.Sequential( 47 | M.AvgPool2d(kernel_size=stride, stride=stride), 48 | M.Conv2d(in_channels, channels, 1, 1, bias=False) 49 | ) 50 | self.fc1 = M.Conv2d(channels, 16, kernel_size=1) 51 | self.fc2 = M.Conv2d(16, channels, kernel_size=1) 52 | self.relu1 = M.LeakyReLU(0.1) 53 | self.relu2 = M.LeakyReLU(0.1) 54 | self.relu3 = M.LeakyReLU(0.1) 55 | 56 | def forward(self, x): 57 | identity = x 58 | x = self.conv1(x) 59 | x = self.relu1(x) 60 | x = self.conv2(x) 61 | identity = self.downsample(identity) 62 | w = x.mean(3, True).mean(2, True) 63 | w = self.relu2(self.fc1(w)) 64 | w = F.sigmoid(self.fc2(w)) 65 | x = x * w + identity 66 | x = self.relu3(x) 67 | return x 68 | 69 | def subpixel(x): 70 | shape = x.shape 71 | x = x.reshape(shape[0], shape[1] // 4, 2, 2, shape[2], shape[3]) 72 | x = F.dimshuffle(x, (0, 1, 4, 2, 5, 3)) 73 | return x.reshape(shape[0], shape[1] // 4, shape[2]*2, shape[3]*2) 74 | 75 | c = 64 76 | class SimpleUNet(M.Module): 77 | def __init__(self): 78 | super().__init__() 79 | 80 | self.conv0_ = (BasicBlock(3, 32, stride=2)) 81 | self.conv1_ = (BasicBlock(32, c, stride=2)) 82 | self.conv0 = (BasicBlock(15, 32, stride=2)) 83 | self.conv1 = (BasicBlock(32, c, stride=2)) 84 | self.conv2 = (BasicBlock(c, 2*c, stride=1)) 85 | self.conv3 = (BasicBlock(2*c, 2*c, stride=1)) 86 | self.conv4 = (BasicBlock(4*c, 2*c, stride=1)) 87 | self.conv5 = (BasicBlock(4*c, 2*c, stride=1)) 88 | self.conv6 = (BasicBlock(6*c, 2*c, stride=1)) 89 | self.conv7 = (BasicBlock(6*c, 2*c, stride=1)) 90 | self.conv8 = (BasicBlock(6*c, 2*c, stride=1)) 91 | self.conv9 = (BasicBlock(6*c, 2*c, stride=1)) 92 | self.conv10 = (BasicBlock(3*c, 4*c, stride=1)) 93 | self.conv11 = addSig(M.Conv2d(c+32, 12, 1)) 94 | 95 | def forward(self, x): 96 | size = x.shape 97 | x = x.reshape((size[0] * 5, 3) + size[2:]) 98 | conv0 = tsm(self.conv0_(x)) 99 | conv1 = tsm(self.conv1_(conv0)) 100 | # 101 | x = (x.reshape((size[0], 15) + x.shape[2:])) 102 | conv0_ = (conv0.reshape((size[0], 5) + conv0.shape[1:]))[:, 2] 103 | conv1_ = (conv1.reshape((size[0], 5) + conv1.shape[1:]))[:, 2] 104 | conv0 = self.conv0(x) 105 | conv1 = self.conv1(conv0) 106 | conv0 += conv0_ 107 | conv1 += conv1_ 108 | 109 | conv2 = (self.conv2(conv1)) 110 | conv3 = (self.conv3(conv2)) 111 | conv4 = (self.conv4(F.concat((conv3, conv2), 1))) 112 | conv5 = (self.conv5(F.concat((conv4, conv3), 1))) 113 | conv6 = (self.conv6(F.concat((conv5, conv4, conv2), 1))) 114 | conv7 = (self.conv7(F.concat((conv6, conv5, conv3), 1))) 115 | conv8 = (self.conv8(F.concat((conv7, conv6, conv4), 1))) 116 | conv9 = (self.conv9(F.concat((conv8, conv7, conv5), 1))) 117 | conv10 = subpixel(self.conv10(F.concat((conv9, conv1), 1))) 118 | conv11 = subpixel(self.conv11(F.concat((conv10, conv0), 1))) 119 | conv11 = conv11 * 2 - 1 # sigmoid to [-1, 1] 120 | 121 | return F.minimum(F.maximum(conv11 + x[:, 6:9], 0), 1) 122 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from IPython.display import Image, display 3 | import tarfile 4 | import cv2 5 | import numpy as np 6 | import os 7 | 8 | MODEL_PATH = "model.mge.state" 9 | TEST_RAW_DATA = "../../../dataset/game1/test.tar" 10 | 11 | net = SimpleUNet() 12 | 13 | with open(MODEL_PATH, 'rb') as f: 14 | net.load_state_dict(mge.load(f)['net']) 15 | 16 | @mge.jit.trace 17 | def inference(inp): 18 | return net(inp) 19 | 20 | for i in range(90, 100): 21 | cur_dir = '/home/megstudio/workspace/input/{}/'.format(i) 22 | save_dir = '/home/megstudio/workspace/test/{}/'.format(i) 23 | if not os.path.exists(save_dir): 24 | os.mkdir(save_dir) 25 | l = os.listdir(cur_dir) 26 | for j in l: 27 | if 'png' in j: 28 | img_dir = cur_dir + j 29 | img = cv2.imread(img_dir) 30 | img = cv2.resize(img, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_CUBIC) 31 | img = np.float32(img / 255.) 32 | h, w, _ = img.shape 33 | ph = ((h - 1) // 64 + 1) * 64 34 | pw = ((w - 1) // 64 + 1) * 64 35 | img = np.pad(img, ((0,ph-h),(0,pw-w),(0,0)), 'constant') 36 | print(img_dir, img.shape) 37 | img = img.transpose((2, 0, 1))[None, :, :, :] 38 | img_out = inference(img).numpy()[0] 39 | print(np.abs(img - img_out).mean()) 40 | img_out = ((img_out * 255).clip(0, 255).transpose((1, 2, 0)).copy()[:h, :w]).astype('uint8') 41 | cv2.imwrite(save_dir + j, img_out) 42 | 43 | ''' 44 | with tarfile.open(TEST_RAW_DATA, mode='r') as tar: 45 | tinfo = tar.getmember("test/90/0045.png") 46 | content = tar.extractfile(tinfo).read() 47 | img = cv2.imdecode(np.frombuffer(content, dtype='uint8'), 1) 48 | img = cv2.resize(img, (0, 0), fx=4, fy=4) 49 | img = (np.float32(img) / 256).transpose((2, 0, 1))[None, :, :, :] 50 | img_out = inference(img) 51 | img_out = (img_out.numpy() * 256).clip(0, 255)[0].transpose((1, 2, 0)).copy() 52 | content_out = cv2.imencode('.png', img_out)[1] 53 | 54 | display(Image(data=content_out, width=400)) 55 | ''' 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import math 4 | import random 5 | import os 6 | import cv2 7 | import tarfile 8 | import io 9 | import av 10 | import time 11 | from utils import * 12 | import tensorboardX 13 | import imgaug.augmenters as iaa 14 | from functools import lru_cache 15 | from megengine.optimizer import Adam 16 | from model import * 17 | from megengine.data.dataset import Dataset 18 | from megengine.data import RandomSampler, SequentialSampler 19 | from megengine.data import DataLoader 20 | from tensorboardX import * 21 | from megengine.jit import trace, SublinearMemoryConfig 22 | 23 | train_epoch = 50 24 | batch_size = 64 25 | 26 | net = SimpleUNet() 27 | optimizer = Adam(net.parameters(), lr=3e-4) 28 | 29 | train_patches = [] 30 | gt_patches = [] 31 | for i in range(5): 32 | train_patches.append(TRAIN_DATA_STORAGE + str(i)) 33 | image_list = [] 34 | config = SublinearMemoryConfig(genetic_nr_iter=20) 35 | 36 | @trace(symbolic=True, sublinear_memory_config=config) 37 | def train_iter(batch_train, batch_gt): 38 | pred = net(batch_train) 39 | loss = (((batch_gt - pred)**2 + 1e-6)**0.4).mean() 40 | optimizer.backward(loss) 41 | return loss, pred 42 | 43 | @trace(symbolic=True, sublinear_memory_config=config) 44 | def train_iterv2(batch_train, batch_gt): 45 | pred = net(batch_train) 46 | loss = ((batch_gt - pred)**2).mean() 47 | optimizer.backward(loss) 48 | return loss, pred 49 | 50 | def validate(): 51 | img_list = [] 52 | gt_list = [] 53 | l2_list = [] 54 | cubic_list = [] 55 | for i in range(537): 56 | gt = cv2.imread('/data/validate/{}.png'.format(i + 1))[256:512, 256:512, :] 57 | img = cv2.imread('/data/validate/{}_down.png'.format(i + 1))[64:128, 64:128, :] 58 | img = cv2.resize(img, (0, 0), fx=4, fy=4, interpolation=cv2.INTER_CUBIC) / 255. 59 | img = np.float32(img) 60 | h, w, _ = img.shape 61 | ph = ((h - 1) // 32 + 1) * 32 62 | pw = ((w - 1) // 32 + 1) * 32 63 | img = np.pad(img, ((0,ph-h),(0,pw-w),(0,0)), 'constant') 64 | if len(img_list) != 0 and img_list[-1].shape != img.shape: 65 | img_list = [] 66 | gt_list = [] 67 | img_list.append(img) 68 | gt_list.append(gt) 69 | if(len(img_list) > 5): 70 | img_list = img_list[-5:] 71 | gt_list = gt_list[-5:] 72 | if(len(img_list) == 5): 73 | inp = np.zeros([5, ph, pw, 3]) 74 | for k in range(5): 75 | inp[k] = img_list[k] 76 | inp = np.float32(inp) 77 | for i in range(4): 78 | if np.abs(inp[4-i] - inp[4-i-1]).mean() > 0.2: 79 | inp[4-i-1] = inp[4-i] 80 | if np.abs(inp[4+i] - inp[4+i+1]).mean() > 0.2: 81 | inp[4+i+1] = inp[4+i] 82 | inp = inp.transpose((0, 3, 1, 2)).reshape(1, 15, ph, pw) 83 | img_out = inference(inp).numpy()[0] 84 | img_out = ((img_out * 255).transpose((1, 2, 0)).copy()[:h, :w]).astype('uint8') 85 | l2_list.append(((gt_list[4][:h, :w]/255. - img_out/255.)**2).mean()) 86 | cubic_list.append(((gt_list[4][:h, :w]/255. - img_list[4][:h, :w])**2).mean()) 87 | print(10 * np.log10(1 / np.array(l2_list).mean()), 10 * np.log10(1 / np.array(cubic_list).mean())) 88 | return 10 * np.log10(1 / np.array(l2_list).mean()) 89 | 90 | @trace(symbolic=True) 91 | def inference(inp): 92 | inp = net(inp) 93 | return inp 94 | 95 | class ImageDataSet(Dataset): 96 | def __init__(self, now_dataset): 97 | super().__init__() 98 | self.data = [] 99 | cnt = 0 100 | 101 | def load(self): 102 | TRAIN_RAW_DATA='/data/train_png/' 103 | img_num = 0 104 | self.data = [] 105 | from tqdm import tqdm 106 | tasks = sorted([i for i in os.listdir(TRAIN_RAW_DATA) if 'down4x' in i]) 107 | for task in tqdm(tasks): 108 | num = task[:2] 109 | if num[0] == '0': 110 | num = num[1] 111 | if(eval(num) < 10): # the first 10 video are used for validation 112 | continue 113 | task = TRAIN_RAW_DATA + task 114 | down4x_list = [] 115 | task_origin = task.replace('_down4x.mp4','') 116 | frames_origin = sorted([os.path.join(task_origin,i) for i in os.listdir(task_origin)]) 117 | frames_down4x = sorted([os.path.join(task,i) for i in os.listdir(task)]) 118 | for k, (frame_down4x) in enumerate(frames_down4x): 119 | img_down4x = cv2.imread(frame_down4x) 120 | down4x_list.append(img_down4x) 121 | assert(len(frames_origin) == len(down4x_list)) 122 | for k, (frame_origin) in enumerate(frames_origin): 123 | if k < 4 or k + 4 >= len(down4x_list): 124 | continue 125 | if np.random.uniform(0, 1) < 0.95: 126 | continue 127 | img_origin = cv2.imread(frame_origin) 128 | x0 = 0 129 | tmp = np.array(down4x_list[k-4:k+5]) 130 | while x0 < img_origin.shape[0]: 131 | if x0 + 128 > img_origin.shape[0]: 132 | x0 = img_origin.shape[0] - 128 133 | y0 = 0 134 | while y0 < img_origin.shape[1]: 135 | if y0 + 128 > img_origin.shape[1]: 136 | y0 = img_origin.shape[1] - 128; 137 | img0 = tmp[:, x0//4:x0//4 + 32, y0//4:y0//4 + 32].copy() 138 | img1 = img_origin[x0:x0 + 128, y0:y0 + 128].copy() 139 | if img1.mean() > 10: 140 | self.data.append((img0, img1)) 141 | img_num += 1 142 | y0 += 128 143 | x0 += 128 144 | self.len = len(self.data) 145 | print(self.len) 146 | 147 | def __getitem__(self, index): 148 | img0 = (self.data[index][0].copy() / 255., self.data[index][1].copy() / 255.) 149 | aug = iaa.Sequential([ 150 | iaa.ChannelShuffle(0.5), 151 | iaa.Fliplr(0.5), 152 | iaa.Flipud(0.5), 153 | iaa.Rot90([0, 3]), 154 | ]).to_deterministic() 155 | if np.random.rand() < 0.5: 156 | inp = img0[0] 157 | gt = img0[1] 158 | else: 159 | p = np.random.uniform(0.1, 0.9) 160 | index2 = np.random.randint(0, self.len - 1) 161 | img1 = (self.data[index2][0].copy() / 255., self.data[index2][1].copy() / 255.) 162 | inp = img0[0] * p + img1[0] * (1-p) 163 | gt = img0[1] * p + img1[1] * (1-p) 164 | if np.random.rand() < 0.5: 165 | inp = inp[::-1] 166 | base = [] 167 | for i in range(9): 168 | inp[i] = aug(image=inp[i]) 169 | gt = aug(image=gt) 170 | for i in range(9): 171 | base.append(cv2.resize((inp[i]*255).astype('uint8'), (128, 128), interpolation=cv2.INTER_CUBIC) / 255.) 172 | for i in range(4): 173 | if np.abs(base[4-i] - base[4-i-1]).mean() > 0.2: 174 | base[4-i-1] = base[4-i] 175 | if np.abs(base[4+i] - base[4+i+1]).mean() > 0.2: 176 | base[4+i+1] = base[4+i] 177 | base = np.transpose(np.array(base), (0, 3, 1, 2)).reshape(15, 128, 128) 178 | gt = np.transpose(gt, (2, 0, 1)) 179 | return np.float32(base), np.float32(gt) 180 | 181 | def shuffle(self): 182 | np.random.shuffle(self.data) 183 | 184 | def __len__(self): 185 | return self.len 186 | 187 | train_dataset = ImageDataSet(train_patches[1:]) 188 | 189 | loss_acc = 0 190 | loss_acc0 = 0 191 | 192 | cnt = 0 193 | writer = SummaryWriter('./log') 194 | 195 | state = { 196 | 'net': net.state_dict(), 197 | 'opt': optimizer.state_dict(), 198 | } 199 | for epoch in range(train_epoch): 200 | if epoch % 2 == 0: 201 | train_dataset.load() 202 | random_sampler = RandomSampler(dataset=train_dataset, batch_size=batch_size, seed=epoch) 203 | image_dataloader = DataLoader( 204 | dataset=train_dataset, 205 | sampler=random_sampler, 206 | num_workers=8, 207 | ) 208 | begin = time.time() 209 | for idx, (img, label) in enumerate(image_dataloader): 210 | if idx % 100 == 0: 211 | print('{} / {}'.format(idx, train_dataset.__len__() // batch_size)) 212 | cosine_decay = 0.5 * (1 + math.cos(math.pi * cnt / (train_epoch * 2700) ) ) 213 | lr = 3e-4 * cosine_decay 214 | for g in optimizer.param_groups: 215 | g['lr'] = lr 216 | train_begin = time.time() 217 | 218 | optimizer.zero_grad() 219 | loss, pred = train_iter(img, label) 220 | optimizer.step() 221 | loss_acc = loss_acc * 0.99 + loss 222 | loss_acc0 = loss_acc0 * 0.99 + 1 223 | end = time.time() 224 | 225 | total_time = end - begin 226 | data_load_time = total_time - (end - train_begin) 227 | 228 | begin = time.time() 229 | if idx % 100 == 0: 230 | writer.add_scalar("loss",(loss_acc / loss_acc0).numpy(), cnt) 231 | writer.add_scalar("learning_rate", lr, cnt) 232 | cnt += 1 233 | 234 | print( 235 | "{}: loss: {}, speed: {:.2f}it/sec, tot: {:.4f}s, data: {:.4f}s, data/tot: {:.4f}" 236 | .format(epoch, loss_acc / (loss_acc0+1e-6), 1 / (total_time+1e-6), total_time, 237 | data_load_time, data_load_time / (total_time+1e-6))) 238 | with open('log.txt','a') as f: 239 | print( 240 | "{}: loss: {}, speed: {:.2f}it/sec, tot: {:.4f}s, data: {:.4f}s, data/tot: {:.4f}" 241 | .format(epoch, loss_acc / (loss_acc0+1e-6), 1 / (total_time+1e-6), total_time, 242 | data_load_time, data_load_time / (total_time+1e-6)), file=f) 243 | 244 | if (epoch+1) % 1 == 0: 245 | val_res = validate() 246 | val_psnr = val_res 247 | net.train() 248 | writer.add_scalar("psnr", val_psnr, cnt) 249 | print(val_res) 250 | with open('log.txt','a') as f: 251 | print(val_res, file=f) 252 | 253 | # save our model 254 | state = { 255 | 'net': net.state_dict(), 256 | 'opt': optimizer.state_dict(), 257 | } 258 | with open('model.mge.state', 'wb') as fout: 259 | mge.save(state, fout) 260 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import megengine as mge 2 | from megengine.module import Module 3 | import megengine.functional as F 4 | class Swish(Module): 5 | def __init__(self): 6 | super().__init__() 7 | self.weight = Parameter(np.ones(1, dtype=np.float32)) 8 | 9 | def forward(self, inp): 10 | return F.sigmoid(inp * self.weight) * inp 11 | 12 | def frame_count(container, video_stream=0): 13 | def count(generator): 14 | res = 0 15 | for _ in generator: 16 | res += 1 17 | return res 18 | 19 | frames = container.streams.video[video_stream].frames 20 | if frames != 0: 21 | return frames 22 | frame_series = container.decode(video=video_stream) 23 | frames = count(frame_series) 24 | container.seek(0) 25 | return frames 26 | 27 | def tsm(x): 28 | # tensor [N*T, C, H, W] 29 | size = x.shape 30 | tensor = x.reshape((-1, 5) + size[1:]) 31 | # tensor [N, T, C, H, W] 32 | p = size[1] // 4 33 | pre_tensor = tensor[:, :, :p] 34 | post_tensor = tensor[:, :, p:2*p] 35 | peri_tensor = tensor[:, :, 2*p:] 36 | pre_tensor_ = F.concat((mge.zeros(pre_tensor[:, -1: ], dtype=tensor.dtype), 37 | pre_tensor [:, :-1]), 1) 38 | post_tensor_ = F.concat((post_tensor[:, 1: ], 39 | mge.zeros(post_tensor[:, :1], dtype=tensor.dtype)), 1) 40 | output = F.concat((pre_tensor_, post_tensor_, peri_tensor), 2).reshape(size) 41 | output = tensor.reshape(size) 42 | return output 43 | --------------------------------------------------------------------------------