├── .gitignore ├── LICENSE ├── README.md ├── demo ├── 001.png └── 002.png ├── inference_video.py ├── model ├── FusionNet.py ├── GMFupSS.py ├── MetricNet.py ├── gmflow │ ├── __init__.py │ ├── backbone.py │ ├── geometry.py │ ├── gmflow.py │ ├── matching.py │ ├── position.py │ ├── transformer.py │ ├── trident_conv.py │ └── utils.py └── softsplat.py ├── requirements.txt └── train_log ├── flownet.pkl ├── fusionnet.pkl └── metric.pkl /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 98mxr 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GMFupSS 2 | 3 | A faster [GMFSS](https://github.com/YiWeiHuang-stack/GMFSS) 4 | 5 | --- 6 | 7 | **2023-04-03: We now provide [GMFSS_Fortuna](https://github.com/98mxr/GMFSS_Fortuna) as a factual basis for training in GMFSS. Please use it. This project will no longer be updated!** 8 | 9 | **2022-11-23: We now provide GMFSS_union as the next generation of GMFSS with better performance. You are welcome to try it.** 10 | 11 | --- 12 | 13 | * CuPy is required at runtime; please follow the [link](https://docs.cupy.dev/en/stable/install.html) to install it. 14 | 15 | ## Run Video Frame Interpolation 16 | 17 | ``` 18 | python3 inference_video.py --img=demo/ --scale=1.0 --multi=2 19 | ``` 20 | 21 | ## Acknowledgment 22 | This project is supported by [SVFI](https://steamcommunity.com/app/1692080) [Development Team](https://github.com/Justin62628/Squirrel-RIFE) 23 | -------------------------------------------------------------------------------- /demo/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/98mxr/GMFupSS/5b5e7f0f5927e35aef4e807d4f1358df22c365b0/demo/001.png -------------------------------------------------------------------------------- /demo/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/98mxr/GMFupSS/5b5e7f0f5927e35aef4e807d4f1358df22c365b0/demo/002.png -------------------------------------------------------------------------------- /inference_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.nn import functional as F 8 | import warnings 9 | import _thread 10 | import
def transferAudio(sourceVideo, targetVideo):
    """Copy the audio track of ``sourceVideo`` into ``targetVideo`` using ffmpeg.

    The interpolated video is renamed to ``<name>_noaudio<ext>`` and then muxed
    with the extracted audio back into ``targetVideo``.  A lossless stream copy
    is tried first; if that leaves no usable output the audio is transcoded to
    AAC, and if that fails as well the audio-less video is restored under its
    original name.

    Args:
        sourceVideo: Path of the original video that carries the audio.
        targetVideo: Path of the interpolated video; it is replaced in place.

    Note:
        The relative directory ``temp`` (in the current working directory) is
        deleted and recreated, and removed again when the function exits.
    """
    import shutil
    import subprocess

    def run_ffmpeg(*ffmpeg_args):
        # Use an argument list (no shell) so quotes, spaces or shell
        # metacharacters in file names cannot break or inject into the command.
        try:
            subprocess.run(["ffmpeg", "-y", *ffmpeg_args])
        except OSError as err:
            # ffmpeg missing or not executable: the output checks below treat
            # this exactly like a failed merge.
            print("Could not run ffmpeg: {}".format(err))

    def is_missing_or_empty(path):
        # ffmpeg may leave a 0-byte file or no file at all when it fails.
        return not os.path.isfile(path) or os.path.getsize(path) == 0

    tempAudioFileName = "./temp/audio.mkv"
    targetRoot, targetExt = os.path.splitext(targetVideo)
    targetNoAudio = targetRoot + "_noaudio" + targetExt

    # always start from a fresh "temp" directory
    if os.path.isdir("temp"):
        shutil.rmtree("temp")
    os.makedirs("temp")

    try:
        # extract the audio stream from the original video without re-encoding
        run_ffmpeg("-i", sourceVideo, "-c:a", "copy", "-vn", tempAudioFileName)

        os.rename(targetVideo, targetNoAudio)
        # combine audio file and new video file
        run_ffmpeg("-i", targetNoAudio, "-i", tempAudioFileName, "-c", "copy", targetVideo)

        if is_missing_or_empty(targetVideo):
            # lossless merge failed: retry with the audio converted to AAC
            tempAudioFileName = "./temp/audio.m4a"
            run_ffmpeg("-i", sourceVideo, "-c:a", "aac", "-b:a", "160k", "-vn", tempAudioFileName)
            run_ffmpeg("-i", targetNoAudio, "-i", tempAudioFileName, "-c", "copy", targetVideo)
            if is_missing_or_empty(targetVideo):
                # AAC is not supported by the selected container either;
                # os.replace also overwrites a leftover 0-byte output file.
                os.replace(targetNoAudio, targetVideo)
                print("Audio transfer failed. Interpolated video will have no audio")
            else:
                print("Lossless audio transfer failed. Audio was transcoded to AAC (M4A) instead.")
                # remove audio-less video
                os.remove(targetNoAudio)
        else:
            os.remove(targetNoAudio)
    finally:
        # remove temp directory even if a step above raised
        shutil.rmtree("temp", ignore_errors=True)
0.5, 1.0, 2.0, 4.0] 80 | if not args.img is None: 81 | args.png = True 82 | 83 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 84 | torch.set_grad_enabled(False) 85 | if torch.cuda.is_available(): 86 | torch.backends.cudnn.enabled = True 87 | torch.backends.cudnn.benchmark = True 88 | if(args.fp16): 89 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 90 | 91 | try: 92 | from model.GMFupSS import Model 93 | except: 94 | print("Please download our model from model list") 95 | model = Model() 96 | if not hasattr(model, 'version'): 97 | model.version = 0 98 | model.load_model(args.modelDir, -1) 99 | print("Loaded model") 100 | model.eval() 101 | model.device() 102 | 103 | if not args.video is None: 104 | videoCapture = cv2.VideoCapture(args.video) 105 | fps = videoCapture.get(cv2.CAP_PROP_FPS) 106 | tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) 107 | videoCapture.release() 108 | if args.fps is None: 109 | fpsNotAssigned = True 110 | args.fps = fps * args.multi 111 | else: 112 | fpsNotAssigned = False 113 | videogen = skvideo.io.vreader(args.video) 114 | lastframe = next(videogen) 115 | fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') 116 | video_path_wo_ext, ext = os.path.splitext(args.video) 117 | print('{}.{}, {} frames in total, {}FPS to {}FPS'.format(video_path_wo_ext, args.ext, tot_frame, fps, args.fps)) 118 | if args.png == False and fpsNotAssigned == True: 119 | print("The audio will be merged after interpolation process") 120 | else: 121 | print("Will not merge audio because using png or fps flag!") 122 | else: 123 | videogen = [] 124 | for f in os.listdir(args.img): 125 | if 'png' in f: 126 | videogen.append(f) 127 | tot_frame = len(videogen) 128 | videogen.sort(key= lambda x:int(x[:-4])) 129 | lastframe = cv2.imread(os.path.join(args.img, videogen[0]), cv2.IMREAD_UNCHANGED)[:, :, ::-1].copy() 130 | videogen = videogen[1:] 131 | h, w, _ = lastframe.shape 132 | vid_out_name = None 133 | vid_out = None 134 | if args.png: 
def make_inference(I0, I1, reuse_things, n):
    """Return ``n`` interpolated frames between ``I0`` and ``I1``.

    Args:
        I0: First frame tensor.
        I1: Second frame tensor.
        reuse_things: Value returned by ``model.reuse(I0, I1, scale)``; handed
            to ``model.inference`` for models with ``version >= 3.9``.
        n: Number of intermediate frames to produce.

    Returns:
        A list with ``n`` frames in temporal order (empty when ``n <= 0``).

    Models with ``version >= 3.9`` are queried once per timestep
    ``(i + 1) / (n + 1)``.  Older models are called as
    ``model.inference(I0, I1, scale)`` and the frames are built by recursive
    bisection around the middle frame.
    """
    if n <= 0:
        # Nothing to interpolate; also stops the bisection branch from
        # recursing forever on n == 0.
        return []
    if model.version >= 3.9:
        return [
            model.inference(I0, I1, reuse_things, (i + 1) * 1. / (n + 1))
            for i in range(n)
        ]
    middle = model.inference(I0, I1, args.scale)
    if n == 1:
        return [middle]
    # reuse_things is not consumed on this path; it is forwarded only because
    # it is a required positional parameter of this function.
    first_half = make_inference(I0, middle, reuse_things, n // 2)
    second_half = make_inference(middle, I1, reuse_things, n // 2)
    if n % 2:
        return [*first_half, middle, *second_half]
    return [*first_half, *second_half]
221 | I1 = F.interpolate(I1, (ph, pw), mode='bilinear', align_corners=False) 222 | 223 | reuse_things = model.reuse(I0, I1, args.scale) 224 | output = make_inference(I0, I1, reuse_things, args.multi-1) 225 | 226 | if args.montage: 227 | write_buffer.put(np.concatenate((lastframe, lastframe), 1)) 228 | for mid in output: 229 | mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0))) 230 | write_buffer.put(np.concatenate((lastframe, mid[:h, :w]), 1)) 231 | else: 232 | write_buffer.put(lastframe) 233 | for mid in output: 234 | mid = F.interpolate(mid, (h, w), mode='bilinear', align_corners=False) 235 | mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0))) 236 | write_buffer.put(mid) 237 | pbar.update(1) 238 | lastframe = frame 239 | 240 | if args.montage: 241 | write_buffer.put(np.concatenate((lastframe, lastframe), 1)) 242 | else: 243 | write_buffer.put(lastframe) 244 | import time 245 | while(not write_buffer.empty()): 246 | time.sleep(0.1) 247 | pbar.close() 248 | if not vid_out is None: 249 | vid_out.release() 250 | 251 | # move audio to new video file if appropriate 252 | if args.png == False and fpsNotAssigned == True and not args.video is None: 253 | try: 254 | transferAudio(args.video, vid_out_name) 255 | except: 256 | print("Audio transfer failed. 
class PixelShuffleBlcok(nn.Module):
    """Output head: 3x3 conv + LeakyReLU, 2x PixelShuffle upsampling, 3x3 conv.

    Args:
        in_feat: Channels of the incoming feature map.
        num_feat: Channels used inside the block.
        num_out_ch: Channels of the returned tensor.
    """

    def __init__(self, in_feat, num_feat, num_out_ch):
        super().__init__()
        entry_conv = nn.Conv2d(in_feat, num_feat, 3, 1, 1)
        self.conv_before_upsample = nn.Sequential(entry_conv, nn.LeakyReLU(inplace=True))
        # PixelShuffle(2) folds 4 * num_feat channels back into num_feat
        # channels at twice the spatial size.
        expand_conv = nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)
        self.upsample = nn.Sequential(expand_conv, nn.PixelShuffle(2))
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        features = self.conv_before_upsample(x)
        upscaled = self.upsample(features)
        return self.conv_last(upscaled)
