├── vsrife ├── models │ ├── flownet_v4.0.pkl │ ├── flownet_v4.1.pkl │ ├── flownet_v4.10.pkl │ ├── flownet_v4.11.pkl │ ├── flownet_v4.12.pkl │ ├── flownet_v4.13.pkl │ ├── flownet_v4.14.pkl │ ├── flownet_v4.15.pkl │ ├── flownet_v4.17.pkl │ ├── flownet_v4.18.pkl │ ├── flownet_v4.19.pkl │ ├── flownet_v4.2.pkl │ ├── flownet_v4.20.pkl │ ├── flownet_v4.21.pkl │ ├── flownet_v4.22.pkl │ ├── flownet_v4.23.pkl │ ├── flownet_v4.24.pkl │ ├── flownet_v4.25.pkl │ ├── flownet_v4.26.pkl │ ├── flownet_v4.3.pkl │ ├── flownet_v4.4.pkl │ ├── flownet_v4.5.pkl │ ├── flownet_v4.6.pkl │ ├── flownet_v4.7.pkl │ ├── flownet_v4.8.pkl │ ├── flownet_v4.9.pkl │ ├── flownet_v4.12.lite.pkl │ ├── flownet_v4.13.lite.pkl │ ├── flownet_v4.14.lite.pkl │ ├── flownet_v4.15.lite.pkl │ ├── flownet_v4.16.lite.pkl │ ├── flownet_v4.17.lite.pkl │ ├── flownet_v4.22.lite.pkl │ ├── flownet_v4.25.heavy.pkl │ ├── flownet_v4.25.lite.pkl │ └── flownet_v4.26.heavy.pkl ├── warplayer.py ├── __main__.py ├── IFNet_HDv3_v4_2.py ├── IFNet_HDv3_v4_3.py ├── IFNet_HDv3_v4_4.py ├── IFNet_HDv3_v4_5.py ├── IFNet_HDv3_v4_6.py ├── IFNet_HDv3_v4_0.py ├── IFNet_HDv3_v4_1.py ├── IFNet_HDv3_v4_7.py ├── IFNet_HDv3_v4_8.py ├── IFNet_HDv3_v4_9.py ├── IFNet_HDv3_v4_21.py ├── IFNet_HDv3_v4_22.py ├── IFNet_HDv3_v4_23.py ├── IFNet_HDv3_v4_22_lite.py ├── IFNet_HDv3_v4_25.py ├── IFNet_HDv3_v4_26.py ├── IFNet_HDv3_v4_25_lite.py ├── IFNet_HDv3_v4_26_heavy.py ├── IFNet_HDv3_v4_25_heavy.py ├── IFNet_HDv3_v4_10.py ├── IFNet_HDv3_v4_11.py ├── IFNet_HDv3_v4_12.py ├── IFNet_HDv3_v4_12_lite.py ├── IFNet_HDv3_v4_13_lite.py ├── IFNet_HDv3_v4_13.py ├── IFNet_HDv3_v4_14.py ├── IFNet_HDv3_v4_15.py ├── IFNet_HDv3_v4_15_lite.py ├── IFNet_HDv3_v4_16_lite.py ├── IFNet_HDv3_v4_17.py ├── IFNet_HDv3_v4_17_lite.py ├── IFNet_HDv3_v4_18.py ├── IFNet_HDv3_v4_19.py ├── IFNet_HDv3_v4_20.py ├── IFNet_HDv3_v4_24.py └── IFNet_HDv3_v4_14_lite.py ├── .gitattributes ├── pyproject.toml ├── LICENSE ├── README.md └── .gitignore /vsrife/models/flownet_v4.0.pkl: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.1.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.10.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.11.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.12.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.13.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.14.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.15.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.17.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.18.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.19.pkl: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.2.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.20.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.21.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.22.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.23.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.24.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.25.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.26.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.3.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.4.pkl: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.5.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.6.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.7.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.8.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.9.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.12.lite.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.13.lite.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.14.lite.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.15.lite.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/vsrife/models/flownet_v4.16.lite.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.17.lite.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.22.lite.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.25.heavy.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.25.lite.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vsrife/models/flownet_v4.26.heavy.pkl: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /vsrife/warplayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def warp(tenInput, tenFlow, tenFlow_div, backwarp_tenGrid): 6 | dtype = tenInput.dtype 7 | tenInput = tenInput.to(torch.float) 8 | tenFlow = tenFlow.to(torch.float) 9 | 10 | tenFlow = torch.cat([tenFlow[:, 0:1] / tenFlow_div[0], tenFlow[:, 1:2] / tenFlow_div[1]], 1) 11 | g = (backwarp_tenGrid + tenFlow).permute(0, 2, 3, 1) 12 | return 
F.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True).to(dtype) 13 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "vsrife" 7 | version = "5.6.0" 8 | description = "RIFE function for VapourSynth" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | license-files = { paths = ["LICENSE"] } 12 | authors = [ 13 | { name = "HolyWu", email = "holywu@gmail.com" }, 14 | ] 15 | keywords = [ 16 | "PyTorch", 17 | "RIFE", 18 | "TensorRT", 19 | "VapourSynth", 20 | ] 21 | classifiers = [ 22 | "Environment :: GPU :: NVIDIA CUDA", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | "Programming Language :: Python :: 3", 26 | "Topic :: Multimedia :: Video", 27 | ] 28 | dependencies = [ 29 | "numpy", 30 | "requests", 31 | "torch>=2.6.0", 32 | "tqdm", 33 | "VapourSynth>=66", 34 | ] 35 | 36 | [project.urls] 37 | Homepage = "https://github.com/HolyWu/vs-rife" 38 | Issues = "https://github.com/HolyWu/vs-rife/issues" 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 HolyWu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission 
notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RIFE 2 | Real-Time Intermediate Flow Estimation for Video Frame Interpolation, based on https://github.com/hzwer/Practical-RIFE. 3 | 4 | 5 | ## Dependencies 6 | - [PyTorch](https://pytorch.org/get-started/) 2.6.0 or later 7 | - [VapourSynth](http://www.vapoursynth.com/) R66 or later 8 | - [vs-miscfilters-obsolete](https://github.com/vapoursynth/vs-miscfilters-obsolete) (only needed for scene change detection) 9 | 10 | `trt` requires additional packages: 11 | - [TensorRT](https://developer.nvidia.com/tensorrt) 10.7.0.post1 or later 12 | - [Torch-TensorRT](https://pytorch.org/TensorRT/) 2.6.0 or later 13 | 14 | To install the latest stable version of PyTorch and Torch-TensorRT, run: 15 | ``` 16 | pip install -U packaging setuptools wheel 17 | pip install -U torch torchvision torch_tensorrt --index-url https://download.pytorch.org/whl/cu126 --extra-index-url https://pypi.nvidia.com 18 | ``` 19 | 20 | 21 | ## Installation 22 | ``` 23 | pip install -U vsrife 24 | ``` 25 | 26 | If you want to download all models at once, run `python -m vsrife`. If you prefer to only download the model you 27 | specified at first run, set `auto_download=True` in `rife()`. 
28 | 29 | 30 | ## Usage 31 | ```python 32 | from vsrife import rife 33 | 34 | ret = rife(clip) 35 | ``` 36 | 37 | See `__init__.py` for the description of the parameters. 38 | -------------------------------------------------------------------------------- /vsrife/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import requests 4 | from tqdm import tqdm 5 | 6 | 7 | def download_model(url: str) -> None: 8 | filename = url.split("/")[-1] 9 | r = requests.get(url, stream=True) 10 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "models", filename), "wb") as f: 11 | with tqdm( 12 | unit="B", 13 | unit_scale=True, 14 | unit_divisor=1024, 15 | miniters=1, 16 | desc=filename, 17 | total=int(r.headers.get("content-length", 0)), 18 | ) as pbar: 19 | for chunk in r.iter_content(chunk_size=4096): 20 | f.write(chunk) 21 | pbar.update(len(chunk)) 22 | 23 | 24 | if __name__ == "__main__": 25 | url = "https://github.com/HolyWu/vs-rife/releases/download/model/" 26 | models = [ 27 | "flownet_v4.0", 28 | "flownet_v4.1", 29 | "flownet_v4.2", 30 | "flownet_v4.3", 31 | "flownet_v4.4", 32 | "flownet_v4.5", 33 | "flownet_v4.6", 34 | "flownet_v4.7", 35 | "flownet_v4.8", 36 | "flownet_v4.9", 37 | "flownet_v4.10", 38 | "flownet_v4.11", 39 | "flownet_v4.12", 40 | "flownet_v4.12.lite", 41 | "flownet_v4.13", 42 | "flownet_v4.13.lite", 43 | "flownet_v4.14", 44 | "flownet_v4.14.lite", 45 | "flownet_v4.15", 46 | "flownet_v4.15.lite", 47 | "flownet_v4.16.lite", 48 | "flownet_v4.17", 49 | "flownet_v4.17.lite", 50 | "flownet_v4.18", 51 | "flownet_v4.19", 52 | "flownet_v4.20", 53 | "flownet_v4.21", 54 | "flownet_v4.22", 55 | "flownet_v4.22.lite", 56 | "flownet_v4.23", 57 | "flownet_v4.24", 58 | "flownet_v4.25", 59 | "flownet_v4.25.lite", 60 | "flownet_v4.25.heavy", 61 | "flownet_v4.26", 62 | "flownet_v4.26.heavy", 63 | ] 64 | for model in models: 65 | download_model(url + model + ".pkl") 66 | 
-------------------------------------------------------------------------------- /vsrife/IFNet_HDv3_v4_2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .warplayer import warp 6 | 7 | 8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 9 | return nn.Sequential( 10 | nn.Conv2d( 11 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True 12 | ), 13 | nn.LeakyReLU(0.2, True), 14 | ) 15 | 16 | 17 | class IFBlock(nn.Module): 18 | def __init__(self, in_planes, c=64): 19 | super(IFBlock, self).__init__() 20 | self.conv0 = nn.Sequential( 21 | conv(in_planes, c // 2, 3, 2, 1), 22 | conv(c // 2, c, 3, 2, 1), 23 | ) 24 | self.convblock = nn.Sequential( 25 | conv(c, c), 26 | conv(c, c), 27 | conv(c, c), 28 | conv(c, c), 29 | conv(c, c), 30 | conv(c, c), 31 | conv(c, c), 32 | conv(c, c), 33 | ) 34 | self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) 35 | 36 | def forward(self, x, flow=None, scale=1): 37 | x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") 38 | if flow is not None: 39 | flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale 40 | x = torch.cat((x, flow), 1) 41 | feat = self.conv0(x) 42 | feat = self.convblock(feat) 43 | tmp = self.lastconv(feat) 44 | tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear") 45 | flow = tmp[:, :4] * scale * 2 46 | mask = tmp[:, 4:5] 47 | return flow, mask 48 | 49 | 50 | class IFNet(nn.Module): 51 | def __init__(self, scale=1, ensemble=False): 52 | super(IFNet, self).__init__() 53 | self.block0 = IFBlock(7, c=192) 54 | self.block1 = IFBlock(8 + 4, c=128) 55 | self.block2 = IFBlock(8 + 4, c=96) 56 | self.block3 = IFBlock(8 + 4, c=64) 57 | self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale] 58 | self.ensemble = ensemble 59 | 60 | def forward(self, img0, img1, timestep, 
tenFlow_div, backwarp_tenGrid): 61 | img0 = img0.clamp(0.0, 1.0) 62 | img1 = img1.clamp(0.0, 1.0) 63 | flow_list = [] 64 | merged = [] 65 | mask_list = [] 66 | warped_img0 = img0 67 | warped_img1 = img1 68 | flow = None 69 | mask = None 70 | block = [self.block0, self.block1, self.block2, self.block3] 71 | for i in range(4): 72 | if flow is None: 73 | flow, mask = block[i](torch.cat((img0, img1, timestep), 1), None, scale=self.scale_list[i]) 74 | if self.ensemble: 75 | f1, m1 = block[i](torch.cat((img1, img0, 1 - timestep), 1), None, scale=self.scale_list[i]) 76 | flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 77 | mask = (mask + (-m1)) / 2 78 | else: 79 | f0, m0 = block[i]( 80 | torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=self.scale_list[i] 81 | ) 82 | if self.ensemble: 83 | f1, m1 = block[i]( 84 | torch.cat((warped_img1, warped_img0, 1 - timestep, -mask), 1), 85 | torch.cat((flow[:, 2:4], flow[:, :2]), 1), 86 | scale=self.scale_list[i], 87 | ) 88 | f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 89 | m0 = (m0 + (-m1)) / 2 90 | flow = flow + f0 91 | mask = mask + m0 92 | mask_list.append(mask) 93 | flow_list.append(flow) 94 | warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid) 95 | warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) 96 | merged.append((warped_img0, warped_img1)) 97 | mask_list[3] = torch.sigmoid(mask_list[3]) 98 | return merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3]) 99 | -------------------------------------------------------------------------------- /vsrife/IFNet_HDv3_v4_3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .warplayer import warp 6 | 7 | 8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 9 | return nn.Sequential( 10 | nn.Conv2d( 11 | in_planes, out_planes, 
kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True 12 | ), 13 | nn.LeakyReLU(0.2, True), 14 | ) 15 | 16 | 17 | class IFBlock(nn.Module): 18 | def __init__(self, in_planes, c=64): 19 | super(IFBlock, self).__init__() 20 | self.conv0 = nn.Sequential( 21 | conv(in_planes, c // 2, 3, 2, 1), 22 | conv(c // 2, c, 3, 2, 1), 23 | ) 24 | self.convblock = nn.Sequential( 25 | conv(c, c), 26 | conv(c, c), 27 | conv(c, c), 28 | conv(c, c), 29 | conv(c, c), 30 | conv(c, c), 31 | conv(c, c), 32 | conv(c, c), 33 | ) 34 | self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) 35 | 36 | def forward(self, x, flow=None, scale=1): 37 | x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") 38 | if flow is not None: 39 | flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale 40 | x = torch.cat((x, flow), 1) 41 | feat = self.conv0(x) 42 | feat = self.convblock(feat) 43 | tmp = self.lastconv(feat) 44 | tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear") 45 | flow = tmp[:, :4] * scale * 2 46 | mask = tmp[:, 4:5] 47 | return flow, mask 48 | 49 | 50 | class IFNet(nn.Module): 51 | def __init__(self, scale=1, ensemble=False): 52 | super(IFNet, self).__init__() 53 | self.block0 = IFBlock(7, c=192) 54 | self.block1 = IFBlock(8 + 4, c=128) 55 | self.block2 = IFBlock(8 + 4, c=96) 56 | self.block3 = IFBlock(8 + 4, c=64) 57 | self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale] 58 | self.ensemble = ensemble 59 | 60 | def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): 61 | img0 = img0.clamp(0.0, 1.0) 62 | img1 = img1.clamp(0.0, 1.0) 63 | flow_list = [] 64 | merged = [] 65 | mask_list = [] 66 | warped_img0 = img0 67 | warped_img1 = img1 68 | flow = None 69 | mask = None 70 | block = [self.block0, self.block1, self.block2, self.block3] 71 | for i in range(4): 72 | if flow is None: 73 | flow, mask = block[i](torch.cat((img0, img1, timestep), 1), None, scale=self.scale_list[i]) 74 | if 
self.ensemble: 75 | f1, m1 = block[i](torch.cat((img1, img0, 1 - timestep), 1), None, scale=self.scale_list[i]) 76 | flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 77 | mask = (mask + (-m1)) / 2 78 | else: 79 | f0, m0 = block[i]( 80 | torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=self.scale_list[i] 81 | ) 82 | if self.ensemble: 83 | f1, m1 = block[i]( 84 | torch.cat((warped_img1, warped_img0, 1 - timestep, -mask), 1), 85 | torch.cat((flow[:, 2:4], flow[:, :2]), 1), 86 | scale=self.scale_list[i], 87 | ) 88 | f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 89 | m0 = (m0 + (-m1)) / 2 90 | flow = flow + f0 91 | mask = mask + m0 92 | mask_list.append(mask) 93 | flow_list.append(flow) 94 | warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid) 95 | warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) 96 | merged.append((warped_img0, warped_img1)) 97 | mask_list[3] = torch.sigmoid(mask_list[3]) 98 | return merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3]) 99 | -------------------------------------------------------------------------------- /vsrife/IFNet_HDv3_v4_4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .warplayer import warp 6 | 7 | 8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 9 | return nn.Sequential( 10 | nn.Conv2d( 11 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True 12 | ), 13 | nn.LeakyReLU(0.2, True), 14 | ) 15 | 16 | 17 | class IFBlock(nn.Module): 18 | def __init__(self, in_planes, c=64): 19 | super(IFBlock, self).__init__() 20 | self.conv0 = nn.Sequential( 21 | conv(in_planes, c // 2, 3, 2, 1), 22 | conv(c // 2, c, 3, 2, 1), 23 | ) 24 | self.convblock = nn.Sequential( 25 | conv(c, c), 26 | conv(c, c), 27 | conv(c, c), 28 | conv(c, c), 29 | 
conv(c, c), 30 | conv(c, c), 31 | conv(c, c), 32 | conv(c, c), 33 | ) 34 | self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) 35 | 36 | def forward(self, x, flow=None, scale=1): 37 | x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") 38 | if flow is not None: 39 | flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale 40 | x = torch.cat((x, flow), 1) 41 | feat = self.conv0(x) 42 | feat = self.convblock(feat) 43 | tmp = self.lastconv(feat) 44 | tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear") 45 | flow = tmp[:, :4] * scale * 2 46 | mask = tmp[:, 4:5] 47 | return flow, mask 48 | 49 | 50 | class IFNet(nn.Module): 51 | def __init__(self, scale=1, ensemble=False): 52 | super(IFNet, self).__init__() 53 | self.block0 = IFBlock(7, c=192) 54 | self.block1 = IFBlock(8 + 4, c=128) 55 | self.block2 = IFBlock(8 + 4, c=96) 56 | self.block3 = IFBlock(8 + 4, c=64) 57 | self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale] 58 | self.ensemble = ensemble 59 | 60 | def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid): 61 | img0 = img0.clamp(0.0, 1.0) 62 | img1 = img1.clamp(0.0, 1.0) 63 | flow_list = [] 64 | merged = [] 65 | mask_list = [] 66 | warped_img0 = img0 67 | warped_img1 = img1 68 | flow = None 69 | mask = None 70 | block = [self.block0, self.block1, self.block2, self.block3] 71 | for i in range(4): 72 | if flow is None: 73 | flow, mask = block[i](torch.cat((img0, img1, timestep), 1), None, scale=self.scale_list[i]) 74 | if self.ensemble: 75 | f1, m1 = block[i](torch.cat((img1, img0, 1 - timestep), 1), None, scale=self.scale_list[i]) 76 | flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 77 | mask = (mask + (-m1)) / 2 78 | else: 79 | f0, m0 = block[i]( 80 | torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=self.scale_list[i] 81 | ) 82 | if self.ensemble: 83 | f1, m1 = block[i]( 84 | torch.cat((warped_img1, warped_img0, 1 - timestep, -mask), 1), 85 | torch.cat((flow[:, 
2:4], flow[:, :2]), 1), 86 | scale=self.scale_list[i], 87 | ) 88 | f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 89 | m0 = (m0 + (-m1)) / 2 90 | flow = flow + f0 91 | mask = mask + m0 92 | mask_list.append(mask) 93 | flow_list.append(flow) 94 | warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid) 95 | warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) 96 | merged.append((warped_img0, warped_img1)) 97 | mask_list[3] = torch.sigmoid(mask_list[3]) 98 | return merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3]) 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
"""RIFE v4.5 IFNet (HDv3): coarse-to-fine optical-flow interpolation network.

Reconstructed from a flattened dump: formatting restored, dead locals removed,
documentation added. Parameter/buffer names are unchanged so the bundled
``flownet_v4.5.pkl`` state dict still loads.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a Conv2d followed by an in-place LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """3x3 residual convolution with a learned per-channel scale ``beta``."""

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One refinement stage: downsample 4x, 8 residual convs, upsample back.

    Returns a 4-channel bidirectional flow (scaled back to input resolution)
    and a 1-channel blend-mask logit.
    """

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        # ConvTranspose2d doubles resolution, PixelShuffle doubles it again,
        # undoing the two stride-2 convs in conv0; 4 * 5 -> 5 channels out.
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 5, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; flow values are rescaled accordingly.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # undo the working-resolution shrink
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """Four-stage coarse-to-fine flow estimator producing the blended frame."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7, c=192)
        self.block1 = IFBlock(8 + 4, c=128)
        self.block2 = IFBlock(8 + 4, c=96)
        self.block3 = IFBlock(8 + 4, c=64)
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
        """Interpolate between img0 and img1 at the given timestep.

        tenFlow_div / backwarp_tenGrid are precomputed warp constants
        forwarded to :func:`warp`.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                # Coarsest stage: estimate flow from the raw frame pair.
                flow, mask = block[i](torch.cat((img0, img1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed estimate (channels swapped).
                    f1, m1 = block[i](torch.cat((img1, img0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    mask = (mask + (-m1)) / 2
            else:
                # Refinement stage: predict residual flow/mask from warped frames.
                f0, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f1, m1 = block[i](
                        torch.cat((warped_img1, warped_img0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    m0 = (m0 + (-m1)) / 2
                flow = flow + f0
                mask = mask + m0
            mask_list.append(mask)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        # Final blend of the finest-stage warps.
        mask_list[3] = torch.sigmoid(mask_list[3])
        return merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3])
"""RIFE v4.6 IFNet (HDv3): coarse-to-fine optical-flow interpolation network.

Identical architecture to v4.5 except the refinement head emits 4 * 6
channels before PixelShuffle. Reconstructed from a flattened dump; module
attribute names are unchanged so ``flownet_v4.6.pkl`` still loads.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a Conv2d followed by an in-place LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """3x3 residual convolution with a learned per-channel scale ``beta``."""

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One refinement stage: downsample 4x, 8 residual convs, upsample back."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        # 4 * 6 channels -> 6 after PixelShuffle(2); only the first 5 are
        # consumed here (4 flow + 1 mask), matching the checkpoint layout.
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; flow values are rescaled accordingly.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """Four-stage coarse-to-fine flow estimator producing the blended frame."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7, c=192)
        self.block1 = IFBlock(8 + 4, c=128)
        self.block2 = IFBlock(8 + 4, c=96)
        self.block3 = IFBlock(8 + 4, c=64)
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
        """Interpolate between img0 and img1 at the given timestep."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed estimate (channels swapped).
                    f1, m1 = block[i](torch.cat((img1, img0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    mask = (mask + (-m1)) / 2
            else:
                f0, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f1, m1 = block[i](
                        torch.cat((warped_img1, warped_img0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    m0 = (m0 + (-m1)) / 2
                flow = flow + f0
                mask = mask + m0
            mask_list.append(mask)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask_list[3] = torch.sigmoid(mask_list[3])
        return merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3])
"""RIFE v4.0 IFNet (HDv3): coarse-to-fine flow network, PReLU variant.

Older head: plain ConvTranspose2d to 5 channels at half resolution, hence
the ``scale * 2`` upsampling. Includes a large-motion fallback that doubles
all working scales when stage-1 residual flow exceeds 32 px. Reconstructed
from a flattened dump; attribute names unchanged for ``flownet_v4.0.pkl``.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a Conv2d followed by a PReLU with per-channel slopes."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.PReLU(out_planes),
    )


class IFBlock(nn.Module):
    """One refinement stage: downsample 4x, 8 convs with a residual skip."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*[conv(c, c) for _ in range(8)])
        # Single transposed conv: output is still at half input resolution,
        # compensated by the scale * 2 interpolation in forward().
        self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)

    def forward(self, x, flow=None, scale=1):
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat) + feat  # residual skip over the conv stack
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear")
        flow = tmp[:, :4] * scale * 2
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """Four-stage coarse-to-fine flow estimator producing the blended frame."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7, c=192)
        self.block1 = IFBlock(8 + 4, c=128)
        self.block2 = IFBlock(8 + 4, c=96)
        self.block3 = IFBlock(8 + 4, c=64)
        self.scale = scale
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
        """Interpolate between img0 and img1 at the given timestep."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Per-call copy: the large-motion fallback below mutates it.
        scale_list = [8 / self.scale, 4 / self.scale, 2 / self.scale, 1 / self.scale]
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, timestep), 1), None, scale=scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed estimate (channels swapped).
                    f1, m1 = block[i](torch.cat((img1, img0, 1 - timestep), 1), None, scale=scale_list[i])
                    flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    mask = (mask + (-m1)) / 2
            else:
                f0, m0 = block[i](torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=scale_list[i])
                if i == 1 and f0[:, :2].abs().max() > 32 and f0[:, 2:4].abs().max() > 32:
                    # Large-motion fallback: redo the coarse pass at doubled
                    # scales, then re-run this stage on the new warps.
                    for k in range(4):
                        scale_list[k] *= 2
                    flow, mask = block[0](torch.cat((img0, img1, timestep), 1), None, scale=scale_list[0])
                    warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                    warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                    f0, m0 = block[i](
                        torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=scale_list[i]
                    )
                if self.ensemble:
                    f1, m1 = block[i](
                        torch.cat((warped_img1, warped_img0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=scale_list[i],
                    )
                    f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    m0 = (m0 + (-m1)) / 2
                flow = flow + f0
                mask = mask + m0
            mask_list.append(mask)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask_list[3] = torch.sigmoid(mask_list[3])
        return merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3])
"""RIFE v4.1 IFNet (HDv3): coarse-to-fine flow network, PReLU variant.

Architecturally identical to v4.0 in this codebase (the weights differ).
Reconstructed from a flattened dump; attribute names unchanged so
``flownet_v4.1.pkl`` still loads.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a Conv2d followed by a PReLU with per-channel slopes."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.PReLU(out_planes),
    )


class IFBlock(nn.Module):
    """One refinement stage: downsample 4x, 8 convs with a residual skip."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*[conv(c, c) for _ in range(8)])
        # Output is at half input resolution; forward() compensates with
        # the scale * 2 interpolation.
        self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)

    def forward(self, x, flow=None, scale=1):
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat) + feat  # residual skip over the conv stack
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale * 2, mode="bilinear")
        flow = tmp[:, :4] * scale * 2
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """Four-stage coarse-to-fine flow estimator producing the blended frame."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7, c=192)
        self.block1 = IFBlock(8 + 4, c=128)
        self.block2 = IFBlock(8 + 4, c=96)
        self.block3 = IFBlock(8 + 4, c=64)
        self.scale = scale
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid):
        """Interpolate between img0 and img1 at the given timestep."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Per-call copy: the large-motion fallback below mutates it.
        scale_list = [8 / self.scale, 4 / self.scale, 2 / self.scale, 1 / self.scale]
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, timestep), 1), None, scale=scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed estimate (channels swapped).
                    f1, m1 = block[i](torch.cat((img1, img0, 1 - timestep), 1), None, scale=scale_list[i])
                    flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    mask = (mask + (-m1)) / 2
            else:
                f0, m0 = block[i](torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=scale_list[i])
                if i == 1 and f0[:, :2].abs().max() > 32 and f0[:, 2:4].abs().max() > 32:
                    # Large-motion fallback: redo the coarse pass at doubled
                    # scales, then re-run this stage on the new warps.
                    for k in range(4):
                        scale_list[k] *= 2
                    flow, mask = block[0](torch.cat((img0, img1, timestep), 1), None, scale=scale_list[0])
                    warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                    warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                    f0, m0 = block[i](
                        torch.cat((warped_img0, warped_img1, timestep, mask), 1), flow, scale=scale_list[i]
                    )
                if self.ensemble:
                    f1, m1 = block[i](
                        torch.cat((warped_img1, warped_img0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=scale_list[i],
                    )
                    f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
                    m0 = (m0 + (-m1)) / 2
                flow = flow + f0
                mask = mask + m0
            mask_list.append(mask)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask_list[3] = torch.sigmoid(mask_list[3])
        return merged[3][0] * mask_list[3] + merged[3][1] * (1 - mask_list[3])
"""RIFE v4.7 IFNet (HDv3): flow network with learned feature pyramid inputs.

The caller runs ``encode`` on each frame and passes the resulting feature
maps as ``f0``/``f1``; refinement stages warp those features alongside the
frames. Reconstructed from a flattened dump; attribute names unchanged so
``flownet_v4.7.pkl`` still loads (``encode`` must stay even though forward
receives pre-encoded features).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a Conv2d followed by an in-place LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """3x3 residual convolution with a learned per-channel scale ``beta``."""

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One refinement stage: downsample 4x, 8 residual convs, upsample back."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """Four-stage coarse-to-fine flow estimator with encoded-feature inputs."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7 + 8, c=192)
        self.block1 = IFBlock(8 + 4 + 8, c=128)
        self.block2 = IFBlock(8 + 4 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 8, c=64)
        # Frame encoder; applied by the caller to produce f0/f1 for forward().
        self.encode = nn.Sequential(nn.Conv2d(3, 16, 3, 2, 1), nn.ConvTranspose2d(16, 4, 4, 2, 1))
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Interpolate between img0/img1; f0/f1 are their encoded features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed estimate (channels swapped).
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # Warp the encoded features with the current flow estimate.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
"""RIFE v4.8 IFNet (HDv3): flow network with learned feature pyramid inputs.

Architecturally identical to v4.7 in this codebase (the weights differ).
Reconstructed from a flattened dump; attribute names unchanged so
``flownet_v4.8.pkl`` still loads.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a Conv2d followed by an in-place LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """3x3 residual convolution with a learned per-channel scale ``beta``."""

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One refinement stage: downsample 4x, 8 residual convs, upsample back."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """Four-stage coarse-to-fine flow estimator with encoded-feature inputs."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7 + 8, c=192)
        self.block1 = IFBlock(8 + 4 + 8, c=128)
        self.block2 = IFBlock(8 + 4 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 8, c=64)
        # Frame encoder; applied by the caller to produce f0/f1 for forward().
        self.encode = nn.Sequential(nn.Conv2d(3, 16, 3, 2, 1), nn.ConvTranspose2d(16, 4, 4, 2, 1))
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Interpolate between img0/img1; f0/f1 are their encoded features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed estimate (channels swapped).
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # Warp the encoded features with the current flow estimate.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
"""RIFE v4.9 IFNet (HDv3): flow network with learned feature pyramid inputs.

Architecturally identical to v4.7/v4.8 in this codebase (the weights differ).
Reconstructed from a flattened dump; attribute names unchanged so
``flownet_v4.9.pkl`` still loads.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a Conv2d followed by an in-place LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """3x3 residual convolution with a learned per-channel scale ``beta``."""

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One refinement stage: downsample 4x, 8 residual convs, upsample back."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """Four-stage coarse-to-fine flow estimator with encoded-feature inputs."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7 + 8, c=192)
        self.block1 = IFBlock(8 + 4 + 8, c=128)
        self.block2 = IFBlock(8 + 4 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 8, c=64)
        # Frame encoder; applied by the caller to produce f0/f1 for forward().
        self.encode = nn.Sequential(nn.Conv2d(3, 16, 3, 2, 1), nn.ConvTranspose2d(16, 4, 4, 2, 1))
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Interpolate between img0/img1; f0/f1 are their encoded features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed estimate (channels swapped).
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # Warp the encoded features with the current flow estimate.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
scale=self.scale_list[i] 95 | ) 96 | if self.ensemble: 97 | f_, m_ = block[i]( 98 | torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1), 99 | torch.cat((flow[:, 2:4], flow[:, :2]), 1), 100 | scale=self.scale_list[i], 101 | ) 102 | fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 103 | mask = (m0 + (-m_)) / 2 104 | else: 105 | mask = m0 106 | flow = flow + fd 107 | mask_list.append(mask) 108 | flow_list.append(flow) 109 | warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid) 110 | warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) 111 | merged.append((warped_img0, warped_img1)) 112 | mask = torch.sigmoid(mask) 113 | return warped_img0 * mask + warped_img1 * (1 - mask) 114 | -------------------------------------------------------------------------------- /vsrife/IFNet_HDv3_v4_21.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .warplayer import warp 6 | 7 | 8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 9 | return nn.Sequential( 10 | nn.Conv2d( 11 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True 12 | ), 13 | nn.LeakyReLU(0.2, True), 14 | ) 15 | 16 | 17 | class Head(nn.Module): 18 | def __init__(self): 19 | super(Head, self).__init__() 20 | self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1) 21 | self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1) 22 | self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1) 23 | self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1) 24 | self.relu = nn.LeakyReLU(0.2, True) 25 | 26 | def forward(self, x, feat=False): 27 | x = x.clamp(0.0, 1.0) 28 | x0 = self.cnn0(x) 29 | x = self.relu(x0) 30 | x1 = self.cnn1(x) 31 | x = self.relu(x1) 32 | x2 = self.cnn2(x) 33 | x = self.relu(x2) 34 | x3 = self.cnn3(x) 35 | if feat: 36 | return [x0, x1, x2, x3] 37 | return x3 38 | 39 | 40 | class ResConv(nn.Module): 41 | 
class IFNet(nn.Module):
    """RIFE v4.21 interpolation network.

    Four IFBlock stages refine a bidirectional flow coarse-to-fine; each
    stage also emits a blend mask and a feature map that is fed forward to
    the next stage.  Ensemble mode is rejected for this version.
    """

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 16, c=256)
        self.block1 = IFBlock(8 + 4 + 16 + 8, c=192)
        self.block2 = IFBlock(8 + 4 + 16 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 16 + 8, c=48)
        self.encode = Head()
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.21")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the frame interpolated at `timestep` between img0 and img1.

        f0/f1 are Head-encoder features of the two frames; tenFlow_div and
        backwarp_tenGrid are precomputed constants for the project's `warp`
        helper.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Histories kept for parity with the training-time API; unused by the return value.
        flow_hist, mask_hist, merged_hist = [], [], []
        warped0, warped1 = img0, img1
        flow = mask = feat = None
        stages = (self.block0, self.block1, self.block2, self.block3)
        for stage, stage_scale in zip(stages, self.scale_list):
            if flow is None:
                # Coarsest level: no prior flow, feed raw frames + features.
                flow, mask, feat = stage(
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=stage_scale
                )
            else:
                # Warp encoder features with the current flow and predict a residual.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                residual, mask, feat = stage(
                    torch.cat((warped0, warped1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=stage_scale,
                )
                flow = flow + residual
            mask_hist.append(mask)
            flow_hist.append(flow)
            warped0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged_hist.append((warped0, warped1))
        mask = torch.sigmoid(mask)
        return warped0 * mask + warped1 * (1 - mask)
class IFNet(nn.Module):
    """RIFE v4.22 interpolation network.

    Identical topology to v4.21 (four feat-carrying IFBlock stages refining a
    bidirectional flow coarse-to-fine); ensemble mode is rejected.
    """

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 16, c=256)
        self.block1 = IFBlock(8 + 4 + 16 + 8, c=192)
        self.block2 = IFBlock(8 + 4 + 16 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 16 + 8, c=48)
        self.encode = Head()
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.22")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the frame interpolated at `timestep` between img0 and img1.

        f0/f1 are Head-encoder features; tenFlow_div / backwarp_tenGrid are
        precomputed constants for the project's `warp` helper.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Histories kept for parity with the training-time API; unused by the return value.
        flow_hist, mask_hist, merged_hist = [], [], []
        warped0, warped1 = img0, img1
        flow = mask = feat = None
        stages = (self.block0, self.block1, self.block2, self.block3)
        for stage, stage_scale in zip(stages, self.scale_list):
            if flow is None:
                # Coarsest level: no prior flow yet.
                flow, mask, feat = stage(
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=stage_scale
                )
            else:
                # Warp encoder features by the current flow, predict a residual flow.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                residual, mask, feat = stage(
                    torch.cat((warped0, warped1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=stage_scale,
                )
                flow = flow + residual
            mask_hist.append(mask)
            flow_hist.append(flow)
            warped0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged_hist.append((warped0, warped1))
        mask = torch.sigmoid(mask)
        return warped0 * mask + warped1 * (1 - mask)
class IFNet(nn.Module):
    """RIFE v4.23 interpolation network.

    Same four-stage, feat-carrying coarse-to-fine design as v4.21/v4.22;
    ensemble mode is rejected.
    """

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 16, c=256)
        self.block1 = IFBlock(8 + 4 + 16 + 8, c=192)
        self.block2 = IFBlock(8 + 4 + 16 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 16 + 8, c=48)
        self.encode = Head()
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.23")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the frame interpolated at `timestep` between img0 and img1.

        f0/f1 are Head-encoder features; tenFlow_div / backwarp_tenGrid feed
        the project's `warp` helper.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Histories kept for parity with the training-time API; unused by the return value.
        flow_hist, mask_hist, merged_hist = [], [], []
        warped0, warped1 = img0, img1
        flow = mask = feat = None
        stages = (self.block0, self.block1, self.block2, self.block3)
        for stage, stage_scale in zip(stages, self.scale_list):
            if flow is None:
                flow, mask, feat = stage(
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=stage_scale
                )
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                residual, mask, feat = stage(
                    torch.cat((warped0, warped1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=stage_scale,
                )
                flow = flow + residual
            mask_hist.append(mask)
            flow_hist.append(flow)
            warped0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged_hist.append((warped0, warped1))
        mask = torch.sigmoid(mask)
        return warped0 * mask + warped1 * (1 - mask)
class Head(nn.Module):
    """Lightweight frame encoder (v4.22.lite): a 2x-downsampling conv, two
    resolution-preserving convs, then a transposed conv back to input
    resolution with 4 output channels.

    With feat=True, returns the list of all four stage outputs instead of
    only the final tensor.  Note the LeakyReLU is in-place, so the stored
    intermediate tensors reflect the activated values.
    """

    def __init__(self):
        super(Head, self).__init__()
        self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
        self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        h = x.clamp(0.0, 1.0)
        taps = []
        for layer in (self.cnn0, self.cnn1, self.cnn2):
            out = layer(h)
            taps.append(out)
            h = self.relu(out)
        # Final stage has no activation.
        taps.append(self.cnn3(h))
        return taps if feat else taps[-1]
class IFNet(nn.Module):
    """RIFE v4.22.lite interpolation network.

    A slimmer four-stage variant (smaller IFBlock widths, 4-channel Head
    features); ensemble mode is rejected.
    """

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 8, c=192)
        self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
        self.block2 = IFBlock(8 + 4 + 8 + 8, c=64)
        self.block3 = IFBlock(8 + 4 + 8 + 8, c=32)
        self.encode = Head()
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.22.lite")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the frame interpolated at `timestep` between img0 and img1.

        f0/f1 are Head-encoder features; tenFlow_div / backwarp_tenGrid feed
        the project's `warp` helper.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Histories kept for parity with the training-time API; unused by the return value.
        flow_hist, mask_hist, merged_hist = [], [], []
        warped0, warped1 = img0, img1
        flow = mask = feat = None
        stages = (self.block0, self.block1, self.block2, self.block3)
        for stage, stage_scale in zip(stages, self.scale_list):
            if flow is None:
                flow, mask, feat = stage(
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=stage_scale
                )
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                residual, mask, feat = stage(
                    torch.cat((warped0, warped1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=stage_scale,
                )
                flow = flow + residual
            mask_hist.append(mask)
            flow_hist.append(flow)
            warped0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged_hist.append((warped0, warped1))
        mask = torch.sigmoid(mask)
        return warped0 * mask + warped1 * (1 - mask)
class IFNet(nn.Module):
    """RIFE v4.25 interpolation network.

    Five IFBlock stages (one more pyramid level than v4.2x, starting at
    scale 16) refine a bidirectional flow coarse-to-fine; ensemble mode is
    rejected.
    """

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 8, c=192)
        self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
        self.block2 = IFBlock(8 + 4 + 8 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 8 + 8, c=64)
        self.block4 = IFBlock(8 + 4 + 8 + 8, c=32)
        self.encode = Head()
        self.scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.25")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the frame interpolated at `timestep` between img0 and img1.

        f0/f1 are Head-encoder features; tenFlow_div / backwarp_tenGrid feed
        the project's `warp` helper.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Histories kept for parity with the training-time API; unused by the return value.
        flow_hist, mask_hist, merged_hist = [], [], []
        warped0, warped1 = img0, img1
        flow = mask = feat = None
        stages = (self.block0, self.block1, self.block2, self.block3, self.block4)
        for stage, stage_scale in zip(stages, self.scale_list):
            if flow is None:
                flow, mask, feat = stage(
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=stage_scale
                )
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                residual, mask, feat = stage(
                    torch.cat((warped0, warped1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=stage_scale,
                )
                flow = flow + residual
            mask_hist.append(mask)
            flow_hist.append(flow)
            warped0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged_hist.append((warped0, warped1))
        mask = torch.sigmoid(mask)
        return warped0 * mask + warped1 * (1 - mask)
class IFNet(nn.Module):
    """RIFE v4.26 interpolation network.

    Same five-stage topology as v4.25 (coarsest scale 16); ensemble mode is
    rejected.
    """

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 8, c=192)
        self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
        self.block2 = IFBlock(8 + 4 + 8 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 8 + 8, c=64)
        self.block4 = IFBlock(8 + 4 + 8 + 8, c=32)
        self.encode = Head()
        self.scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.26")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the frame interpolated at `timestep` between img0 and img1.

        f0/f1 are Head-encoder features; tenFlow_div / backwarp_tenGrid feed
        the project's `warp` helper.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Histories kept for parity with the training-time API; unused by the return value.
        flow_hist, mask_hist, merged_hist = [], [], []
        warped0, warped1 = img0, img1
        flow = mask = feat = None
        stages = (self.block0, self.block1, self.block2, self.block3, self.block4)
        for stage, stage_scale in zip(stages, self.scale_list):
            if flow is None:
                flow, mask, feat = stage(
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=stage_scale
                )
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                residual, mask, feat = stage(
                    torch.cat((warped0, warped1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=stage_scale,
                )
                flow = flow + residual
            mask_hist.append(mask)
            flow_hist.append(flow)
            warped0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged_hist.append((warped0, warped1))
        mask = torch.sigmoid(mask)
        return warped0 * mask + warped1 * (1 - mask)
class IFNet(nn.Module):
    """RIFE v4.25.lite interpolation network.

    Five slimmer IFBlock stages.  Note the unusual scale schedule
    [32, 16, 8, 4, 1] (no level at scale 2) — copied from the released
    weights' architecture.  Ensemble mode is rejected.
    """

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 8, c=192)
        self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
        self.block2 = IFBlock(8 + 4 + 8 + 8, c=96)
        self.block3 = IFBlock(8 + 4 + 8 + 8, c=64)
        self.block4 = IFBlock(8 + 4 + 8 + 8, c=24)
        self.encode = Head()
        self.scale_list = [32 / scale, 16 / scale, 8 / scale, 4 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.25.lite")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the frame interpolated at `timestep` between img0 and img1.

        f0/f1 are Head-encoder features; tenFlow_div / backwarp_tenGrid feed
        the project's `warp` helper.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Histories kept for parity with the training-time API; unused by the return value.
        flow_hist, mask_hist, merged_hist = [], [], []
        warped0, warped1 = img0, img1
        flow = mask = feat = None
        stages = (self.block0, self.block1, self.block2, self.block3, self.block4)
        for stage, stage_scale in zip(stages, self.scale_list):
            if flow is None:
                flow, mask, feat = stage(
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=stage_scale
                )
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                residual, mask, feat = stage(
                    torch.cat((warped0, warped1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=stage_scale,
                )
                flow = flow + residual
            mask_hist.append(mask)
            flow_hist.append(flow)
            warped0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged_hist.append((warped0, warped1))
        mask = torch.sigmoid(mask)
        return warped0 * mask + warped1 * (1 - mask)
class Head(nn.Module):
    """Frame encoder for the v4.26.heavy variant: a 2x-downsampling conv, two
    resolution-preserving convs, then a transposed conv back to input
    resolution with 16 output channels (wider than the lite/base Heads).

    With feat=True, returns the list of all four stage outputs instead of
    only the final tensor.  Note the LeakyReLU is in-place, so the stored
    intermediate tensors reflect the activated values.
    """

    def __init__(self):
        super(Head, self).__init__()
        self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
        self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(16, 16, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        h = x.clamp(0.0, 1.0)
        taps = []
        for layer in (self.cnn0, self.cnn1, self.cnn2):
            out = layer(h)
            taps.append(out)
            h = self.relu(out)
        # Final stage has no activation.
        taps.append(self.cnn3(h))
        return taps if feat else taps[-1]
class IFNet(nn.Module):
    """RIFE v4.26.heavy interpolation network.

    Five IFBlock stages consuming the wide 16-channel-per-frame Head features
    (hence the 7 + 32 / 8 + 4 + 8 + 32 input widths); ensemble mode is
    rejected.
    """

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 32, c=192)
        self.block1 = IFBlock(8 + 4 + 8 + 32, c=128)
        self.block2 = IFBlock(8 + 4 + 8 + 32, c=96)
        self.block3 = IFBlock(8 + 4 + 8 + 32, c=64)
        self.block4 = IFBlock(8 + 4 + 8 + 32, c=32)
        self.encode = Head()
        self.scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.26.heavy")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Return the frame interpolated at `timestep` between img0 and img1.

        f0/f1 are Head-encoder features; tenFlow_div / backwarp_tenGrid feed
        the project's `warp` helper.
        """
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Histories kept for parity with the training-time API; unused by the return value.
        flow_hist, mask_hist, merged_hist = [], [], []
        warped0, warped1 = img0, img1
        flow = mask = feat = None
        stages = (self.block0, self.block1, self.block2, self.block3, self.block4)
        for stage, stage_scale in zip(stages, self.scale_list):
            if flow is None:
                flow, mask, feat = stage(
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=stage_scale
                )
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                residual, mask, feat = stage(
                    torch.cat((warped0, warped1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=stage_scale,
                )
                flow = flow + residual
            mask_hist.append(mask)
            flow_hist.append(flow)
            warped0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged_hist.append((warped0, warped1))
        mask = torch.sigmoid(mask)
        return warped0 * mask + warped1 * (1 - mask)
# (continuation of IFBlock.__init__ from the preceding lines)
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        # 4*13 -> PixelShuffle(2) -> 13 channels: 4 flow + 1 mask + 8 feat.
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; upsample the result and rescale flow.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # bidirectional flow
        mask = tmp[:, 4:5]  # blend-mask logits
        feat = tmp[:, 5:]  # features carried to the next stage
        return flow, mask, feat


class IFNet(nn.Module):
    """RIFE v4.25.heavy interpolation network (5-stage coarse-to-fine)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3) + img1(3) + timestep(1); 8 = Head features f0 + f1.
        self.block0 = IFBlock(7 + 8, c=192 * 2)
        # Later stages: warped imgs + timestep + mask (8), flow (4), feat (8),
        # warped encoder features (8).
        self.block1 = IFBlock(8 + 4 + 8 + 8, c=128 * 2)
        self.block2 = IFBlock(8 + 4 + 8 + 8, c=96 * 2)
        self.block3 = IFBlock(8 + 4 + 8 + 8, c=64 * 2)
        self.block4 = IFBlock(8 + 4 + 8 + 8, c=32 * 2)
        self.encode = Head()
        self.scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale]
        if ensemble:
            raise ValueError("rife: ensemble is not supported in v4.25.heavy")

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are Head() features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3, self.block4]
        for i in range(5):
            if flow is None:
                # First stage: estimate flow from scratch.
                flow, mask, feat = block[i](
                    torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i]
                )
            else:
                # Refinement: warp encoder features and predict a flow residual.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0, feat = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask, feat), 1),
                    flow,
                    scale=self.scale_list[i],
                )
                mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)


# ---- vsrife/IFNet_HDv3_v4_10.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by LeakyReLU(0.2) — the basic IFBlock building unit."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """Residual 3x3 conv with a learnable per-channel scale (beta)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One coarse-to-fine stage (definition continues on the following lines)."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
# (continuation of IFBlock.__init__'s convblock from the preceding lines)
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        # 4*6 -> PixelShuffle(2) -> 6 channels: 4 flow + 1 mask (last unused).
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; upsample the result and rescale flow.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # bidirectional flow
        mask = tmp[:, 4:5]  # blend-mask logits
        return flow, mask


class IFNet(nn.Module):
    """RIFE v4.10 interpolation network (4-stage coarse-to-fine, ensemble-capable)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3) + img1(3) + timestep(1); 16 = encoder features f0 + f1.
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        # Inline feature encoder (RGB -> 8 channels at input resolution).
        self.encode = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose2d(32, 8, 4, 2, 1),
        )
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are self.encode features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed prediction: swap the two
                    # inputs, mirror the timestep, then swap/negate outputs.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # Refinement: warp encoder features and predict a flow residual.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)


# ---- vsrife/IFNet_HDv3_v4_11.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by LeakyReLU(0.2) — the basic IFBlock building unit."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """Residual 3x3 conv (forward continues on the following lines)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)
# (continuation of ResConv from the preceding lines)
    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One coarse-to-fine stage: predicts flow and blend-mask logits."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        # 4*6 -> PixelShuffle(2) -> 6 channels: 4 flow + 1 mask (last unused).
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; upsample the result and rescale flow.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # bidirectional flow
        mask = tmp[:, 4:5]  # blend-mask logits
        return flow, mask


