├── .gitignore ├── README.md ├── LICENSE ├── dataloader.py ├── train.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | /samples/* 2 | /__pycache__/* 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Super-SloMo 2 | 3 | This repository contains an unofficial PyTorch implementation of "Super SloMo: High Quality Estimation of Multiple Intermediate Frames for Video Interpolation", Jiang et al., CVPR 2018. 4 | 5 | The repository is in its initial stages. It can be used as a reference to build upon. I have tried to stick to the maths in the paper, but please let me know if any changes should be made. 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 SMonk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | import cv2 5 | from PIL import Image 6 | import glob 7 | import os 8 | import random 9 | 10 | 11 | def populateTrainList(folderPath): 12 | folderList_pre = [x[0] for x in os.walk(folderPath)] 13 | folderList = [] 14 | trainList = [] 15 | 16 | for folder in folderList_pre: 17 | if folder[-3:] == '240': 18 | folderList.append(folder + "/" + folder.split("/")[-2]) 19 | 20 | 21 | for folder in folderList: 22 | imageList = sorted(glob.glob(folder + '/' + '*.jpg')) 23 | for i in range(0, len(imageList), 12): 24 | tmp = imageList[i:i+12] 25 | if len(tmp) == 12: 26 | trainList.append(imageList[i:i+12]) 27 | 28 | 29 | return trainList 30 | 31 | def populateTrainList2(folderPath): 32 | folderList = [x[0] for x in os.walk(folderPath)] 33 | trainList = [] 34 | 35 | for folder in folderList: 36 | imageList = sorted(glob.glob(folder + '/' + '*.jpg')) 37 | for i in range(0, len(imageList), 12): 38 | tmp = imageList[i:i+12] 39 | if len(tmp) == 12: 40 | trainList.append(imageList[i:i+12]) 41 | return trainList 42 | 43 | 44 | 45 | 46 | 47 | def randomCropOnList(image_list, output_size): 48 | 49 | cropped_img_list = [] 50 | 51 | h,w = output_size 52 | height, width, _ = image_list[0].shape 53 | 54 | #print(h,w,height,width) 55 | 56 | i = random.randint(0, height - h) 57 | j = random.randint(0, width - w) 58 | 59 | st_y = 0 60 | ed_y = w 61 | st_x = 0 62 | ed_x = h 63 | 64 | or_st_y = i 65 | or_ed_y = i + w 66 | or_st_x = j 67 | or_ed_x = j + h 68 | 69 
| #print(st_x, ed_x, st_y, ed_y) 70 | #print(or_st_x, or_ed_x, or_st_y, or_ed_y) 71 | 72 | 73 | for img in image_list: 74 | new_img = np.empty((h,w,3), dtype=np.float32) 75 | new_img.fill(128) 76 | new_img[st_y: ed_y, st_x: ed_x, :] = img[or_st_y: or_ed_y, or_st_x: or_ed_x, :].copy() 77 | cropped_img_list.append(np.ascontiguousarray(new_img)) 78 | 79 | 80 | return cropped_img_list 81 | 82 | 83 | 84 | #print(len(populateTrainList('/home/user/data/nfs/'))) 85 | 86 | class expansionLoader(data.Dataset): 87 | 88 | def __init__(self, folderPath): 89 | 90 | self.trainList = populateTrainList2(folderPath) 91 | print("# of training samples:", len(self.trainList)) 92 | 93 | 94 | def __getitem__(self, index): 95 | 96 | img_path_list = self.trainList[index] 97 | start = random.randint(0,3) 98 | h,w,c = cv2.imread(img_path_list[0]).shape 99 | 100 | image = cv2.cv2.imread(img_path_list[0]) 101 | 102 | #print(h,w,c) 103 | 104 | if h > w: 105 | scaleX = int(360*(h/w)) 106 | scaleY = 360 107 | elif h <= w: 108 | scaleX = 360 109 | scaleY = int(360*(w/h)) 110 | 111 | 112 | 113 | img_list = [] 114 | 115 | flip = random.randint(0,1) 116 | if flip: 117 | for img_path in img_path_list[start:start+9]: 118 | tmp = cv2.resize(cv2.imread(img_path), (scaleX,scaleY))[:,:,(2,1,0)] 119 | img_list.append(np.array(cv2.flip(tmp,1), dtype=np.float32)) 120 | else: 121 | for img_path in img_path_list[start:start+9]: 122 | tmp = cv2.resize(cv2.imread(img_path), (scaleX, scaleY))[:,:,(2,1,0)] 123 | img_list.append(np.array(tmp,dtype=np.float32)) 124 | #cv2.imshow("j",tmp) 125 | #cv2.waitKey(0) & 0xff 126 | #brak 127 | for i in range(len(img_list)): 128 | #print(img_list[i].shape) 129 | #brak 130 | img_list[i] /= 255 131 | img_list[i][:,:,0] -= 0.485#(img_list[i]/127.5) - 1 132 | img_list[i][:,:,1] -= 0.456 133 | img_list[i][:,:,2] -= 0.406 134 | 135 | img_list[i][:,:,0] /= 0.229 136 | img_list[i][:,:,1] /= 0.224 137 | img_list[i][:,:,2] /= 0.225 138 | 139 | cropped_img_list = 
randomCropOnList(img_list,(352,352)) 140 | for i in range(len(cropped_img_list)): 141 | cropped_img_list[i] = torch.from_numpy(cropped_img_list[i].transpose((2, 0, 1))) 142 | 143 | 144 | return cropped_img_list 145 | 146 | 147 | def __len__(self): 148 | return len(self.trainList) 149 | 150 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import torch.backends.cudnn as cudnn 5 | import torch.optim 6 | import os 7 | import sys 8 | import argparse 9 | import time 10 | import dataloader 11 | import model 12 | import numpy as np 13 | 14 | class FlowWarper(nn.Module): 15 | def __init__(self, w, h): 16 | super(FlowWarper, self).__init__() 17 | x = np.arange(0,w) 18 | y = np.arange(0,h) 19 | gx, gy = np.meshgrid(x,y) 20 | self.w = w 21 | self.h = h 22 | self.grid_x = torch.autograd.Variable(torch.Tensor(gx), requires_grad=False).cuda() 23 | self.grid_y = torch.autograd.Variable(torch.Tensor(gy), requires_grad=False).cuda() 24 | 25 | def forward(self, img, uv): 26 | u = uv[:,0,:,:] 27 | v = uv[:,1,:,:] 28 | X = self.grid_x.unsqueeze(0).expand_as(u) + u 29 | Y = self.grid_y.unsqueeze(0).expand_as(v) + v 30 | X = 2*(X/self.w - 0.5) 31 | Y = 2*(Y/self.h - 0.5) 32 | grid_tf = torch.stack((X,Y), dim=3) 33 | img_tf = torch.nn.functional.grid_sample(img, grid_tf) 34 | return img_tf 35 | 36 | 37 | def train_val(): 38 | 39 | #cudnn.benchmark = True 40 | flowModel = model.UNet_flow().cuda() 41 | interpolationModel = model.UNet_refine().cuda() 42 | 43 | ### ResNet for Perceptual Loss 44 | res50_model = torchvision.models.resnet18(pretrained=True) 45 | res50_conv = nn.Sequential(*list(res50_model.children())[:-2]) 46 | res50_conv.cuda() 47 | 48 | for param in res50_conv.parameters(): 49 | param.requires_grad = False 50 | 51 | 52 | #dataFeeder = dataloader.expansionLoader('/home/user/data/nfs') 
53 | dataFeeder = dataloader.expansionLoader('/home/user/data/original_high_fps_videos') 54 | train_loader = torch.utils.data.DataLoader(dataFeeder, batch_size=2, 55 | shuffle=True, num_workers=1, 56 | pin_memory=True) 57 | criterion = nn.L1Loss().cuda() 58 | criterionMSE = nn.MSELoss().cuda() 59 | 60 | optimizer = torch.optim.Adam(list(flowModel.parameters()) + list(interpolationModel.parameters()), lr=0.0001) 61 | 62 | flowModel.train() 63 | interpolationModel.train() 64 | 65 | warper = FlowWarper(352,352) 66 | 67 | for epoch in range(5): 68 | for i, (imageList) in enumerate(train_loader): 69 | 70 | I0_var = torch.autograd.Variable(imageList[0]).cuda() 71 | I1_var = torch.autograd.Variable(imageList[-1]).cuda() 72 | #torchvision.utils.save_image((I0_var),'samples/'+ str(i+1) +'1.jpg',normalize=True) 73 | #brak 74 | 75 | 76 | flow_out_var = flowModel(I0_var, I1_var) 77 | 78 | F_0_1 = flow_out_var[:,:2,:,:] 79 | F_1_0 = flow_out_var[:,2:,:,:] 80 | 81 | loss_vector = [] 82 | perceptual_loss_collector = [] 83 | warping_loss_collector = [] 84 | 85 | image_collector = [] 86 | for t_ in range(1,8): 87 | 88 | t = t_/8 89 | It_var = torch.autograd.Variable(imageList[t_]).cuda() 90 | 91 | F_t_0 = -(1-t)*t*F_0_1 + t*t*F_1_0 92 | 93 | F_t_1 = (1-t)*(1-t)*F_0_1 - t*(1-t)*(F_1_0) 94 | 95 | 96 | g_I0_F_t_0 = warper(I0_var, F_t_0) 97 | g_I1_F_t_1 = warper(I1_var, F_t_1) 98 | 99 | interp_out_var = interpolationModel(I0_var, I1_var, F_0_1, F_1_0, F_t_0, F_t_1, g_I0_F_t_0, g_I1_F_t_1) 100 | F_t_0_final = interp_out_var[:,:2,:,:] + F_t_0 101 | F_t_1_final = interp_out_var[:,2:4,:,:] + F_t_1 102 | V_t_0 = torch.unsqueeze(interp_out_var[:,4,:,:],1) 103 | V_t_1 = 1 - V_t_0 104 | 105 | g_I0_F_t_0_final = warper(I0_var, F_t_0_final) 106 | g_I0_F_t_1_final = warper(I1_var, F_t_1_final) 107 | 108 | normalization = (1-t)*V_t_0 + t*V_t_1 109 | interpolated_image_t_pre = (1-t)*V_t_0*g_I0_F_t_0_final + t*V_t_1*g_I0_F_t_1_final 110 | interpolated_image_t = interpolated_image_t_pre / 
normalization 111 | image_collector.append(interpolated_image_t) 112 | 113 | ### Reconstruction Loss Collector ### 114 | loss_reconstruction_t = criterion(interpolated_image_t, It_var) 115 | loss_vector.append(loss_reconstruction_t) 116 | 117 | ### Perceptual Loss Collector ### 118 | feat_pred = res50_conv(interpolated_image_t) 119 | feat_gt = res50_conv(It_var) 120 | loss_perceptual_t = criterionMSE(feat_pred, feat_gt) 121 | perceptual_loss_collector.append(loss_perceptual_t) 122 | 123 | ### Warping Loss Collector ### 124 | g_I0_F_t_0_i = warper(I0_var, F_t_0) 125 | g_I1_F_t_1_i = warper(I1_var, F_t_1) 126 | loss_warping_t = criterion(g_I0_F_t_0_i, It_var) + criterion(g_I1_F_t_1_i, It_var) 127 | warping_loss_collector.append(loss_warping_t) 128 | 129 | ### Reconstruction Loss Computation ### 130 | loss_reconstruction = sum(loss_vector)/len(loss_vector) 131 | 132 | ### Perceptual Loss Computation ### 133 | loss_perceptual = sum(perceptual_loss_collector)/len(perceptual_loss_collector) 134 | 135 | ### Warping Loss Computation ### 136 | g_I0_F_1_0 = warper(I0_var, F_1_0) 137 | g_I1_F_0_1 = warper(I1_var, F_0_1) 138 | loss_warping = (criterion(g_I0_F_1_0, I1_var) + criterion(g_I1_F_0_1, I0_var)) + sum(warping_loss_collector)/len(warping_loss_collector) 139 | 140 | ### Smoothness Loss Computation ### 141 | loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :])) 142 | loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :])) 143 | loss_smooth = loss_smooth_1_0 + loss_smooth_0_1 144 | 145 | 146 | ### Overall Loss 147 | loss = 0.8*loss_reconstruction + 0.005*loss_perceptual + 0.4*loss_warping + loss_smooth 148 | 149 | ### Optimization 150 | optimizer.zero_grad() 151 | loss.backward() 152 | optimizer.step() 153 | 154 | if ((i+1) % 10) == 0: 155 | print("Loss at iteration", i+1, "/", 
len(train_loader), ":", loss.item()) 156 | 157 | if ((i+1) % 100) == 0: 158 | torchvision.utils.save_image((I0_var),'samples/'+ str(i+1) +'1.jpg',normalize=True) 159 | for jj,image in enumerate(image_collector): 160 | torchvision.utils.save_image((image),'samples/'+ str(i+1) + str(jj+1)+'.jpg',normalize=True) 161 | torchvision.utils.save_image((I1_var),'samples/'+str(i+1)+'9.jpg',normalize=True) 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | if __name__ == '__main__': 171 | train_val() 172 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class UNet_flow(nn.Module): 7 | 8 | def __init__(self): 9 | super(UNet_flow, self).__init__() 10 | 11 | ## Encoder 12 | self.conv1 = nn.Conv2d(6,32,7,1,3,bias=False) 13 | self.conv2 = nn.Conv2d(32,32,7,1,3,bias=False) 14 | self.relu = nn.LeakyReLU(0.1,True)#nn.ReLU(inplace=True) 15 | self.avgpool1 = nn.AvgPool2d(kernel_size=7,stride=2, padding=3) 16 | 17 | self.conv3 = nn.Conv2d(32,64,5,1,2,bias=False) 18 | self.conv4 = nn.Conv2d(64,64,5,1,2,bias=False) 19 | self.avgpool2 = nn.AvgPool2d(kernel_size=5,stride=2, padding=2) 20 | 21 | self.conv5 = nn.Conv2d(64,128,3,1,1,bias=False) 22 | self.conv6 = nn.Conv2d(128,128,3,1,1,bias=False) 23 | #self.relu = nn.ReLU(inplace=True) 24 | self.avgpool3 = nn.AvgPool2d(kernel_size=3,stride=2, padding=1) 25 | 26 | self.conv7 = nn.Conv2d(128,256,3,1,1,bias=False) 27 | self.conv8 = nn.Conv2d(256,256,3,1,1,bias=False) 28 | #self.relu = nn.ReLU(inplace=True) 29 | self.avgpool4 = nn.AvgPool2d(kernel_size=3,stride=2, padding=1) 30 | 31 | self.conv9 = nn.Conv2d(256,512,3,1,1,bias=False) 32 | self.conv10 = nn.Conv2d(512,512,3,1,1,bias=False) 33 | #self.relu = nn.ReLU(inplace=True) 34 | self.avgpool5 = nn.AvgPool2d(kernel_size=3,stride=2, padding=1) 35 | 36 | self.conv11 = 
nn.Conv2d(512,512,3,1,1,bias=False) 37 | self.conv12 = nn.Conv2d(512,512,3,1,1,bias=False) 38 | #self.relu = nn.ReLU(inplace=True) 39 | 40 | 41 | ## Decoder 42 | self.upsample2D = nn.Upsample(scale_factor=2, mode='bilinear') 43 | 44 | self.conv13 = nn.Conv2d(512,512,3,1,1,bias=False) 45 | self.conv14 = nn.Conv2d(512,512,3,1,1,bias=False) 46 | 47 | self.conv15 = nn.Conv2d(512,256,3,1,1,bias=False) 48 | self.conv16 = nn.Conv2d(256,256,3,1,1,bias=False) 49 | 50 | self.conv17 = nn.Conv2d(256,128,3,1,1,bias=False) 51 | self.conv18 = nn.Conv2d(128,128,3,1,1,bias=False) 52 | 53 | self.conv19 = nn.Conv2d(128,64,3,1,1,bias=False) 54 | self.conv20 = nn.Conv2d(64,64,3,1,1,bias=False) 55 | 56 | self.conv21 = nn.Conv2d(64,32,3,1,1,bias=False) 57 | self.conv22 = nn.Conv2d(32,32,3,1,1,bias=False) 58 | 59 | self.conv23 = nn.Conv2d(32,4,3,1,1,bias=False) 60 | 61 | #self.tanh = nn.Tanh() 62 | 63 | 64 | 65 | 66 | def forward(self, I0, I1): 67 | sources = [] 68 | 69 | 70 | X = torch.cat([I0, I1], 1) 71 | 72 | ## Encoder 73 | X = self.conv1(X) 74 | X = self.relu(X) 75 | X = self.conv2(X) 76 | X = self.relu(X) 77 | ##print(X.size()) 78 | sources.append(X) 79 | 80 | X = self.avgpool1(X) 81 | X = self.conv3(X) 82 | X = self.relu(X) 83 | X = self.conv4(X) 84 | X = self.relu(X) 85 | ##print(X.size()) 86 | sources.append(X) 87 | 88 | X = self.avgpool2(X) 89 | X = self.conv5(X) 90 | X = self.relu(X) 91 | X = self.conv6(X) 92 | X = self.relu(X) 93 | ##print(X.size()) 94 | sources.append(X) 95 | 96 | X = self.avgpool3(X) 97 | X = self.conv7(X) 98 | X = self.relu(X) 99 | X = self.conv8(X) 100 | X = self.relu(X) 101 | ##print(X.size()) 102 | sources.append(X) 103 | 104 | X = self.avgpool4(X) 105 | X = self.conv9(X) 106 | X = self.relu(X) 107 | X = self.conv10(X) 108 | X = self.relu(X) 109 | #print(X.size()) 110 | sources.append(X) 111 | 112 | X = self.avgpool5(X) 113 | X = self.conv11(X) 114 | X = self.relu(X) 115 | X = self.conv12(X) 116 | X = self.relu(X) 117 | #print(X.size()) 118 | 119 | ## 
Decoder 120 | X = self.upsample2D(X) 121 | X = self.conv13(X) 122 | X = self.relu(X) 123 | #print(X.size()) 124 | X = X + sources[-1] 125 | X = self.conv14(X) 126 | X = self.relu(X) 127 | 128 | X = self.upsample2D(X) 129 | X = self.conv15(X) 130 | X = self.relu(X) 131 | #print(X.size()) 132 | X = X + sources[-2] 133 | X = self.conv16(X) 134 | X = self.relu(X) 135 | 136 | X = self.upsample2D(X) 137 | X = self.conv17(X) 138 | X = self.relu(X) 139 | #print(X.size()) 140 | X = X + sources[-3] 141 | X = self.conv18(X) 142 | X = self.relu(X) 143 | 144 | X = self.upsample2D(X) 145 | X = self.conv19(X) 146 | X = self.relu(X) 147 | #print(X.size()) 148 | X = X + sources[-4] 149 | X = self.conv20(X) 150 | X = self.relu(X) 151 | 152 | X = self.upsample2D(X) 153 | X = self.conv21(X) 154 | X = self.relu(X) 155 | #print(X.size()) 156 | X = X + sources[-5] 157 | X = self.conv22(X) 158 | X = self.relu(X) 159 | 160 | X = self.conv23(X) 161 | X = self.relu(X) 162 | #print(X.size()) 163 | out = X#self.tanh(X) 164 | 165 | return out 166 | 167 | 168 | class UNet_refine(nn.Module): 169 | 170 | def __init__(self): 171 | super(UNet_refine, self).__init__() 172 | 173 | ## Encoder 174 | self.conv1 = nn.Conv2d(20,32,7,1,3,bias=False) 175 | self.conv2 = nn.Conv2d(32,32,7,1,3,bias=False) 176 | self.relu = nn.LeakyReLU(0.1,True)#nn.ReLU(inplace=True) 177 | self.avgpool1 = nn.AvgPool2d(kernel_size=7,stride=2, padding=3) 178 | 179 | self.conv3 = nn.Conv2d(32,64,5,1,2,bias=False) 180 | self.conv4 = nn.Conv2d(64,64,5,1,2,bias=False) 181 | self.avgpool2 = nn.AvgPool2d(kernel_size=5,stride=2, padding=2) 182 | 183 | self.conv5 = nn.Conv2d(64,128,3,1,1,bias=False) 184 | self.conv6 = nn.Conv2d(128,128,3,1,1,bias=False) 185 | #self.relu = nn.ReLU(inplace=True) 186 | self.avgpool3 = nn.AvgPool2d(kernel_size=3,stride=2, padding=1) 187 | 188 | self.conv7 = nn.Conv2d(128,256,3,1,1,bias=False) 189 | self.conv8 = nn.Conv2d(256,256,3,1,1,bias=False) 190 | #self.relu = nn.ReLU(inplace=True) 191 | self.avgpool4 = 
nn.AvgPool2d(kernel_size=3,stride=2, padding=1) 192 | 193 | self.conv9 = nn.Conv2d(256,512,3,1,1,bias=False) 194 | self.conv10 = nn.Conv2d(512,512,3,1,1,bias=False) 195 | #self.relu = nn.ReLU(inplace=True) 196 | self.avgpool5 = nn.AvgPool2d(kernel_size=3,stride=2, padding=1) 197 | 198 | self.conv11 = nn.Conv2d(512,512,3,1,1,bias=False) 199 | self.conv12 = nn.Conv2d(512,512,3,1,1,bias=False) 200 | #self.relu = nn.ReLU(inplace=True) 201 | 202 | 203 | ## Decoder 204 | self.upsample2D = nn.Upsample(scale_factor=2, mode='bilinear') 205 | 206 | self.conv13 = nn.Conv2d(512,512,3,1,1,bias=False) 207 | self.conv14 = nn.Conv2d(512,512,3,1,1,bias=False) 208 | 209 | self.conv15 = nn.Conv2d(512,256,3,1,1,bias=False) 210 | self.conv16 = nn.Conv2d(256,256,3,1,1,bias=False) 211 | 212 | self.conv17 = nn.Conv2d(256,128,3,1,1,bias=False) 213 | self.conv18 = nn.Conv2d(128,128,3,1,1,bias=False) 214 | 215 | self.conv19 = nn.Conv2d(128,64,3,1,1,bias=False) 216 | self.conv20 = nn.Conv2d(64,64,3,1,1,bias=False) 217 | 218 | self.conv21 = nn.Conv2d(64,32,3,1,1,bias=False) 219 | self.conv22 = nn.Conv2d(32,32,3,1,1,bias=False) 220 | 221 | self.conv23 = nn.Conv2d(32,5,3,1,1,bias=False) 222 | 223 | #self.tanh = nn.Tanh() 224 | self.sigmoid = nn.Sigmoid() 225 | 226 | 227 | 228 | 229 | def forward(self, I0, I1, F_0_1, F_1_0, F_t_0, F_t_1, g_I0_F_t_0, g_I1_F_t_1): 230 | sources = [] 231 | 232 | 233 | X = torch.cat([I0, I1, F_0_1, F_1_0, F_t_0, F_t_1, g_I0_F_t_0, g_I1_F_t_1], 1) 234 | 235 | ## Encoder 236 | X = self.conv1(X) 237 | X = self.relu(X) 238 | X = self.conv2(X) 239 | X = self.relu(X) 240 | #print(X.size()) 241 | sources.append(X) 242 | 243 | X = self.avgpool1(X) 244 | X = self.conv3(X) 245 | X = self.relu(X) 246 | X = self.conv4(X) 247 | X = self.relu(X) 248 | #print(X.size()) 249 | sources.append(X) 250 | 251 | X = self.avgpool2(X) 252 | X = self.conv5(X) 253 | X = self.relu(X) 254 | X = self.conv6(X) 255 | X = self.relu(X) 256 | #print(X.size()) 257 | sources.append(X) 258 | 259 | X = 
self.avgpool3(X) 260 | X = self.conv7(X) 261 | X = self.relu(X) 262 | X = self.conv8(X) 263 | X = self.relu(X) 264 | #print(X.size()) 265 | sources.append(X) 266 | 267 | X = self.avgpool4(X) 268 | X = self.conv9(X) 269 | X = self.relu(X) 270 | X = self.conv10(X) 271 | X = self.relu(X) 272 | #print(X.size()) 273 | sources.append(X) 274 | 275 | X = self.avgpool5(X) 276 | X = self.conv11(X) 277 | X = self.relu(X) 278 | X = self.conv12(X) 279 | X = self.relu(X) 280 | #print(X.size()) 281 | 282 | ## Decoder 283 | X = self.upsample2D(X) 284 | X = self.conv13(X) 285 | X = self.relu(X) 286 | #print(X.size()) 287 | X = X + sources[-1] 288 | X = self.conv14(X) 289 | X = self.relu(X) 290 | 291 | X = self.upsample2D(X) 292 | X = self.conv15(X) 293 | X = self.relu(X) 294 | #print(X.size()) 295 | X = X + sources[-2] 296 | X = self.conv16(X) 297 | X = self.relu(X) 298 | 299 | X = self.upsample2D(X) 300 | X = self.conv17(X) 301 | X = self.relu(X) 302 | #print(X.size()) 303 | X = X + sources[-3] 304 | X = self.conv18(X) 305 | X = self.relu(X) 306 | 307 | X = self.upsample2D(X) 308 | X = self.conv19(X) 309 | X = self.relu(X) 310 | #print(X.size()) 311 | X = X + sources[-4] 312 | X = self.conv20(X) 313 | X = self.relu(X) 314 | 315 | X = self.upsample2D(X) 316 | X = self.conv21(X) 317 | X = self.relu(X) 318 | #print(X.size()) 319 | X = X + sources[-5] 320 | X = self.conv22(X) 321 | X = self.relu(X) 322 | 323 | X = self.conv23(X) 324 | #X = self.relu(X) 325 | #print(X.size()) 326 | out = X#self.tanh(X) 327 | out_processed = torch.cat((self.relu(X[:,:4,:,:]),self.sigmoid(torch.unsqueeze(out[:,4,:,:],1))),1) 328 | 329 | return out_processed 330 | 331 | 332 | 333 | 334 | --------------------------------------------------------------------------------