#self.downsample_model_12=DownsampleBlock(32, 64, stride=2) 96 | #self.downsample_model_22=DownsampleBlock(64, 96, stride=2) 97 | 98 | # 99 | 100 | #self.upsample_model_03=UpsampleBlock(64, 32, stride=1) 101 | #self.upsample_model_13=UpsampleBlock(96, 64, stride=1) 102 | 103 | self.upsample_model_04=UpsampleBlock(64, 32, stride=1) 104 | self.upsample_model_14=UpsampleBlock(96, 64, stride=1) 105 | 106 | self.upsample_model_05=UpsampleBlock(64, 32, stride=1) 107 | self.upsample_model_15=UpsampleBlock(96, 64, stride=1) 108 | 109 | def forward(self, x, x1, x2, x3): 110 | X00=self.residual_model_head(x) + self.residual_model_head1(x1) #--- 182 ~ 185 111 | # X10 = self.residual_model_head1(x1) 112 | 113 | X01=self.residual_model_01(X00) + X00#--- 208 ~ 211 ,AddBackward1213 114 | 115 | X10=self.downsample_model_10(X00) + self.residual_model_head2(x2) #--- 186 ~ 189 116 | X20=self.downsample_model_20(X10) + self.residual_model_head3(x3) #--- 190 ~ 193 117 | 118 | residual_11=self.residual_model_11(X10) + X10 #201 ~ 204 , sum AddBackward1206 119 | downsample_11=self.downsample_model_11(X01) #214 ~ 217 120 | X11=residual_11 + downsample_11 #--- AddBackward1218 121 | 122 | residual_21=self.residual_model_21(X20) + X20 #194 ~ 197 , sum AddBackward1199 123 | downsample_21=self.downsample_model_21(X11) #219 ~ 222 124 | X21=residual_21 + downsample_21 # AddBackward1223 125 | 126 | 127 | X24=self.residual_model_24(X21) + X21 #--- 224 ~ 227 , AddBackward1229 128 | X25=self.residual_model_25(X24) + X24 #--- 230 ~ 233 , AddBackward1235 129 | 130 | 131 | upsample_14=self.upsample_model_14(X24) #242 ~ 246 132 | residual_14=self.residual_model_14(X11) + X11 #248 ~ 251, AddBackward1253 133 | X14=upsample_14 + residual_14 #--- AddBackward1254 134 | 135 | upsample_04=self.upsample_model_04(X14) #268 ~ 272 136 | residual_04=self.residual_model_04(X01) + X01 #274 ~ 277, AddBackward1279 137 | X04=upsample_04 + residual_04 #--- AddBackward1280 138 | 139 | 
class AnimeInterp(nn.Module):
    """Fusion network: warps both frames and their feature pyramids to time
    ``t`` and synthesises the intermediate frame with a ``GridNet``."""

    def __init__(self):
        super().__init__()
        self.feat_ext = FeatureExtractor()
        self.synnet = GridNet(6, 64, 128, 96*2, 3)

    def dflow(self, flo, target):
        """Resize ``flo`` to the spatial size of ``target`` and rescale it.

        Channel 0 is scaled by the width ratio and the remaining channels by
        the height ratio, so displacements stay in pixels of the new size.
        """
        resized = F.interpolate(flo, target.shape[2:4])
        new_h, new_w = resized.shape[2], resized.shape[3]
        old_h, old_w = flo.shape[2], flo.shape[3]
        resized[:, :1] = resized[:, :1].clone() * new_w / old_w
        resized[:, 1:] = resized[:, 1:].clone() * new_h / old_h
        return resized

    def dmetric(self, metric, target):
        """Resize ``metric`` to the spatial size of ``target`` (values kept)."""
        return F.interpolate(metric, target.shape[2:4])

    def forward(self, I1, I2, reuse_things, t):
        """Synthesise the frame at time ``t`` between ``I1`` and ``I2``.

        ``reuse_things`` is indexed as ``(flow 1->2, flow 2->1, metric 1,
        metric 2, features of frame 1, features of frame 2)`` where each
        feature entry holds three pyramid levels.  The result is clamped to
        ``[0, 1]``.
        """
        flow12, metric1 = reuse_things[0], reuse_things[2]
        flow21, metric2 = reuse_things[1], reuse_things[3]
        pyramid1 = [reuse_things[4][level] for level in range(3)]
        pyramid2 = [reuse_things[5][level] for level in range(3)]

        # scale the flows to the requested intermediate time
        flow1t = t * flow12
        flow2t = (1 - t) * flow21

        # the images are warped at half resolution
        half1 = F.interpolate(I1, scale_factor=0.5, mode="bilinear", align_corners=False)
        warped_img1 = warp(half1, flow1t, metric1, strMode='soft')
        half2 = F.interpolate(I2, scale_factor=0.5, mode="bilinear", align_corners=False)
        warped_img2 = warp(half2, flow2t, metric2, strMode='soft')

        synth_inputs = [
            torch.cat([warped_img1, warped_img2], dim=1),
            # first pyramid level uses the flows and metrics as they are
            torch.cat([
                warp(pyramid1[0], flow1t, metric1, strMode='soft'),
                warp(pyramid2[0], flow2t, metric2, strMode='soft'),
            ], dim=1),
        ]

        # remaining levels: flows and metrics are resized to each level first
        for feat_a, feat_b in zip(pyramid1[1:], pyramid2[1:]):
            flow_a = self.dflow(flow1t, feat_a)
            flow_b = self.dflow(flow2t, feat_b)
            metric_a = self.dmetric(metric1, feat_a)
            metric_b = self.dmetric(metric2, feat_b)
            warped_a = warp(feat_a, flow_a, metric_a, strMode='soft')
            warped_b = warp(feat_b, flow_b, metric_b, strMode='soft')
            synth_inputs.append(torch.cat([warped_a, warped_b], dim=1))

        synthesised = self.synnet(*synth_inputs)
        return torch.clamp(synthesised, 0, 1)