class IFNet(nn.Module):
    """RIFE v4.11 interpolation network (4-stage coarse-to-fine, ensemble-capable)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3) + img1(3) + timestep(1); 16 = encoder features f0 + f1.
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        # Inline feature encoder (RGB -> 8 channels at input resolution).
        self.encode = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose2d(32, 8, 4, 2, 1),
        )
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are self.encode features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed prediction.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)


# ---- vsrife/IFNet_HDv3_v4_12.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by LeakyReLU(0.2) (arguments continue on the following lines)."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
# (continuation of conv()'s Conv2d arguments from the preceding lines)
            dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """Residual 3x3 conv with a learnable per-channel scale (beta)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One coarse-to-fine stage: predicts flow and blend-mask logits."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        # 4*6 -> PixelShuffle(2) -> 6 channels: 4 flow + 1 mask (last unused).
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; upsample the result and rescale flow.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # bidirectional flow
        mask = tmp[:, 4:5]  # blend-mask logits
        return flow, mask


class IFNet(nn.Module):
    """RIFE v4.12 interpolation network (4-stage coarse-to-fine, ensemble-capable)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3) + img1(3) + timestep(1); 16 = encoder features f0 + f1.
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        # Inline feature encoder (RGB -> 8 channels at input resolution).
        self.encode = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose2d(32, 8, 4, 2, 1),
        )
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are self.encode features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed prediction.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)


# ---- vsrife/IFNet_HDv3_v4_12_lite.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by LeakyReLU(0.2) — the basic IFBlock building unit."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """Residual 3x3 conv with a learnable per-channel scale (beta)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One coarse-to-fine stage: predicts flow and blend-mask logits."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        # 4*6 -> PixelShuffle(2) -> 6 channels: 4 flow + 1 mask (last unused).
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; upsample the result and rescale flow.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # bidirectional flow
        mask = tmp[:, 4:5]  # blend-mask logits
        return flow, mask


