├── sundry
│   ├── 1f379.gif
│   ├── 1f3ac.gif
│   ├── 1f43c.gif
│   ├── 1f4a1.gif
│   ├── 1f4f0.gif
│   ├── 1f52c.gif
│   ├── 1f9f3.gif
│   ├── 听诊器.gif
│   ├── 小提琴.gif
│   ├── 调色盘.gif
│   ├── cheers.gif
│   └── 1f5c2-fe0f.gif
├── demo_images
│   ├── DSA_1.png
│   └── DSA_2.png
├── model
│   ├── __init__.py
│   ├── warplayer.py
│   ├── refine.py
│   ├── loss.py
│   ├── decoder.py
│   ├── vgg19_losses.py
│   └── encoder.py
├── utils
│   ├── padder.py
│   ├── yuv_frame_io.py
│   └── pytorch_msssim.py
├── config.py
├── GenDSA_env.txt
├── Simple_Interpolator.py
├── README.md
├── LICENSE
└── Trainer.py

/sundry/1f379.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f379.gif
--------------------------------------------------------------------------------
/sundry/1f3ac.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f3ac.gif
--------------------------------------------------------------------------------
/sundry/1f43c.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f43c.gif
--------------------------------------------------------------------------------
/sundry/1f4a1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f4a1.gif
--------------------------------------------------------------------------------
/sundry/1f4f0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f4f0.gif
--------------------------------------------------------------------------------
/sundry/1f52c.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f52c.gif
--------------------------------------------------------------------------------
/sundry/1f9f3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f9f3.gif
--------------------------------------------------------------------------------
/sundry/听诊器.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/听诊器.gif
--------------------------------------------------------------------------------
/sundry/小提琴.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/小提琴.gif
--------------------------------------------------------------------------------
/sundry/调色盘.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/调色盘.gif
--------------------------------------------------------------------------------
/sundry/cheers.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/cheers.gif
--------------------------------------------------------------------------------
/demo_images/DSA_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/demo_images/DSA_1.png
--------------------------------------------------------------------------------
/demo_images/DSA_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/demo_images/DSA_2.png
--------------------------------------------------------------------------------
/sundry/1f5c2-fe0f.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZyoungXu/GenDSA/HEAD/sundry/1f5c2-fe0f.gif
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import encoder
2 | from .decoder import decoder
3 | 
4 | 
5 | __all__ = ['encoder', 'decoder']
6 | 
--------------------------------------------------------------------------------
/utils/padder.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | 
3 | 
4 | class InputPadder:
5 |     def __init__(self, dims, divisor = 16):
6 |         self.ht, self.wd = dims[-2:]
7 |         pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor
8 |         pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor
9 |         self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
10 | 
11 |     def pad(self, *inputs):
12 |         return [F.pad(x, self._pad, mode='replicate') for x in inputs]
13 | 
14 |     def unpad(self, x):
15 |         ht, wd = x.shape[-2:]
16 |         c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
17 |         return x[..., c[0]:c[1], c[2]:c[3]]
18 | 
--------------------------------------------------------------------------------
/model/warplayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4 | backwarp_tenGrid = {}
5 | 
6 | def warp(tenInput, tenFlow):
7 |     k = (str(tenFlow.device), str(tenFlow.size()))
8 |     if k not in backwarp_tenGrid:
9 |         tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
10 |             1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
11 |         tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
12 |             1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
13 |         backwarp_tenGrid[k] = torch.cat(
14 |             [tenHorizontal, tenVertical], 1).to(device)
15 | 
16 |     tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
17 |                          tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
18 | 
19 |     g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
20 |     return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
--------------------------------------------------------------------------------
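`warp` backward-warps an image (or feature map) with a dense optical-flow field: the pixel-unit flow is rescaled into the normalized [-1, 1] coordinates that `grid_sample` expects, then added to a cached identity grid. A minimal sketch on random tensors (the shapes and the 2-pixel shift are illustrative assumptions, not repo defaults):

```python
import torch
from model.warplayer import warp

# The module caches its sampling grid on CUDA when available,
# so place the inputs on that same device.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

img = torch.rand(1, 3, 64, 64, device=dev)    # N x C x H x W image
flow = torch.zeros(1, 2, 64, 64, device=dev)  # (dx, dy) flow in pixels
flow[:, 0] += 2.0                             # shift every pixel 2 px along x

warped = warp(img, flow)                      # same shape as img
print(warped.shape)                           # torch.Size([1, 3, 64, 64])
```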
/config.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import torch.nn as nn
3 | from model import encoder
4 | from model import decoder
5 | 
6 | 
7 | def init_model_config(F=32, lambda_range='local', depth=[2, 2, 2, 4]):
8 |     return {
9 |         'embed_dims':[F, 2*F, 4*F, 8*F],
10 |         'motion_dims':[0, 0, 0, 8*F//depth[-1]],
11 |         'num_heads':[4],
12 |         'mlp_ratios':[4],
13 |         'lambda_global_or_local': lambda_range,
14 |         'lambda_dim_k':16,
15 |         'lambda_dim_u':1,
16 |         'lambda_n':32,
17 |         'lambda_r':15,
18 |         'norm_layer':partial(nn.LayerNorm, eps=1e-6),
19 |         'depths':depth,
20 |     }, {
21 |         'embed_dims':[F, 2*F, 4*F, 8*F],
22 |         'motion_dims':[0, 0, 0, 8*F//depth[-1]],
23 |         'depths':depth,
24 |         'scales':[4, 8],
25 |         'hidden_dims':[4*F],
26 |         'c':F
27 |     }
28 | 
29 | 
30 | MODEL_CONFIG = {
31 |     'LOGNAME': 'GenDSA',
32 |     'MODEL_TYPE': (encoder, decoder),
33 |     'MODEL_ARCH': init_model_config(
34 |         F = 32,
35 |         lambda_range='local',
36 |         depth = [2, 2, 2, 4]
37 |     )
38 | }
39 | 
--------------------------------------------------------------------------------
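`init_model_config` returns two keyword dicts: the first configures the encoder backbone, the second the multi-scale decoder. `Trainer.Model` consumes them through `MODEL_CONFIG`; a minimal sketch of the same wiring, mirroring `Trainer.Model.__init__`:

```python
from config import MODEL_CONFIG

backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']  # (encoder, decoder)
backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']    # the two dicts above

# The decoder wraps the encoder backbone, exactly as Trainer.Model does.
net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
```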
/GenDSA_env.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | blessed==1.20.0
3 | cachetools==5.3.1
4 | certifi==2022.12.7
5 | charset-normalizer==3.1.0
6 | contourpy==1.1.1
7 | cycler==0.12.1
8 | DateTime==5.1
9 | einops==0.6.1
10 | filelock==3.10.0
11 | fonttools==4.49.0
12 | google-auth==2.20.0
13 | google-auth-oauthlib==1.0.0
14 | gpustat==1.0.0
15 | grpcio==1.56.0
16 | huggingface-hub==0.13.3
17 | idna==3.4
18 | imageio==2.26.1
19 | importlib-metadata==6.7.0
20 | importlib-resources==6.1.1
21 | kiwisolver==1.4.5
22 | Markdown==3.4.3
23 | MarkupSafe==2.1.3
24 | matplotlib==3.7.5
25 | natsort==8.3.1
26 | networkx==3.0
27 | numpy==1.24.3
28 | nvidia-ml-py==11.495.46
29 | oauthlib==3.2.2
30 | opencv-python==4.6.0.66
31 | packaging==23.0
32 | pandas==2.0.2
33 | Pillow==9.4.0
34 | protobuf==4.23.3
35 | psutil==5.9.4
36 | pyasn1==0.5.0
37 | pyasn1-modules==0.3.0
38 | pydicom==2.4.4
39 | pyparsing==3.1.1
40 | python-dateutil==2.8.2
41 | pytz==2023.3
42 | PyWavelets==1.4.1
43 | PyYAML==6.0
44 | requests==2.28.2
45 | requests-oauthlib==1.3.1
46 | rsa==4.9
47 | scikit-image==0.19.2
48 | scipy==1.10.1
49 | setuptools==58.5.2
50 | six==1.16.0
51 | tensorboard==2.13.0
52 | tensorboard-data-server==0.7.1
53 | tensorboard-plugin-wit==1.8.1
54 | tifffile==2023.3.21
55 | timm==0.6.11
56 | torch==1.9.0+cu111
57 | torch-tb-profiler==0.4.1
58 | torchvision==0.10.0+cu111
59 | tqdm==4.65.0
60 | typing_extensions==4.5.0
61 | tzdata==2023.3
62 | urllib3==1.26.15
63 | wcwidth==0.2.6
64 | Werkzeug==2.3.6
65 | wheel==0.41.2
66 | zipp==3.15.0
67 | zope.interface==6.0
--------------------------------------------------------------------------------
/Simple_Interpolator.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import fnmatch
4 | import numpy as np
5 | import cv2
6 | import math
7 | import torch
8 | import argparse
9 | import config as cfg
10 | from Trainer import Model
11 | from utils.padder import InputPadder
12 | 
13 | 
14 | def run_interpolator(model, Frame1, Frame2, time_list, Output_Frames_list, TTA = True):
15 |     I0 = cv2.imread(Frame1)
16 |     I2 = cv2.imread(Frame2)
17 | 
18 |     I0_ = (torch.tensor(I0.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0)
19 |     I2_ = (torch.tensor(I2.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0)
20 | 
21 |     padder = InputPadder(I0_.shape, divisor=32)
22 |     I0_, I2_ = padder.pad(I0_, I2_)
23 | 
24 |     preds = model.multi_inference(I0_, I2_, TTA = TTA, time_list = time_list, fast_TTA = TTA)
25 | 
26 |     for pred, Output_Frame in zip(preds, Output_Frames_list):
27 |         mid_image = (padder.unpad(pred).detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1]
28 |         cv2.imwrite(Output_Frame, mid_image)
29 | 
30 | 
31 | if __name__ == '__main__':
32 |     parser = argparse.ArgumentParser()
33 |     parser.add_argument('--model_path', type=str, default='./weights/checkpoints/3D-vas-Inf1.pkl')
34 |     parser.add_argument('--frame1', type=str, default='./demo_images/DSA_1.png')
35 |     parser.add_argument('--frame2', type=str, default='./demo_images/DSA_2.png')
36 |     parser.add_argument('--inter_frames', type=int, default=1)
37 |     args = parser.parse_args()
38 | 
39 |     model_path = args.model_path
40 |     inf_folder_path = os.path.dirname(args.frame1)
41 |     Interframe_num = args.inter_frames
42 | 
43 |     if Interframe_num == 1:
44 |         TimeStepList = [0.5]
45 |         Inter_Frames_list = [inf_folder_path + "//" + 'InferImage.png']
46 |     elif Interframe_num == 2:
47 |         TimeStepList = [0.3333333333333333, 0.6666666666666667]
48 |         Inter_Frames_list = [inf_folder_path + "//" + 'InferImage_1_in_2.png',
49 |                              inf_folder_path + "//" + 'InferImage_2_in_2.png']
50 |     elif Interframe_num == 3:
51 |         TimeStepList = [0.25, 0.50, 0.75]
52 |         Inter_Frames_list = [inf_folder_path + "//" + 'InferImage_1_in_3.png',
53 |                              inf_folder_path + "//" + 'InferImage_2_in_3.png',
54 |                              inf_folder_path + "//" + 'InferImage_3_in_3.png']
55 |     else:
56 |         print("'inter_frames' invalid. Currently, 1, 2, and 3 frames are supported. You can also try training a model that interpolates more frames.")
57 |         sys.exit()
58 | 
59 |     TTA = True
60 |     cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(
61 |         F = 32,
62 |         lambda_range='local',
63 |         depth = [2, 2, 2, 4]
64 |     )
65 | 
66 |     model = Model(-1)
67 |     model.load_model(full_path = model_path)
68 |     model.eval()
69 |     model.device()
70 | 
71 |     run_interpolator(model, args.frame1, args.frame2, time_list = TimeStepList, Output_Frames_list = Inter_Frames_list, TTA = TTA)
72 | 
--------------------------------------------------------------------------------
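The same interpolation can be driven from Python instead of the command line; a minimal sketch, assuming the demo images and the downloaded checkpoint are in place:

```python
import config as cfg
from Trainer import Model
from Simple_Interpolator import run_interpolator

cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config(F=32, lambda_range='local', depth=[2, 2, 2, 4])

model = Model(-1)  # -1 = no distributed training
model.load_model(full_path='./weights/checkpoints/3D-vas-Inf1.pkl')
model.eval()
model.device()

# Generate the midpoint frame between the two demo images.
run_interpolator(model, './demo_images/DSA_1.png', './demo_images/DSA_2.png',
                 time_list=[0.5], Output_Frames_list=['./demo_images/InferImage.png'])
```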
/model/refine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from timm.models.layers import trunc_normal_
5 | 
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 | 
8 | 
9 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
10 |     return nn.Sequential(
11 |         nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
12 |                   padding=padding, dilation=dilation, bias=True),
13 |         nn.PReLU(out_planes)
14 |     )
15 | 
16 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
17 |     return nn.Sequential(
18 |         torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
19 |         nn.PReLU(out_planes)
20 |     )
21 | 
22 | class Conv2(nn.Module):
23 |     def __init__(self, in_planes, out_planes, stride=2):
24 |         super(Conv2, self).__init__()
25 |         self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
26 |         self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
27 | 
28 |     def forward(self, x):
29 |         x = self.conv1(x)
30 |         x = self.conv2(x)
31 |         return x
32 | 
33 | class Unet(nn.Module):
34 |     def __init__(self, c, out=3):
35 |         super(Unet, self).__init__()
36 |         self.down0 = Conv2(17+c, 2*c)
37 |         self.down1 = Conv2(4*c, 4*c)
38 |         self.down2 = Conv2(8*c, 8*c)
39 |         self.down3 = Conv2(16*c, 16*c)
40 |         self.up0 = deconv(32*c, 8*c)
41 |         self.supple = Conv2(4*c, 8*c)
42 | 
43 |         self.up1 = deconv(16*c, 4*c)
44 |         self.up2 = deconv(8*c, 2*c)
45 |         self.up3 = deconv(4*c, c)
46 |         self.conv = nn.Conv2d(c, out, 3, 1, 1)
47 |         self.apply(self._init_weights)
48 | 
49 |     def _init_weights(self, m):
50 |         if isinstance(m, nn.Linear):
51 |             trunc_normal_(m.weight, std=.02)
52 |             if isinstance(m, nn.Linear) and m.bias is not None:
53 |                 nn.init.constant_(m.bias, 0)
54 |         elif isinstance(m, nn.LayerNorm):
55 |             nn.init.constant_(m.bias, 0)
56 |             nn.init.constant_(m.weight, 1.0)
57 |         elif isinstance(m, nn.Conv2d):
58 |             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
59 |             fan_out //= m.groups
60 |             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
61 |             if m.bias is not None:
62 |                 m.bias.data.zero_()
63 | 
64 |     def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
65 |         s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow, c0[0], c1[0]), 1))
66 |         s1 = self.down1(torch.cat((s0, c0[1], c1[1]), 1))
67 |         s2 = self.down2(torch.cat((s1, c0[2], c1[2]), 1))
68 |         s3 = self.down3(torch.cat((s2, c0[3], c1[3]), 1))
69 |         c0_4 = self.supple(c0[3])
70 |         c1_4 = self.supple(c1[3])
71 |         x = self.up0(torch.cat((s3, c0_4, c1_4), 1))
72 | 
73 |         x = self.up1(torch.cat((x, s2), 1))
74 |         x = self.up2(torch.cat((x, s1), 1))
75 |         x = self.up3(torch.cat((x, s0), 1))
76 |         x = self.conv(x)
77 |         return torch.sigmoid(x)
78 | 
--------------------------------------------------------------------------------
/utils/yuv_frame_io.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import getopt
3 | import math
4 | import numpy
5 | import random
6 | import logging
7 | import numpy as np
8 | from skimage.color import rgb2yuv, yuv2rgb
9 | from PIL import Image
10 | import os
11 | from shutil import copyfile
12 | 
13 | 
14 | class YUV_Read():
15 |     def __init__(self, filepath, h, w, format='yuv420', toRGB=True):
16 | 
17 |         self.h = h
18 |         self.w = w
19 | 
20 |         self.fp = open(filepath, 'rb')
21 | 
22 |         if format == 'yuv420':
23 |             self.frame_length = int(1.5 * h * w)
24 |             self.Y_length = h * w
25 |             self.Uv_length = int(0.25 * h * w)
26 |         else:
27 |             pass
28 |         self.toRGB = toRGB
29 | 
30 |     def read(self, offset_frame=None):
31 |         if offset_frame is not None:
32 |             self.fp.seek(offset_frame * self.frame_length, 0)
33 | 
34 |         Y = np.fromfile(self.fp, np.uint8, count=self.Y_length)
35 |         U = np.fromfile(self.fp, np.uint8, count=self.Uv_length)
36 |         V = np.fromfile(self.fp, np.uint8, count=self.Uv_length)
37 |         if Y.size < self.Y_length or \
38 |                 U.size < self.Uv_length or \
39 |                 V.size < self.Uv_length:
40 |             return None, False
41 | 
42 |         Y = np.reshape(Y, [self.w, self.h], order='F')
43 |         Y = np.transpose(Y)
44 | 
45 |         U = np.reshape(U, [int(self.w / 2), int(self.h / 2)], order='F')
46 |         U = np.transpose(U)
47 | 
48 |         V = np.reshape(V, [int(self.w / 2), int(self.h / 2)], order='F')
49 |         V = np.transpose(V)
50 | 
51 |         U = np.array(Image.fromarray(U).resize([self.w, self.h]))
52 |         V = np.array(Image.fromarray(V).resize([self.w, self.h]))
53 | 
54 |         if self.toRGB:
55 |             Y = Y / 255.0
56 |             U = U / 255.0 - 0.5
57 |             V = V / 255.0 - 0.5
58 | 
59 |             self.YUV = np.stack((Y, U, V), axis=-1)
60 |             self.RGB = (255.0 * np.clip(yuv2rgb(self.YUV), 0.0, 1.0)).astype('uint8')
61 | 
62 |             self.YUV = None
63 |             return self.RGB, True
64 |         else:
65 |             self.YUV = np.stack((Y, U, V), axis=-1)
66 |             return self.YUV, True
67 | 
68 |     def close(self):
69 |         self.fp.close()
70 | 
71 | 
72 | class YUV_Write():
73 |     def __init__(self, filepath, fromRGB=True, format='yuv420'):
74 |         if os.path.exists(filepath):
75 |             print(filepath)
76 | 
77 |         self.fp = open(filepath, 'wb')
78 |         self.fromRGB = fromRGB
79 |         self.format = format
80 | 
81 |     def write(self, Frame):
82 | 
83 |         self.h = Frame.shape[0]
84 |         self.w = Frame.shape[1]
85 |         c = Frame.shape[2]
86 | 
87 |         assert c == 3
88 |         if self.format == 'yuv420':
89 |             self.frame_length = int(1.5 * self.h * self.w)
90 |             self.Y_length = self.h * self.w
91 |             self.Uv_length = int(0.25 * self.h * self.w)
92 |         else:
93 |             pass
94 |         if self.fromRGB:
95 |             Frame = Frame / 255.0
96 |             YUV = rgb2yuv(Frame)
97 |             Y, U, V = np.dsplit(YUV, 3)
98 |             Y = Y[:, :, 0]
99 |             U = U[:, :, 0]
100 |             V = V[:, :, 0]
101 |             U = np.clip(U + 0.5, 0.0, 1.0)
102 |             V = np.clip(V + 0.5, 0.0, 1.0)
103 | 
104 |             U = U[::2, ::2]
105 |             V = V[::2, ::2]
106 |             Y = (255.0 * Y).astype('uint8')
107 |             U = (255.0 * U).astype('uint8')
108 |             V = (255.0 * V).astype('uint8')
109 |         else:
110 |             YUV = Frame
111 |             Y = YUV[:, :, 0]
112 |             U = YUV[::2, ::2, 1]
113 |             V = YUV[::2, ::2, 2]
114 | 
115 |         Y = Y.flatten()
116 |         U = U.flatten()
117 |         V = V.flatten()
118 | 
119 |         Y.tofile(self.fp)
120 |         U.tofile(self.fp)
121 |         V.tofile(self.fp)
122 | 
123 |         return True
124 | 
125 |     def close(self):
126 |         self.fp.close()
127 | 
--------------------------------------------------------------------------------
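A minimal sketch of reading one frame from a raw YUV420 file with `YUV_Read` (the file name and resolution below are illustrative assumptions):

```python
from utils.yuv_frame_io import YUV_Read

# A 1080p YUV420 clip; the path and dimensions are hypothetical.
reader = YUV_Read('clip_1920x1080.yuv', h=1080, w=1920, toRGB=True)
rgb, ok = reader.read(offset_frame=0)  # HxWx3 uint8 array, or (None, False) at EOF
reader.close()
```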
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from . import vgg19_losses as vgg19
6 | 
7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8 | 
9 | 
10 | def gauss_kernel(channels=3):
11 |     kernel = torch.tensor([[1., 4., 6., 4., 1.],
12 |                            [4., 16., 24., 16., 4.],
13 |                            [6., 24., 36., 24., 6.],
14 |                            [4., 16., 24., 16., 4.],
15 |                            [1., 4., 6., 4., 1.]])
16 |     kernel /= 256.
17 |     kernel = kernel.repeat(channels, 1, 1, 1)
18 |     kernel = kernel.to(device)
19 |     return kernel
20 | 
21 | def downsample(x):
22 |     return x[:, :, ::2, ::2]
23 | 
24 | def upsample(x):
25 |     cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
26 |     cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
27 |     cc = cc.permute(0,1,3,2)
28 |     cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2).to(device)], dim=3)
29 |     cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
30 |     x_up = cc.permute(0,1,3,2)
31 |     return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1]))
32 | 
33 | def conv_gauss(img, kernel):
34 |     img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
35 |     out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
36 |     return out
37 | 
38 | def laplacian_pyramid(img, kernel, max_levels=3):
39 |     current = img
40 |     pyr = []
41 |     for level in range(max_levels):
42 |         filtered = conv_gauss(current, kernel)
43 |         down = downsample(filtered)
44 |         up = upsample(down)
45 |         diff = current-up
46 |         pyr.append(diff)
47 |         current = down
48 |     return pyr
49 | 
50 | class LapLoss(torch.nn.Module):
51 |     def __init__(self, max_levels=5, channels=3):
52 |         super(LapLoss, self).__init__()
53 |         self.max_levels = max_levels
54 |         self.gauss_kernel = gauss_kernel(channels=channels)
55 | 
56 |     def forward(self, input, target):
57 |         pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
58 |         pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
59 |         return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
60 | 
61 | class Ternary(nn.Module):
62 |     def __init__(self, device):
63 |         super(Ternary, self).__init__()
64 |         patch_size = 7
65 |         out_channels = patch_size * patch_size
66 |         self.w = np.eye(out_channels).reshape(
67 |             (patch_size, patch_size, 1, out_channels))
68 |         self.w = np.transpose(self.w, (3, 2, 0, 1))
69 |         self.w = torch.tensor(self.w).float().to(device)
70 | 
71 |     def transform(self, img):
72 |         patches = F.conv2d(img, self.w, padding=3, bias=None)
73 |         transf = patches - img
74 |         transf_norm = transf / torch.sqrt(0.81 + transf**2)
75 |         return transf_norm
76 | 
77 |     def rgb2gray(self, rgb):
78 |         r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
79 |         gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
80 |         return gray
81 | 
82 |     def hamming(self, t1, t2):
83 |         dist = (t1 - t2) ** 2
84 |         dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
85 |         return dist_norm
86 | 
87 |     def valid_mask(self, t, padding):
88 |         n, _, h, w = t.size()
89 |         inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
90 |         mask = F.pad(inner, [padding] * 4)
91 |         return mask
92 | 
93 |     def forward(self, img0, img1):
94 |         img0 = self.transform(self.rgb2gray(img0))
95 |         img1 = self.transform(self.rgb2gray(img1))
96 |         return self.hamming(img0, img1) * self.valid_mask(img0, 1)
97 | 
98 | class Perceptual_Loss(torch.nn.Module):
99 |     def __init__(self):
100 |         super(Perceptual_Loss, self).__init__()
101 | 
102 |     def forward(self, image: torch.Tensor, reference: torch.Tensor, vgg_model_file: str, weights = None):
103 |         return vgg19.perceptual_loss(image, reference, vgg_model_file, weights)
104 | 
105 | 
106 | class Style_Loss(torch.nn.Module):
107 |     def __init__(self):
108 |         super(Style_Loss, self).__init__()
109 | 
110 |     def forward(self, image: torch.Tensor, reference: torch.Tensor, vgg_model_file: str, weights = None):
111 |         return vgg19.style_loss(image, reference, vgg_model_file, weights)
112 | 
--------------------------------------------------------------------------------
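`LapLoss` compares two images with an L1 penalty at every level of a Laplacian pyramid. A minimal sketch on random tensors (the shapes are illustrative assumptions):

```python
import torch
from model.loss import LapLoss

# The module builds its Gaussian kernel on CUDA when available,
# so put the inputs on the same device.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pred = torch.rand(1, 3, 128, 128, device=dev)
gt = torch.rand(1, 3, 128, 128, device=dev)

loss = LapLoss(max_levels=5, channels=3)(pred, gt)  # scalar tensor
```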
/model/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from .warplayer import warp
6 | from .refine import *
7 | 
8 | 
9 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
10 |     return nn.Sequential(
11 |         nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
12 |                   padding=padding, dilation=dilation, bias=True),
13 |         nn.PReLU(out_planes)
14 |     )
15 | 
16 | 
17 | class Head(nn.Module):
18 |     def __init__(self, in_planes, scale, c, in_else=17):
19 |         super(Head, self).__init__()
20 |         self.upsample = nn.Sequential(nn.PixelShuffle(2), nn.PixelShuffle(2))
21 |         self.scale = scale
22 |         self.conv = nn.Sequential(
23 |             conv(in_planes*2 // (4*4) + in_else, c),
24 |             conv(c, c),
25 |             conv(c, 5),
26 |         )
27 | 
28 |     def forward(self, motion_feature, x, flow):
29 |         motion_feature = self.upsample(motion_feature)
30 | 
31 |         if self.scale != 4:
32 |             x = F.interpolate(x, scale_factor = 4. / self.scale, mode="bilinear", align_corners=False)
33 | 
34 |         if flow is not None:
35 |             if self.scale != 4:
36 |                 flow = F.interpolate(flow, scale_factor = 4. / self.scale, mode="bilinear", align_corners=False) * 4. / self.scale
37 |             x = torch.cat((x, flow), 1)
38 | 
39 |         x = self.conv(torch.cat([motion_feature, x], 1))
40 | 
41 |         if self.scale != 4:
42 |             x = F.interpolate(x, scale_factor = self.scale // 4, mode="bilinear", align_corners=False)
43 |             flow = x[:, :4] * (self.scale // 4)
44 |         else:
45 |             flow = x[:, :4]
46 | 
47 |         mask = x[:, 4:5]
48 | 
49 |         return flow, mask
50 | 
51 | 
52 | class decoder(nn.Module):
53 |     def __init__(self, backbone, **kargs):
54 |         super(decoder, self).__init__()
55 |         self.flow_num_stage = len(kargs['hidden_dims'])
56 |         self.feature_bone = backbone
57 |         self.block = nn.ModuleList([Head( kargs['motion_dims'][-1-i] * kargs['depths'][-1-i] + kargs['embed_dims'][-1-i],
58 |                                           kargs['scales'][-1-i],
59 |                                           kargs['hidden_dims'][-1-i],
60 |                                           6 if i==0 else 17)
61 |                                     for i in range(self.flow_num_stage)])
62 |         self.unet = Unet(kargs['c'] * 2)
63 | 
64 |     def warp_features(self, xs, flow):
65 |         y0 = []
66 |         y1 = []
67 |         B = xs[0].size(0) // 2
68 |         for x in xs:
69 |             y0.append(warp(x[:B], flow[:, 0:2]))
70 |             y1.append(warp(x[B:], flow[:, 2:4]))
71 |             flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
72 |         return y0, y1
73 | 
74 |     def calculate_flow(self, imgs, timestep, af=None, mf=None):
75 |         img0, img1 = imgs[:, :3], imgs[:, 3:6]
76 |         B = img0.size(0)
77 |         flow, mask = None, None
78 |         if (af is None) or (mf is None):
79 |             af, mf = self.feature_bone(img0, img1)
80 |         for i in range(self.flow_num_stage):
81 |             t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda()
82 |             if flow is not None:
83 |                 warped_img0 = warp(img0, flow[:, :2])
84 |                 warped_img1 = warp(img1, flow[:, 2:4])
85 |                 flow_, mask_ = self.block[i](
86 |                     torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:], af[-1-i][:B], af[-1-i][B:]], 1),
87 |                     torch.cat((img0, img1, warped_img0, warped_img1, mask), 1),
88 |                     flow
89 |                 )
90 |                 flow = flow + flow_
91 |                 mask = mask + mask_
92 |             else:
93 |                 flow, mask = self.block[i](
94 |                     torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:], af[-1-i][:B], af[-1-i][B:]], 1),
95 |                     torch.cat((img0, img1), 1),
96 |                     None
97 |                 )
98 | 
99 |         return flow, mask
100 | 
101 |     def coraseWarp_and_Refine(self, imgs, af, flow, mask):
102 |         img0, img1 = imgs[:, :3], imgs[:, 3:6]
103 |         warped_img0 = warp(img0, flow[:, :2])
104 |         warped_img1 = warp(img1, flow[:, 2:4])
105 |         c0, c1 = self.warp_features(af, flow)
106 |         tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
107 |         res = tmp[:, :3] * 2 - 1
108 |         mask_ = torch.sigmoid(mask)
109 |         merged = warped_img0 * mask_ + warped_img1 * (1 - mask_)
110 |         pred = torch.clamp(merged + res, 0, 1)
111 |         return pred
112 | 
113 | 
114 |     def forward(self, x, timestep=0.5):
115 |         img0, img1 = x[:, :3], x[:, 3:6]
116 |         B = x.size(0)
117 | 
118 |         flow_list = []
119 |         merged = []
120 |         mask_list = []
121 |         warped_img0 = img0
122 |         warped_img1 = img1
123 |         flow = None
124 | 
125 |         af, mf = self.feature_bone(img0, img1)
126 | 
127 |         for i in range(self.flow_num_stage):
128 |             t = torch.full(mf[-1-i][:B].shape, timestep, dtype=torch.float).cuda()
129 | 
130 |             if flow is not None:
131 |                 flow_d, mask_d = self.block[i]( torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:], af[-1-i][:B], af[-1-i][B:]], 1),
132 |                                                 torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow)
133 |                 flow = flow + flow_d
134 |                 mask = mask + mask_d
135 |             else:
136 |                 flow, mask = self.block[i]( torch.cat([t*mf[-1-i][:B], (1-t)*mf[-1-i][B:], af[-1-i][:B], af[-1-i][B:]], 1),
137 |                                             torch.cat((img0, img1), 1), None)
138 | 
139 |             mask_list.append(torch.sigmoid(mask))
140 |             flow_list.append(flow)
141 | 
142 |             warped_img0 = warp(img0, flow[:, :2])
143 |             warped_img1 = warp(img1, flow[:, 2:4])
144 |             merged.append(warped_img0 * mask_list[i] + warped_img1 * (1 - mask_list[i]))
145 | 
146 |         c0, c1 = self.warp_features(af, flow)
147 |         tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
148 |         res = tmp[:, :3] * 2 - 1
149 |         pred = torch.clamp(merged[-1] + res, 0, 1)
150 |         return flow_list, mask_list, merged, pred
--------------------------------------------------------------------------------
/utils/pytorch_msssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 | 
6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 | 
8 | def gaussian(window_size, sigma):
9 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10 |     return gauss/gauss.sum()
11 | 
12 | 
13 | def create_window(window_size, channel=1):
14 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
16 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
17 |     return window
18 | 
19 | 
20 | def create_window_3d(window_size, channel=1):
21 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
22 |     _2D_window = _1D_window.mm(_1D_window.t())
23 |     _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
24 |     window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
25 |     return window
26 | 
27 | 
28 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
29 |     if val_range is None:
30 |         if torch.max(img1) > 128:
31 |             max_val = 255
32 |         else:
33 |             max_val = 1
34 | 
35 |         if torch.min(img1) < -0.5:
36 |             min_val = -1
37 |         else:
38 |             min_val = 0
39 |         L = max_val - min_val
40 |     else:
41 |         L = val_range
42 | 
43 |     padd = 0
44 |     (_, channel, height, width) = img1.size()
45 |     if window is None:
46 |         real_size = min(window_size, height, width)
47 |         window = create_window(real_size, channel=channel).to(img1.device)
48 | 
49 |     mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
50 |     mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=channel)
51 | 
52 |     mu1_sq = mu1.pow(2)
53 |     mu2_sq = mu2.pow(2)
54 |     mu1_mu2 = mu1 * mu2
55 | 
56 |     sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_sq
57 |     sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu2_sq
58 |     sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), 'replicate'), window, padding=padd, groups=channel) - mu1_mu2
59 | 
60 |     C1 = (0.01 * L) ** 2
61 |     C2 = (0.03 * L) ** 2
62 | 
63 |     v1 = 2.0 * sigma12 + C2
64 |     v2 = sigma1_sq + sigma2_sq + C2
65 |     cs = torch.mean(v1 / v2)
66 | 
67 |     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
68 | 
69 |     if size_average:
70 |         ret = ssim_map.mean()
71 |     else:
72 |         ret = ssim_map.mean(1).mean(1).mean(1)
73 | 
74 |     if full:
75 |         return ret, cs
76 |     return ret
77 | 
78 | 
79 | def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
80 |     if val_range is None:
81 |         if torch.max(img1) > 128:
82 |             max_val = 255
83 |         else:
84 |             max_val = 1
85 | 
86 |         if torch.min(img1) < -0.5:
87 |             min_val = -1
88 |         else:
89 |             min_val = 0
90 |         L = max_val - min_val
91 |     else:
92 |         L = val_range
93 | 
94 |     padd = 0
95 |     (_, _, height, width) = img1.size()
96 |     if window is None:
97 |         real_size = min(window_size, height, width)
98 |         window = create_window_3d(real_size, channel=1).to(img1.device)
99 | 
100 |     img1 = img1.unsqueeze(1)
101 |     img2 = img2.unsqueeze(1)
102 | 
103 |     mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
104 |     mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1)
105 | 
106 |     mu1_sq = mu1.pow(2)
107 |     mu2_sq = mu2.pow(2)
108 |     mu1_mu2 = mu1 * mu2
109 | 
110 |     sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq
111 |     sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq
112 |     sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2
113 | 
114 |     C1 = (0.01 * L) ** 2
115 |     C2 = (0.03 * L) ** 2
116 | 
117 |     v1 = 2.0 * sigma12 + C2
118 |     v2 = sigma1_sq + sigma2_sq + C2
119 |     cs = torch.mean(v1 / v2)
120 | 
121 |     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
122 | 
123 |     if size_average:
124 |         ret = ssim_map.mean()
125 |     else:
126 |         ret = ssim_map.mean(1).mean(1).mean(1)
127 | 
128 |     if full:
129 |         return ret, cs
130 |     return ret
131 | 
132 | 
133 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
134 |     device = img1.device
135 |     weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
136 |     levels = weights.size()[0]
137 |     mssim = []
138 |     mcs = []
139 |     for _ in range(levels):
140 |         sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
141 |         mssim.append(sim)
142 |         mcs.append(cs)
143 | 
144 |         img1 = F.avg_pool2d(img1, (2, 2))
145 |         img2 = F.avg_pool2d(img2, (2, 2))
146 | 
147 |     mssim = torch.stack(mssim)
148 |     mcs = torch.stack(mcs)
149 | 
150 |     if normalize:
151 |         mssim = (mssim + 1) / 2
152 |         mcs = (mcs + 1) / 2
153 | 
154 |     pow1 = mcs ** weights
155 |     pow2 = mssim ** weights
156 | 
157 |     output = torch.prod(pow1[:-1] * pow2[-1])
158 |     return output
159 | 
160 | 
161 | 
162 | class SSIM(torch.nn.Module):
163 |     def __init__(self, window_size=11, size_average=True, val_range=None):
164 |         super(SSIM, self).__init__()
165 |         self.window_size = window_size
166 |         self.size_average = size_average
167 |         self.val_range = val_range
168 | 
169 |         self.channel = 3
170 |         self.window = create_window(window_size, channel=self.channel)
171 | 
172 |     def forward(self, img1, img2):
173 |         (_, channel, _, _) = img1.size()
174 | 
175 |         if channel == self.channel and self.window.dtype == img1.dtype:
176 |             window = self.window
177 |         else:
178 |             window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
179 |             self.window = window
180 |             self.channel = channel
181 | 
182 |         _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
183 |         dssim = (1 - _ssim) / 2
184 |         return dssim
185 | 
186 | class MSSSIM(torch.nn.Module):
187 |     def __init__(self, window_size=11, size_average=True, channel=3):
188 |         super(MSSSIM, self).__init__()
189 |         self.window_size = window_size
190 |         self.size_average = size_average
191 |         self.channel = channel
192 | 
193 |     def forward(self, img1, img2):
194 |         return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
--------------------------------------------------------------------------------
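`ssim` and `ssim_matlab` return a scalar SSIM score (near 1 for near-identical images); the value range L is inferred from the inputs unless `val_range` is given. A minimal sketch on random tensors:

```python
import torch
from utils.pytorch_msssim import ssim_matlab

a = torch.rand(1, 3, 256, 256)
b = a.clone()
print(float(ssim_matlab(a, b)))  # ~1.0 for identical images
```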
/model/vgg19_losses.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional, Sequence, Tuple
2 | 
3 | import numpy as np
4 | import scipy.io as sio
5 | import torch
6 | import torch.nn.functional as F
7 | 
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 | 
10 | 
11 | def _build_net(layer_type: str, input_tensor: torch.Tensor,
12 |                weight_bias: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> Callable[[Any], Any]:
13 |     if layer_type == 'conv':
14 |         return F.relu(F.conv2d(input = input_tensor, weight = weight_bias[0], bias = weight_bias[1], stride=1, padding=0))
15 |     elif layer_type == 'pool':
16 |         return F.avg_pool2d(input = input_tensor, kernel_size=2, stride=2)
17 |     else:
18 |         raise ValueError('Unsupported layer types: %s' % layer_type)
19 | 
20 | 
21 | def _get_weight_and_bias(vgg_layers: np.ndarray,
22 |                          index: int) -> Tuple[torch.Tensor, torch.Tensor]:
23 |     weights = vgg_layers[index][0][0][2][0][0]
24 |     weights = torch.tensor(weights).permute(3, 2, 0, 1).to(device)
25 |     bias = vgg_layers[index][0][0][2][0][1]
26 |     bias = torch.tensor(np.reshape(bias, bias.size)).to(device)
27 | 
28 |     return weights, bias
29 | 
30 | 
31 | def _build_vgg19(image: torch.Tensor, model_filepath: str) -> Dict[str, torch.Tensor]:
32 |     net = {}
33 |     if not hasattr(_build_vgg19, 'vgg_rawnet'):
34 |         _build_vgg19.vgg_rawnet = sio.loadmat(model_filepath)
35 |     vgg_layers = _build_vgg19.vgg_rawnet['layers'][0]
36 |     imagenet_mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape(1, 3, 1, 1).to(device)
37 |     net['input'] = image - imagenet_mean
38 |     net['conv1_1'] = _build_net(
39 |         'conv',
40 |         net['input'],
41 |         _get_weight_and_bias(vgg_layers, 0))
42 |     net['conv1_2'] = _build_net(
43 |         'conv',
44 |         net['conv1_1'],
45 |         _get_weight_and_bias(vgg_layers, 2))
46 |     net['pool1'] = _build_net('pool', net['conv1_2'])
47 |     net['conv2_1'] = _build_net(
48 |         'conv',
49 |         net['pool1'],
50 |         _get_weight_and_bias(vgg_layers, 5))
51 |     net['conv2_2'] = _build_net(
52 |         'conv',
53 |         net['conv2_1'],
54 |         _get_weight_and_bias(vgg_layers, 7))
55 |     net['pool2'] = _build_net('pool', net['conv2_2'])
56 |     net['conv3_1'] = _build_net(
57 |         'conv',
58 |         net['pool2'],
59 |         _get_weight_and_bias(vgg_layers, 10))
60 |     net['conv3_2'] = _build_net(
61 |         'conv',
62 |         net['conv3_1'],
63 |         _get_weight_and_bias(vgg_layers, 12))
64 |     net['conv3_3'] = _build_net(
65 |         'conv',
66 |         net['conv3_2'],
67 |         _get_weight_and_bias(vgg_layers, 14))
68 |     net['conv3_4'] = _build_net(
69 |         'conv',
70 |         net['conv3_3'],
71 |         _get_weight_and_bias(vgg_layers, 16))
72 |     net['pool3'] = _build_net('pool', net['conv3_4'])
73 |     net['conv4_1'] = _build_net(
74 |         'conv',
75 |         net['pool3'],
76 |         _get_weight_and_bias(vgg_layers, 19))
77 |     net['conv4_2'] = _build_net(
78 |         'conv',
79 |         net['conv4_1'],
80 |         _get_weight_and_bias(vgg_layers, 21))
81 |     net['conv4_3'] = _build_net(
82 |         'conv',
83 |         net['conv4_2'],
84 |         _get_weight_and_bias(vgg_layers, 23))
85 |     net['conv4_4'] = _build_net(
86 |         'conv',
87 |         net['conv4_3'],
88 |         _get_weight_and_bias(vgg_layers, 25))
89 |     net['pool4'] = _build_net('pool', net['conv4_4'])
90 |     net['conv5_1'] = _build_net(
91 |         'conv',
92 |         net['pool4'],
93 |         _get_weight_and_bias(vgg_layers, 28))
94 |     net['conv5_2'] = _build_net(
95 |         'conv',
96 |         net['conv5_1'],
97 |         _get_weight_and_bias(vgg_layers, 30))
98 | 
99 |     return net
100 | 
101 | 
102 | def _compute_error(fake: torch.Tensor,
103 |                    real: torch.Tensor,
104 |                    mask: Optional[torch.Tensor] = None) -> torch.Tensor:
105 |     if mask is None:
106 |         return torch.mean(torch.abs(fake - real))
107 |     else:
108 |         size = (fake.size(2), fake.size(3))
109 |         resized_mask = F.interpolate(mask, size=size, mode='bilinear', align_corners=False)
110 |         return torch.mean(torch.abs(fake - real) * resized_mask)
111 | 
112 | 
113 | def perceptual_loss(image: torch.Tensor,
114 |                     reference: torch.Tensor,
115 |                     vgg_model_file: str,
116 |                     weights: Optional[Sequence[float]] = None,
117 |                     mask: Optional[torch.Tensor] = None) -> torch.Tensor:
118 |     if weights is None:
119 |         weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
120 | 
121 |     vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file)
122 |     vgg_img = _build_vgg19(image * 255.0, vgg_model_file)
123 | 
124 |     p1 = _compute_error(vgg_ref['conv1_2'], vgg_img['conv1_2'], mask) * weights[0]
125 |     p2 = _compute_error(vgg_ref['conv2_2'], vgg_img['conv2_2'], mask) * weights[1]
126 |     p3 = _compute_error(vgg_ref['conv3_2'], vgg_img['conv3_2'], mask) * weights[2]
127 |     p4 = _compute_error(vgg_ref['conv4_2'], vgg_img['conv4_2'], mask) * weights[3]
128 |     p5 = _compute_error(vgg_ref['conv5_2'], vgg_img['conv5_2'], mask) * weights[4]
129 | 
130 |     final_loss = p1 + p2 + p3 + p4 + p5
131 |     final_loss /= 255.0
132 | 
133 |     return final_loss
134 | 
135 | 
136 | def _compute_gram_matrix(input_features: torch.Tensor,
137 |                          mask: torch.Tensor) -> torch.Tensor:
138 |     b, c, h, w = input_features.size()
139 |     if mask is None:
140 |         reshaped_features = input_features.view(b, c, h * w)
141 |     else:
142 |         resized_mask = F.interpolate(
143 |             mask, size=(h, w), mode='bilinear', align_corners=False)
144 |         reshaped_features = (input_features * resized_mask).view(b, c, h * w)
145 |     return torch.matmul(
146 |         reshaped_features, reshaped_features.transpose(1, 2)) / float(h * w)
147 | 
148 | 
149 | def style_loss(image: torch.Tensor,
150 |                reference: torch.Tensor,
151 |                vgg_model_file: str,
152 |                weights: Optional[Sequence[float]] = None,
153 |                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
154 |     if not weights:
155 |         weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10.0 / 1.5]
156 | 
157 |     vgg_ref = _build_vgg19(reference * 255.0, vgg_model_file)
158 |     vgg_img = _build_vgg19(image * 255.0, vgg_model_file)
159 | 
160 |     p1 = torch.mean(
161 |         torch.square(
162 |             _compute_gram_matrix(vgg_ref['conv1_2'] / 255.0, mask) -
163 |             _compute_gram_matrix(vgg_img['conv1_2'] / 255.0, mask))) * weights[0]
164 |     p2 = torch.mean(
165 |         torch.square(
166 |             _compute_gram_matrix(vgg_ref['conv2_2'] / 255.0, mask) -
167 |             _compute_gram_matrix(vgg_img['conv2_2'] / 255.0, mask))) * weights[1]
168 |     p3 = torch.mean(
169 |         torch.square(
170 |             _compute_gram_matrix(vgg_ref['conv3_2'] / 255.0, mask) -
171 |             _compute_gram_matrix(vgg_img['conv3_2'] / 255.0, mask))) * weights[2]
172 |     p4 = torch.mean(
173 |         torch.square(
174 |             _compute_gram_matrix(vgg_ref['conv4_2'] / 255.0, mask) -
175 |             _compute_gram_matrix(vgg_img['conv4_2'] / 255.0, mask))) * weights[3]
176 |     p5 = torch.mean(
177 |         torch.square(
178 |             _compute_gram_matrix(vgg_ref['conv5_2'] / 255.0, mask) -
179 |             _compute_gram_matrix(vgg_img['conv5_2'] / 255.0, mask))) * weights[4]
180 | 
181 |     final_loss = p1 + p2 + p3 + p4 + p5
182 |     return final_loss
183 | 
--------------------------------------------------------------------------------
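`perceptual_loss` and `style_loss` compare VGG-19 features and Gram matrices; both load MatConvNet VGG-19 weights from the `.mat` file passed as `vgg_model_file`, which is not shipped in this repo. A minimal sketch, assuming you have downloaded such a file (the path below is an assumption):

```python
import torch
from model import vgg19_losses as vgg19

# The module moves the VGG weights to CUDA when available.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img = torch.rand(1, 3, 224, 224, device=dev)
ref = torch.rand(1, 3, 224, 224, device=dev)

# Path is hypothetical; any MatConvNet VGG-19 .mat file works here.
loss = vgg19.perceptual_loss(img, ref, 'imagenet-vgg-verydeep-19.mat')
```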
/README.md:
--------------------------------------------------------------------------------
1 | # GenDSA

## News
31 |
32 | * **`October 3, 2025`:** 🔥🔥 GenDSA-V2 has been accepted by Nature Medicine!
For details, please refer to [this repo](https://github.com/ZrH42/GenDSA-V2).
33 |
34 | * **`January 25, 2025`:** 🔥🔥🔥 GenDSA-V2 is coming soon! 🔥🔥🔥 We have further scaled up the data and conducted a large number of clinical trials at multiple centers. The new work is under review, so please stay tuned!
35 |
36 | * **`September 4, 2024`:** The [Paper](https://www.cell.com/cms/10.1016/j.medj.2024.07.025/attachment/4209f5f8-1a77-4e15-a872-80a6239ec7bd/mmc3.pdf) is online and available now. Enjoy!
37 |
38 | * **`July 5, 2024`:** GenDSA has been accepted by Med (a Cell Press journal)!
39 |
40 | * **`June 5, 2024`:** We released a portion of the 3D vascular and 3D non-vascular datasets. ([link here](https://drive.google.com/drive/folders/1t-esIdUnVcZdFXmSGhGcwhcpI0KyGKY0?usp=sharing))
41 |
42 | * **`March 27, 2024`:** We released our inference code. Paper/Project pages are coming soon. Please stay tuned!
43 |
44 | ## Abstract
45 | Digital subtraction angiography (DSA) devices are commonly used in hundreds of different interventional procedures in various parts of the body and require multiple scans of the patient in a single procedure, causing high radiation exposure for both doctors and patients. Inspired by generative artificial intelligence techniques, this study proposed GenDSA, a real-time, low-dose DSA imaging system based on a large-scale pretrained multi-frame generative model. Suitable for most DSA scanning protocols, GenDSA can reduce the DSA frame rate (i.e., the radiation dose) to 1/3 and generate videos that are virtually identical to those from clinically available protocols. GenDSA was pre-trained, fine-tuned, and tested on ten million images from 35 hospitals. Objective quantitative metrics (PSNR = 36.83, SSIM = 0.911, generation time = 0.07 s/frame) demonstrated that GenDSA's performance surpasses that of state-of-the-art algorithms in the field of image frame generation. Subjective ratings and statistical results from five doctors showed that the generated videos reached a level comparable to the fully sampled videos, both in overall quality (4.905 vs 4.935) and in lesion assessment (4.825 vs 4.860), fully demonstrating the potential of GenDSA for clinical applications.
46 |
47 |
48 | ## Environment Setups
49 |
50 | * python 3.8
51 | * cudatoolkit 11.2.1
52 | * cudnn 8.1.0.77
53 | * See 'GenDSA_env.txt' for the required Python libraries
54 |
55 | ```shell
56 | conda create -n GenDSA python=3.8
57 | conda activate GenDSA
58 | conda install cudatoolkit=11.2.1 cudnn=8.1.0.77
59 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
60 | # cd /xx/xx/GenDSA
61 | pip install -r GenDSA_env.txt
62 | ```
63 |
64 |
65 | ## Model Checkpoints
66 | Download the zip of [model checkpoints](https://share.weiyun.com/ze6bOv0i) (key: ```mqfd5s```), decompress it, and put all .pkl files into ../GenDSA/weights/checkpoints.
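For example (the archive name below is an assumption; use whatever the downloaded zip is actually called):
```shell
# cd /xx/xx/GenDSA
mkdir -p weights/checkpoints
unzip GenDSA_checkpoints.zip -d weights/checkpoints
```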
67 |
68 | ## Our Dataset and Inference Cases
69 | We released a portion of the 3D vascular and non-vascular datasets, including our model's inference results. ([Google Drive](https://drive.google.com/drive/folders/1t-esIdUnVcZdFXmSGhGcwhcpI0KyGKY0?usp=sharing))
70 |
71 |
72 | ## Inference Demo
73 | Run the following commands to generate single/multi-frame interpolation:
74 |
75 | * Single-frame interpolation
76 | ```shell
77 | python Simple_Interpolator.py \
78 | --model_path ./weights/checkpoints/3D-vas-Inf1.pkl \
79 | --frame1 ./demo_images/DSA_1.png \
80 | --frame2 ./demo_images/DSA_2.png \
81 | --inter_frames 1
82 | ```
83 |
84 | * Two-frame interpolation
85 | ```shell
86 | python Simple_Interpolator.py \
87 | --model_path ./weights/checkpoints/3D-vas-Inf2.pkl \
88 | --frame1 ./demo_images/DSA_1.png \
89 | --frame2 ./demo_images/DSA_2.png \
90 | --inter_frames 2
91 | ```
92 |
93 | * Three-frame interpolation
94 | ```shell
95 | python Simple_Interpolator.py \
96 | --model_path ./weights/checkpoints/3D-vas-Inf3.pkl \
97 | --frame1 ./demo_images/DSA_1.png \
98 | --frame2 ./demo_images/DSA_2.png \
99 | --inter_frames 3
100 | ```
101 |
102 | You can also use the other checkpoints to generate 1–3 interpolated frames for your own 2D/3D Head/Abdomen/Thorax/Pelvic/Periph images.
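For example (the checkpoint placeholder below is hypothetical; substitute a file name from the downloaded zip):
```shell
python Simple_Interpolator.py \
    --model_path ./weights/checkpoints/<your-checkpoint>.pkl \
    --frame1 ./your_images/frame_1.png \
    --frame2 ./your_images/frame_2.png \
    --inter_frames 2
```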
103 |
104 | ## 💖 Citation
105 | Please promote and cite our work if you find it helpful. Enjoy!
106 | ```bibtex
107 | @article{zhao2024large,
108 | title={Large-scale pretrained frame generative model enables real-time low-dose DSA imaging: An AI system development and multi-center validation study},
109 | author={Zhao, Huangxuan and Xu, Ziyang and Chen, Lei and Wu, Linxia and Cui, Ziwei and Ma, Jinqiang and Sun, Tao and Lei, Yu and Wang, Nan and Hu, Hongyao and others},
110 | journal={Med},
111 | year={2024},
112 | publisher={Elsevier},
113 | url={https://doi.org/10.1016/j.medj.2024.07.025},
114 | }
115 | ```
116 |
117 |
118 |
124 |
125 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
203 |
--------------------------------------------------------------------------------
/Trainer.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.nn.parallel import DistributedDataParallel as DDP
5 | from torch.optim import AdamW
6 | from model.loss import *
7 | from model.warplayer import warp
8 | from config import *
9 |
10 |
11 | class Model:
12 | def __init__(self, local_rank):
13 | backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
14 | backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
15 | self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
16 | self.name = MODEL_CONFIG['LOGNAME']
17 | self.device()
18 |
19 | self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
20 | self.lap = LapLoss()
21 | self.ploss = Perceptual_Loss()
22 | self.styloss = Style_Loss()
23 | if local_rank != -1:
24 | self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank, broadcast_buffers = False)
25 |
26 | def train(self):
27 | self.net.train()
28 |
29 | def eval(self):
30 | self.net.eval()
31 |
32 | def device(self):
33 | self.net.to(torch.device("cuda"))
34 |
35 | def load_model(self, name = None, folder_path = None, full_path = None, rank = 0):
36 | def convert(param):
37 | return {
38 | k.replace("module.", ""): v
39 | for k, v in param.items()
40 | if "module." in k and 'attn_mask' not in k and 'HW' not in k
41 | }
42 | if rank <= 0 :
43 | if full_path is not None:
44 | self.net.load_state_dict(convert(torch.load(full_path)))
45 | else:
46 | if name is None:
47 | name = self.name
48 | if folder_path is None:
49 | self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')))
50 | else:
51 | self.net.load_state_dict(convert(torch.load(folder_path + f'/{name}.pkl')))
52 |
53 | def load_pretrain_weight(self, name = None, folder_path = None, full_path = None, rank = 0):
54 | if rank <= 0 :
55 | if full_path is not None:
56 | self.net.load_state_dict(torch.load(full_path))
57 | else:
58 | if name is None:
59 | name = self.name
60 | if folder_path is None:
61 | self.net.load_state_dict(torch.load(f'ckpt/{name}.pkl'))
62 | else:
63 | self.net.load_state_dict(torch.load(folder_path + f'/{name}.pkl'))
64 |
65 | def save_model(self, name = None, folder_path = None, rank=0):
66 | if rank == 0:
67 | if name is None:
68 | name = self.name
69 | if folder_path is None:
70 | torch.save(self.net.state_dict(), f'ckpt/{name}.pkl')
71 | else:
72 | torch.save(self.net.state_dict(), folder_path + f'/{name}.pkl')
73 |
74 | @torch.no_grad()
75 | def hr_inference(self, img0, img1, TTA = False, down_scale = 1.0, timestep = 0.5, fast_TTA = False):
76 | '''
77 | Infer with flow estimated at down_scale resolution, then refine at full resolution.
78 | Note: returns BxCxHxW.
79 | '''
80 | def infer(imgs):
81 | img0, img1 = imgs[:, :3], imgs[:, 3:6]
82 | imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
83 |
84 | flow, mask = self.net.calculate_flow(imgs_down, timestep)
85 |
86 | flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
87 | mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False)
88 |
89 | af, _ = self.net.feature_bone(img0, img1)
90 | pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
91 | return pred
92 |
93 | imgs = torch.cat((img0, img1), 1)
94 | if fast_TTA:
95 | imgs_ = imgs.flip(2).flip(3)
96 | inputs = torch.cat((imgs, imgs_), 0)
97 | preds = infer(inputs)
98 | return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.
99 |
100 | if not TTA:
101 | return infer(imgs)
102 | else:
103 | return (infer(imgs) + infer(imgs.flip(2).flip(3)).flip(2).flip(3)) / 2
104 |
105 | @torch.no_grad()
106 | def inference(self, img0, img1, TTA = False, timestep = 0.5, fast_TTA = False):
107 | '''
108 | Note: returns BxCxHxW.
109 | '''
110 | imgs = torch.cat((img0, img1), 1)
111 | if fast_TTA:
112 | imgs_ = imgs.flip(2).flip(3)
113 | inputs = torch.cat((imgs, imgs_), 0)
114 | _, _, _, preds = self.net(inputs, timestep=timestep)
115 | return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.
116 |
117 | _, _, _, pred = self.net(imgs, timestep=timestep)
118 | if not TTA:
119 | return pred
120 | else:
121 | _, _, _, pred2 = self.net(imgs.flip(2).flip(3), timestep=timestep)
122 | return (pred + pred2.flip(2).flip(3)) / 2
123 |
124 | @torch.no_grad()
125 | def multi_inference(self, img0, img1, TTA = False, down_scale = 1.0, time_list=[], fast_TTA = False):
126 | '''
127 | Run the backbone once and predict multiple frames at different timesteps.
128 | Note: returns a list of [CxHxW].
129 | '''
130 | assert len(time_list) > 0, 'time_list should not be empty!'
131 | def infer(imgs):
132 | img0, img1 = imgs[:, :3], imgs[:, 3:6]
133 | af, mf = self.net.feature_bone(img0, img1)
134 | imgs_down = None
135 | if down_scale != 1.0:
136 | imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
137 | afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6])
138 |
139 | pred_list = []
140 | for timestep in time_list:
141 | if imgs_down is None:
142 | flow, mask = self.net.calculate_flow(imgs, timestep, af, mf)
143 | else:
144 | flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd)
145 | flow = F.interpolate(flow, scale_factor = 1/down_scale, mode="bilinear", align_corners=False) * (1/down_scale)
146 | mask = F.interpolate(mask, scale_factor = 1/down_scale, mode="bilinear", align_corners=False)
147 |
148 | pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
149 | pred_list.append(pred)
150 |
151 | return pred_list
152 |
153 | imgs = torch.cat((img0, img1), 1)
154 | if fast_TTA:
155 | imgs_ = imgs.flip(2).flip(3)
156 | inputs = torch.cat((imgs, imgs_), 0)
157 | preds_lst = infer(inputs)
158 | return [(preds_lst[i][0] + preds_lst[i][1].flip(1).flip(2))/2 for i in range(len(time_list))]
159 |
160 | preds = infer(imgs)
161 | if not TTA:
162 | return [preds[i][0] for i in range(len(time_list))]
163 | else:
164 | flip_pred = infer(imgs.flip(2).flip(3))
165 | return [(preds[i][0] + flip_pred[i][0].flip(1).flip(2))/2 for i in range(len(time_list))]
166 |
167 | def update(self, imgs, gt, learning_rate=0, training=True):
168 | for param_group in self.optimG.param_groups:
169 | param_group['lr'] = learning_rate
170 | if training:
171 | self.train()
172 | else:
173 | self.eval()
174 |
175 | if training:
176 | flow, mask, merged, pred = self.net(imgs)
177 | loss_l1 = (self.lap(pred, gt)).mean()
178 |
179 | factor = 1.0 / len(merged)
180 | for merge in merged:
181 | loss_l1 += (self.lap(merge, gt)).mean() * factor
182 |
183 | self.optimG.zero_grad()
184 | loss_l1.backward()
185 | self.optimG.step()
186 | return pred, loss_l1
187 | else:
188 | with torch.no_grad():
189 | flow, mask, merged, pred = self.net(imgs)
190 | return pred, 0
191 |
192 | def multi_gts_update(self, imgs, gts, TimeStepList:list, learning_rate=0, training=True):
193 | for param_group in self.optimG.param_groups:
194 | param_group['lr'] = learning_rate
195 | if training:
196 | self.train()
197 | else:
198 | self.eval()
199 |
200 | if training:
201 | loss_l1_all = 0
202 | preds = []
203 |
204 | for i, timestep in enumerate(TimeStepList):
205 | flow, mask, merged, pred = self.net(imgs, timestep)
206 | gt_index1 = i * 3
207 | gt_index2 = (i + 1) * 3
208 | loss_l1 = (self.lap(pred, gts[:, gt_index1 : gt_index2])).mean()
209 |
210 | loss_l1_all += loss_l1
211 | preds.append(pred)
212 |
213 | factor = 1.0 / len(merged)
214 | for merge in merged:
215 | loss_l1_all += (self.lap(merge, gts[:, gt_index1 : gt_index2])).mean() * factor
216 |
217 | self.optimG.zero_grad()
218 | loss_l1_all.backward()
219 | self.optimG.step()
220 | return preds, loss_l1_all
221 | else:
222 | with torch.no_grad():
223 | preds = []
224 |
225 | for timestep in TimeStepList:
226 | flow, mask, merged, pred = self.net(imgs, timestep)
227 | preds.append(pred)
228 |
229 | return preds, 0
230 |
231 | def multi_gts_losses_update(self, imgs, gts, TimeStepList:list, vgg_model_file:str = '', losses_weight_schedules:list = [], now_epoch:int = 0, now_step:int = 0, learning_rate=0, training=True):
232 | '''
233 | vgg_model_file: path to the VGG-19 network weights in MATLAB (.mat) format.
234 | losses_weight_schedules: weight schedule for each loss, in the order [L1/Laplacian, perceptual, style].
235 |
236 | Example:
237 | ----------
238 | losses_weight_schedules = [
239 | {'boundaries_epoch':[0], 'boundaries_step':[0], 'values':[1.0, 1.0]},
240 | {'boundaries_epoch':[0], 'boundaries_step':[2400], 'values':[1.0, 0.25]},
241 | {'boundaries_epoch':[2], 'boundaries_step':[2400], 'values':[0.0, 40.0]}]
242 | ----------
243 | Boundaries given by boundaries_epoch take priority; if boundaries_epoch is empty, boundaries_step is used instead.
244 | Before the specified boundary, a loss is weighted by values[0]; from the boundary onward, by values[1]. boundaries_step works the same way.
245 | '''
246 | def decide_values(losses_weight_schedules:list, now_epoch:int, now_step:int):
247 | '''
248 | Determine the weight of each loss from the current epoch, step (iteration), and the stage weight schedule.
249 | '''
250 | l1_epoch = losses_weight_schedules[0]['boundaries_epoch']
251 | p_epoch = losses_weight_schedules[1]['boundaries_epoch']
252 | sty_epoch = losses_weight_schedules[2]['boundaries_epoch']
253 |
254 | l1_step = losses_weight_schedules[0]['boundaries_step']
255 | p_step = losses_weight_schedules[1]['boundaries_step']
256 | sty_step = losses_weight_schedules[2]['boundaries_step']
257 |
258 | if (len(l1_epoch) != 0) and (len(p_epoch) != 0) and (len(sty_epoch) != 0): # Prioritize the boundaries specified by boundaries_epoch.
259 | if now_epoch < l1_epoch[0]:
260 | l1_value = losses_weight_schedules[0]['values'][0]
261 | else:
262 | l1_value = losses_weight_schedules[0]['values'][1]
263 | if now_epoch < p_epoch[0]:
264 | p_value = losses_weight_schedules[1]['values'][0]
265 | else:
266 | p_value = losses_weight_schedules[1]['values'][1]
267 | if now_epoch < sty_epoch[0]:
268 | sty_value = losses_weight_schedules[2]['values'][0]
269 | else:
270 | sty_value = losses_weight_schedules[2]['values'][1]
271 | elif (len(l1_step) != 0) and (len(p_step) != 0) and (len(sty_step) != 0): # Otherwise, use the boundaries specified by boundaries_step.
272 | if now_step < l1_step[0]:
273 | l1_value = losses_weight_schedules[0]['values'][0]
274 | else:
275 | l1_value = losses_weight_schedules[0]['values'][1]
276 | if now_step < p_step[0]:
277 | p_value = losses_weight_schedules[1]['values'][0]
278 | else:
279 | p_value = losses_weight_schedules[1]['values'][1]
280 | if now_step < sty_step[0]:
281 | sty_value = losses_weight_schedules[2]['values'][0]
282 | else:
283 | sty_value = losses_weight_schedules[2]['values'][1]
284 | else:
285 | print("'losses_weight_schedules' is illegal, it needs to meet the following conditions: (1) the boundaries_epoch of each loss is not empty, or (2) the boundaries_step of each loss is not empty!")
286 | sys.exit()
287 |
288 | return l1_value, p_value, sty_value
289 |
290 |
291 | for param_group in self.optimG.param_groups:
292 | param_group['lr'] = learning_rate
293 | if training:
294 | self.train()
295 | else:
296 | self.eval()
297 |
298 | if training:
299 | loss_all = 0
300 | preds = []
301 | l1_weight, p_weight, sty_weight = decide_values(losses_weight_schedules, now_epoch, now_step)
302 |
303 | for i, timestep in enumerate(TimeStepList):
304 | flow, mask, merged, pred = self.net(imgs, timestep)
305 | gt_index1 = i * 3
306 | gt_index2 = (i + 1) * 3
307 | if l1_weight != 0:
308 | loss_l1 = (self.lap(pred, gts[:, gt_index1 : gt_index2])).mean()
309 | print("loss_l1", loss_l1)
310 | loss_all += loss_l1 * l1_weight
311 | if p_weight != 0:
312 | loss_p = (self.ploss(pred, gts[:, gt_index1 : gt_index2], vgg_model_file)).mean()
313 | print("loss_p", loss_p)
314 | loss_all += loss_p * p_weight
315 | if sty_weight != 0:
316 | loss_style = (self.styloss(pred, gts[:, gt_index1 : gt_index2], vgg_model_file)).mean()
317 | print("loss_style", loss_style)
318 | loss_all += loss_style * sty_weight
319 |
320 | preds.append(pred)
321 |
322 | factor = 1.0 / len(merged)
323 | for merge in merged:
324 | if l1_weight != 0:
325 | loss_all += (self.lap(merge, gts[:, gt_index1 : gt_index2])).mean() * factor * l1_weight
326 | if p_weight != 0:
327 | loss_all += (self.ploss(merge, gts[:, gt_index1 : gt_index2], vgg_model_file)).mean() * factor * p_weight
328 | if sty_weight != 0:
329 | loss_all += (self.styloss(merge, gts[:, gt_index1 : gt_index2], vgg_model_file)).mean() * factor * sty_weight
330 |
331 | self.optimG.zero_grad()
332 | loss_all.backward()
333 | self.optimG.step()
334 | return preds, loss_all
335 | else:
336 | with torch.no_grad():
337 | preds = []
338 |
339 | for timestep in TimeStepList:
340 | flow, mask, merged, pred = self.net(imgs, timestep)
341 | preds.append(pred)
342 |
343 | return preds, 0
344 |
--------------------------------------------------------------------------------
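A minimal usage sketch for the Model class above (an editor's illustration, not a file in the repository). It assumes a CUDA device, since Model.device() moves the network to "cuda"; a hypothetical checkpoint at ckpt/ours.pkl saved from the DDP-wrapped net (load_model's convert() strips the "module." prefix from its keys); and frames already loaded as 1x3xHxW tensors in [0, 1]. The padding divisor of 32 is a guess at the network's total stride; adjust it if the shapes disagree.

    import torch
    from Trainer import Model
    from utils.padder import InputPadder

    model = Model(local_rank=-1)                  # -1 skips DistributedDataParallel
    model.load_model(full_path='ckpt/ours.pkl')   # hypothetical checkpoint path
    model.eval()

    # Random stand-ins for two consecutive frames (1 x 3 x H x W, values in [0, 1]).
    img0 = torch.rand(1, 3, 512, 512, device='cuda')
    img1 = torch.rand(1, 3, 512, 512, device='cuda')

    # Pad H and W so every encoder stage divides evenly, then interpolate the midpoint.
    padder = InputPadder(img0.shape, divisor=32)
    img0_p, img1_p = padder.pad(img0, img1)
    mid = padder.unpad(model.inference(img0_p, img1_p, timestep=0.5))

    # One backbone pass, several intermediate frames (each C x H x W).
    frames = model.multi_inference(img0_p, img1_p, time_list=[0.25, 0.5, 0.75])
    frames = [padder.unpad(f) for f in frames]

--------------------------------------------------------------------------------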
/model/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import einsum
4 | from einops import rearrange
5 | import math
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 |
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | def default(val, d):
13 | return val if exists(val) else d
14 |
15 | def calc_rel_pos(n):
16 | pos = torch.meshgrid(torch.arange(n), torch.arange(n))
17 | pos = rearrange(torch.stack(pos), 'n i j -> (i j) n')
18 | rel_pos = pos[None, :] - pos[:, None]
19 | rel_pos += n - 1
20 | return rel_pos
21 |
22 |
23 | class Mlp(nn.Module):
24 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
25 | super().__init__()
26 | out_features = out_features or in_features
27 | hidden_features = hidden_features or in_features
28 | self.fc1 = nn.Linear(in_features, hidden_features)
29 | self.dwconv = DWConv(hidden_features)
30 | self.act = act_layer()
31 | self.fc2 = nn.Linear(hidden_features, out_features)
32 | self.drop = nn.Dropout(drop)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.apply(self._init_weights)
35 |
36 | def _init_weights(self, m):
37 | if isinstance(m, nn.Linear):
38 | trunc_normal_(m.weight, std=.02)
39 | if isinstance(m, nn.Linear) and m.bias is not None:
40 | nn.init.constant_(m.bias, 0)
41 | elif isinstance(m, nn.LayerNorm):
42 | nn.init.constant_(m.bias, 0)
43 | nn.init.constant_(m.weight, 1.0)
44 | elif isinstance(m, nn.Conv2d):
45 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
46 | fan_out //= m.groups
47 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
48 | if m.bias is not None:
49 | m.bias.data.zero_()
50 |
51 | def forward(self, x, H, W):
52 | x = self.fc1(x)
53 | x = self.dwconv(x, H, W)
54 | x = self.act(x)
55 | x = self.drop(x)
56 | x = self.fc2(x)
57 | x = self.drop(x)
58 | return x
59 |
60 |
61 | class InterFrameLambda(nn.Module):
62 | def __init__(self, dim, *, dim_k, n = None, r = None, heads = 4, dim_out = None, dim_u = 1, motion_dim):
63 | super().__init__()
64 | dim_out = default(dim_out, dim)
65 | self.u = dim_u
66 | self.heads = heads
67 |
68 | assert (dim_out % heads) == 0, "'dim_out' must be divisible by number of heads for multi-head query"
69 | assert (motion_dim % heads) == 0, "'motion_dim' must be divisible by number of heads for multi-head query"
70 | dim_v = dim_out // heads
71 |
72 | self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
73 | self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
74 | self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
75 | self.cor_embed = nn.Conv2d(2, dim_k * heads, 1, bias = True)
76 | self.motion_proj = nn.Conv2d(dim_k * heads, motion_dim, 1, bias = True)
77 | self.channel_compress = nn.Conv2d(dim_out, dim_k * heads, 1, bias = False)
78 |
79 | self.norm_q = nn.BatchNorm2d(dim_k * heads)
80 | self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
81 |
82 | self.local_contexts = exists(r)
83 | if exists(r):
84 | assert (r % 2) == 1, 'Receptive kernel size should be odd'
85 | self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
86 | else:
87 | assert exists(n), 'You must specify the window size (n=h=w)'
88 | rel_lengths = 2 * n - 1
89 | self.rel_pos_emb = nn.Parameter(torch.randn(rel_lengths, rel_lengths, dim_k, dim_u))
90 | self.rel_pos = calc_rel_pos(n)
91 |
92 | def forward(self, x, cor):
93 | x = rearrange(x, 'b h w c -> b c h w')
94 | cor = rearrange(cor, 'b h w c -> b c h w')
95 | b, c, hh, ww, u, h = *x.shape, self.u, self.heads
96 | x_reverse = torch.cat([x[b//2:], x[:b//2]])
97 |
98 | q = self.to_q(x)
99 | k = self.to_k(x_reverse)
100 | v = self.to_v(x_reverse)
101 |
102 | Q = self.norm_q(q)
103 | V = self.norm_v(v)
104 |
105 | Q = rearrange(Q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
106 | k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
107 | V = rearrange(V, 'b (u v) hh ww -> b u v (hh ww)', u = u)
108 |
109 | k = k.softmax(dim=-1)
110 | λc = einsum('b u k m, b u v m -> b k v', k, V)
111 | Yc = einsum('b h k n, b k v -> b h v n', Q, λc)
112 |
113 | if self.local_contexts:
114 | V = rearrange(V, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
115 | λp = self.pos_conv(V)
116 | Yp = einsum('b h k n, b k v n -> b h v n', Q, λp.flatten(3))
117 | else:
118 | n, m = self.rel_pos.unbind(dim = -1)
119 | rel_pos_emb = self.rel_pos_emb[n, m]
120 | λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, V)
121 | Yp = einsum('b h k n, b n k v -> b h v n', Q, λp)
122 |
123 | Y = Yc + Yp
124 | appearance = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
125 |
126 | cor_embed_ = self.cor_embed(cor)
127 | cor_embed = rearrange(cor_embed_, 'b (h k) hh ww -> b h k (hh ww)', h = h)
128 | cor_reverse_c = einsum('b h k n, b k v -> b h v n', cor_embed, λc)
129 | if self.local_contexts:
130 | cor_reverse_p = einsum('b h k n, b k v n -> b h v n', cor_embed, λp.flatten(3))
131 | else:
132 | cor_reverse_p = einsum('b h k n, b n k v -> b h v n', cor_embed, λp)
133 | cor_reverse_ = rearrange(cor_reverse_c + cor_reverse_p, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
134 | cor_reverse = self.channel_compress(cor_reverse_)
135 | motion = self.motion_proj(cor_reverse - cor_embed_)
136 |
137 | appearance = rearrange(appearance, 'b c h w -> b h w c')
138 | motion = rearrange(motion, 'b c h w -> b h w c')
139 |
140 | return appearance, motion
141 |
142 |
143 | class StructFormerBlock(nn.Module):
144 | def __init__(self, dim, dim_out, dim_k, dim_u, n, r, motion_dim, heads, range, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
145 | super().__init__()
146 | if range == 'global':
147 | self.Lambda = InterFrameLambda(dim = dim,
148 | dim_k = dim_k,
149 | n = n,
150 | heads = heads,
151 | dim_out = dim_out,
152 | dim_u = dim_u,
153 | motion_dim = motion_dim)
154 | elif range == 'local':
155 | self.Lambda = InterFrameLambda(dim = dim,
156 | dim_k = dim_k,
157 | r = r,
158 | heads = heads,
159 | dim_out = dim_out,
160 | dim_u = dim_u,
161 | motion_dim = motion_dim)
162 | else:
163 | print("range不符合规范, 取'global'或'local'!")
164 | sys.exit()
165 | self.norm1 = norm_layer(dim)
166 | self.norm2 = norm_layer(dim)
167 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
168 | mlp_hidden_dim = int(dim * mlp_ratio)
169 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
170 | self.apply(self._init_weights)
171 |
172 | def _init_weights(self, m):
173 | if isinstance(m, nn.Linear):
174 | trunc_normal_(m.weight, std=.02)
175 | if isinstance(m, nn.Linear) and m.bias is not None:
176 | nn.init.constant_(m.bias, 0)
177 | elif isinstance(m, nn.LayerNorm):
178 | nn.init.constant_(m.bias, 0)
179 | nn.init.constant_(m.weight, 1.0)
180 | elif isinstance(m, nn.Conv2d):
181 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
182 | fan_out //= m.groups
183 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
184 | if m.bias is not None:
185 | m.bias.data.zero_()
186 |
187 | def forward(self, x, cor, H, W, B):
188 | x_norm1 = self.norm1(x)
189 | x = x_norm1.view(2*B, H, W, -1)
190 | x_appearence, x_motion = self.Lambda(x, cor)
191 | x_appearence = rearrange(x_appearence, 'b h w c -> b (h w) c')
192 | x_motion = rearrange(x_motion, 'b h w c -> b (h w) c')
193 | x_appearence = x_norm1 + self.drop_path(x_appearence)
194 | x_appearence = x_appearence + self.drop_path(self.mlp(self.norm2(x_appearence), H, W))
195 |
196 | return x_appearence, x_motion
197 |
198 |
199 | class ConvBlock(nn.Module):
200 | def __init__(self, in_dim, out_dim, depths=2,act_layer=nn.PReLU):
201 | super().__init__()
202 | layers = []
203 | for i in range(depths):
204 | if i == 0:
205 | layers.append(nn.Conv2d(in_dim, out_dim, 3,1,1))
206 | else:
207 | layers.append(nn.Conv2d(out_dim, out_dim, 3,1,1))
208 | layers.extend([
209 | act_layer(out_dim),
210 | ])
211 | self.conv = nn.Sequential(*layers)
212 |
213 | def _init_weights(self, m):
214 | if isinstance(m, nn.Conv2d):
215 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
216 | fan_out //= m.groups
217 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
218 | if m.bias is not None:
219 | m.bias.data.zero_()
220 |
221 | def forward(self, x):
222 | x = self.conv(x)
223 | return x
224 |
225 |
226 | class OverlapPatchEmbed(nn.Module):
227 | def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
228 | super().__init__()
229 | patch_size = to_2tuple(patch_size)
230 |
231 | self.patch_size = patch_size
232 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
233 | padding=(patch_size[0] // 2, patch_size[1] // 2))
234 | self.norm = nn.LayerNorm(embed_dim)
235 |
236 | self.apply(self._init_weights)
237 |
238 | def _init_weights(self, m):
239 | if isinstance(m, nn.Linear):
240 | trunc_normal_(m.weight, std=.02)
241 | if isinstance(m, nn.Linear) and m.bias is not None:
242 | nn.init.constant_(m.bias, 0)
243 | elif isinstance(m, nn.LayerNorm):
244 | nn.init.constant_(m.bias, 0)
245 | nn.init.constant_(m.weight, 1.0)
246 | elif isinstance(m, nn.Conv2d):
247 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
248 | fan_out //= m.groups
249 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
250 | if m.bias is not None:
251 | m.bias.data.zero_()
252 |
253 | def forward(self, x):
254 | x = self.proj(x)
255 | _, _, H, W = x.shape
256 | x = x.flatten(2).transpose(1, 2)
257 | x = self.norm(x)
258 |
259 | return x, H, W
260 |
261 |
262 | class CrossScalePatchEmbed(nn.Module):
263 | def __init__(self, in_dims=[16,32,64], embed_dim=768):
264 | super().__init__()
265 | base_dim = in_dims[0]
266 |
267 | layers = []
268 | for i in range(len(in_dims)):
269 | for j in range(2 ** i):
270 | layers.append(nn.Conv2d(in_dims[-1-i], base_dim, 3, 2**(i+1), 1+j, 1+j))
271 | self.layers = nn.ModuleList(layers)
272 | self.proj = nn.Conv2d(base_dim * len(layers), embed_dim, 1, 1)
273 | self.norm = nn.LayerNorm(embed_dim)
274 |
275 | self.apply(self._init_weights)
276 |
277 | def _init_weights(self, m):
278 | if isinstance(m, nn.Linear):
279 | trunc_normal_(m.weight, std=.02)
280 | if isinstance(m, nn.Linear) and m.bias is not None:
281 | nn.init.constant_(m.bias, 0)
282 | elif isinstance(m, nn.LayerNorm):
283 | nn.init.constant_(m.bias, 0)
284 | nn.init.constant_(m.weight, 1.0)
285 | elif isinstance(m, nn.Conv2d):
286 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
287 | fan_out //= m.groups
288 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
289 | if m.bias is not None:
290 | m.bias.data.zero_()
291 |
292 | def forward(self, xs):
293 | ys = []
294 | k = 0
295 | for i in range(len(xs)):
296 | for _ in range(2 ** i):
297 | ys.append(self.layers[k](xs[-1-i]))
298 | k += 1
299 |
300 | x = self.proj(torch.cat(ys,1))
301 |
302 | _, _, H, W = x.shape
303 | x = x.flatten(2).transpose(1, 2)
304 |
305 | x = self.norm(x)
306 |
307 | return x, H, W
308 |
309 |
310 | class StructFormer(nn.Module):
311 | def __init__(self, in_chans=3, embed_dims=[32, 64, 128, 256, 512], motion_dims=64, num_heads=[8, 16],
312 | mlp_ratios=[4, 4], lambda_global_or_local = 'local', lambda_dim_k = 16, lambda_dim_u = 1,
313 | lambda_n = 32, lambda_r = None, drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
314 | depths=[2, 2, 2, 4, 4], **kwarg):
315 | super().__init__()
316 | self.depths = depths
317 | self.num_stages = len(embed_dims)
318 |
319 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
320 | cur = 0
321 |
322 | self.conv_stages = self.num_stages - len(num_heads)
323 |
324 | for i in range(self.num_stages):
325 | if i == 0:
326 | block = ConvBlock(in_chans,embed_dims[i],depths[i])
327 | else:
328 | if i < self.conv_stages:
329 | patch_embed = nn.Sequential(
330 | nn.Conv2d(embed_dims[i-1], embed_dims[i], 3,2,1),
331 | nn.PReLU(embed_dims[i])
332 | )
333 | block = ConvBlock(embed_dims[i],embed_dims[i],depths[i])
334 | else:
335 | if i == self.conv_stages:
336 | patch_embed = CrossScalePatchEmbed(embed_dims[:i],
337 | embed_dim=embed_dims[i])
338 | block = nn.ModuleList([StructFormerBlock(
339 | dim=embed_dims[i], dim_out=embed_dims[i], dim_k=lambda_dim_k, dim_u=lambda_dim_u, n=lambda_n, r=lambda_r,
340 | motion_dim=motion_dims[i], heads=num_heads[i-self.conv_stages], range = lambda_global_or_local, mlp_ratio=mlp_ratios[i-self.conv_stages],
341 | drop=drop_rate, drop_path=dpr[cur + j])
342 | for j in range(depths[i])])
343 | else:
344 | patch_embed = OverlapPatchEmbed(patch_size=3,
345 | stride=2,
346 | in_chans=embed_dims[i - 1],
347 | embed_dim=embed_dims[i])
348 | block = nn.ModuleList([StructFormerBlock(
349 | dim=embed_dims[i], dim_out=embed_dims[i], dim_k=lambda_dim_k, dim_u=lambda_dim_u, n=int(lambda_n/2), r=lambda_r,
350 | motion_dim=motion_dims[i], heads=num_heads[i-self.conv_stages], range = lambda_global_or_local, mlp_ratio=mlp_ratios[i-self.conv_stages],
351 | drop=drop_rate, drop_path=dpr[cur + j])
352 | for j in range(depths[i])])
353 |
354 | norm = norm_layer(embed_dims[i])
355 | setattr(self, f"norm{i + 1}", norm)
356 |
357 | setattr(self, f"patch_embed{i + 1}", patch_embed)
358 | cur += depths[i]
359 |
360 | setattr(self, f"block{i + 1}", block)
361 |
362 | self.cor = {}
363 |
364 | self.apply(self._init_weights)
365 |
366 | def _init_weights(self, m):
367 | if isinstance(m, nn.Linear):
368 | trunc_normal_(m.weight, std=.02)
369 | if isinstance(m, nn.Linear) and m.bias is not None:
370 | nn.init.constant_(m.bias, 0)
371 | elif isinstance(m, nn.LayerNorm):
372 | nn.init.constant_(m.bias, 0)
373 | nn.init.constant_(m.weight, 1.0)
374 | elif isinstance(m, nn.Conv2d):
375 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
376 | fan_out //= m.groups
377 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
378 | if m.bias is not None:
379 | m.bias.data.zero_()
380 |
381 | def get_cor(self, shape, device):
382 | k = (str(shape), str(device))
383 | if k not in self.cor:
384 | tenHorizontal = torch.linspace(-1.0, 1.0, shape[2], device=device).view(
385 | 1, 1, 1, shape[2]).expand(shape[0], -1, shape[1], -1).permute(0, 2, 3, 1)
386 | tenVertical = torch.linspace(-1.0, 1.0, shape[1], device=device).view(
387 | 1, 1, shape[1], 1).expand(shape[0], -1, -1, shape[2]).permute(0, 2, 3, 1)
388 | self.cor[k] = torch.cat([tenHorizontal, tenVertical], -1).to(device)
389 | return self.cor[k]
390 |
391 | def forward(self, x1, x2):
392 | B = x1.shape[0]
393 |
394 | x = torch.cat([x1, x2], 0)
395 |
396 | motion_features = []
397 | appearence_features = []
398 | xs = []
399 |
400 | for i in range(self.num_stages):
401 | motion_features.append([])
402 | patch_embed = getattr(self, f"patch_embed{i + 1}",None)
403 | block = getattr(self, f"block{i + 1}",None)
404 | norm = getattr(self, f"norm{i + 1}",None)
405 | if i < self.conv_stages:
406 | if i > 0:
407 | x = patch_embed(x)
408 | x = block(x)
409 | xs.append(x)
410 | else:
411 | if i == self.conv_stages:
412 | x, H, W = patch_embed(xs)
413 | else:
414 | x, H, W = patch_embed(x)
415 |
416 | cor = self.get_cor((x.shape[0], H, W), x.device)
417 |
418 | for blk in block:
419 | x, x_motion = blk(x, cor, H, W, B)
420 |
421 | motion_features[i].append(x_motion.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous())
422 |
423 | x = norm(x)
424 | x = x.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous()
425 | motion_features[i] = torch.cat(motion_features[i], 1)
426 |
427 | appearence_features.append(x)
428 | return appearence_features, motion_features
429 |
430 |
431 | class DWConv(nn.Module):
432 | def __init__(self, dim):
433 | super(DWConv, self).__init__()
434 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
435 |
436 | def forward(self, x, H, W):
437 | B, N, C = x.shape
438 | x = x.transpose(1, 2).reshape(B, C, H, W)
439 | x = self.dwconv(x)
440 | x = x.reshape(B, C, -1).transpose(1, 2)
441 |
442 | return x
443 |
444 |
445 | def encoder(**kargs):
446 | model = StructFormer(**kargs)
447 | return model
448 |
--------------------------------------------------------------------------------
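A quick smoke test for the encoder above (an editor's sketch, not a repository file). It builds the backbone keyword arguments with config.init_model_config, whose first returned dict configures the encoder, and pushes two random frames through StructFormer on the CPU. F=16, lambda_range='local', and the 256x256 frame size are illustrative assumptions; any resolution that the conv stages and CrossScalePatchEmbed divide evenly should work.

    import torch
    from config import init_model_config
    from model import encoder

    # First dict configures the encoder backbone; the second is for the decoder.
    backbone_cfg, _ = init_model_config(F=16, lambda_range='local')
    net = encoder(**backbone_cfg).eval()

    # Random stand-ins for two consecutive frames (B x 3 x H x W).
    x1 = torch.rand(1, 3, 256, 256)
    x2 = torch.rand(1, 3, 256, 256)

    with torch.no_grad():
        appearance_feats, motion_feats = net(x1, x2)

    # One appearance feature map per stage; motion features exist only for the
    # lambda (transformer) stages, so earlier entries stay empty.
    print([f.shape for f in appearance_feats])

--------------------------------------------------------------------------------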