class Model:
    """Frame-interpolation pipeline made of GMFlow, MetricNet and AnimeInterp.

    ``reuse`` computes everything that depends only on the frame pair;
    ``inference`` then synthesises a frame for a given timestep from it.
    """

    def __init__(self):
        self.flownet = GMFlow()
        self.metricnet = MetricNet()
        self.fusionnet = AnimeInterp()
        self.version = 3.9

    def eval(self):
        """Put all three sub-networks into evaluation mode."""
        self.flownet.eval()
        self.metricnet.eval()
        self.fusionnet.eval()

    def device(self):
        """Move all three sub-networks to the module-level ``device``."""
        self.flownet.to(device)
        self.metricnet.to(device)
        self.fusionnet.to(device)

    def load_model(self, path, rank):
        """Load ``flownet.pkl``, ``metric.pkl`` and ``fusionnet.pkl`` from ``path``.

        Args:
            path: Directory that contains the three checkpoint files.
            rank: Weights are loaded only when ``rank <= 0``.  With
                ``rank == -1`` the ``module.`` prefix is stripped from the
                metric and fusion checkpoints (entries without the prefix are
                dropped); the flow checkpoint is always loaded as stored.
        """
        def convert(param):
            if rank == -1:
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            return param

        def load(file_name):
            # map_location="cpu": checkpoints saved from a GPU stay loadable on
            # CPU-only hosts; load_state_dict copies the values onto whatever
            # device the parameters already live on.
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            return torch.load('{}/{}'.format(path, file_name), map_location='cpu')

        if rank <= 0:
            self.flownet.load_state_dict(load('flownet.pkl'))
            self.metricnet.load_state_dict(convert(load('metric.pkl')))
            self.fusionnet.load_state_dict(convert(load('fusionnet.pkl')))

    def reuse(self, img0, img1, scale):
        """Compute the timestep-independent data for one frame pair.

        Args:
            img0: First frame tensor.
            img1: Second frame tensor.
            scale: Extra resize factor applied to the (already halved) frames
                before flow estimation; ``1.0`` means no extra resize.

        Returns:
            ``(flow01, flow10, metric0, metric1, feat_ext0, feat_ext1)`` where
            each ``feat_ext`` is the list of three feature maps produced by
            ``fusionnet.feat_ext`` for the corresponding full-size frame.
        """
        feat11, feat12, feat13 = self.fusionnet.feat_ext(img0)
        feat21, feat22, feat23 = self.fusionnet.feat_ext(img1)
        feat_ext0 = [feat11, feat12, feat13]
        feat_ext1 = [feat21, feat22, feat23]

        # flows and metrics are computed on half-size frames
        img0 = F.interpolate(img0, scale_factor = 0.5, mode="bilinear", align_corners=False)
        img1 = F.interpolate(img1, scale_factor = 0.5, mode="bilinear", align_corners=False)

        if scale != 1.0:
            imgf0 = F.interpolate(img0, scale_factor = scale, mode="bilinear", align_corners=False)
            imgf1 = F.interpolate(img1, scale_factor = scale, mode="bilinear", align_corners=False)
        else:
            imgf0 = img0
            imgf1 = img1
        flow01 = self.flownet(imgf0, imgf1)
        flow10 = self.flownet(imgf1, imgf0)
        if scale != 1.0:
            # resize the flows back to the half-size grid and rescale their
            # magnitudes by the same factor
            flow01 = F.interpolate(flow01, scale_factor = 1. / scale, mode="bilinear", align_corners=False) / scale
            flow10 = F.interpolate(flow10, scale_factor = 1. / scale, mode="bilinear", align_corners=False) / scale

        metric0, metric1 = self.metricnet(img0, img1, flow01, flow10)

        return flow01, flow10, metric0, metric1, feat_ext0, feat_ext1

    def inference(self, img0, img1, reuse_things, timestep):
        """Return the frame at ``timestep`` using the output of ``reuse``."""
        return self.fusionnet(img0, img1, reuse_things, timestep)