class IFNet(nn.Module):
    """RIFE v4.12.lite interpolation network (narrower channels than v4.12)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3) + img1(3) + timestep(1); 8 = encoder features f0 + f1.
        self.block0 = IFBlock(7 + 8, c=128)
        self.block1 = IFBlock(8 + 4 + 8, c=96)
        self.block2 = IFBlock(8 + 4 + 8, c=64)
        self.block3 = IFBlock(8 + 4 + 8, c=48)
        # Inline feature encoder (RGB -> 4 channels at input resolution).
        self.encode = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose2d(32, 4, 4, 2, 1),
        )
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are self.encode features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed prediction.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
# (continuation of IFNet.forward from the preceding lines)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)


# ---- vsrife/IFNet_HDv3_v4_13_lite.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by LeakyReLU(0.2) — the basic IFBlock building unit."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class ResConv(nn.Module):
    """Residual 3x3 conv with a learnable per-channel scale (beta)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One coarse-to-fine stage: predicts flow and blend-mask logits."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        # 4*6 -> PixelShuffle(2) -> 6 channels: 4 flow + 1 mask (last unused).
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; upsample the result and rescale flow.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # bidirectional flow
        mask = tmp[:, 4:5]  # blend-mask logits
        return flow, mask


class IFNet(nn.Module):
    """RIFE v4.13.lite interpolation network (narrow channels)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3) + img1(3) + timestep(1); 8 = encoder features f0 + f1.
        self.block0 = IFBlock(7 + 8, c=128)
        self.block1 = IFBlock(8 + 4 + 8, c=96)
        self.block2 = IFBlock(8 + 4 + 8, c=64)
        self.block3 = IFBlock(8 + 4 + 8, c=48)
        # Inline feature encoder (RGB -> 4 channels at input resolution).
        self.encode = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose2d(32, 4, 4, 2, 1),
        )
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are self.encode features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed prediction.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
# (continuation of IFNet.forward from the preceding lines)
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)


# ---- vsrife/IFNet_HDv3_v4_13.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by LeakyReLU(0.2) — the basic IFBlock building unit."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class Head(nn.Module):
    """Shallow feature encoder (RGB -> 8 channels at input resolution)."""

    def __init__(self):
        super(Head, self).__init__()
        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        # Inputs are expected in [0, 1]; clamp defensively.
        x = x.clamp(0.0, 1.0)
        x0 = self.cnn0(x)
        x = self.relu(x0)
        x1 = self.cnn1(x)
        x = self.relu(x1)
        x2 = self.cnn2(x)
        x = self.relu(x2)
        x3 = self.cnn3(x)
        if feat:
            # Return every intermediate activation instead of only the last.
            return [x0, x1, x2, x3]
        return x3


class ResConv(nn.Module):
    """Residual 3x3 conv (definition continues on the following lines)."""

    def __init__(self, c, dilation=1):
# (continuation of ResConv.__init__ from the preceding lines)
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One coarse-to-fine stage: predicts flow and blend-mask logits."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        # 4*6 -> PixelShuffle(2) -> 6 channels: 4 flow + 1 mask (last unused).
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; upsample the result and rescale flow.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # bidirectional flow
        mask = tmp[:, 4:5]  # blend-mask logits
        return flow, mask


class IFNet(nn.Module):
    """RIFE v4.13 interpolation network (Head encoder, 4-stage, ensemble-capable)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3) + img1(3) + timestep(1); 16 = Head features f0 + f1.
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are Head() features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed prediction.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)


