├── sundry ├── 1f379.gif ├── 1f3ac.gif ├── 1f43c.gif ├── 1f4a1.gif ├── 1f4f0.gif ├── 1f52c.gif ├── 1f9f3.gif ├── 听诊器.gif ├── 小提琴.gif ├── 调色盘.gif ├── cheers.gif └── 1f5c2-fe0f.gif ├── demo_images ├── DSA_1.png └── DSA_2.png ├── model ├── __init__.py ├── warplayer.py ├── refine.py ├── loss.py ├── decoder.py ├── vgg19_losses.py └── encoder.py ├── utils ├── padder.py ├── yuv_frame_io.py └── pytorch_msssim.py ├── config.py ├── GenDSA_env.txt ├── Simple_Interpolator.py ├── README.md ├── LICENSE └── Trainer.py /sundry/1f379.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f379.gif -------------------------------------------------------------------------------- /sundry/1f3ac.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f3ac.gif -------------------------------------------------------------------------------- /sundry/1f43c.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f43c.gif -------------------------------------------------------------------------------- /sundry/1f4a1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f4a1.gif -------------------------------------------------------------------------------- /sundry/1f4f0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f4f0.gif -------------------------------------------------------------------------------- /sundry/1f52c.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f52c.gif -------------------------------------------------------------------------------- /sundry/1f9f3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f9f3.gif -------------------------------------------------------------------------------- /sundry/听诊器.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/听诊器.gif -------------------------------------------------------------------------------- /sundry/小提琴.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/小提琴.gif -------------------------------------------------------------------------------- /sundry/调色盘.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/调色盘.gif -------------------------------------------------------------------------------- /sundry/cheers.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/cheers.gif -------------------------------------------------------------------------------- /demo_images/DSA_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/demo_images/DSA_1.png -------------------------------------------------------------------------------- 
/demo_images/DSA_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/demo_images/DSA_2.png -------------------------------------------------------------------------------- /sundry/1f5c2-fe0f.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f5c2-fe0f.gif -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import encoder 2 | from .decoder import decoder 3 | 4 | 5 | __all__ = ['encoder', 'decoder'] 6 | -------------------------------------------------------------------------------- /utils/padder.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | class InputPadder: 5 | def __init__(self, dims, divisor = 16): 6 | self.ht, self.wd = dims[-2:] 7 | pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor 8 | pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor 9 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 10 | 11 | def pad(self, *inputs): 12 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 13 | 14 | def unpad(self,x): 15 | ht, wd = x.shape[-2:] 16 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 17 | return x[..., c[0]:c[1], c[2]:c[3]] 18 | 19 | -------------------------------------------------------------------------------- /model/warplayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | backwarp_tenGrid = {} 5 | 6 | def warp(tenInput, tenFlow): 7 | k = (str(tenFlow.device), str(tenFlow.size())) 8 | if k not in backwarp_tenGrid: 9 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( 10 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 11 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( 12 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 13 | backwarp_tenGrid[k] = torch.cat( 14 | [tenHorizontal, tenVertical], 1).to(device) 15 | 16 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 17 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) 18 | 19 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 20 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 21 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch.nn as nn 3 | from model import encoder 4 | from model import decoder 5 | 6 | 7 | def init_model_config(F=32, lambda_range='local', depth=[2, 2, 2, 4]): 8 | return { 9 | 'embed_dims':[F, 2*F, 4*F, 8*F], 10 | 'motion_dims':[0, 0, 0, 8*F//depth[-1]], 11 | 'num_heads':[4], 12 | 'mlp_ratios':[4], 13 | 'lambda_global_or_local': lambda_range, 14 | 'lambda_dim_k':16, 15 | 'lambda_dim_u':1, 16 | 'lambda_n':32, 17 | 'lambda_r':15, 18 | 'norm_layer':partial(nn.LayerNorm, eps=1e-6), 19 | 'depths':depth, 20 | }, { 21 | 'embed_dims':[F, 2*F, 4*F, 8*F], 22 | 
'motion_dims':[0, 0, 0, 8*F//depth[-1]], 23 | 'depths':depth, 24 | 'scales':[4, 8], 25 | 'hidden_dims':[4*F], 26 | 'c':F 27 | } 28 | 29 | 30 | MODEL_CONFIG = { 31 | 'LOGNAME': 'GenDSA', 32 | 'MODEL_TYPE': (encoder, decoder), 33 | 'MODEL_ARCH': init_model_config( 34 | F = 32, 35 | lambda_range='local', 36 | depth = [2, 2, 2, 4] 37 | ) 38 | } 39 | -------------------------------------------------------------------------------- /GenDSA_env.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | blessed==1.20.0 3 | cachetools==5.3.1 4 | certifi==2022.12.7 5 | charset-normalizer==3.1.0 6 | contourpy==1.1.1 7 | cycler==0.12.1 8 | DateTime==5.1 9 | einops==0.6.1 10 | filelock==3.10.0 11 | fonttools==4.49.0 12 | google-auth==2.20.0 13 | google-auth-oauthlib==1.0.0 14 | gpustat==1.0.0 15 | grpcio==1.56.0 16 | huggingface-hub==0.13.3 17 | idna==3.4 18 | imageio==2.26.1 19 | importlib-metadata==6.7.0 20 | importlib-resources==6.1.1 21 | kiwisolver==1.4.5 22 | Markdown==3.4.3 23 | MarkupSafe==2.1.3 24 | matplotlib==3.7.5 25 | natsort==8.3.1 26 | networkx==3.0 27 | numpy==1.24.3 28 | nvidia-ml-py==11.495.46 29 | oauthlib==3.2.2 30 | opencv-python==4.6.0.66 31 | packaging==23.0 32 | pandas==2.0.2 33 | Pillow==9.4.0 34 | protobuf==4.23.3 35 | psutil==5.9.4 36 | pyasn1==0.5.0 37 | pyasn1-modules==0.3.0 38 | pydicom==2.4.4 39 | pyparsing==3.1.1 40 | python-dateutil==2.8.2 41 | pytz==2023.3 42 | PyWavelets==1.4.1 43 | PyYAML==6.0 44 | requests==2.28.2 45 | requests-oauthlib==1.3.1 46 | rsa==4.9 47 | scikit-image==0.19.2 48 | scipy==1.10.1 49 | setuptools==58.5.2 50 | six==1.16.0 51 | tensorboard==2.13.0 52 | tensorboard-data-server==0.7.1 53 | tensorboard-plugin-wit==1.8.1 54 | tifffile==2023.3.21 55 | timm==0.6.11 56 | torch==1.9.0+cu111 57 | torch-tb-profiler==0.4.1 58 | torchvision==0.10.0+cu111 59 | tqdm==4.65.0 60 | typing_extensions==4.5.0 61 | tzdata==2023.3 62 | urllib3==1.26.15 63 | wcwidth==0.2.6 64 | Werkzeug==2.3.6 65 | wheel==0.41.2 66 | zipp==3.15.0 67 | zope.interface==6.0 -------------------------------------------------------------------------------- /Simple_Interpolator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import fnmatch 4 | import numpy as np 5 | import cv2 6 | import math 7 | import torch 8 | import argparse 9 | import config as cfg 10 | from Trainer import Model 11 | from utils.padder import InputPadder 12 | 13 | 14 | def run_interpolator(model, Frame1, Frame2, time_list, Output_Frames_list, TTA = True): 15 | I0 = cv2.imread(Frame1) 16 | I2 = cv2.imread(Frame2) 17 | 18 | I0_ = (torch.tensor(I0.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0) 19 | I2_ = (torch.tensor(I2.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0) 20 | 21 | padder = InputPadder(I0_.shape, divisor=32) 22 | I0_, I2_ = padder.pad(I0_, I2_) 23 | 24 | preds = model.multi_inference(I0_, I2_, TTA = TTA, time_list = time_list, fast_TTA = TTA) 25 | 26 | for pred, Output_Frame in zip(preds, Output_Frames_list): 27 | mid_image = (padder.unpad(pred).detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1] 28 | cv2.imwrite(Output_Frame, mid_image) 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--model_path', type=str, default='./weights/checkpoints/3D-vas-Inf1.pkl') 34 | parser.add_argument('--frame1', type=str, default='./demo_images/DSA_1.png') 35 | parser.add_argument('--frame2', type=str, 
default='./demo_images/DSA_2.png') 36 | parser.add_argument('--inter_frames', type=int, default=1) 37 | args = parser.parse_args() 38 | 39 | model_path = args.model_path 40 | inf_folder_path = os.path.dirname(args.frame1) 41 | Interframe_num = args.inter_frames 42 | 43 | if Interframe_num == 1: 44 | TimeStepList = [0.5] 45 | Inter_Frames_list = [inf_folder_path + "//" + 'InferImage.png'] 46 | elif Interframe_num == 2: 47 | TimeStepList = [0.3333333333333333, 0.6666666666666667] 48 | Inter_Frames_list = [inf_folder_path + "//" + 'InferImage_1_in_2.png', 49 | inf_folder_path + "//" + 'InferImage_2_in_2.png'] 50 | elif Interframe_num == 3: 51 | TimeStepList = [0.25, 0.50, 0.75] 52 | Inter_Frames_list = [inf_folder_path + "//" + 'InferImage_1_in_3.png', 53 | inf_folder_path + "//" + 'InferImage_2_in_3.png', 54 | inf_folder_path + "//" + 'InferImage_3_in_3.png'] 55 | else: 56 | print("'inter_frames' invalid. Currently, 1, 2, and 3 frames are supported. You can also try training a model that interpolates more frames.") 57 | sys.exit() 58 | 59 | TTA = True 60 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 61 | F = 32, 62 | lambda_range='local', 63 | depth = [2, 2, 2, 4] 64 | ) 65 | 66 | model = Model(-1) 67 | model.load_model(full_path = model_path) 68 | model.eval() 69 | model.device() 70 | 71 | run_interpolator(model, args.frame1, args.frame2, time_list = TimeStepList, Output_Frames_list = Inter_Frames_list, TTA = TTA) 72 | -------------------------------------------------------------------------------- /model/refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from timm.models.layers import trunc_normal_ 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | 9 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 10 | return nn.Sequential( 11 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 12 | padding=padding, dilation=dilation, bias=True), 13 | nn.PReLU(out_planes) 14 | ) 15 | 16 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 17 | return nn.Sequential( 18 | torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True), 19 | nn.PReLU(out_planes) 20 | ) 21 | 22 | class Conv2(nn.Module): 23 | def __init__(self, in_planes, out_planes, stride=2): 24 | super(Conv2, self).__init__() 25 | self.conv1 = conv(in_planes, out_planes, 3, stride, 1) 26 | self.conv2 = conv(out_planes, out_planes, 3, 1, 1) 27 | 28 | def forward(self, x): 29 | x = self.conv1(x) 30 | x = self.conv2(x) 31 | return x 32 | 33 | class Unet(nn.Module): 34 | def __init__(self, c, out=3): 35 | super(Unet, self).__init__() 36 | self.down0 = Conv2(17+c, 2*c) 37 | self.down1 = Conv2(4*c, 4*c) 38 | self.down2 = Conv2(8*c, 8*c) 39 | self.down3 = Conv2(16*c, 16*c) 40 | self.up0 = deconv(32*c, 8*c) 41 | self.supple = Conv2(4*c, 8*c) 42 | 43 | self.up1 = deconv(16*c, 4*c) 44 | self.up2 = deconv(8*c, 2*c) 45 | self.up3 = deconv(4*c, c) 46 | self.conv = nn.Conv2d(c, out, 3, 1, 1) 47 | self.apply(self._init_weights) 48 | 49 | def _init_weights(self, m): 50 | if isinstance(m, nn.Linear): 51 | trunc_normal_(m.weight, std=.02) 52 | if isinstance(m, nn.Linear) and m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | elif isinstance(m, nn.LayerNorm): 55 | nn.init.constant_(m.bias, 0) 56 | nn.init.constant_(m.weight, 1.0) 57 | elif isinstance(m, nn.Conv2d): 58 | 
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 59 | fan_out //= m.groups 60 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 61 | if m.bias is not None: 62 | m.bias.data.zero_() 63 | 64 | def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): 65 | s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow,c0[0], c1[0]), 1)) 66 | s1 = self.down1(torch.cat((s0, c0[1], c1[1]), 1)) 67 | s2 = self.down2(torch.cat((s1, c0[2], c1[2]), 1)) 68 | s3 = self.down3(torch.cat((s2, c0[3], c1[3]), 1)) 69 | c0_4 = self.supple(c0[3]) 70 | c1_4 = self.supple(c1[3]) 71 | x = self.up0(torch.cat((s3, c0_4, c1_4), 1)) 72 | 73 | x = self.up1(torch.cat((x, s2), 1)) 74 | x = self.up2(torch.cat((x, s1), 1)) 75 | x = self.up3(torch.cat((x, s0), 1)) 76 | x = self.conv(x) 77 | return torch.sigmoid(x) 78 | -------------------------------------------------------------------------------- /utils/yuv_frame_io.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import getopt 3 | import math 4 | import numpy 5 | import random 6 | import logging 7 | import numpy as np 8 | from skimage.color import rgb2yuv, yuv2rgb 9 | from PIL import Image 10 | import os 11 | from shutil import copyfile 12 | 13 | 14 | class YUV_Read(): 15 | def __init__(self, filepath, h, w, format='yuv420', toRGB=True): 16 | 17 | self.h = h 18 | self.w = w 19 | 20 | self.fp = open(filepath, 'rb') 21 | 22 | if format == 'yuv420': 23 | self.frame_length = int(1.5 * h * w) 24 | self.Y_length = h * w 25 | self.Uv_length = int(0.25 * h * w) 26 | else: 27 | pass 28 | self.toRGB = toRGB 29 | 30 | def read(self, offset_frame=None): 31 | if not offset_frame == None: 32 | self.fp.seek(offset_frame * self.frame_length, 0) 33 | 34 | Y = np.fromfile(self.fp, np.uint8, count=self.Y_length) 35 | U = np.fromfile(self.fp, np.uint8, count=self.Uv_length) 36 | V = np.fromfile(self.fp, np.uint8, count=self.Uv_length) 37 | if Y.size < self.Y_length or \ 38 | U.size < self.Uv_length or \ 39 | V.size < self.Uv_length: 40 | return None, False 41 | 42 | Y = np.reshape(Y, [self.w, self.h], order='F') 43 | Y = np.transpose(Y) 44 | 45 | U = np.reshape(U, [int(self.w / 2), int(self.h / 2)], order='F') 46 | U = np.transpose(U) 47 | 48 | V = np.reshape(V, [int(self.w / 2), int(self.h / 2)], order='F') 49 | V = np.transpose(V) 50 | 51 | U = np.array(Image.fromarray(U).resize([self.w, self.h])) 52 | V = np.array(Image.fromarray(V).resize([self.w, self.h])) 53 | 54 | if self.toRGB: 55 | Y = Y / 255.0 56 | U = U / 255.0 - 0.5 57 | V = V / 255.0 - 0.5 58 | 59 | self.YUV = np.stack((Y, U, V), axis=-1) 60 | self.RGB = (255.0 * np.clip(yuv2rgb(self.YUV), 0.0, 1.0)).astype('uint8') 61 | 62 | self.YUV = None 63 | return self.RGB, True 64 | else: 65 | self.YUV = np.stack((Y, U, V), axis=-1) 66 | return self.YUV, True 67 | 68 | def close(self): 69 | self.fp.close() 70 | 71 | 72 | class YUV_Write(): 73 | def __init__(self, filepath, fromRGB=True): 74 | if os.path.exists(filepath): 75 | print(filepath) 76 | 77 | self.fp = open(filepath, 'wb') 78 | self.fromRGB = fromRGB 79 | 80 | def write(self, Frame): 81 | 82 | self.h = Frame.shape[0] 83 | self.w = Frame.shape[1] 84 | c = Frame.shape[2] 85 | 86 | assert c == 3 87 | if format == 'yuv420': 88 | self.frame_length = int(1.5 * self.h * self.w) 89 | self.Y_length = self.h * self.w 90 | self.Uv_length = int(0.25 * self.h * self.w) 91 | else: 92 | pass 93 | if self.fromRGB: 94 | Frame = Frame / 255.0 95 | YUV = rgb2yuv(Frame) 96 | Y, U, V = 
np.dsplit(YUV, 3) 97 | Y = Y[:, :, 0] 98 | U = U[:, :, 0] 99 | V = V[:, :, 0] 100 | U = np.clip(U + 0.5, 0.0, 1.0) 101 | V = np.clip(V + 0.5, 0.0, 1.0) 102 | 103 | U = U[::2, ::2] 104 | V = V[::2, ::2] 105 | Y = (255.0 * Y).astype('uint8') 106 | U = (255.0 * U).astype('uint8') 107 | V = (255.0 * V).astype('uint8') 108 | else: 109 | YUV = Frame 110 | Y = YUV[:, :, 0] 111 | U = YUV[::2, ::2, 1] 112 | V = YUV[::2, ::2, 2] 113 | 114 | Y = Y.flatten() 115 | U = U.flatten() 116 | V = V.flatten() 117 | 118 | Y.tofile(self.fp) 119 | U.tofile(self.fp) 120 | V.tofile(self.fp) 121 | 122 | return True 123 | 124 | def close(self): 125 | self.fp.close() 126 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from . import vgg19_losses as vgg19 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | def gauss_kernel(channels=3): 11 | kernel = torch.tensor([[1., 4., 6., 4., 1], 12 | [4., 16., 24., 16., 4.], 13 | [6., 24., 36., 24., 6.], 14 | [4., 16., 24., 16., 4.], 15 | [1., 4., 6., 4., 1.]]) 16 | kernel /= 256. 17 | kernel = kernel.repeat(channels, 1, 1, 1) 18 | kernel = kernel.to(device) 19 | return kernel 20 | 21 | def downsample(x): 22 | return x[:, :, ::2, ::2] 23 | 24 | def upsample(x): 25 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3) 26 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]) 27 | cc = cc.permute(0,1,3,2) 28 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2).to(device)], dim=3) 29 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2) 30 | x_up = cc.permute(0,1,3,2) 31 | return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1])) 32 | 33 | def conv_gauss(img, kernel): 34 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect') 35 | out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1]) 36 | return out 37 | 38 | def laplacian_pyramid(img, kernel, max_levels=3): 39 | current = img 40 | pyr = [] 41 | for level in range(max_levels): 42 | filtered = conv_gauss(current, kernel) 43 | down = downsample(filtered) 44 | up = upsample(down) 45 | diff = current-up 46 | pyr.append(diff) 47 | current = down 48 | return pyr 49 | 50 | class LapLoss(torch.nn.Module): 51 | def __init__(self, max_levels=5, channels=3): 52 | super(LapLoss, self).__init__() 53 | self.max_levels = max_levels 54 | self.gauss_kernel = gauss_kernel(channels=channels) 55 | 56 | def forward(self, input, target): 57 | pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels) 58 | pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels) 59 | return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target)) 60 | 61 | class Ternary(nn.Module): 62 | def __init__(self, device): 63 | super(Ternary, self).__init__() 64 | patch_size = 7 65 | out_channels = patch_size * patch_size 66 | self.w = np.eye(out_channels).reshape( 67 | (patch_size, patch_size, 1, out_channels)) 68 | self.w = np.transpose(self.w, (3, 2, 0, 1)) 69 | self.w = torch.tensor(self.w).float().to(device) 70 | 71 | def transform(self, img): 72 | patches = F.conv2d(img, self.w, padding=3, bias=None) 73 | transf = patches - img 74 | transf_norm = transf / torch.sqrt(0.81 + transf**2) 
75 | return transf_norm 76 | 77 | def rgb2gray(self, rgb): 78 | r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] 79 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 80 | return gray 81 | 82 | def hamming(self, t1, t2): 83 | dist = (t1 - t2) ** 2 84 | dist_norm = torch.mean(dist / (0.1 + dist), 1, True) 85 | return dist_norm 86 | 87 | def valid_mask(self, t, padding): 88 | n, _, h, w = t.size() 89 | inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) 90 | mask = F.pad(inner, [padding] * 4) 91 | return mask 92 | 93 | def forward(self, img0, img1): 94 | img0 = self.transform(self.rgb2gray(img0)) 95 | img1 = self.transform(self.rgb2gray(img1)) 96 | return self.hamming(img0, img1) * self.valid_mask(img0, 1) 97 | 98 | class Perceptual_Loss(torch.nn.Module): 99 | def __init__(self): 100 | super(Perceptual_Loss, self).__init__() 101 | 102 | def forward(self, image: torch.Tensor, reference: torch.Tensor, vgg_model_file: str, weights = None,): 103 | return vgg19.perceptual_loss(image, reference, vgg_model_file, weights) 104 | 105 | 106 | class Style_Loss(torch.nn.Module): 107 | def __init__(self): 108 | super(Style_Loss, self).__init__() 109 | 110 | def forward(self, image: torch.Tensor, reference: torch.Tensor, vgg_model_file: str, weights = None,): 111 | return vgg19.style_loss(image, reference, vgg_model_file, weights) 112 | -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .warplayer import warp 6 | from .refine import * 7 | 8 | 9 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 10 | return nn.Sequential( 11 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 12 | padding=padding, dilation=dilation, bias=True), 13 | nn.PReLU(out_planes) 14 | ) 15 | 16 | 17 | class Head(nn.Module): 18 | def __init__(self, in_planes, scale, c, in_else=17): 19 | super(Head, self).__init__() 20 | self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2)) 21 | self.scale = scale 22 | self.conv = nn.Sequential( 23 | conv(in_planes*2 // (4*4) + in_else, c), 24 | conv(c, c), 25 | conv(c, 5), 26 | ) 27 | 28 | def forward(self, motion_feature, x, flow): 29 | motion_feature = self.upsample(motion_feature) 30 | 31 | if self.scale != 4: 32 | x = F.interpolate(x, scale_factor = 4. / self.scale, mode="bilinear", align_corners=False) 33 | 34 | if flow != None: 35 | if self.scale != 4: 36 | flow = F.interpolate(flow, scale_factor = 4. / self.scale, mode="bilinear", align_corners=False) * 4. 
/ self.scale 37 | x = torch.cat((x, flow), 1) 38 | 39 | x = self.conv(torch.cat([motion_feature, x], 1)) 40 | 41 | if self.scale != 4: 42 | x = F.interpolate(x, scale_factor = self.scale // 4, mode="bilinear", align_corners=False) 43 | flow = x[:, :4] * (self.scale // 4) 44 | else: 45 | flow = x[:, :4] 46 | 47 | mask = x[:, 4:5] 48 | 49 | return flow, mask 50 | 51 | 52 | class decoder(nn.Module): 53 | def __init__(self, backbone, **kargs): 54 | super(decoder, self).__init__() 55 | self.flow_num_stage = len(kargs['hidden_dims']) 56 | self.feature_bone = backbone 57 | self.block = nn.ModuleList([Head( kargs['motion_dims'][-1-i] * kargs['depths'][-1-i] + kargs['embed_dims'][-1-i], 58 | kargs['scales'][-1-i], 59 | kargs['hidden_dims'][-1-i], 60 | 6 if i==0 else 17) 61 | for i in range(self.flow_num_stage)]) 62 | self.unet = Unet(kargs['c'] * 2) 63 | 64 | def warp_features(self, xs, flow): 65 | y0 = [] 66 | y1 = [] 67 | B = xs[0].size(0) // 2 68 | for x in xs: 69 | y0.append(warp(x[:B], flow[:, 0:2])) 70 | y1.append(warp(x[B:], flow[:, 2:4])) 71 | flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 72 | return y0, y1 73 | 74 | def calculate_flow(self, imgs, timestep, af=None, mf=None): 75 | img0, img1 = imgs[:, :3], imgs[:, 3:6] 76 | B = img0.size(0) 77 | flow, mask = None, None 78 | if (af is None) or (mf is None): 79 | af, mf = self.feature_bone(img0, img1) 80 | for i in range(self.flow_num_stage): 81 | t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda() 82 | if flow != None: 83 | warped_img0 = warp(img0, flow[:, :2]) 84 | warped_img1 = warp(img1, flow[:, 2:4]) 85 | flow_, mask_ = self.block[i]( 86 | torch.cat([t*mf[-1-i][:B],(1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1), 87 | torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), 88 | flow 89 | ) 90 | flow = flow + flow_ 91 | mask = mask + mask_ 92 | else: 93 | flow, mask = self.block[i]( 94 | torch.cat([t*mf[-1-i][:B],(1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1), 95 | torch.cat((img0, img1), 1), 96 | None 97 | ) 98 | 99 | return flow, mask 100 | 101 | def coraseWarp_and_Refine(self, imgs, af, flow, mask): 102 | img0, img1 = imgs[:, :3], imgs[:, 3:6] 103 | warped_img0 = warp(img0, flow[:, :2]) 104 | warped_img1 = warp(img1, flow[:, 2:4]) 105 | c0, c1 = self.warp_features(af, flow) 106 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) 107 | res = tmp[:, :3] * 2 - 1 108 | mask_ = torch.sigmoid(mask) 109 | merged = warped_img0 * mask_ + warped_img1 * (1 - mask_) 110 | pred = torch.clamp(merged + res, 0, 1) 111 | return pred 112 | 113 | 114 | def forward(self, x, timestep=0.5): 115 | img0, img1 = x[:, :3], x[:, 3:6] 116 | B = x.size(0) 117 | 118 | flow_list = [] 119 | merged = [] 120 | mask_list = [] 121 | warped_img0 = img0 122 | warped_img1 = img1 123 | flow = None 124 | 125 | af, mf = self.feature_bone(img0, img1) 126 | 127 | for i in range(self.flow_num_stage): 128 | t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda() 129 | 130 | if flow != None: 131 | flow_d, mask_d = self.block[i]( torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1), 132 | torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow) 133 | flow = flow + flow_d 134 | mask = mask + mask_d 135 | else: 136 | flow, mask = self.block[i]( torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:],af[-1-i][:B],af[-1-i][B:]],1), 137 | torch.cat((img0, img1), 1), None) 138 | 139 | mask_list.append(torch.sigmoid(mask)) 140 | 
flow_list.append(flow) 141 | 142 | warped_img0 = warp(img0, flow[:, :2]) 143 | warped_img1 = warp(img1, flow[:, 2:4]) 144 | merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i])) 145 | 146 | c0, c1 = self.warp_features(af, flow) 147 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) 148 | res = tmp[:, :3] * 2 - 1 149 | pred = torch.clamp(merged[-1] + res, 0, 1) 150 | return flow_list, mask_list, merged, pred -------------------------------------------------------------------------------- /utils/pytorch_msssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from math import exp 4 | import numpy as np 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | 13 | def create_window(window_size, channel=1): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device) 16 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 17 | return window 18 | 19 | 20 | def create_window_3d(window_size, channel=1): 21 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 22 | _2D_window = _1D_window.mm(_1D_window.t()) 23 | _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) 24 | window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) 25 | return window 26 | 27 | 28 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 29 | if val_range is None: 30 | if torch.max(img1) > 128: 31 | max_val = 255 32 | else: 33 | max_val = 1 34 | 35 | if torch.min(img1) < -0.5: 36 | min_val = -1 37 | else: 38 | min_val = 0 39 | L = max_val - min_val 40 | else: 41 | L = val_range 42 | 43 | padd = 0 44 | (_, channel, height, width) = img1.size() 45 | if window is None: 46 | real_size = min(window_size, height, width) 47 | window = create_window(real_size, channel=channel).to(img1.device) 48 | 49 | mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) 50 | mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel) 51 | 52 | mu1_sq = mu1.pow(2) 53 | mu2_sq = mu2.pow(2) 54 | mu1_mu2 = mu1 * mu2 55 | 56 | sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq 57 | sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq 58 | sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2 59 | 60 | C1 = (0.01 * L) ** 2 61 | C2 = (0.03 * L) ** 2 62 | 63 | v1 = 2.0 * sigma12 + C2 64 | v2 = sigma1_sq + sigma2_sq + C2 65 | cs = torch.mean(v1 / v2) 66 | 67 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 68 | 69 | if size_average: 70 | ret = ssim_map.mean() 71 | else: 72 | ret = ssim_map.mean(1).mean(1).mean(1) 73 | 74 | if full: 75 | return ret, cs 76 | return ret 77 | 78 | 79 | def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 80 | if val_range is None: 81 | if torch.max(img1) > 128: 82 | max_val = 255 83 | else: 84 | max_val = 1 85 | 86 | if torch.min(img1) < -0.5: 
87 | min_val = -1 88 | else: 89 | min_val = 0 90 | L = max_val - min_val 91 | else: 92 | L = val_range 93 | 94 | padd = 0 95 | (_, _, height, width) = img1.size() 96 | if window is None: 97 | real_size = min(window_size, height, width) 98 | window = create_window_3d(real_size, channel=1).to(img1.device) 99 | 100 | img1 = img1.unsqueeze(1) 101 | img2 = img2.unsqueeze(1) 102 | 103 | mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) 104 | mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) 105 | 106 | mu1_sq = mu1.pow(2) 107 | mu2_sq = mu2.pow(2) 108 | mu1_mu2 = mu1 * mu2 109 | 110 | sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq 111 | sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq 112 | sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2 113 | 114 | C1 = (0.01 * L) ** 2 115 | C2 = (0.03 * L) ** 2 116 | 117 | v1 = 2.0 * sigma12 + C2 118 | v2 = sigma1_sq + sigma2_sq + C2 119 | cs = torch.mean(v1 / v2) 120 | 121 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 122 | 123 | if size_average: 124 | ret = ssim_map.mean() 125 | else: 126 | ret = ssim_map.mean(1).mean(1).mean(1) 127 | 128 | if full: 129 | return ret, cs 130 | return ret 131 | 132 | 133 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False): 134 | device = img1.device 135 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 136 | levels = weights.size()[0] 137 | mssim = [] 138 | mcs = [] 139 | for _ in range(levels): 140 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 141 | mssim.append(sim) 142 | mcs.append(cs) 143 | 144 | img1 = F.avg_pool2d(img1, (2, 2)) 145 | img2 = F.avg_pool2d(img2, (2, 2)) 146 | 147 | mssim = torch.stack(mssim) 148 | mcs = torch.stack(mcs) 149 | 150 | if normalize: 151 | mssim = (mssim + 1) / 2 152 | mcs = (mcs + 1) / 2 153 | 154 | pow1 = mcs ** weights 155 | pow2 = mssim ** weights 156 | 157 | output = torch.prod(pow1[:-1] * pow2[-1]) 158 | return output 159 | 160 | 161 | 162 | class SSIM(torch.nn.Module): 163 | def __init__(self, window_size=11, size_average=True, val_range=None): 164 | super(SSIM, self).__init__() 165 | self.window_size = window_size 166 | self.size_average = size_average 167 | self.val_range = val_range 168 | 169 | self.channel = 3 170 | self.window = create_window(window_size, channel=self.channel) 171 | 172 | def forward(self, img1, img2): 173 | (_, channel, _, _) = img1.size() 174 | 175 | if channel == self.channel and self.window.dtype == img1.dtype: 176 | window = self.window 177 | else: 178 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 179 | self.window = window 180 | self.channel = channel 181 | 182 | _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 183 | dssim = (1 - _ssim) / 2 184 | return dssim 185 | 186 | class MSSSIM(torch.nn.Module): 187 | def __init__(self, window_size=11, size_average=True, channel=3): 188 | super(MSSSIM, self).__init__() 189 | self.window_size = window_size 190 | self.size_average = size_average 191 | self.channel = channel 192 | 193 | def forward(self, img1, img2): 194 | return msssim(img1, img2, window_size=self.window_size, 
size_average=self.size_average) -------------------------------------------------------------------------------- /model/vgg19_losses.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional, Sequence, Tuple 2 | 3 | import numpy as np 4 | import scipy.io as sio 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def _build_net(layer_type: str, input_tensor: torch.Tensor, 12 | weight_bias: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> Callable[[Any], Any]: 13 | if layer_type == 'conv': 14 | return F.relu(F.conv2d(input = input_tensor, weight = weight_bias[0], bias = weight_bias[1], stride=1, padding=0)) 15 | elif layer_type == 'pool': 16 | return F.avg_pool2d(input = input_tensor, kernel_size=2, stride=2) 17 | else: 18 | raise ValueError('Unsupported layer types: %s' % layer_type) 19 | 20 | 21 | def _get_weight_and_bias(vgg_layers: np.ndarray, 22 | index: int) -> Tuple[torch.Tensor, torch.Tensor]: 23 | weights = vgg_layers[index][0][0][2][0][0] 24 | weights = torch.tensor(weights).permute(3, 2, 0, 1).to(device) 25 | bias = vgg_layers[index][0][0][2][0][1] 26 | bias = torch.tensor(np.reshape(bias, bias.size)).to(device) 27 | 28 | return weights, bias 29 | 30 | 31 | def _build_vgg19(image: torch.Tensor, model_filepath: str) -> Dict[str, torch.Tensor]: 32 | net = {} 33 | if not hasattr(_build_vgg19, 'vgg_rawnet'): 34 | _build_vgg19.vgg_rawnet = sio.loadmat(model_filepath) 35 | vgg_layers = _build_vgg19.vgg_rawnet['layers'][0] 36 | imagenet_mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape(1, 3, 1, 1).to(device) 37 | net['input'] = image - imagenet_mean 38 | net['conv1_1'] = _build_net( 39 | 'conv', 40 | net['input'], 41 | _get_weight_and_bias(vgg_layers, 0)) 42 | net['conv1_2'] = _build_net( 43 | 'conv', 44 | net['conv1_1'], 45 | _get_weight_and_bias(vgg_layers, 2)) 46 | net['pool1'] = _build_net('pool', net['conv1_2']) 47 | net['conv2_1'] = _build_net( 48 | 'conv', 49 | net['pool1'], 50 | _get_weight_and_bias(vgg_layers, 5)) 51 | net['conv2_2'] = _build_net( 52 | 'conv', 53 | net['conv2_1'], 54 | _get_weight_and_bias(vgg_layers, 7)) 55 | net['pool2'] = _build_net('pool', net['conv2_2']) 56 | net['conv3_1'] = _build_net( 57 | 'conv', 58 | net['pool2'], 59 | _get_weight_and_bias(vgg_layers, 10)) 60 | net['conv3_2'] = _build_net( 61 | 'conv', 62 | net['conv3_1'], 63 | _get_weight_and_bias(vgg_layers, 12)) 64 | net['conv3_3'] = _build_net( 65 | 'conv', 66 | net['conv3_2'], 67 | _get_weight_and_bias(vgg_layers, 14)) 68 | net['conv3_4'] = _build_net( 69 | 'conv', 70 | net['conv3_3'], 71 | _get_weight_and_bias(vgg_layers, 16)) 72 | net['pool3'] = _build_net('pool', net['conv3_4']) 73 | net['conv4_1'] = _build_net( 74 | 'conv', 75 | net['pool3'], 76 | _get_weight_and_bias(vgg_layers, 19)) 77 | net['conv4_2'] = _build_net( 78 | 'conv', 79 | net['conv4_1'], 80 | _get_weight_and_bias(vgg_layers, 21)) 81 | net['conv4_3'] = _build_net( 82 | 'conv', 83 | net['conv4_2'], 84 | _get_weight_and_bias(vgg_layers, 23)) 85 | net['conv4_4'] = _build_net( 86 | 'conv', 87 | net['conv4_3'], 88 | _get_weight_and_bias(vgg_layers, 25)) 89 | net['pool4'] = _build_net('pool', net['conv4_4']) 90 | net['conv5_1'] = _build_net( 91 | 'conv', 92 | net['pool4'], 93 | _get_weight_and_bias(vgg_layers, 28)) 94 | net['conv5_2'] = _build_net( 95 | 'conv', 96 | net['conv5_1'], 97 | _get_weight_and_bias(vgg_layers, 30)) 98 | 99 | return net 100 | 
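# --- Illustrative usage (editor's sketch, not part of the original file) ---
# A minimal example of how the perceptual/style losses defined below are
# typically invoked, assuming the MatConvNet VGG-19 weights file has been
# downloaded separately; the filename 'imagenet-vgg-verydeep-19.mat' is an
# assumption, not a path shipped with this repository. Inputs are BxCxHxW
# RGB tensors in [0, 1].
#
#   pred = torch.rand(1, 3, 256, 256, device=device)  # e.g. an interpolated frame
#   ref = torch.rand(1, 3, 256, 256, device=device)   # the reference frame
#   p = perceptual_loss(pred, ref, vgg_model_file='imagenet-vgg-verydeep-19.mat')
#   s = style_loss(pred, ref, vgg_model_file='imagenet-vgg-verydeep-19.mat')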
101 | 102 | def _compute_error(fake: torch.Tensor, 103 | real: torch.Tensor, 104 | mask: Optional[torch.Tensor] = None) -> torch.Tensor: 105 | if mask is None: 106 | return torch.mean(torch.abs(fake - real)) 107 | else: 108 | size = (fake.size(2), fake.size(3)) 109 | resized_mask = F.interpolate(mask, size=size, mode='bilinear', align_corners=False) 110 | return torch.mean(torch.abs(fake - real) * resized_mask) 111 | 112 | 113 | def perceptual_loss(image: torch.Tensor, 114 | reference: torch.Tensor, 115 | vgg_model_file: str, 116 | weights: Optional[Sequence[float]] = None, 117 | mask: Optional[torch.Tensor] = None) -> torch.Tensor: 118 | if weights is None: 119 | weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] 120 | 121 | vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file) 122 | vgg_img = _build_vgg19(image * 255.0, vgg_model_file) 123 | 124 | p1 = _compute_error(vgg_ref['conv1_2'], vgg_img['conv1_2'], mask) * weights[0] 125 | p2 = _compute_error(vgg_ref['conv2_2'], vgg_img['conv2_2'], mask) * weights[1] 126 | p3 = _compute_error(vgg_ref['conv3_2'], vgg_img['conv3_2'], mask) * weights[2] 127 | p4 = _compute_error(vgg_ref['conv4_2'], vgg_img['conv4_2'], mask) * weights[3] 128 | p5 = _compute_error(vgg_ref['conv5_2'], vgg_img['conv5_2'], mask) * weights[4] 129 | 130 | final_loss = p1 + p2 + p3 + p4 + p5 131 | final_loss /= 255.0 132 | 133 | return final_loss 134 | 135 | 136 | def _compute_gram_matrix(input_features: torch.Tensor, 137 | mask: torch.Tensor) -> torch.Tensor: 138 | b, c, h, w = input_features.size() 139 | if mask is None: 140 | reshaped_features = input_features.view(b, c, h * w) 141 | else: 142 | resized_mask = F.interpolate( 143 | mask, size=(h, w), mode='bilinear', align_corners=False) 144 | reshaped_features = (input_features * resized_mask).view(b, c, h * w) 145 | return torch.matmul( 146 | reshaped_features, reshaped_features.transpose(1, 2)) / float(h * w) 147 | 148 | 149 | def style_loss(image: torch.Tensor, 150 | reference: torch.Tensor, 151 | vgg_model_file: str, 152 | weights: Optional[Sequence[float]] = None, 153 | mask: Optional[torch.Tensor] = None) -> torch.Tensor: 154 | if not weights: 155 | weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5] 156 | 157 | vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file) 158 | vgg_img = _build_vgg19(image * 255.0, vgg_model_file) 159 | 160 | p1 = torch.mean( 161 | torch.square( 162 | _compute_gram_matrix(vgg_ref['conv1_2'] / 255.0, mask) - 163 | _compute_gram_matrix(vgg_img['conv1_2'] / 255.0, mask))) * weights[0] 164 | p2 = torch.mean( 165 | torch.square( 166 | _compute_gram_matrix(vgg_ref['conv2_2'] / 255.0, mask) - 167 | _compute_gram_matrix(vgg_img['conv2_2'] / 255.0, mask))) * weights[1] 168 | p3 = torch.mean( 169 | torch.square( 170 | _compute_gram_matrix(vgg_ref['conv3_2'] / 255.0, mask) - 171 | _compute_gram_matrix(vgg_img['conv3_2'] / 255.0, mask))) * weights[2] 172 | p4 = torch.mean( 173 | torch.square( 174 | _compute_gram_matrix(vgg_ref['conv4_2'] / 255.0, mask) - 175 | _compute_gram_matrix(vgg_img['conv4_2'] / 255.0, mask))) * weights[3] 176 | p5 = torch.mean( 177 | torch.square( 178 | _compute_gram_matrix(vgg_ref['conv5_2'] / 255.0, mask) - 179 | _compute_gram_matrix(vgg_img['conv5_2'] / 255.0, mask))) * weights[4] 180 | 181 | final_loss = p1 + p2 + p3 + p4 + p5 182 | return final_loss 183 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # GenDSA
3 | ### Large-scale Pretrained Multi-Frame Generative Model Enables Real-Time Low-Dose DSA Imaging
4 | 5 | [![paper](https://img.shields.io/badge/Paper-Online_Available_(click_here)-orange)](https://www.cell.com/cms/10.1016/j.medj.2024.07.025/attachment/4209f5f8-1a77-4e15-a872-80a6239ec7bd/mmc3.pdf#:~:text=development%20and%20multi-center%20validation%20study,%20Med%20(2024),) [![license](https://img.shields.io/badge/License-Apache_2.0_(click_here)-blue)](LICENSE) 6 | 7 | Huangxuan Zhao1 🏷️ :email:,[Ziyang Xu](https://ziyangxu.top/)2 🏷️,Linxia Wu1 🏷️, Lei Chen1 🏷️, [Ziwei Cui](https://github.com/ziwei-cui)2, Jinqiang Ma1, Tao Sun1, Yu Lei1, Nan Wang3, Hongyao Hu4, Yiqing Tan5, Wei Lu6, Wenzhong Yang7, Kaibing Liao8, Gaojun Teng9, Xiaoyun Liang10, Yi Li10, Congcong Feng11, Tong Nie1, Xiaoyu Han1, P.Matthijs van der Sluijs12, Charles B.L.M. Majoie13, Wim H. van Zwam14, Yun Feng15, Theo van Walsum11, Aad van der Lugt11, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu/)2, Xuefeng Kan1 :email:, Ruisheng Su11 :email:, Weihua Zhang9 :email:, [Xinggang Wang](https://xwcv.github.io/)2 :email:, Chuansheng Zheng1 :email: 8 | 9 | (🏷️) equal contribution, (:email:) corresponding author. 10 | 11 | 1 Department of Radiology, Union Hospital, Tongji Medical College, Huazhong University of Science and Technology, Wuhan, China. 12 | 2 Institute of AI, School of Electronic Information and Communications, Huazhong University of Science and Technology, Wuhan, China. 13 | 3 Department of Radiology, Tongji Hospital, Tongji Medical College, Huazhong University of Science and Technology, Wuhan, China. 14 | 4 Department of Interventional Radiology, Renmin Hospital of Wuhan University, Wuhan, China. 15 | 5 Department of Radiology, Tongren Hospital of Wuhan University (Wuhan Third Hospital), Wuhan University, Wuhan, China. 16 | 6 Department of Interventional Radiology, Zhongnan Hospital of Wuhan University, Wuhan, China. 17 | 7 Department of Radiology, Maternal and Child Health Hospital of Hubei Province, Wuhan, China. 18 | 8 Department of Radiology, Hubei Integrated Traditional Chinese and Western Medicine Hospital, Wuhan, China. 19 | 9 Department of Radiology, Zhongda Hospital, Medical School, Southeast University, Nanjing, China. 20 | 10 Institute of Research and Clinical Innovations, Neusoft Medical Systems, Co., Ltd, Shanghai, China. 21 | 11 CV Systems Research and Development Department, Neusoft Medical Systems, Co., Ltd, Shenyang, China. 22 | 12 Department of Radiology & Nuclear Medicine, Erasmus MC, University Medical Center Rotterdam, The Netherlands. 23 | 13 Department of Radiology and Nuclear Medicine, Amsterdam University Medical Centers, location AMC, Amsterdam, The Netherlands. 24 | 14 Department of Radiology and Nuclear Medicine, Cardiovascular Research Institute Maastricht, Maastricht University Medical Center, Maastricht, The Netherlands. 25 | 15 Center for Biological Imaging, Institute of Biophysics, Chinese Academy of Sciences, Beijing, China. 26 | 27 | 28 |
29 | 30 | ## News 31 | 32 | * **`October 3, 2025`:** 🔥🔥 GenDSA-V2 has been accepted by Nature Medicine! For details, please refer to [this repo](https://github.com/ZrH42/GenDSA-V2). 33 | 34 | * **`January 25, 2025`:** 🔥🔥🔥 GenDSA-V2 coming soon! 🔥🔥🔥 We further scaled up the data and conducted a large number of clinical trials at multiple centers. The new work is under review; please stay tuned! 35 | 36 | * **`September 4, 2024`:** [Paper](https://www.cell.com/cms/10.1016/j.medj.2024.07.025/attachment/4209f5f8-1a77-4e15-a872-80a6239ec7bd/mmc3.pdf#:~:text=development%20and%20multi-center%20validation%20study,%20Med%20(2024),) is online and available now. Enjoy! 37 | 38 | * **`July 5, 2024`:** GenDSA was accepted by Med (a Cell Press journal)! 39 | 40 | * **`June 5, 2024`:** We released a portion of the 3D vascular and 3D non-vascular datasets. ([link here](https://drive.google.com/drive/folders/1t-esIdUnVcZdFXmSGhGcwhcpI0KyGKY0?usp=sharing)) 41 | 42 | * **`March 27, 2024`:** We released our inference code. Paper/Project pages are coming soon. Please stay tuned! 43 | 44 | ## Abstract 45 | Digital subtraction angiography (DSA) devices are commonly used in hundreds of different interventional procedures in various parts of the body, and a single procedure often requires multiple scans of the patient, exposing both doctors and patients to high radiation doses. Inspired by generative artificial intelligence techniques, this study proposed GenDSA, a real-time, low-dose DSA imaging system based on a large-scale pretrained multi-frame generative model. Suitable for most DSA scanning protocols, GenDSA can reduce the DSA frame rate (i.e., the radiation dose) to 1/3 while generating videos that are virtually identical to those acquired with clinically available protocols. GenDSA was pre-trained, fine-tuned, and tested on ten million images from 35 hospitals. Objective quantitative metrics (PSNR = 36.83, SSIM = 0.911, generation time = 0.07 s/frame) demonstrated that GenDSA surpasses state-of-the-art algorithms in the field of image frame generation. Subjective ratings and statistical analyses from five doctors showed that the generated videos reached a level comparable to the fully sampled videos in both overall quality (4.905 vs. 4.935) and lesion assessment (4.825 vs. 4.860), fully demonstrating the potential of GenDSA for clinical applications. 46 | 47 | 48 | ## Environment Setups 49 | 50 | * python 3.8 51 | * cudatoolkit 11.2.1 52 | * cudnn 8.1.0.77 53 | * See 'GenDSA_env.txt' for the required Python libraries 54 | 55 | ```shell 56 | conda create -n GenDSA python=3.8 57 | conda activate GenDSA 58 | conda install cudatoolkit=11.2.1 cudnn=8.1.0.77 59 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 60 | # cd /xx/xx/GenDSA 61 | pip install -r GenDSA_env.txt 62 | ``` 63 | 64 | 65 | ## Model Checkpoints 66 | Download the zip of [model checkpoints](https://share.weiyun.com/ze6bOv0i) (key: ```mqfd5s```), decompress it, and put all .pkl files into ../GenDSA/weights/checkpoints. 67 | 68 | ## Our Dataset and Inference Cases 69 | We have released a portion of the 3D vascular and non-vascular datasets, including the results of our model inference. 
([Google Drive](https://drive.google.com/drive/folders/1t-esIdUnVcZdFXmSGhGcwhcpI0KyGKY0?usp=sharing)) 70 | 71 | 72 | ## Inference Demo 73 | Run the following commands to generate single/multi-frame interpolation: 74 | 75 | * Single-frame interpolation 76 | ```shell 77 | python Simple_Interpolator.py \ 78 | --model_path ./weights/checkpoints/3D-vas-Inf1.pkl \ 79 | --frame1 ./demo_images/DSA_1.png \ 80 | --frame2 ./demo_images/DSA_2.png \ 81 | --inter_frames 1 82 | ``` 83 | 84 | * Two-frame interpolation 85 | ```shell 86 | python Simple_Interpolator.py \ 87 | --model_path ./weights/checkpoints/3D-vas-Inf2.pkl \ 88 | --frame1 ./demo_images/DSA_1.png \ 89 | --frame2 ./demo_images/DSA_2.png \ 90 | --inter_frames 2 91 | ``` 92 | 93 | * Three-frame interpolation 94 | ```shell 95 | python Simple_Interpolator.py \ 96 | --model_path ./weights/checkpoints/3D-vas-Inf3.pkl \ 97 | --frame1 ./demo_images/DSA_1.png \ 98 | --frame2 ./demo_images/DSA_2.png \ 99 | --inter_frames 3 100 | ``` 101 | 102 | You can also use other checkpoints to generate 1~3 frame interpolation for your 2D/3D - Head/Abdomen/Thorax/Pelvic/Periph images. 103 | 104 | ## 💖 Citation 105 | Please promote and cite our work if you find it helpful. Enjoy! 106 | ```shell 107 | @article{zhao2024large, 108 | title={Large-scale pretrained frame generative model enables real-time low-dose DSA imaging: An AI system development and multi-center validation study}, 109 | author={Zhao, Huangxuan and Xu, Ziyang and Chen, Lei and Wu, Linxia and Cui, Ziwei and Ma, Jinqiang and Sun, Tao and Lei, Yu and Wang, Nan and Hu, Hongyao and others}, 110 | journal={Med}, 111 | year={2024}, 112 | publisher={Elsevier}, 113 | url={https://doi.org/10.1016/j.medj.2024.07.025}, 114 | } 115 | ``` 116 | 117 | 118 | 124 | 125 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 203 | -------------------------------------------------------------------------------- /Trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | from torch.optim import AdamW 6 | from model.loss import * 7 | from model.warplayer import warp 8 | from config import * 9 | 10 | 11 | class Model: 12 | def __init__(self, local_rank): 13 | backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE'] 14 | backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH'] 15 | self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg) 16 | self.name = MODEL_CONFIG['LOGNAME'] 17 | self.device() 18 | 19 | self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4) 20 | self.lap = LapLoss() 21 | self.ploss = Perceptual_Loss() 22 | self.styloss = Style_Loss() 23 | if local_rank != -1: 24 | self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank, broadcast_buffers = False) 25 | 26 | def train(self): 27 | self.net.train() 28 | 29 | def eval(self): 30 | self.net.eval() 31 | 32 | def device(self): 33 | self.net.to(torch.device("cuda")) 34 | 35 | def load_model(self, name = None, folder_path = None, full_path = None, rank = 0): 36 | def convert(param): 37 | return { 38 | k.replace("module.", ""): v 39 | for k, v in param.items() 40 | if "module." in k and 'attn_mask' not in k and 'HW' not in k 41 | } 42 | if rank <= 0 : 43 | if full_path is not None: 44 | self.net.load_state_dict(convert(torch.load(full_path))) 45 | else: 46 | if name is None: 47 | name = self.name 48 | if folder_path is None: 49 | self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl'))) 50 | else: 51 | self.net.load_state_dict(convert(torch.load(folder_path + f'/{name}.pkl'))) 52 | 53 | def load_pretrain_weight(self, name = None, folder_path = None, full_path = None, rank = 0): 54 | if rank <= 0 : 55 | if full_path is not None: 56 | self.net.load_state_dict(torch.load(full_path)) 57 | else: 58 | if name is None: 59 | name = self.name 60 | if folder_path is None: 61 | self.net.load_state_dict(torch.load(f'ckpt/{name}.pkl')) 62 | else: 63 | self.net.load_state_dict(torch.load(folder_path + f'/{name}.pkl')) 64 | 65 | def save_model(self, name = None, folder_path = None, rank=0): 66 | if rank == 0: 67 | if name is None: 68 | name = self.name 69 | if folder_path is None: 70 | torch.save(self.net.state_dict(), f'ckpt/{name}.pkl') 71 | else: 72 | torch.save(self.net.state_dict(), folder_path + f'/{name}.pkl') 73 | 74 | @torch.no_grad() 75 | def hr_inference(self, img0, img1, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False): 76 | ''' 77 | Infer with down_scale flow 78 | Noting: return BxCxHxW 79 | ''' 80 | def infer(imgs): 81 | img0, img1 = imgs[:, :3], imgs[:, 3:6] 82 | imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False) 83 | 84 | flow, mask = self.net.calculate_flow(imgs_down, timestep) 85 | 86 | flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale) 87 | mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) 88 | 89 | af, _ = self.net.feature_bone(img0, img1) 90 | pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask) 91 | return pred 92 | 93 | imgs = torch.cat((img0, img1), 1) 94 | if fast_TTA: 95 | imgs_ = imgs.flip(2).flip(3) 96 | input = torch.cat((imgs, imgs_), 0) 97 | 
preds = infer(input) 98 | return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2. 99 | 100 | if TTA == False: 101 | return infer(imgs) 102 | else: 103 | return (infer(imgs) + infer(imgs.flip(2).flip(3)).flip(2).flip(3)) / 2 104 | 105 | @torch.no_grad() 106 | def inference(self, img0, img1, TTA = False, timestep = 0.5, fast_TTA = False): 107 | imgs = torch.cat((img0, img1), 1) 108 | ''' 109 | Noting: return BxCxHxW 110 | ''' 111 | if fast_TTA: 112 | imgs_ = imgs.flip(2).flip(3) 113 | input = torch.cat((imgs, imgs_), 0) 114 | _, _, _, preds = self.net(input, timestep=timestep) 115 | return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2. 116 | 117 | _, _, _, pred = self.net(imgs, timestep=timestep) 118 | if TTA == False: 119 | return pred 120 | else: 121 | _, _, _, pred2 = self.net(imgs.flip(2).flip(3), timestep=timestep) 122 | return (pred + pred2.flip(2).flip(3)) / 2 123 | 124 | @torch.no_grad() 125 | def multi_inference(self, img0, img1, TTA = False, down_scale = 1.0, time_list=[], fast_TTA = False): 126 | ''' 127 | Run backbone once, get multi frames at different timesteps 128 | Noting: return a list of [CxHxW] 129 | ''' 130 | assert len(time_list) > 0, 'Time_list should not be empty!' 131 | def infer(imgs): 132 | img0, img1 = imgs[:, :3], imgs[:, 3:6] 133 | af, mf = self.net.feature_bone(img0, img1) 134 | imgs_down = None 135 | if down_scale != 1.0: 136 | imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False) 137 | afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6]) 138 | 139 | pred_list = [] 140 | for timestep in time_list: 141 | if imgs_down is None: 142 | flow, mask = self.net.calculate_flow(imgs, timestep, af, mf) 143 | else: 144 | flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd) 145 | flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale) 146 | mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) 147 | 148 | pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask) 149 | pred_list.append(pred) 150 | 151 | return pred_list 152 | 153 | imgs = torch.cat((img0, img1), 1) 154 | if fast_TTA: 155 | imgs_ = imgs.flip(2).flip(3) 156 | input = torch.cat((imgs, imgs_), 0) 157 | preds_lst = infer(input) 158 | return [(preds_lst[i][0] + preds_lst[i][1].flip(1).flip(2))/2 for i in range(len(time_list))] 159 | 160 | preds = infer(imgs) 161 | if TTA is False: 162 | return [preds[i][0] for i in range(len(time_list))] 163 | else: 164 | flip_pred = infer(imgs.flip(2).flip(3)) 165 | return [(preds[i][0] + flip_pred[i][0].flip(1).flip(2))/2 for i in range(len(time_list))] 166 | 167 | def update(self, imgs, gt, learning_rate=0, training=True): 168 | for param_group in self.optimG.param_groups: 169 | param_group['lr'] = learning_rate 170 | if training: 171 | self.train() 172 | else: 173 | self.eval() 174 | 175 | if training: 176 | flow, mask, merged, pred = self.net(imgs) 177 | loss_l1 = (self.lap(pred, gt)).mean() 178 | 179 | factor = 1.0 / len(merged) 180 | for merge in merged: 181 | loss_l1 += (self.lap(merge, gt)).mean() * factor 182 | 183 | self.optimG.zero_grad() 184 | loss_l1.backward() 185 | self.optimG.step() 186 | return pred, loss_l1 187 | else: 188 | with torch.no_grad(): 189 | flow, mask, merged, pred = self.net(imgs) 190 | return pred, 0 191 | 192 | def multi_gts_update(self, imgs, gts, TimeStepList:list, learning_rate=0, training=True): 193 | for param_group in self.optimG.param_groups: 194 | 
param_group['lr'] = learning_rate 195 | if training: 196 | self.train() 197 | else: 198 | self.eval() 199 | 200 | if training: 201 | loss_l1_all = 0 202 | preds = [] 203 | 204 | for timestep, i in zip(TimeStepList, range(len(TimeStepList))): 205 | flow, mask, merged, pred = self.net(imgs, timestep) 206 | gt_index1 = i * 3 207 | gt_index2 = (i + 1) * 3 208 | loss_l1 = (self.lap(pred, gts[:, gt_index1 : gt_index2])).mean() 209 | 210 | loss_l1_all += loss_l1 211 | preds.append(pred) 212 | 213 | factor = 1.0 / len(merged) 214 | for merge in merged: 215 | loss_l1_all += (self.lap(merge, gts[:, gt_index1 : gt_index2])).mean() * factor 216 | 217 | self.optimG.zero_grad() 218 | loss_l1_all.backward() 219 | self.optimG.step() 220 | return preds, loss_l1_all 221 | else: 222 | with torch.no_grad(): 223 | preds = [] 224 | 225 | for timestep in TimeStepList: 226 | flow, mask, merged, pred = self.net(imgs, timestep) 227 | preds.append(pred) 228 | 229 | return preds, 0 230 | 231 | def multi_gts_losses_update(self, imgs, gts, TimeStepList:list, vgg_model_file:str = '', losses_weight_schedules:list = [], now_epoch:int = 0, now_step:int = 0, learning_rate=0, training=True): 232 | ''' 233 | vgg_model_file: MATLAB format file path for VGG 19 network weights. 234 | losses_weight_schedules: Weight plan for each loss. 235 | 236 | Specific content schematic: 237 | ---------- 238 | losses_weight_schedules = [ 239 | {'boundaries_epoch':[0], 'boundaries_step':[0], 'values':[1.0, 1.0]}, 240 | {'boundaries_epoch':[0], 'boundaries_step':[2400], 'values':[1.0, 0.25]}, 241 | {'boundaries_epoch':[2], 'boundaries_step':[2400], 'values':[0.0, 40.0]}] 242 | ---------- 243 | Prioritize the boundaries specified by the boundaries_epoch. If boundaries_epoch is empty, it is specified by boundaries_step. 244 | Before the epoch specified by boundaries_epoch, the weight of a loss is calculated as values [0], and then as values [1]. Boundaries_step is the same. 245 | ''' 246 | def decide_values(losses_weight_schedules:list, now_epoch:int, now_step:int): 247 | ''' 248 | Based on the current number of epochs, steps (iters), and the stage weight plan, determine the weight of each loss 249 | ''' 250 | l1_epoch = losses_weight_schedules[0]['boundaries_epoch'] 251 | p_epoch = losses_weight_schedules[1]['boundaries_epoch'] 252 | sty_epoch = losses_weight_schedules[2]['boundaries_epoch'] 253 | 254 | l1_step = losses_weight_schedules[0]['boundaries_step'] 255 | p_step = losses_weight_schedules[1]['boundaries_step'] 256 | sty_step = losses_weight_schedules[2]['boundaries_step'] 257 | 258 | if ((len(l1_epoch) != 0) & (len(p_epoch) != 0) & (len(sty_epoch) != 0)): # Prioritize the boundaries specified by the boundaries_epoch. 259 | if now_epoch < l1_epoch[0]: 260 | l1_value = losses_weight_schedules[0]['values'][0] 261 | else: 262 | l1_value = losses_weight_schedules[0]['values'][1] 263 | if now_epoch < p_epoch[0]: 264 | p_value = losses_weight_schedules[1]['values'][0] 265 | else: 266 | p_value = losses_weight_schedules[1]['values'][1] 267 | if now_epoch < sty_epoch[0]: 268 | sty_value = losses_weight_schedules[2]['values'][0] 269 | else: 270 | sty_value = losses_weight_schedules[2]['values'][1] 271 | elif ((len(l1_step) != 0) & (len(p_step) != 0) & (len(sty_step) != 0)): # Otherwise, boundaries specified by boundaries_step. 
272 | if now_step < l1_step[0]: 273 | l1_value = losses_weight_schedules[0]['values'][0] 274 | else: 275 | l1_value = losses_weight_schedules[0]['values'][1] 276 | if now_step < p_step[0]: 277 | p_value = losses_weight_schedules[1]['values'][0] 278 | else: 279 | p_value = losses_weight_schedules[1]['values'][1] 280 | if now_step < sty_step[0]: 281 | sty_value = losses_weight_schedules[2]['values'][0] 282 | else: 283 | sty_value = losses_weight_schedules[2]['values'][1] 284 | else: 285 | print("'losses_weight_schedules' is illegal, it needs to meet the following conditions: (1) the boundaries_epoch of each loss is not empty, or (2) the boundaries_step of each loss is not empty!") 286 | sys.exit() 287 | 288 | return l1_value, p_value, sty_value 289 | 290 | 291 | for param_group in self.optimG.param_groups: 292 | param_group['lr'] = learning_rate 293 | if training: 294 | self.train() 295 | else: 296 | self.eval() 297 | 298 | if training: 299 | loss_all = 0 300 | preds = [] 301 | l1_weight, p_weight, sty_weight = decide_values(losses_weight_schedules, now_epoch, now_step) 302 | 303 | for timestep, i in zip(TimeStepList, range(len(TimeStepList))): 304 | flow, mask, merged, pred = self.net(imgs, timestep) 305 | gt_index1 = i * 3 306 | gt_index2 = (i + 1) * 3 307 | if l1_weight != 0: 308 | loss_l1 = (self.lap(pred, gts[:, gt_index1 : gt_index2])).mean() 309 | print("loss_l1", loss_l1) 310 | loss_all += loss_l1 * l1_weight 311 | if p_weight != 0: 312 | loss_p = (self.ploss(pred, gts[:, gt_index1 : gt_index2], vgg_model_file)).mean() 313 | print("loss_p", loss_p) 314 | loss_all += loss_p * p_weight 315 | if sty_weight != 0: 316 | loss_style = (self.styloss(pred, gts[:, gt_index1 : gt_index2], vgg_model_file)).mean() 317 | print("loss_style", loss_style) 318 | loss_all += loss_style * sty_weight 319 | 320 | preds.append(pred) 321 | 322 | factor = 1.0 / len(merged) 323 | for merge in merged: 324 | if l1_weight != 0: 325 | loss_all += (self.lap(merge, gts[:, gt_index1 : gt_index2])).mean() * factor * l1_weight 326 | if p_weight != 0: 327 | loss_all += (self.ploss(merge, gts[:, gt_index1 : gt_index2], vgg_model_file)).mean() * factor * p_weight 328 | if sty_weight != 0: 329 | loss_all += (self.styloss(merge, gts[:, gt_index1 : gt_index2], vgg_model_file)).mean() * factor * sty_weight 330 | 331 | self.optimG.zero_grad() 332 | loss_all.backward() 333 | self.optimG.step() 334 | return preds, loss_all 335 | else: 336 | with torch.no_grad(): 337 | preds = [] 338 | 339 | for timestep in TimeStepList: 340 | flow, mask, merged, pred = self.net(imgs, timestep) 341 | preds.append(pred) 342 | 343 | return preds, 0 344 | -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import einsum 4 | from einops import rearrange 5 | import math 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def default(val, d): 13 | return val if exists(val) else d 14 | 15 | def calc_rel_pos(n): 16 | pos = torch.meshgrid(torch.arange(n), torch.arange(n)) 17 | pos = rearrange(torch.stack(pos), 'n i j -> (i j) n') 18 | rel_pos = pos[None, :] - pos[:, None] 19 | rel_pos += n - 1 20 | return rel_pos 21 | 22 | 23 | class Mlp(nn.Module): 24 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 25 | super().__init__() 26 | 
out_features = out_features or in_features 27 | hidden_features = hidden_features or in_features 28 | self.fc1 = nn.Linear(in_features, hidden_features) 29 | self.dwconv = DWConv(hidden_features) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features) 32 | self.drop = nn.Dropout(drop) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.apply(self._init_weights) 35 | 36 | def _init_weights(self, m): 37 | if isinstance(m, nn.Linear): 38 | trunc_normal_(m.weight, std=.02) 39 | if isinstance(m, nn.Linear) and m.bias is not None: 40 | nn.init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.LayerNorm): 42 | nn.init.constant_(m.bias, 0) 43 | nn.init.constant_(m.weight, 1.0) 44 | elif isinstance(m, nn.Conv2d): 45 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 46 | fan_out //= m.groups 47 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 48 | if m.bias is not None: 49 | m.bias.data.zero_() 50 | 51 | def forward(self, x, H, W): 52 | x = self.fc1(x) 53 | x = self.dwconv(x, H, W) 54 | x = self.act(x) 55 | x = self.drop(x) 56 | x = self.fc2(x) 57 | x = self.drop(x) 58 | return x 59 | 60 | 61 | class InterFrameLambda(nn.Module): 62 | def __init__(self, dim, *, dim_k, n = None, r = None, heads = 4, dim_out = None, dim_u = 1, motion_dim): 63 | super().__init__() 64 | dim_out = default(dim_out, dim) 65 | self.u = dim_u 66 | self.heads = heads 67 | 68 | assert (dim_out % heads) == 0, "'dim_out' must be divisible by number of heads for multi-head query" 69 | assert (motion_dim % heads) == 0, "'motion_dim' must be divisible by number of heads for multi-head query" 70 | dim_v = dim_out // heads 71 | 72 | self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False) 73 | self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False) 74 | self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False) 75 | self.cor_embed = nn.Conv2d(2, dim_k * heads, 1, bias = True) 76 | self.motion_proj = nn.Conv2d(dim_k * heads, motion_dim, 1, bias = True) 77 | self.channel_compress = nn.Conv2d(dim_out, dim_k * heads, 1, bias = False) 78 | 79 | self.norm_q = nn.BatchNorm2d(dim_k * heads) 80 | self.norm_v = nn.BatchNorm2d(dim_v * dim_u) 81 | 82 | self.local_contexts = exists(r) 83 | if exists(r): 84 | assert (r % 2) == 1, 'Receptive kernel size should be odd' 85 | self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2)) 86 | else: 87 | assert exists(n), 'You must specify the window size (n=h=w)' 88 | rel_lengths = 2 * n - 1 89 | self.rel_pos_emb = nn.Parameter(torch.randn(rel_lengths, rel_lengths, dim_k, dim_u)) 90 | self.rel_pos = calc_rel_pos(n) 91 | 92 | def forward(self, x, cor): 93 | x = rearrange(x, 'b h w c -> b c h w') 94 | cor = rearrange(cor, 'b h w c -> b c h w') 95 | b, c, hh, ww, u, h = *x.shape, self.u, self.heads 96 | x_reverse = torch.cat([x[b//2:], x[:b//2]]) 97 | 98 | q = self.to_q(x) 99 | k = self.to_k(x_reverse) 100 | v = self.to_v(x_reverse) 101 | 102 | Q = self.norm_q(q) 103 | V = self.norm_v(v) 104 | 105 | Q = rearrange(Q, 'b (h k) hh ww -> b h k (hh ww)', h = h) 106 | k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u) 107 | V = rearrange(V, 'b (u v) hh ww -> b u v (hh ww)', u = u) 108 | 109 | k = k.softmax(dim=-1) 110 | λc = einsum('b u k m, b u v m -> b k v', k, V) 111 | Yc = einsum('b h k n, b k v -> b h v n', Q, λc) 112 | 113 | if self.local_contexts: 114 | V = rearrange(V, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww) 115 | λp = self.pos_conv(V) 116 | Yp = einsum('b h k n, b k v n -> b h v n', Q, λp.flatten(3)) 117 | else: 118 | 
n, m = self.rel_pos.unbind(dim = -1) 119 | rel_pos_emb = self.rel_pos_emb[n, m] 120 | λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, V) 121 | Yp = einsum('b h k n, b n k v -> b h v n', Q, λp) 122 | 123 | Y = Yc + Yp 124 | appearance = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww) 125 | 126 | cor_embed_ = self.cor_embed(cor) 127 | cor_embed = rearrange(cor_embed_, 'b (h k) hh ww -> b h k (hh ww)', h = h) 128 | cor_reverse_c = einsum('b h k n, b k v -> b h v n', cor_embed, λc) 129 | if self.local_contexts: 130 | cor_reverse_p = einsum('b h k n, b k v n -> b h v n', cor_embed, λp.flatten(3)) 131 | else: 132 | cor_reverse_p = einsum('b h k n, b n k v -> b h v n', cor_embed, λp) 133 | cor_reverse_ = rearrange(cor_reverse_c + cor_reverse_p, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww) 134 | cor_reverse = self.channel_compress(cor_reverse_) 135 | motion = self.motion_proj(cor_reverse - cor_embed_) 136 | 137 | appearance = rearrange(appearance, 'b c h w -> b h w c') 138 | motion = rearrange(motion, 'b c h w -> b h w c') 139 | 140 | return appearance, motion 141 | 142 | 143 | class StructFormerBlock(nn.Module): 144 | def __init__(self, dim, dim_out, dim_k, dim_u, n, r, motion_dim, heads, range, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 145 | super().__init__() 146 | if range == 'global': 147 | self.Lambda = InterFrameLambda(dim = dim, 148 | dim_k = dim_k, 149 | n = n, 150 | heads = heads, 151 | dim_out = dim_out, 152 | dim_u = dim_u, 153 | motion_dim = motion_dim) 154 | elif range == 'local': 155 | self.Lambda = InterFrameLambda(dim = dim, 156 | dim_k = dim_k, 157 | r = r, 158 | heads = heads, 159 | dim_out = dim_out, 160 | dim_u = dim_u, 161 | motion_dim = motion_dim) 162 | else: 163 | raise ValueError("range must be 'global' or 'local'!")  # was print + sys.exit(); sys is not imported in this file 164 | 165 | self.norm1 = norm_layer(dim) 166 | self.norm2 = norm_layer(dim) 167 | self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() 168 | mlp_hidden_dim = int(dim * mlp_ratio) 169 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 170 | self.apply(self._init_weights) 171 | 172 | def _init_weights(self, m): 173 | if isinstance(m, nn.Linear): 174 | trunc_normal_(m.weight, std=.02) 175 | if isinstance(m, nn.Linear) and m.bias is not None: 176 | nn.init.constant_(m.bias, 0) 177 | elif isinstance(m, nn.LayerNorm): 178 | nn.init.constant_(m.bias, 0) 179 | nn.init.constant_(m.weight, 1.0) 180 | elif isinstance(m, nn.Conv2d): 181 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 182 | fan_out //= m.groups 183 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 184 | if m.bias is not None: 185 | m.bias.data.zero_() 186 | 187 | def forward(self, x, cor, H, W, B): 188 | x_norm1 = self.norm1(x) 189 | x = x_norm1.view(2*B, H, W, -1) 190 | x_appearence, x_motion = self.Lambda(x, cor) 191 | x_appearence = rearrange(x_appearence, 'b h w c -> b (h w) c') 192 | x_motion = rearrange(x_motion, 'b h w c -> b (h w) c') 193 | x_appearence = x_norm1 + self.drop_path(x_appearence) 194 | x_appearence = x_appearence + self.drop_path(self.mlp(self.norm2(x_appearence), H, W)) 195 | 196 | return x_appearence, x_motion 197 | 198 | 199 | class ConvBlock(nn.Module): 200 | def __init__(self, in_dim, out_dim, depths=2,act_layer=nn.PReLU): 201 | super().__init__() 202 | layers = [] 203 | for i in range(depths): 204 | if i == 0: 205 | layers.append(nn.Conv2d(in_dim, out_dim, 3,1,1)) 206 | else: 207 | layers.append(nn.Conv2d(out_dim, out_dim, 3,1,1)) 208 | layers.extend([ 209 | act_layer(out_dim), 210 | ]) 211 | self.conv = nn.Sequential(*layers) 212 | 213 | def _init_weights(self, m): 214 | if isinstance(m, nn.Conv2d): 215 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 216 | fan_out //= m.groups 217 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 218 | if m.bias is not None: 219 | m.bias.data.zero_() 220 | 221 | def forward(self, x): 222 | x = self.conv(x) 223 | return x 224 | 225 | 226 | class OverlapPatchEmbed(nn.Module): 227 | def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): 228 | super().__init__() 229 | patch_size = to_2tuple(patch_size) 230 | 231 | self.patch_size = patch_size 232 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 233 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 234 | self.norm = nn.LayerNorm(embed_dim) 235 | 236 | self.apply(self._init_weights) 237 | 238 | def _init_weights(self, m): 239 | if isinstance(m, nn.Linear): 240 | trunc_normal_(m.weight, std=.02) 241 | if isinstance(m, nn.Linear) and m.bias is not None: 242 | nn.init.constant_(m.bias, 0) 243 | elif isinstance(m, nn.LayerNorm): 244 | nn.init.constant_(m.bias, 0) 245 | nn.init.constant_(m.weight, 1.0) 246 | elif isinstance(m, nn.Conv2d): 247 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 248 | fan_out //= m.groups 249 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 250 | if m.bias is not None: 251 | m.bias.data.zero_() 252 | 253 | def forward(self, x): 254 | x = self.proj(x) 255 | _, _, H, W = x.shape 256 | x = x.flatten(2).transpose(1, 2) 257 | x = self.norm(x) 258 | 259 | return x, H, W 260 | 261 | 262 | class CrossScalePatchEmbed(nn.Module): 263 | def __init__(self, in_dims=[16,32,64], embed_dim=768): 264 | super().__init__() 265 | base_dim = in_dims[0] 266 | 267 | layers = [] 268 | for i in range(len(in_dims)): 269 | for j in range(2 ** i): 270 | 
layers.append(nn.Conv2d(in_dims[-1-i], base_dim, 3, 2**(i+1), 1+j, 1+j)) 271 | self.layers = nn.ModuleList(layers) 272 | self.proj = nn.Conv2d(base_dim * len(layers), embed_dim, 1, 1) 273 | self.norm = nn.LayerNorm(embed_dim) 274 | 275 | self.apply(self._init_weights) 276 | 277 | def _init_weights(self, m): 278 | if isinstance(m, nn.Linear): 279 | trunc_normal_(m.weight, std=.02) 280 | if isinstance(m, nn.Linear) and m.bias is not None: 281 | nn.init.constant_(m.bias, 0) 282 | elif isinstance(m, nn.LayerNorm): 283 | nn.init.constant_(m.bias, 0) 284 | nn.init.constant_(m.weight, 1.0) 285 | elif isinstance(m, nn.Conv2d): 286 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 287 | fan_out //= m.groups 288 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 289 | if m.bias is not None: 290 | m.bias.data.zero_() 291 | 292 | def forward(self, xs): 293 | ys = [] 294 | k = 0 295 | for i in range(len(xs)): 296 | for _ in range(2 ** i): 297 | ys.append(self.layers[k](xs[-1-i])) 298 | k += 1 299 | 300 | x = self.proj(torch.cat(ys,1)) 301 | 302 | _, _, H, W = x.shape 303 | x = x.flatten(2).transpose(1, 2) 304 | 305 | x = self.norm(x) 306 | 307 | return x, H, W 308 | 309 | 310 | class StructFormer(nn.Module): 311 | def __init__(self, in_chans=3, embed_dims=[32, 64, 128, 256, 512], motion_dims=64, num_heads=[8, 16], 312 | mlp_ratios=[4, 4], lambda_global_or_local = 'local', lambda_dim_k = 16, lambda_dim_u = 1, 313 | lambda_n = 32, lambda_r = None, drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 314 | depths=[2, 2, 2, 4, 4], **kwarg): 315 | super().__init__() 316 | self.depths = depths 317 | self.num_stages = len(embed_dims) 318 | 319 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 320 | cur = 0 321 | 322 | self.conv_stages = self.num_stages - len(num_heads) 323 | 324 | for i in range(self.num_stages): 325 | if i == 0: 326 | block = ConvBlock(in_chans,embed_dims[i],depths[i]) 327 | else: 328 | if i < self.conv_stages: 329 | patch_embed = nn.Sequential( 330 | nn.Conv2d(embed_dims[i-1], embed_dims[i], 3,2,1), 331 | nn.PReLU(embed_dims[i]) 332 | ) 333 | block = ConvBlock(embed_dims[i],embed_dims[i],depths[i]) 334 | else: 335 | if i == self.conv_stages: 336 | patch_embed = CrossScalePatchEmbed(embed_dims[:i], 337 | embed_dim=embed_dims[i]) 338 | block = nn.ModuleList([StructFormerBlock( 339 | dim=embed_dims[i], dim_out=embed_dims[i], dim_k=lambda_dim_k, dim_u=lambda_dim_u, n=lambda_n, r=lambda_r, 340 | motion_dim=motion_dims[i], heads=num_heads[i-self.conv_stages], range = lambda_global_or_local, mlp_ratio=mlp_ratios[i-self.conv_stages], 341 | drop=drop_rate, drop_path=dpr[cur + j]) 342 | for j in range(depths[i])]) 343 | else: 344 | patch_embed = OverlapPatchEmbed(patch_size=3, 345 | stride=2, 346 | in_chans=embed_dims[i - 1], 347 | embed_dim=embed_dims[i]) 348 | block = nn.ModuleList([StructFormerBlock( 349 | dim=embed_dims[i], dim_out=embed_dims[i], dim_k=lambda_dim_k, dim_u=lambda_dim_u, n=int(lambda_n/2), r=lambda_r, 350 | motion_dim=motion_dims[i], heads=num_heads[i-self.conv_stages], range = lambda_global_or_local, mlp_ratio=mlp_ratios[i-self.conv_stages], 351 | drop=drop_rate, drop_path=dpr[cur + j]) 352 | for j in range(depths[i])]) 353 | 354 | norm = norm_layer(embed_dims[i]) 355 | setattr(self, f"norm{i + 1}", norm) 356 | 357 | setattr(self, f"patch_embed{i + 1}", patch_embed) 358 | cur += depths[i] 359 | 360 | setattr(self, f"block{i + 1}", block) 361 | 362 | self.cor = {} 363 | 364 | self.apply(self._init_weights) 365 | 366 | def 
_init_weights(self, m): 367 | if isinstance(m, nn.Linear): 368 | trunc_normal_(m.weight, std=.02) 369 | if isinstance(m, nn.Linear) and m.bias is not None: 370 | nn.init.constant_(m.bias, 0) 371 | elif isinstance(m, nn.LayerNorm): 372 | nn.init.constant_(m.bias, 0) 373 | nn.init.constant_(m.weight, 1.0) 374 | elif isinstance(m, nn.Conv2d): 375 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 376 | fan_out //= m.groups 377 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 378 | if m.bias is not None: 379 | m.bias.data.zero_() 380 | 381 | def get_cor(self, shape, device): 382 | k = (str(shape), str(device)) 383 | if k not in self.cor: 384 | tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view( 385 | 1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1) 386 | tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view( 387 | 1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1) 388 | self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device) 389 | return self.cor[k] 390 | 391 | def forward(self, x1, x2): 392 | B = x1.shape[0] 393 | 394 | x = torch.cat([x1, x2], 0) 395 | 396 | motion_features = [] 397 | appearence_features = [] 398 | xs = [] 399 | 400 | for i in range(self.num_stages): 401 | motion_features.append([]) 402 | patch_embed = getattr(self, f"patch_embed{i + 1}",None) 403 | block = getattr(self, f"block{i + 1}",None) 404 | norm = getattr(self, f"norm{i + 1}",None) 405 | if i < self.conv_stages: 406 | if i > 0: 407 | x = patch_embed(x) 408 | x = block(x) 409 | xs.append(x) 410 | else: 411 | if i == self.conv_stages: 412 | x, H, W = patch_embed(xs) 413 | else: 414 | x, H, W = patch_embed(x) 415 | 416 | cor = self.get_cor((x.shape[0], H, W), x.device) 417 | 418 | for blk in block: 419 | x, x_motion = blk(x, cor, H, W, B) 420 | 421 | motion_features[i].append(x_motion.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous()) 422 | 423 | x = norm(x) 424 | x = x.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous() 425 | motion_features[i] = torch.cat(motion_features[i], 1) 426 | 427 | appearence_features.append(x) 428 | return appearence_features, motion_features 429 | 430 | 431 | class DWConv(nn.Module): 432 | def __init__(self, dim): 433 | super(DWConv, self).__init__() 434 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 435 | 436 | def forward(self, x, H, W): 437 | B, N, C = x.shape 438 | x = x.transpose(1, 2).reshape(B, C, H, W) 439 | x = self.dwconv(x) 440 | x = x.reshape(B, C, -1).transpose(1, 2) 441 | 442 | return x 443 | 444 | 445 | def encoder(**kargs): 446 | model = StructFormer(**kargs) 447 | return model 448 | --------------------------------------------------------------------------------
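
Usage sketch for Trainer.Model. The Model class above wraps checkpoint loading, TTA-averaged inference, and multi-timestep interpolation; the sketch below drives only that API. It is a hypothetical example, not a repository file: the checkpoint location, the placeholder frames "frame_0.png" / "frame_1.png", the use of OpenCV for image I/O, and the assumption that the frame size is already one the network accepts without padding are all assumptions.

# Minimal inference sketch. Assumes a CUDA device, a checkpoint saved by
# Model.save_model() during DDP training (load_model() strips the "module." prefix),
# and frames whose height/width the network accepts without padding.
import cv2
import torch
from Trainer import Model

model = Model(local_rank=-1)   # -1: skip the DistributedDataParallel wrapper
model.load_model()             # reads ckpt/<LOGNAME>.pkl by default
model.eval()

def load_frame(path):
    # uint8 HxWx3 (BGR) -> float 1x3xHxW in [0, 1] on the GPU
    img = cv2.imread(path)
    return torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0).cuda() / 255.

img0 = load_frame("frame_0.png")
img1 = load_frame("frame_1.png")

# One backbone pass, three intermediate frames; fast_TTA also averages a flipped copy in-batch.
preds = model.multi_inference(img0, img1, time_list=[0.25, 0.5, 0.75], fast_TTA=True)

for i, p in enumerate(preds):  # each p is CxHxW in [0, 1]
    out = (p.clamp(0, 1) * 255).byte().permute(1, 2, 0).cpu().numpy()
    cv2.imwrite(f"interp_{i + 1}.png", out)

For large frames, hr_inference(img0, img1, down_scale=0.5) follows the same pattern but, as its implementation above shows, estimates flow at reduced resolution and rescales it before the refinement step.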
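
Worked example for losses_weight_schedules. The docstring of multi_gts_losses_update describes the schedule format abstractly; a concrete call makes the behaviour of decide_values easier to check. The sketch below reuses the schematic values from that docstring; the batch shapes, the random tensors, and the VGG-19 weight path are placeholders, and the perceptual/style terms require a real VGG-19 .mat weights file to actually run.

# Hypothetical single training step with staged loss weights.
# Assumptions: a CUDA device; imgs is Bx6xHxW (the two input frames stacked on
# channels) and gts is Bx9xHxW (one 3-channel ground-truth frame per timestep).
import torch
from Trainer import Model

model = Model(local_rank=-1)
timesteps = [0.25, 0.5, 0.75]

# One plan per loss, in the order [Laplacian L1, perceptual, style].
# boundaries_epoch takes priority over boundaries_step; values[0] applies before
# the boundary, values[1] from the boundary onward.
schedules = [
    {'boundaries_epoch': [0], 'boundaries_step': [0],    'values': [1.0, 1.0]},   # Laplacian L1 loss
    {'boundaries_epoch': [0], 'boundaries_step': [2400], 'values': [1.0, 0.25]},  # perceptual loss
    {'boundaries_epoch': [2], 'boundaries_step': [2400], 'values': [0.0, 40.0]},  # style loss: off before epoch 2
]

imgs = torch.rand(2, 6, 256, 256).cuda()   # placeholder input batch
gts  = torch.rand(2, 9, 256, 256).cuda()   # placeholder ground-truth frames

preds, loss = model.multi_gts_losses_update(
    imgs, gts, timesteps,
    vgg_model_file='path/to/imagenet-vgg-verydeep-19.mat',  # placeholder path
    losses_weight_schedules=schedules,
    now_epoch=0, now_step=0,
    learning_rate=2e-4, training=True)

With now_epoch=0 these schedules resolve to weights (1.0, 0.25, 0.0): the style term stays disabled until epoch 2, matching the epoch-first rule in decide_values.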