class CNNEncoder(nn.Module):
    """Convolutional feature encoder built from a stem and three residual stages.

    Args:
        output_dim: Channels of the returned feature maps.
        norm_layer: Normalisation layer class used throughout the encoder.
        num_output_scales: Number of output branches.  With ``1`` the last
            stage uses stride 2; otherwise it uses stride 1 and a
            ``MultiScaleTridentConv`` produces one map per branch
            (2, 3 or 4 branches are supported).
    """

    def __init__(self, output_dim=128,
                 norm_layer=nn.InstanceNorm2d,
                 num_output_scales=1,
                 **kwargs,
                 ):
        super().__init__()
        self.num_branch = num_output_scales

        stem_dim, mid_dim, top_dim = 64, 96, 128

        # stem: 7x7 stride-2 convolution -> 1/2 resolution
        self.conv1 = nn.Conv2d(3, stem_dim, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm1 = norm_layer(stem_dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = stem_dim
        self.layer1 = self._make_layer(stem_dim, stride=1, norm_layer=norm_layer)  # 1/2
        self.layer2 = self._make_layer(mid_dim, stride=2, norm_layer=norm_layer)  # 1/4

        # the last stage ends at 1/8 for a single scale, otherwise stays at 1/4
        last_stride = 1 if num_output_scales != 1 else 2
        self.layer3 = self._make_layer(top_dim, stride=last_stride, norm_layer=norm_layer)

        self.conv2 = nn.Conv2d(top_dim, output_dim, 1, 1, 0)

        if self.num_branch > 1:
            strides_by_branch = {2: (1, 2), 3: (1, 2, 4), 4: (1, 2, 4, 8)}
            if self.num_branch not in strides_by_branch:
                raise ValueError

            self.trident_conv = MultiScaleTridentConv(
                output_dim, output_dim,
                kernel_size=3,
                strides=strides_by_branch[self.num_branch],
                paddings=1,
                num_branch=self.num_branch,
            )

        # weight initialisation for convolutions and affine normalisation layers
        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if not isinstance(module, norm_types):
                continue
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
        """Build a stage of two residual blocks and record ``dim`` as the new
        input width for the next stage."""
        entry_block = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
        second_block = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
        self.in_planes = dim
        return nn.Sequential(entry_block, second_block)

    def forward(self, x):
        """Return a list of feature maps (a single-element list without branches)."""
        stages = (
            self.conv1, self.norm1, self.relu1,
            self.layer1,  # 1/2
            self.layer2,  # 1/4
            self.layer3,  # 1/8 or 1/4
            self.conv2,
        )
        for stage in stages:
            x = stage(x)

        if self.num_branch > 1:
            return self.trident_conv([x] * self.num_branch)  # high to low res
        return [x]
| grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] 15 | 16 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 17 | 18 | if device is not None: 19 | grid = grid.to(device) 20 | 21 | return grid 22 | 23 | 24 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 25 | assert device is not None 26 | 27 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), 28 | torch.linspace(h_min, h_max, len_h, device=device)], 29 | ) 30 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] 31 | 32 | return grid 33 | 34 | 35 | def normalize_coords(coords, h, w): 36 | # coords: [B, H, W, 2] 37 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) 38 | return (coords - c) / c # [-1, 1] 39 | 40 | 41 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): 42 | # img: [B, C, H, W] 43 | # sample_coords: [B, 2, H, W] in image scale 44 | if sample_coords.size(1) != 2: # [B, H, W, 2] 45 | sample_coords = sample_coords.permute(0, 3, 1, 2) 46 | 47 | b, _, h, w = sample_coords.shape 48 | 49 | # Normalize to [-1, 1] 50 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 51 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 52 | 53 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] 54 | 55 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 56 | 57 | if return_mask: 58 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] 59 | 60 | return img, mask 61 | 62 | return img 63 | 64 | 65 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'): 66 | b, c, h, w = feature.size() 67 | assert flow.size(1) == 2 68 | 69 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 70 | 71 | return bilinear_sample(feature, grid, padding_mode=padding_mode, 72 | return_mask=mask) 73 | 74 | 75 | def forward_backward_consistency_check(fwd_flow, bwd_flow, 76 | 
class GMFlow(nn.Module):
    """Multi-scale optical-flow estimator.

    Pipeline per scale (coarse to fine): CNN features -> positional encoding
    -> transformer -> correlation + softmax matching (global or local) ->
    flow propagation with self-attention. The final flow is upsampled to the
    input resolution with learned convex upsampling.
    """

    def __init__(self,
                 num_scales=2,
                 upsample_factor=4,
                 feature_channels=128,
                 attention_type='swin',
                 num_transformer_layers=6,
                 ffn_dim_expansion=4,
                 num_head=1,
                 **kwargs,
                 ):
        super().__init__()

        self.num_scales = num_scales
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.attention_type = attention_type
        self.num_transformer_layers = num_transformer_layers

        # CNN backbone
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer
        self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
                                              d_model=feature_channels,
                                              nhead=num_head,
                                              attention_type=attention_type,
                                              ffn_dim_expansion=ffn_dim_expansion,
                                              )

        # flow propagation with self-attn
        self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels)

        # convex upsampling: concat feature0 and flow as input
        self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))

    def extract_feature(self, img0, img1):
        """Run the backbone once on both images and split the results.

        Returns two lists (one per image) of feature maps ordered from low
        to high resolution.
        """
        concat = torch.cat((img0, img1), dim=0)  # [2B, C, H, W]
        features = self.backbone(concat)  # list of [2B, C, H, W], resolution from high to low

        feature0, feature1 = [], []

        # reversed: resolution from low to high
        for feature in features[::-1]:
            chunks = torch.chunk(feature, 2, 0)  # tuple
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
                      ):
        """Upsample ``flow`` and rescale its magnitude accordingly.

        With ``bilinear=True`` the flow is bilinearly interpolated by
        ``upsample_factor`` (``feature`` is not used). Otherwise learned
        convex upsampling by ``self.upsample_factor`` is applied, with the
        mask predicted from the concatenation of ``flow`` and ``feature``;
        the ``upsample_factor`` argument is ignored in that mode.
        """
        if bilinear:
            return F.interpolate(flow, scale_factor=upsample_factor,
                                 mode='bilinear', align_corners=True) * upsample_factor

        # convex upsampling
        k = self.upsample_factor
        concat = torch.cat((flow, feature), dim=1)

        mask = self.upsampler(concat)
        b, flow_channel, h, w = flow.shape
        mask = mask.view(b, 1, 9, k, k, h, w)  # [B, 1, 9, K, K, H, W]
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(k * flow, [3, 3], padding=1)
        up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, 2, 9, 1, 1, H, W]

        up_flow = torch.sum(mask * up_flow, dim=2)  # [B, 2, K, K, H, W]
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, 2, K, H, K, W]
        up_flow = up_flow.reshape(b, flow_channel, k * h, k * w)  # [B, 2, K*H, K*W]

        return up_flow

    def forward(self, img0, img1,
                attn_splits_list=(2, 8),
                corr_radius_list=(-1, 4),
                prop_radius_list=(-1, 1),
                pred_bidir_flow=False,
                **kwargs,
                ):
        """Estimate the flow from ``img0`` to ``img1``.

        Args:
            img0, img1: image batches, [B, 3, H, W].
            attn_splits_list: per-scale number of attention window splits.
            corr_radius_list: per-scale matching radius; -1 selects global
                matching.
            prop_radius_list: per-scale propagation radius; values <= 0
                select global propagation.
            pred_bidir_flow: when True, forward and backward flow are
                predicted together, concatenated along the batch dimension.

        Returns:
            The flow of the last scale, convex-upsampled to input resolution.

        The per-scale lists are only read (``len`` and indexing), so the
        defaults are tuples rather than shared mutable lists; callers may
        still pass lists.
        """
        img0, img1 = normalize_img(img0, img1)  # [B, 3, H, W]

        # resolution low to high
        feature0_list, feature1_list = self.extract_feature(img0, img1)  # list of features

        flow = None

        assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales

        for scale_idx in range(self.num_scales):
            feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

            if pred_bidir_flow and scale_idx > 0:
                # predicting bidirectional flow with refinement
                feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)

            if scale_idx > 0:
                flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2

            if flow is not None:
                flow = flow.detach()
                feature1 = flow_warp(feature1, flow)  # [B, C, H, W]

            attn_splits = attn_splits_list[scale_idx]
            corr_radius = corr_radius_list[scale_idx]
            prop_radius = prop_radius_list[scale_idx]

            # add position to features
            feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)

            # Transformer
            feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits)

            # correlation and softmax
            if corr_radius == -1:  # global matching
                flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
            else:  # local matching
                flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]

            # flow or residual flow
            flow = flow + flow_pred if flow is not None else flow_pred

            # NOTE: the bilinearly upsampled intermediate predictions that
            # used to be computed here at training time (for supervision)
            # were never returned or stored, so that dead work was removed.
            # Only the final convex-upsampled flow is produced.

            # flow propagation with self-attn
            if pred_bidir_flow and scale_idx == 0:
                feature0 = torch.cat((feature0, feature1), dim=0)  # [2*B, C, H, W] for propagation
            flow = self.feature_flow_attn(feature0, flow.detach(),
                                          local_window_attn=prop_radius > 0,
                                          local_window_radius=prop_radius)

            if scale_idx == self.num_scales - 1:
                flow_up = self.upsample_flow(flow, feature0)

        return flow_up