# ---- vsrife/IFNet_HDv3_v4_14.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d followed by LeakyReLU(0.2) (arguments continue on the following lines)."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride,
# (continuation of conv()'s Conv2d arguments from the preceding lines)
            padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class Head(nn.Module):
    """Shallow feature encoder (RGB -> 8 channels at input resolution)."""

    def __init__(self):
        super(Head, self).__init__()
        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        x = x.clamp(0.0, 1.0)
        x0 = self.cnn0(x)
        x = self.relu(x0)
        x1 = self.cnn1(x)
        x = self.relu(x1)
        x2 = self.cnn2(x)
        x = self.relu(x2)
        x3 = self.cnn3(x)
        if feat:
            return [x0, x1, x2, x3]
        return x3


class ResConv(nn.Module):
    """Residual 3x3 conv with a learnable per-channel scale (beta)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One coarse-to-fine stage: predicts flow and blend-mask logits."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        # 4*6 -> PixelShuffle(2) -> 6 channels: 4 flow + 1 mask (last unused).
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        # Work at 1/scale resolution; upsample the result and rescale flow.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale  # bidirectional flow
        mask = tmp[:, 4:5]  # blend-mask logits
        return flow, mask


class IFNet(nn.Module):
    """RIFE v4.14 interpolation network (Head encoder, 4-stage, ensemble-capable)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3) + img1(3) + timestep(1); 16 = Head features f0 + f1.
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are Head() features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the time-reversed prediction.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid) 133 | warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) 134 | merged.append((warped_img0, warped_img1)) 135 | mask = torch.sigmoid(mask) 136 | return warped_img0 * mask + warped_img1 * (1 - mask) 137 | -------------------------------------------------------------------------------- /vsrife/IFNet_HDv3_v4_15.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .warplayer import warp 6 | 7 | 8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 9 | return nn.Sequential( 10 | nn.Conv2d( 11 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True 12 | ), 13 | nn.LeakyReLU(0.2, True), 14 | ) 15 | 16 | 17 | class Head(nn.Module): 18 | def __init__(self): 19 | super(Head, self).__init__() 20 | self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1) 21 | self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1) 22 | self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1) 23 | self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1) 24 | self.relu = nn.LeakyReLU(0.2, True) 25 | 26 | def forward(self, x, feat=False): 27 | x = x.clamp(0.0, 1.0) 28 | x0 = self.cnn0(x) 29 | x = self.relu(x0) 30 | x1 = self.cnn1(x) 31 | x = self.relu(x1) 32 | x2 = self.cnn2(x) 33 | x = self.relu(x2) 34 | x3 = self.cnn3(x) 35 | if feat: 36 | return [x0, x1, x2, x3] 37 | return x3 38 | 39 | 40 | class ResConv(nn.Module): 41 | def __init__(self, c, dilation=1): 42 | super(ResConv, self).__init__() 43 | self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1) 44 | self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 45 | self.relu = nn.LeakyReLU(0.2, True) 46 | 47 | def forward(self, x): 48 | return self.relu(self.conv(x) * self.beta + x) 49 | 50 | 51 | class IFBlock(nn.Module): 52 | def __init__(self, in_planes, c=64): 53 | 
class IFNet(nn.Module):
    """v4.15 interpolation net: four coarse-to-fine IFBlocks refine a
    bidirectional flow + blend mask; output is a sigmoid-masked merge."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3)+img1(3)+timestep(1); 16 = two 8-ch Head features.
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()  # not used in forward; presumably called by the caller to make f0/f1 — confirm
        # Stage i runs at 1/scale_list[i] resolution.
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are precomputed Head features,
        the warp constants are forwarded unchanged to warp()."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []  # kept for parity with the training code; not returned
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                # Initial estimate from the raw frames and features.
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Averaged with the mirrored (swapped-frame, 1-t) estimate.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # Refinement: residual flow from warped images and features.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
class ResConv(nn.Module):
    """Residual 3x3 conv with learned per-channel gate: relu(conv(x)*beta + x)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One flow-refinement stage: 4x downsample, 8 ResConvs, upsample via
    transpose conv + PixelShuffle; returns 4-ch flow and 1-ch mask."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        """Operate at 1/scale resolution; incoming/outgoing flow values are
        rescaled to match (flow magnitudes are resolution-dependent)."""
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """v4.15 lite interpolation net (smaller Head features: 4 ch per frame,
    and narrower IFBlocks than the full model)."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3)+img1(3)+timestep(1); 8 = two 4-ch lite Head features.
        self.block0 = IFBlock(7 + 8, c=128)
        self.block1 = IFBlock(8 + 4 + 8, c=96)
        self.block2 = IFBlock(8 + 4 + 8, c=64)
        self.block3 = IFBlock(8 + 4 + 8, c=48)
        self.encode = Head()  # unused in forward; presumably the caller builds f0/f1 with it — confirm
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep` via four coarse-to-fine flow stages."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []  # parity with training code; not returned
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Mirrored estimate: frames swapped, time reversed, flow halves swapped.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d + LeakyReLU(0.2, inplace) building block."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True
        ),
        nn.LeakyReLU(0.2, True),
    )


class Head(nn.Module):
    """Lite frame encoder: 16 internal channels, 4 output feature channels."""

    def __init__(self):
        super(Head, self).__init__()
        self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
        self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        """Encode x (clamped to [0, 1]); feat=True returns all stage outputs."""
        x = x.clamp(0.0, 1.0)
        x0 = self.cnn0(x)
        x = self.relu(x0)
        x1 = self.cnn1(x)
        x = self.relu(x1)
        x2 = self.cnn2(x)
        x = self.relu(x2)
        x3 = self.cnn3(x)  # no activation after the final transpose conv
        if feat:
            return [x0, x1, x2, x3]
        return x3


class ResConv(nn.Module):
    """Residual 3x3 conv with learned per-channel gate: relu(conv(x)*beta + x)."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


class IFBlock(nn.Module):
    """One flow-refinement stage: 4x downsample, 8 ResConvs, upsample; returns
    a 4-channel flow and a 1-channel mask."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        """Run at 1/scale resolution, rescaling flow values to match."""
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """v4.16 lite interpolation net: four coarse-to-fine IFBlocks, sigmoid-mask
    merge of the two warped frames."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3)+img1(3)+timestep(1); 8 = two 4-ch lite Head features.
        self.block0 = IFBlock(7 + 8, c=128)
        self.block1 = IFBlock(8 + 4 + 8, c=96)
        self.block2 = IFBlock(8 + 4 + 8, c=64)
        self.block3 = IFBlock(8 + 4 + 8, c=48)
        self.encode = Head()  # not called in forward; presumably used by the caller for f0/f1 — confirm
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; warp constants pass through to warp()."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []  # parity with training code; unused in the return
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Mirrored estimate averaged in; flow halves swap with direction.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
class IFBlock(nn.Module):
    """One v4.17 flow-refinement stage: 4x downsample, 8 ResConvs, upsample via
    transpose conv + PixelShuffle; emits 4-ch flow + 1-ch mask."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        """Run at 1/scale resolution; flow values are rescaled to match."""
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    """v4.17 interpolation net: four coarse-to-fine IFBlocks refine flow and a
    blend mask; output is the sigmoid-masked merge of the warped frames."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3)+img1(3)+timestep(1); 16 = two 8-ch Head features.
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()  # unused in forward; presumably the caller builds f0/f1 with it — confirm
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are precomputed Head features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []  # parity with the training code; not returned
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the mirrored (swapped-frame, 1-t) estimate.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
class Head(nn.Module):
    """Lite frame encoder: a stride-2 conv down, two 3x3 convs, and a transpose
    conv back to input resolution, producing 4 feature channels."""

    def __init__(self):
        super(Head, self).__init__()
        self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
        self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        """Encode x (clamped to [0, 1]).

        With feat=True, return the pre-ReLU output of every stage as a list;
        otherwise return only the final 4-channel feature map.
        """
        stages = []
        out = x.clamp(0.0, 1.0)
        for layer in (self.cnn0, self.cnn1, self.cnn2):
            out = layer(out)
            stages.append(out)  # keep the pre-activation tensor
            out = self.relu(out)
        stages.append(self.cnn3(out))  # no activation after the transpose conv
        return stages if feat else stages[-1]


class ResConv(nn.Module):
    """Residual 3x3 conv with a learned per-channel gate:
    relu(conv(x) * beta + x), beta initialized to ones."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        gated = self.conv(x) * self.beta
        return self.relu(gated + x)
class IFNet(nn.Module):
    """v4.17 lite interpolation net: four coarse-to-fine IFBlocks (narrower than
    the full model, 4-ch lite Head features) with a sigmoid-mask merge."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3)+img1(3)+timestep(1); 8 = two 4-ch lite Head features.
        self.block0 = IFBlock(7 + 8, c=128)
        self.block1 = IFBlock(8 + 4 + 8, c=96)
        self.block2 = IFBlock(8 + 4 + 8, c=64)
        self.block3 = IFBlock(8 + 4 + 8, c=48)
        self.encode = Head()  # not called in forward; presumably used by the caller for f0/f1 — confirm
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; warp constants pass through to warp()."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []  # parity with training code; unused in the return
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    # Average with the mirrored (swapped-frame, 1-t) estimate.
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
class IFNet(nn.Module):
    """v4.18 interpolation net: four coarse-to-fine IFBlocks refine flow and a
    blend mask; the warped frames are merged with a sigmoid mask."""

    def __init__(self, scale=1, ensemble=False):
        super(IFNet, self).__init__()
        # 7 = img0(3)+img1(3)+timestep(1); 16 = two 8-ch Head features.
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()  # unused in forward; presumably the caller builds f0/f1 with it — confirm
        # Stage i runs at 1/scale_list[i] resolution (coarse to fine).
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep`; f0/f1 are precomputed Head features."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []  # parity with training code; not returned
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        mask = None
        block = [self.block0, self.block1, self.block2, self.block3]
        for i in range(4):
            if flow is None:
                # First stage estimates flow/mask from scratch.
                flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
                if self.ensemble:
                    f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i])
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # Later stages predict residual flow from the warped inputs.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = block[i](
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i]
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
class ResConv(nn.Module):
    """Residual 3x3 conv with a learned per-channel gate:
    relu(conv(x) * beta + x), beta initialized to ones."""

    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        residual = self.conv(x) * self.beta
        return self.relu(residual + x)