def local_correlation_softmax(feature0, feature1, local_radius,
                              padding_mode='zeros',
                              ):
    """Match each pixel of ``feature0`` inside a local window of ``feature1``.

    For every pixel, the (2R+1) x (2R+1) neighbourhood around the same
    location is sampled from ``feature1`` and correlated with the
    ``feature0`` vector. A softmax over the window gives matching
    probabilities, and the flow is the expected matched coordinate minus
    the pixel's own coordinate.

    Returns:
        (flow [B, 2, H, W], match probability [B, H*W, (2R+1)^2]).
    """
    batch, channels, height, width = feature0.size()
    device = feature0.device

    base_coords = coords_grid(batch, height, width).to(device)  # [B, 2, H, W]
    flat_coords = base_coords.view(batch, 2, -1).permute(0, 2, 1)  # [B, H*W, 2]

    side = 2 * local_radius + 1
    offsets = generate_window_grid(-local_radius, local_radius,
                                   -local_radius, local_radius,
                                   side, side, device=device)  # [2R+1, 2R+1, 2]
    offsets = offsets.reshape(-1, 2).repeat(batch, 1, 1, 1)  # [B, 1, (2R+1)^2, 2]

    # absolute coordinates of every window candidate
    candidates = flat_coords.unsqueeze(-2) + offsets  # [B, H*W, (2R+1)^2, 2]

    # candidates that fall outside the image are masked out before the softmax
    cand_x = candidates[:, :, :, 0]
    cand_y = candidates[:, :, :, 1]
    inside = (cand_x >= 0) & (cand_x < width) & (cand_y >= 0) & (cand_y < height)  # [B, H*W, (2R+1)^2]

    # grid_sample expects coordinates normalized to [-1, 1]
    sampled = F.grid_sample(feature1, normalize_coords(candidates, height, width),
                            padding_mode=padding_mode, align_corners=True)
    sampled = sampled.permute(0, 2, 1, 3)  # [B, H*W, C, (2R+1)^2]
    queries = feature0.permute(0, 2, 3, 1).view(batch, height * width, 1, channels)  # [B, H*W, 1, C]

    corr = torch.matmul(queries, sampled).view(batch, height * width, -1) / (channels ** 0.5)  # [B, H*W, (2R+1)^2]
    corr[~inside] = -1e9

    prob = F.softmax(corr, -1)  # [B, H*W, (2R+1)^2]

    expected = torch.matmul(prob.unsqueeze(-2), candidates).squeeze(-2)  # [B, H*W, 2]
    correspondence = expected.view(batch, height, width, 2).permute(0, 3, 1, 2)  # [B, 2, H, W]

    return correspondence - base_coords, prob
def single_head_full_attention(q, k, v):
    """Scaled dot-product attention with a single head.

    q, k, v are 3-D tensors ([B, L, C] in this file). Scores are scaled by
    the square root of the channel size of ``q`` and softmax-normalized over
    the key axis.
    """
    assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3

    scale = q.size(2) ** .5
    weights = torch.softmax(torch.matmul(q, k.permute(0, 2, 1)) / scale, dim=2)  # [B, L, L]

    return torch.matmul(weights, v)  # [B, L, C]
class TransformerLayer(nn.Module):
    """Single-head attention layer with an optional feed-forward network.

    ``source`` provides the queries and ``target`` the keys and values. With
    ``no_ffn=True`` the layer is attention only; otherwise the attention
    message is concatenated with ``source`` and passed through an MLP. The
    result is added residually to ``source``.
    """

    def __init__(self,
                 d_model=256,
                 nhead=1,
                 attention_type='swin',
                 no_ffn=False,
                 ffn_dim_expansion=4,
                 with_shift=False,
                 **kwargs,
                 ):
        super().__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn

        self.with_shift = with_shift

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            # norm2 is only used together with the FFN (see forward), so it
            # is only created when the FFN exists
            self.norm2 = nn.LayerNorm(d_model)

    def forward(self, source, target,
                height=None,
                width=None,
                shifted_window_attn_mask=None,
                attn_num_splits=None,
                **kwargs,
                ):
        """Attend from ``source`` to ``target`` and return the updated source.

        Args:
            source, target: [B, L, C].
            height, width: spatial size with ``height * width == L``;
                required for split-window attention.
            shifted_window_attn_mask: mask for shifted windows; required when
                the layer was built with ``with_shift=True`` and window
                attention is used.
            attn_num_splits: number of window splits per side. ``None`` or a
                value <= 1 selects full attention.

        Raises:
            NotImplementedError: window attention requested with ``nhead > 1``.
        """
        query, key, value = source, target, target

        # single-head attention
        query = self.q_proj(query)  # [B, L, C]
        key = self.k_proj(key)  # [B, L, C]
        value = self.v_proj(value)  # [B, L, C]

        # attn_num_splits defaults to None; comparing None with an int would
        # raise TypeError, so an omitted value means "no window splitting"
        use_windows = (self.attention_type == 'swin'
                       and attn_num_splits is not None
                       and attn_num_splits > 1)

        if use_windows:
            if self.nhead > 1:
                # we observe that multihead attention slows down the speed and increases the memory consumption
                # without bringing obvious performance gains and thus the implementation is removed
                raise NotImplementedError
            message = single_head_split_window_attention(query, key, value,
                                                         num_splits=attn_num_splits,
                                                         with_shift=self.with_shift,
                                                         h=height,
                                                         w=width,
                                                         attn_mask=shifted_window_attn_mask,
                                                         )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message
| ): 199 | super(TransformerBlock, self).__init__() 200 | 201 | self.self_attn = TransformerLayer(d_model=d_model, 202 | nhead=nhead, 203 | attention_type=attention_type, 204 | no_ffn=True, 205 | ffn_dim_expansion=ffn_dim_expansion, 206 | with_shift=with_shift, 207 | ) 208 | 209 | self.cross_attn_ffn = TransformerLayer(d_model=d_model, 210 | nhead=nhead, 211 | attention_type=attention_type, 212 | ffn_dim_expansion=ffn_dim_expansion, 213 | with_shift=with_shift, 214 | ) 215 | 216 | def forward(self, source, target, 217 | height=None, 218 | width=None, 219 | shifted_window_attn_mask=None, 220 | attn_num_splits=None, 221 | **kwargs, 222 | ): 223 | # source, target: [B, L, C] 224 | 225 | # self attention 226 | source = self.self_attn(source, source, 227 | height=height, 228 | width=width, 229 | shifted_window_attn_mask=shifted_window_attn_mask, 230 | attn_num_splits=attn_num_splits, 231 | ) 232 | 233 | # cross attention and ffn 234 | source = self.cross_attn_ffn(source, target, 235 | height=height, 236 | width=width, 237 | shifted_window_attn_mask=shifted_window_attn_mask, 238 | attn_num_splits=attn_num_splits, 239 | ) 240 | 241 | return source 242 | 243 | 244 | class FeatureTransformer(nn.Module): 245 | def __init__(self, 246 | num_layers=6, 247 | d_model=128, 248 | nhead=1, 249 | attention_type='swin', 250 | ffn_dim_expansion=4, 251 | **kwargs, 252 | ): 253 | super(FeatureTransformer, self).__init__() 254 | 255 | self.attention_type = attention_type 256 | 257 | self.d_model = d_model 258 | self.nhead = nhead 259 | 260 | self.layers = nn.ModuleList([ 261 | TransformerBlock(d_model=d_model, 262 | nhead=nhead, 263 | attention_type=attention_type, 264 | ffn_dim_expansion=ffn_dim_expansion, 265 | with_shift=True if attention_type == 'swin' and i % 2 == 1 else False, 266 | ) 267 | for i in range(num_layers)]) 268 | 269 | for p in self.parameters(): 270 | if p.dim() > 1: 271 | nn.init.xavier_uniform_(p) 272 | 273 | def forward(self, feature0, feature1, 274 | 
class FeatureFlowAttention(nn.Module):
    """
    Flow propagation through self-attention on features:
    queries and keys come from feature0, values are the flow vectors.
    """

    def __init__(self, in_channels,
                 **kwargs,
                 ):
        super().__init__()

        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        # Xavier init for the weight matrices only (biases are 1-D)
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, feature0, flow,
                local_window_attn=False,
                local_window_radius=1,
                **kwargs,
                ):
        """Propagate ``flow`` using attention weights computed from ``feature0``.

        feature0: [B, C, H, W]; flow: [B, 2, H, W]. With
        ``local_window_attn=True`` attention is restricted to a
        (2R+1) x (2R+1) window around each pixel, otherwise it is global.
        """
        if local_window_attn:
            return self.forward_local_window_attn(feature0, flow,
                                                  local_window_radius=local_window_radius)

        b, c, h, w = feature0.size()
        tokens = feature0.view(b, c, h * w).permute(0, 2, 1)  # [B, H*W, C]

        # NOTE: the key is deliberately computed from the *projected* query.
        # The ``correct'' form would project the raw feature with k_proj, but
        # since both projections are linear they can be merged, and keeping
        # this form avoids re-training existing models.
        q = self.q_proj(tokens)  # [B, H*W, C]
        k = self.k_proj(q)  # [B, H*W, C]

        v = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)  # [B, H*W, 2]

        weights = torch.softmax(torch.matmul(q, k.permute(0, 2, 1)) / (c ** 0.5), dim=-1)  # [B, H*W, H*W]

        propagated = torch.matmul(weights, v)  # [B, H*W, 2]
        return propagated.view(b, h, w, v.size(-1)).permute(0, 3, 1, 2)  # [B, 2, H, W]

    def forward_local_window_attn(self, feature0, flow,
                                  local_window_radius=1,
                                  ):
        """Local-window variant: each pixel attends to its (2R+1)^2 neighbours.

        Windows are gathered with ``F.unfold`` using zero padding, so
        neighbours outside the image contribute zero keys and zero flow.
        """
        assert flow.size(1) == 2
        assert local_window_radius > 0

        b, c, h, w = feature0.size()
        side = 2 * local_window_radius + 1
        window = side ** 2

        tokens = feature0.view(b, c, -1).permute(0, 2, 1)  # [B, H*W, C]

        # one query vector per pixel
        q = self.q_proj(tokens).reshape(b * h * w, 1, c)  # [B*H*W, 1, C]

        # projected keys, laid out as an image again so they can be unfolded
        key_map = self.k_proj(tokens).permute(0, 2, 1).reshape(b, c, h, w)

        key_windows = F.unfold(key_map, kernel_size=side,
                               padding=local_window_radius)  # [B, C*(2R+1)^2, H*W]
        key_windows = key_windows.view(b, c, window, h, w).permute(
            0, 3, 4, 1, 2).reshape(b * h * w, c, window)  # [B*H*W, C, (2R+1)^2]

        flow_windows = F.unfold(flow, kernel_size=side,
                                padding=local_window_radius)  # [B, 2*(2R+1)^2, H*W]
        flow_windows = flow_windows.view(b, 2, window, h, w).permute(
            0, 3, 4, 2, 1).reshape(b * h * w, window, 2)  # [B*H*W, (2R+1)^2, 2]

        weights = torch.softmax(torch.matmul(q, key_windows) / (c ** 0.5), dim=-1)  # [B*H*W, 1, (2R+1)^2]

        propagated = torch.matmul(weights, flow_windows)  # [B*H*W, 1, 2]
        return propagated.view(b, h, w, 2).permute(0, 3, 1, 2).contiguous()  # [B, 2, H, W]