class IFBlock(nn.Module):
    """One flow-refinement stage: downsample 4x, run 8 residual convs, and
    upsample back via transpose conv + PixelShuffle, producing a 4-channel
    flow field and a 1-channel blend mask."""

    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        # The module-level conv() helper is inlined here as nested Sequentials
        # so the module tree (and state_dict keys) is identical to the original.
        self.conv0 = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(in_planes, c // 2, kernel_size=3, stride=2, padding=1, dilation=1, bias=True),
                nn.LeakyReLU(0.2, True),
            ),
            nn.Sequential(
                nn.Conv2d(c // 2, c, kernel_size=3, stride=2, padding=1, dilation=1, bias=True),
                nn.LeakyReLU(0.2, True),
            ),
        )
        self.convblock = nn.Sequential(*[ResConv(c) for _ in range(8)])
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        """Run the stage at 1/scale resolution.

        An incoming flow is downscaled spatially and divided by scale (flow
        magnitudes shrink with resolution); the predicted flow is multiplied
        back by scale after upsampling. Returns (flow, mask).
        """
        inv = 1.0 / scale
        x = F.interpolate(x, scale_factor=inv, mode="bilinear")
        if flow is not None:
            flow = F.interpolate(flow, scale_factor=inv, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        hidden = self.convblock(self.conv0(x))
        out = F.interpolate(self.lastconv(hidden), scale_factor=scale, mode="bilinear")
        return out[:, :4] * scale, out[:, 4:5]
self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale] 93 | self.ensemble = ensemble 94 | 95 | def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1): 96 | img0 = img0.clamp(0.0, 1.0) 97 | img1 = img1.clamp(0.0, 1.0) 98 | flow_list = [] 99 | merged = [] 100 | mask_list = [] 101 | warped_img0 = img0 102 | warped_img1 = img1 103 | flow = None 104 | mask = None 105 | block = [self.block0, self.block1, self.block2, self.block3] 106 | for i in range(4): 107 | if flow is None: 108 | flow, mask = block[i](torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i]) 109 | if self.ensemble: 110 | f_, m_ = block[i](torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=self.scale_list[i]) 111 | flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 112 | mask = (mask + (-m_)) / 2 113 | else: 114 | wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid) 115 | wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) 116 | fd, m0 = block[i]( 117 | torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=self.scale_list[i] 118 | ) 119 | if self.ensemble: 120 | f_, m_ = block[i]( 121 | torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1), 122 | torch.cat((flow[:, 2:4], flow[:, :2]), 1), 123 | scale=self.scale_list[i], 124 | ) 125 | fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2 126 | mask = (m0 + (-m_)) / 2 127 | else: 128 | mask = m0 129 | flow = flow + fd 130 | mask_list.append(mask) 131 | flow_list.append(flow) 132 | warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid) 133 | warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid) 134 | merged.append((warped_img0, warped_img1)) 135 | mask = torch.sigmoid(mask) 136 | return warped_img0 * mask + warped_img1 * (1 - mask) 137 | -------------------------------------------------------------------------------- /vsrife/IFNet_HDv3_v4_20.py: 
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a Conv2d + LeakyReLU(0.2) pair."""
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.LeakyReLU(0.2, True),
    ]
    return nn.Sequential(*layers)


class Head(nn.Module):
    """Shallow image encoder producing an 8-channel feature map."""

    def __init__(self):
        super().__init__()
        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        # Clamp to the expected [0, 1] input range.
        x = x.clamp(0.0, 1.0)
        x0 = self.cnn0(x)
        x1 = self.cnn1(self.relu(x0))
        x2 = self.cnn2(self.relu(x1))
        x3 = self.cnn3(self.relu(x2))
        # Callers may request the full list of intermediate activations.
        return [x0, x1, x2, x3] if feat else x3


class ResConv(nn.Module):
    """3x3 residual convolution whose conv branch is scaled by a learned gate."""

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(x + self.conv(x) * self.beta)


class IFBlock(nn.Module):
    """Single refinement stage emitting a 4-channel flow and 1-channel mask."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        # Eight identical residual blocks; indices match the original layout.
        self.convblock = nn.Sequential(*(ResConv(c) for _ in range(8)))
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        inv = 1.0 / scale
        x = F.interpolate(x, scale_factor=inv, mode="bilinear")
        if flow is not None:
            # Displacements shrink with the resolution, hence the extra division.
            flow = F.interpolate(flow, scale_factor=inv, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        tmp = self.lastconv(self.convblock(self.conv0(x)))
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        return tmp[:, :4] * scale, tmp[:, 4:5]


class IFNet(nn.Module):
    """RIFE v4.20 cascade: four IFBlocks refined coarse-to-fine."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7 + 16, c=384)
        self.block1 = IFBlock(8 + 4 + 16, c=192)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=48)
        self.encode = Head()
        self.scale_list = [factor / scale for factor in (8, 4, 2, 1)]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Interpolate between img0 and img1 at `timestep` via refined flow."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0, warped_img1 = img0, img1
        flow = None
        mask = None
        blocks = (self.block0, self.block1, self.block2, self.block3)
        for stage, cur_scale in zip(blocks, self.scale_list):
            if flow is None:
                # Initial estimate straight from the images and their features.
                flow, mask = stage(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=cur_scale)
                if self.ensemble:
                    # Time-reversed pass averaged in (flow halves swapped, mask negated).
                    f_, m_ = stage(torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=cur_scale)
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                # Refinement: feed warped images/features plus the current flow.
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = stage(
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=cur_scale
                )
                if self.ensemble:
                    f_, m_ = stage(
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=cur_scale,
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Convolution + LeakyReLU(0.2) building block."""
    body = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    return nn.Sequential(body, nn.LeakyReLU(0.2, True))


class Head(nn.Module):
    """Encoder: RGB image -> 8-channel feature map at the same resolution."""

    def __init__(self):
        super().__init__()
        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        x = x.clamp(0.0, 1.0)
        x0 = self.cnn0(x)
        x1 = self.cnn1(self.relu(x0))
        x2 = self.cnn2(self.relu(x1))
        x3 = self.cnn3(self.relu(x2))
        if feat:
            return [x0, x1, x2, x3]
        return x3


class ResConv(nn.Module):
    """Residual convolution with a learnable per-channel scale on the branch."""

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(x + self.conv(x) * self.beta)


class IFBlock(nn.Module):
    """One stage of the cascade: residual flow (4 ch) plus a mask logit (1 ch)."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*(ResConv(c) for _ in range(8)))
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        down = 1.0 / scale
        x = F.interpolate(x, scale_factor=down, mode="bilinear")
        if flow is not None:
            # Rescale displacement magnitudes along with the resolution.
            flow = F.interpolate(flow, scale_factor=down, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        out = self.lastconv(self.convblock(self.conv0(x)))
        out = F.interpolate(out, scale_factor=scale, mode="bilinear")
        return out[:, :4] * scale, out[:, 4:5]


class IFNet(nn.Module):
    """RIFE v4.24 interpolation cascade (four IFBlocks, coarse to fine)."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()
        self.scale_list = [factor / scale for factor in (8, 4, 2, 1)]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend the two frames at `timestep` using iteratively refined flow."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0, warped_img1 = img0, img1
        flow = None
        mask = None
        stages = (self.block0, self.block1, self.block2, self.block3)
        for idx, stage in enumerate(stages):
            cur_scale = self.scale_list[idx]
            if flow is None:
                flow, mask = stage(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=cur_scale)
                if self.ensemble:
                    # Symmetric (time-reversed) pass averaged into the estimate.
                    f_, m_ = stage(torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=cur_scale)
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = stage(
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=cur_scale
                )
                if self.ensemble:
                    f_, m_ = stage(
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=cur_scale,
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)
import torch
import torch.nn as nn
import torch.nn.functional as F

from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Conv2d + LeakyReLU(0.2) building block."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.LeakyReLU(0.2, True),
    )