class MultiScaleTridentConv(nn.Module):
    """Convolution whose single weight is shared across several branches.

    Each branch applies the same kernel with its own stride and padding, so
    one input can be turned into outputs at several resolutions.

    In training mode, or when ``test_branch_idx == -1``, ``forward`` expects
    one input per branch and returns one output per branch. Otherwise it
    expects a single input and applies only the branch selected by
    ``test_branch_idx``.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            strides=1,
            paddings=0,
            dilations=1,
            dilation=1,
            groups=1,
            num_branch=1,
            test_branch_idx=-1,
            bias=False,
            norm=None,
            activation=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation

        # a scalar setting is shared by every branch
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(p) for p in paddings]
        # NOTE: per-branch dilations are stored, but forward() applies the
        # shared ``self.dilation`` to every branch.
        self.dilations = [_pair(d) for d in dilations]
        self.strides = [_pair(s) for s in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        """Apply the shared convolution.

        Args:
            inputs: list of tensors; one per branch in multi-branch mode,
                exactly one in single-branch inference mode.

        Returns:
            List of output tensors, one per applied branch.
        """
        all_branches = self.training or self.test_branch_idx == -1
        assert len(inputs) == (self.num_branch if all_branches else 1)

        if all_branches:
            outputs = [
                F.conv2d(x, self.weight, self.bias, stride, padding, self.dilation, self.groups)
                for x, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # Use the stride/padding of the requested branch. The previous
            # selection tested ``test_branch_idx == -1`` here, which is never
            # true on this path, so the last branch was always used.
            idx = self.test_branch_idx
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[idx],
                    self.paddings[idx],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs
def merge_splits(splits,
                 num_splits=2,
                 channel_last=False,
                 ):
    """Reassemble windows stacked on the batch axis into full feature maps.

    This is the inverse of splitting a map into a ``num_splits`` x
    ``num_splits`` grid of windows ordered row-major along the batch axis.

    Args:
        splits: ``[B*K*K, C, H/K, W/K]`` tensor, or ``[B*K*K, H/K, W/K, C]``
            when ``channel_last`` is true.
        num_splits: number of windows K along each spatial axis.
        channel_last: selects which of the two layouts ``splits`` uses.

    Returns:
        ``[B, C, H, W]`` tensor, or ``[B, H, W, C]`` for channel-last input.
    """
    k = num_splits

    if channel_last:  # [B*K*K, H/K, W/K, C]
        stacked, tile_h, tile_w, channels = splits.size()
        batch = stacked // k // k

        grid = splits.view(batch, k, k, tile_h, tile_w, channels)
        # interleave (window row, in-window row) and (window col, in-window col)
        grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
        return grid.view(batch, k * tile_h, k * tile_w, channels)  # [B, H, W, C]

    # [B*K*K, C, H/K, W/K]
    stacked, channels, tile_h, tile_w = splits.size()
    batch = stacked // k // k

    grid = splits.view(batch, k, k, channels, tile_h, tile_w)
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return grid.view(batch, channels, k * tile_h, k * tile_w)  # [B, C, H, W]


def normalize_img(img0, img1):
    """Standardise two image batches with the ImageNet per-channel mean and std.

    The statistics are reshaped to ``[1, 3, 1, 1]`` and moved to ``img1``'s
    device, so both inputs are expected to have 3 channels on axis 1.

    NOTE(review): no division by 255 happens here, and the mean/std values are
    on a 0..1 scale - presumably callers pass images already scaled to [0, 1];
    verify against the caller.

    Returns:
        Tuple ``(img0, img1)`` of normalised tensors.
    """
    device = img1.device
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)

    return (img0 - mean) / std, (img1 - mean) / std


def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add the same sine position encoding to both feature maps.

    When ``attn_splits`` is greater than 1 the encoding is computed for one
    attention window (from the split form of ``feature0``) and added inside
    every window of both maps before the windows are merged back; otherwise it
    is computed from ``feature0`` directly and added to both maps.

    Args:
        feature0, feature1: feature maps passed through ``split_feature`` with
            the default (channel-first) layout.
        attn_splits: number of attention windows per spatial axis.
        feature_channels: channel count; half of it is used as
            ``num_pos_feats`` of ``PositionEmbeddingSine``.

    Returns:
        Tuple ``(feature0, feature1)`` with the position encoding added.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits <= 1:
        # global attention: encode positions over the whole map
        position = pos_enc(feature0)
        return feature0 + position, feature1 + position

    # windowed attention: positions are relative to each split window
    windows0 = split_feature(feature0, num_splits=attn_splits)
    windows1 = split_feature(feature1, num_splits=attn_splits)

    position = pos_enc(windows0)

    merged0 = merge_splits(windows0 + position, num_splits=attn_splits)
    merged1 = merge_splits(windows1 + position, num_splits=attn_splits)

    return merged0, merged1
#!/usr/bin/env python

import collections
import cupy
import os
import re
import torch
import typing


##########################################################


# Cache shared by cuda_kernel / cuda_launch: maps a kernel key to its specialised
# source and function name; the key 'device' holds the CUDA device name.
objCudacache = {}

# C element type substituted for the {{type}} placeholder, keyed by tensor dtype.
objCudatypes = {
    torch.uint8: 'unsigned char',
    torch.float16: 'half',
    torch.float32: 'float',
    torch.float64: 'double',
    torch.int32: 'int',
    torch.int64: 'long'
}


def cuda_int32(intIn:int):
    """Wrap a Python int as a cupy.int32 so it can be passed as a kernel argument."""
    return cupy.int32(intIn)
# end


def cuda_float32(fltIn:float):
    """Wrap a Python float as a cupy.float32 so it can be passed as a kernel argument."""
    return cupy.float32(fltIn)
# end


def _cuda_item(objValue):
    """Return a plain Python number for a size / stride entry that may be a 0-d tensor."""
    return objValue.item() if torch.is_tensor(objValue) else objValue
# end


def _cuda_expand(strKernel:str, strMacro:str, objVariables:typing.Dict):
    """Expand every ``<strMacro>_<N>(tensor, i0, ..., iN-1)`` call in the kernel source.

    The N index expressions are multiplied with the strides of the named tensor
    and summed. ``OFFSET`` calls become the parenthesised flat offset, ``VALUE``
    calls become ``tensor[offset]``.

    Args:
        strKernel: kernel source containing the macro calls.
        strMacro: either ``'OFFSET'`` or ``'VALUE'``.
        objVariables: maps tensor names used in the source to the tensors.

    Returns:
        The kernel source with all calls of the given macro expanded.
    """
    assert(strMacro in ['OFFSET', 'VALUE'])

    while True:
        # raw string for the escaped parenthesis; a plain '\(' is an invalid escape sequence
        objMatch = re.search('(' + strMacro + r'_)([0-4])(\()', strKernel)

        if objMatch is None:
            break
        # end

        # walk forward to the parenthesis closing the call, skipping nested pairs
        intStart = objMatch.span()[1]
        intStop = objMatch.span()[1]
        intParentheses = 1

        while True:
            intParentheses += 1 if strKernel[intStop] == '(' else 0
            intParentheses -= 1 if strKernel[intStop] == ')' else 0

            if intParentheses == 0:
                break
            # end

            intStop += 1
        # end

        intArgs = int(objMatch.group(2))
        strArgs = strKernel[intStart:intStop].split(',')

        # first argument is the tensor name, followed by one index per dimension
        assert(intArgs == len(strArgs) - 1)

        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()

        strIndex = []

        for intArg in range(intArgs):
            # braces stand in for parentheses inside an index expression
            strExpression = strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip()

            strIndex.append('((' + strExpression + ')*' + str(_cuda_item(intStrides[intArg])) + ')')
        # end

        strCall = strMacro + '_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')'

        if strMacro == 'OFFSET':
            strKernel = strKernel.replace(strCall, '(' + str.join('+', strIndex) + ')')

        elif strMacro == 'VALUE':
            strKernel = strKernel.replace(strCall, strTensor + '[' + str.join('+', strIndex) + ']')

        # end
    # end

    return strKernel
# end


def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
    """Specialise a kernel template for the given variables and cache the result.

    ``{{name}}`` placeholders are replaced with int / float / bool / str
    values, ``{{type}}`` with the C type of the first tensor in
    ``objVariables``, and the ``SIZE_n(t)``, ``OFFSET_n(t, ...)`` and
    ``VALUE_n(t, ...)`` macros with the concrete sizes / strides of tensor
    ``t``. Because sizes and strides are baked into the source, the cache key
    includes each tensor's dtype, shape and stride as well as the device name.

    Args:
        strFunction: name of the ``__global__`` function inside ``strKernel``.
        strKernel: kernel source template.
        objVariables: maps variable names to ``None`` (ignored), int, float,
            bool, str or ``torch.Tensor`` values.

    Returns:
        The cache key to pass to ``cuda_launch``.

    Raises:
        AssertionError: for a value of unsupported type or a tensor of
            unsupported dtype.
    """
    if 'device' not in objCudacache:
        objCudacache['device'] = torch.cuda.get_device_name()
    # end

    strKey = strFunction

    for strVariable, objValue in objVariables.items():
        strKey += strVariable

        if objValue is None:
            continue

        elif type(objValue) in [int, float, bool]:
            strKey += str(objValue)

        elif type(objValue) == str:
            strKey += objValue

        elif type(objValue) == torch.Tensor:
            strKey += str(objValue.dtype)
            strKey += str(objValue.shape)
            strKey += str(objValue.stride())

        else:
            # raised explicitly so the check survives running with -O
            raise AssertionError('unsupported type for kernel variable ' + strVariable + ': ' + str(type(objValue)))

        # end
    # end

    strKey += objCudacache['device']

    if strKey not in objCudacache:
        for strVariable, objValue in objVariables.items():
            if objValue is None:
                continue

            elif type(objValue) in [int, float, bool]:
                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))

            elif type(objValue) == str:
                strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)

            elif type(objValue) == torch.Tensor:
                if objValue.dtype not in objCudatypes:
                    raise AssertionError('unsupported dtype for kernel variable ' + strVariable + ': ' + str(objValue.dtype))
                # end

                strKernel = strKernel.replace('{{type}}', objCudatypes[objValue.dtype])

            # end
        # end

        while True:
            objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

            if objMatch is None:
                break
            # end

            intArg = int(objMatch.group(2))

            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(objMatch.group(), str(_cuda_item(intSizes[intArg])))
        # end

        strKernel = _cuda_expand(strKernel, 'OFFSET', objVariables)
        strKernel = _cuda_expand(strKernel, 'VALUE', objVariables)

        objCudacache[strKey] = {
            'strFunction': strFunction,
            'strKernel': strKernel
        }
    # end

    return strKey
# end


@cupy.memoize(for_each_device=True)
def cuda_launch(strKey:str):
    """Compile the cached kernel for ``strKey`` and return its function handle.

    The result is memoised per device through ``cupy.memoize``. ``CUDA_HOME``
    is set from ``cupy.cuda.get_cuda_path()`` when it is not already defined
    and is passed to the compiler as include directories.

    NOTE(review): ``cupy.cuda.compile_with_cache`` looks to be deprecated in
    newer CuPy releases in favour of ``cupy.RawKernel`` / ``cupy.RawModule`` -
    confirm against the CuPy version this project targets before upgrading.
    """
    if 'CUDA_HOME' not in os.environ:
        os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
    # end

    return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction'])
# end


##########################################################


def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
    """Forward-warp ``tenIn`` along ``tenFlow`` with bilinear splatting.

    Args:
        tenIn: 4-D CUDA tensor to be splatted.
        tenFlow: 4-D flow tensor whose dimension 1 has size 2 (asserted in the
            kernel); channel 0 is added to the x coordinate and channel 1 to
            the y coordinate of every source pixel.
        tenMetric: per-pixel weighting used by the ``linear`` and ``soft``
            modes; must be ``None`` for ``'sum'`` and ``'avg'``.
        strMode: ``'sum'``, ``'avg'``, ``'linear'`` or ``'soft'``, optionally
            followed by ``'-addeps'``, ``'-zeroeps'`` or ``'-clipeps'`` to
            select how the normaliser is kept away from zero (no suffix behaves
            like ``addeps``).

    Returns:
        The splatted tensor; for every mode but ``sum`` it is divided by the
        splatted weights.
    """
    strModes = strMode.split('-')
    strBase = strModes[0]

    assert(strBase in ['sum', 'avg', 'linear', 'soft'])

    if strMode == 'sum': assert(tenMetric is None)
    if strMode == 'avg': assert(tenMetric is None)
    if strBase == 'linear': assert(tenMetric is not None)
    if strBase == 'soft': assert(tenMetric is not None)

    # the weights are splatted as an extra last channel and used as normaliser below
    if strBase == 'avg':
        # keyed on the base mode so that 'avg-<eps>' also gets its ones channel; the
        # normalisation below is keyed on the base mode too and would otherwise
        # divide by the last channel of the input
        tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)

    elif strBase == 'linear':
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)

    elif strBase == 'soft':
        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)

    # end

    tenOut = softsplat_func.apply(tenIn, tenFlow)

    if strBase in ['avg', 'linear', 'soft']:
        tenNormalize = tenOut[:, -1:, :, :]

        if len(strModes) == 1:
            tenNormalize = tenNormalize + 0.0000001

        elif strModes[1] == 'addeps':
            tenNormalize = tenNormalize + 0.0000001

        elif strModes[1] == 'zeroeps':
            tenNormalize[tenNormalize == 0.0] = 1.0

        elif strModes[1] == 'clipeps':
            tenNormalize = tenNormalize.clip(0.0000001, None)

        # end

        tenOut = tenOut[:, :-1, :, :] / tenNormalize
    # end

    return tenOut
# end


class softsplat_func(torch.autograd.Function):
    """Autograd function for summation splatting, implemented with CUDA kernels.

    ``forward`` distributes every input pixel to the four integer positions
    surrounding its flow target using bilinear weights; ``backward`` computes
    the gradients with respect to the input and the flow. Only CUDA tensors
    are supported.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        if tenIn.is_cuda != True:
            # there is no CPU implementation; raised explicitly so it survives -O
            raise AssertionError('softsplat requires CUDA tensors')
        # end

        # zero-initialised because the kernel accumulates into it with atomicAdd
        tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])

        cuda_launch(cuda_kernel('softsplat_out', '''
            extern "C" __global__ void __launch_bounds__(512) softsplat_out(
                const int n,
                const {{type}}* __restrict__ tenIn,
                const {{type}}* __restrict__ tenFlow,
                {{type}}* __restrict__ tenOut
            ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut)                  ) % SIZE_1(tenOut);
                const int intY = ( intIndex / SIZE_3(tenOut)                                   ) % SIZE_2(tenOut);
                const int intX = ( intIndex                                                    ) % SIZE_3(tenOut);

                assert(SIZE_1(tenFlow) == 2);

                {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                if (isfinite(fltX) == false) { return; }
                if (isfinite(fltY) == false) { return; }

                {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);

                int intNorthwestX = (int) (floor(fltX));
                int intNorthwestY = (int) (floor(fltY));
                int intNortheastX = intNorthwestX + 1;
                int intNortheastY = intNorthwestY;
                int intSouthwestX = intNorthwestX;
                int intSouthwestY = intNorthwestY + 1;
                int intSoutheastX = intNorthwestX + 1;
                int intSoutheastY = intNorthwestY + 1;

                {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));

                if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
                    atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
                }

                if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
                    atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
                }

                if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
                    atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
                }

                if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
                    atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
                }
            } }
        ''', {
            'tenIn': tenIn,
            'tenFlow': tenFlow,
            'tenOut': tenOut
        }))(
            grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
            block=tuple([512, 1, 1]),
            args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
            stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
        )

        self.save_for_backward(tenIn, tenFlow)

        return tenOut
    # end

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        tenIn, tenFlow = self.saved_tensors

        tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)

        # Zero-initialised rather than new_empty: the kernels return before writing
        # their element when the flow target is not finite, which would otherwise
        # leave uninitialised memory in the gradient. Allocated from an explicit
        # shape so the buffers are contiguous, as the kernels write them by flat index.
        tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
        tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None

        if tenIngrad is not None:
            cuda_launch(cuda_kernel('softsplat_ingrad', '''
                extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
                    const int n,
                    const {{type}}* __restrict__ tenIn,
                    const {{type}}* __restrict__ tenFlow,
                    const {{type}}* __restrict__ tenOutgrad,
                    {{type}}* __restrict__ tenIngrad,
                    {{type}}* __restrict__ tenFlowgrad
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
                    const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad)                     ) % SIZE_1(tenIngrad);
                    const int intY = ( intIndex / SIZE_3(tenIngrad)                                         ) % SIZE_2(tenIngrad);
                    const int intX = ( intIndex                                                             ) % SIZE_3(tenIngrad);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltIngrad = 0.0f;

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
                    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));

                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
                    }

                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
                    }

                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
                    }

                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
                    }

                    tenIngrad[intIndex] = fltIngrad;
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOutgrad': tenOutgrad,
                'tenIngrad': tenIngrad,
                'tenFlowgrad': tenFlowgrad
            }))(
                grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        if tenFlowgrad is not None:
            cuda_launch(cuda_kernel('softsplat_flowgrad', '''
                extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
                    const int n,
                    const {{type}}* __restrict__ tenIn,
                    const {{type}}* __restrict__ tenFlow,
                    const {{type}}* __restrict__ tenOutgrad,
                    {{type}}* __restrict__ tenIngrad,
                    {{type}}* __restrict__ tenFlowgrad
                ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                    const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
                    const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad)                       ) % SIZE_1(tenFlowgrad);
                    const int intY = ( intIndex / SIZE_3(tenFlowgrad)                                             ) % SIZE_2(tenFlowgrad);
                    const int intX = ( intIndex                                                                   ) % SIZE_3(tenFlowgrad);

                    assert(SIZE_1(tenFlow) == 2);

                    {{type}} fltFlowgrad = 0.0f;

                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                    if (isfinite(fltX) == false) { return; }
                    if (isfinite(fltY) == false) { return; }

                    int intNorthwestX = (int) (floor(fltX));
                    int intNorthwestY = (int) (floor(fltY));
                    int intNortheastX = intNorthwestX + 1;
                    int intNortheastY = intNorthwestY;
                    int intSouthwestX = intNorthwestX;
                    int intSouthwestY = intNorthwestY + 1;
                    int intSoutheastX = intNorthwestX + 1;
                    int intSoutheastY = intNorthwestY + 1;

                    {{type}} fltNorthwest = 0.0f;
                    {{type}} fltNortheast = 0.0f;
                    {{type}} fltSouthwest = 0.0f;
                    {{type}} fltSoutheast = 0.0f;

                    if (intC == 0) {
                        fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
                        fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
                        fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
                        fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));

                    } else if (intC == 1) {
                        fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
                        fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
                        fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
                        fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));

                    }

                    for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
                        {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);

                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
                        }

                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
                        }

                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
                        }

                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
                        }
                    }

                    tenFlowgrad[intIndex] = fltFlowgrad;
                } }
            ''', {
                'tenIn': tenIn,
                'tenFlow': tenFlow,
                'tenOutgrad': tenOutgrad,
                'tenIngrad': tenIngrad,
                'tenFlowgrad': tenFlowgrad
            }))(
                grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
            )
        # end

        return tenIngrad, tenFlowgrad
    # end
# end