class Head(nn.Module):
    """Shallow encoder turning an RGB image into an 8-channel feature map."""

    def __init__(self):
        super().__init__()
        self.cnn0 = nn.Conv2d(3, 32, 3, 2, 1)
        self.cnn1 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn2 = nn.Conv2d(32, 32, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(32, 8, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        x = x.clamp(0.0, 1.0)
        x0 = self.cnn0(x)
        x1 = self.cnn1(self.relu(x0))
        x2 = self.cnn2(self.relu(x1))
        x3 = self.cnn3(self.relu(x2))
        return [x0, x1, x2, x3] if feat else x3


class ResConv(nn.Module):
    """Gated residual convolution; the lite variant uses 2 groups (fewer weights)."""

    def __init__(self, c, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=2)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(x + self.conv(x) * self.beta)


class IFBlock(nn.Module):
    """One refinement stage producing a 4-channel flow and a 1-channel mask."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*(ResConv(c) for _ in range(8)))
        self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        shrink = 1.0 / scale
        x = F.interpolate(x, scale_factor=shrink, mode="bilinear")
        if flow is not None:
            # Displacements must be divided along with the spatial downscale.
            flow = F.interpolate(flow, scale_factor=shrink, mode="bilinear") / scale
            x = torch.cat((x, flow), 1)
        pred = self.lastconv(self.convblock(self.conv0(x)))
        pred = F.interpolate(pred, scale_factor=scale, mode="bilinear")
        return pred[:, :4] * scale, pred[:, 4:5]


class IFNet(nn.Module):
    """RIFE v4.14 lite interpolation cascade of four IFBlocks."""

    def __init__(self, scale=1, ensemble=False):
        super().__init__()
        self.block0 = IFBlock(7 + 16, c=192)
        self.block1 = IFBlock(8 + 4 + 16, c=128)
        self.block2 = IFBlock(8 + 4 + 16, c=96)
        self.block3 = IFBlock(8 + 4 + 16, c=64)
        self.encode = Head()
        self.scale_list = [factor / scale for factor in (8, 4, 2, 1)]
        self.ensemble = ensemble

    def forward(self, img0, img1, timestep, tenFlow_div, backwarp_tenGrid, f0, f1):
        """Blend img0/img1 at `timestep` through four coarse-to-fine stages."""
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        flow_list = []
        merged = []
        mask_list = []
        warped_img0, warped_img1 = img0, img1
        flow = None
        mask = None
        stages = (self.block0, self.block1, self.block2, self.block3)
        for stage, cur_scale in zip(stages, self.scale_list):
            if flow is None:
                flow, mask = stage(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=cur_scale)
                if self.ensemble:
                    # Fold in the time-reversed prediction for symmetry.
                    f_, m_ = stage(torch.cat((img1, img0, f1, f0, 1 - timestep), 1), None, scale=cur_scale)
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(f0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
                wf1 = warp(f1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
                fd, m0 = stage(
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask), 1), flow, scale=cur_scale
                )
                if self.ensemble:
                    f_, m_ = stage(
                        torch.cat((warped_img1, warped_img0, wf1, wf0, 1 - timestep, -mask), 1),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=cur_scale,
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], tenFlow_div, backwarp_tenGrid)
            warped_img1 = warp(img1, flow[:, 2:4], tenFlow_div, backwarp_tenGrid)
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        return warped_img0 * mask + warped_img1 * (1 - mask)