├── LICENSE ├── MultiPassDedup.ipynb ├── README.md ├── assert └── result.gif ├── infer.py ├── models ├── IFNet_HDv3.py ├── gimm │ ├── configs │ │ ├── gimm │ │ │ └── gimm.yaml │ │ └── gimmvfi │ │ │ ├── gimmvfi_f_arb.yaml │ │ │ └── gimmvfi_r_arb.yaml │ └── src │ │ ├── models │ │ ├── __init__.py │ │ ├── ema.py │ │ └── generalizable_INR │ │ │ ├── __init__.py │ │ │ ├── configs.py │ │ │ ├── flowformer │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── alt_cuda_corr │ │ │ │ ├── correlation.cpp │ │ │ │ ├── correlation_kernel.cu │ │ │ │ └── setup.py │ │ │ ├── assets │ │ │ │ └── teaser.png │ │ │ ├── chairs_split.txt │ │ │ ├── configs │ │ │ │ ├── default.py │ │ │ │ ├── kitti.py │ │ │ │ ├── sintel.py │ │ │ │ ├── small_things_eval.py │ │ │ │ ├── submission.py │ │ │ │ ├── things.py │ │ │ │ ├── things_eval.py │ │ │ │ └── things_flowformer_sharp.py │ │ │ ├── core │ │ │ │ ├── FlowFormer │ │ │ │ │ ├── LatentCostFormer │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── attention.py │ │ │ │ │ │ ├── cnn.py │ │ │ │ │ │ ├── convnext.py │ │ │ │ │ │ ├── decoder.py │ │ │ │ │ │ ├── encoder.py │ │ │ │ │ │ ├── gma.py │ │ │ │ │ │ ├── gru.py │ │ │ │ │ │ ├── mlpmixer.py │ │ │ │ │ │ ├── transformer.py │ │ │ │ │ │ └── twins.py │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── common.py │ │ │ │ │ └── encoders.py │ │ │ │ ├── __init__.py │ │ │ │ ├── corr.py │ │ │ │ ├── datasets.py │ │ │ │ ├── extractor.py │ │ │ │ ├── loss.py │ │ │ │ ├── optimizer │ │ │ │ │ └── __init__.py │ │ │ │ ├── position_encoding.py │ │ │ │ ├── raft.py │ │ │ │ ├── update.py │ │ │ │ └── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── augmentor.py │ │ │ │ │ ├── datasets.py │ │ │ │ │ ├── flow_transforms.py │ │ │ │ │ ├── flow_viz.py │ │ │ │ │ ├── frame_utils.py │ │ │ │ │ ├── logger.py │ │ │ │ │ ├── misc.py │ │ │ │ │ └── utils.py │ │ │ ├── evaluate_FlowFormer.py │ │ │ ├── evaluate_FlowFormer_tile.py │ │ │ ├── run_train.sh │ │ │ ├── train_FlowFormer.py │ │ │ └── visualize_flow.py │ │ │ ├── gimm.py │ │ │ ├── gimmvfi_f.py │ │ │ ├── gimmvfi_r.py │ │ │ ├── gmflow │ │ │ ├── __init__.py │ │ │ ├── backbone.py │ │ │ ├── geometry.py │ │ │ ├── gmflow.py │ │ │ ├── matching.py │ │ │ ├── position.py │ │ │ ├── transformer.py │ │ │ ├── trident_conv.py │ │ │ └── utils.py │ │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── coord_sampler.py │ │ │ ├── fi_components.py │ │ │ ├── fi_utils.py │ │ │ ├── hyponet.py │ │ │ ├── layers.py │ │ │ ├── module_config.py │ │ │ ├── softsplat.py │ │ │ └── utils.py │ │ │ └── raft │ │ │ ├── __init__.py │ │ │ ├── corr.py │ │ │ ├── extractor.py │ │ │ ├── other_raft.py │ │ │ ├── raft.py │ │ │ ├── update.py │ │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── augmentor.py │ │ │ ├── flow_viz.py │ │ │ ├── frame_utils.py │ │ │ └── utils.py │ │ └── utils │ │ ├── __init__.py │ │ ├── accumulator.py │ │ ├── config.py │ │ ├── dist.py │ │ ├── flow_viz.py │ │ ├── frame_utils.py │ │ ├── loss.py │ │ ├── lpips │ │ ├── __init__.py │ │ ├── alex.pth │ │ ├── lpips.py │ │ └── pretrained_networks.py │ │ ├── profiler.py │ │ ├── setup.py │ │ ├── utils.py │ │ └── writer.py ├── gmflow │ ├── __init__.py │ ├── backbone.py │ ├── geometry.py │ ├── gmflow.py │ ├── matching.py │ ├── position.py │ ├── transformer.py │ ├── trident_conv.py │ └── utils.py ├── model_pg104 │ ├── FeatureNet.py │ ├── FusionNet.py │ ├── GMFSS.py │ ├── IFNet_HDv3.py │ ├── MetricNet.py │ ├── softsplat.py │ └── warplayer.py ├── pytorch_msssim │ └── __init__.py ├── softsplat │ ├── softsplat.py │ └── softsplat_torch.py ├── utils │ └── tools.py └── vfi.py └── requirements.txt /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 hyw-dev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📖MultiPassDedup 2 | 3 | ### Efficient Deduplication for Anime Video Frame Interpolation 4 | > When performing frame interpolation on anime footage, conventional deduplication methods often rely on identifying and removing duplicate frames, which has many drawbacks, such as losing background textures and failing to correctly handle multiple characters drawn at different cadences in the same scene. 5 | > 6 | > Through observation and summarization of patterns in anime videos, we found that repeatedly updating the original frames provides an easier and more effective solution to these issues. 7 | Therefore, we developed this project to implement this approach. Combined with the powerful GMFSS interpolation algorithm, we can achieve excellent results in most anime scenarios. 8 | 9 | ![result](assert/result.gif) 10 | 11 | 12 | 13 | ## 👀Demo Videos (BiliBili) 14 | ### [Jujutsu Kaisen S2 NCOP](https://www.bilibili.com/video/BV16W421N7s5/?share_source=copy_web&vd_source=8a8926eb0f1d5f0f1cab7529c8f51282) 15 | ### [Houseki no Kuni NCOP](https://www.bilibili.com/video/BV1py4y1A7qj/?share_source=copy_web&vd_source=8a8926eb0f1d5f0f1cab7529c8f51282) 16 | 17 | ## 🔧Installation 18 | ```bash 19 | git clone https://github.com/routineLife1/MultiPassDedup.git 20 | cd MultiPassDedup 21 | pip3 install -r requirements.txt 22 | ``` 23 | Download the weights from [Google Drive](https://drive.google.com/file/d/1gXyqRiLgZ0sQEuDl4vbbxIgbUvg3k50x/view?usp=sharing), unzip them, and put them into ./weights/ 24 | 25 | 26 | The cupy package is included in the requirements, but installing it is optional: it is only used to accelerate computation. If you encounter difficulties while installing this package, you can skip it.
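If the CuPy install from requirements.txt gives you trouble but you still want the acceleration, a minimal sketch of installing a prebuilt wheel instead (this assumes an NVIDIA GPU; the wheel name must match your local CUDA version, so treat the names below as examples):

```bash
# optional: CuPy is only used to accelerate computation; the project still runs without it
pip3 install cupy-cuda12x   # for CUDA 12.x; use cupy-cuda11x on CUDA 11.x
```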
27 | 28 | 29 | ## ⚡Usage 30 | - Normalize the source video to 24000/1001 fps with the following ffmpeg command **(If the INPUT video framerate is around 23.976, skip this step.)** 31 | ```bash 32 | ffmpeg -i INPUT -crf 16 -r 24000/1001 -preset slow -c:v libx265 -x265-params profile=main10 -c:a copy OUTPUT 33 | ``` 34 | - Open the video and check its max consistent deduplication count (3 -> on threes, 2 -> on twos, 0 -> AUTO) **(If the INPUT video framerate is around 23.976, skip this step.)** 35 | - Run the following command to perform the interpolation 36 | (N_PASS = max_consistent_deduplication_counts) **(In most circumstances, -np 0 can automatically determine an appropriate n_pass value)** 37 | ```bash 38 | python infer.py -i [VIDEO] -o [VIDEO_OUTPUT] -np [N_PASS] -t [TIMES] -m [MODEL_TYPE] -s -st 0.3 -scale [SCALE] 39 | # or use the following command to export the video at any frame rate 40 | python infer.py -i [VIDEO] -o [VIDEO_OUTPUT] -np [N_PASS] -fps [OUTPUT_FPS] -m [MODEL_TYPE] -s -st 0.3 -scale [SCALE] 41 | ``` 42 | 43 | **Example (smooth a 23.976 fps video animated on threes and interpolate it to 60 fps):** 44 | 45 | ```bash 46 | ffmpeg -i E:/Myvideo/01_src.mkv -crf 16 -r 24000/1001 -preset slow -c:v libx265 -x265-params profile=main10 -c:a copy E:/Myvideo/01.mkv 47 | 48 | python infer.py -i E:/MyVideo/01.mkv -o E:/MyVideo/out.mkv -np 3 -fps 60 -m gmfss -s -st 0.3 -scale 1.0 49 | ``` 50 | 51 | **Full Usage** 52 | ```bash 53 | Usage: python infer.py -i in_video -o out_video [options]... 54 | 55 | -h show this help 56 | -i input input video path (absolute path of input video) 57 | -o output output video path (absolute path of output video) 58 | -fps dst_fps target frame rate (default=60) 59 | -s enable_scdet enable scene change detection (default Enable) 60 | -st scdet_threshold SSIM scene detection threshold (default=0.3) 61 | -hw hwaccel enable hardware-accelerated encoding (default Enable) (requires an NVIDIA graphics card) 62 | -scale scale flow scale factor (default=1.0), generally use 1.0 for 1080P and 0.5 for 4K resolution 63 | -m model_type model type (default=gmfss) 64 | -np n_pass max consistent deduplication counts (default=3) 65 | ``` 66 | 67 | - input accepts an absolute video file path. Example: E:/input.mp4 68 | - output accepts an absolute video file path. Example: E:/output.mp4 69 | - dst_fps = target interpolated video frame rate. Example: 60 70 | - enable_scdet = enable scene change detection. 71 | - scdet_threshold = scene change detection threshold. The larger the value, the more sensitive the detection. 72 | - hwaccel = enable hardware acceleration when encoding the output video. 73 | - scale = flow scale factor. Decrease this value to reduce the computational load of the model at higher resolutions. Generally, use 1.0 for 1080P and 0.5 for 4K resolution. 74 | - model_type = model type. Currently, gmfss, rife and gimm are supported. 75 | - n_pass = max consistent deduplication counts. 76 | 77 | ## 🤗 Acknowledgement 78 | This project is supported by the [SVFI](https://doc.svfi.group/) Development Team. 79 | 80 | Thanks to [Q8sh2ing](https://github.com/Q8sh2ing) for implementing the Online Colab Demo.
81 | 82 | ## Reference 83 | [GMFSS](https://github.com/98mxr/GMFSS_Fortuna) [Practical-RIFE](https://github.com/hzwer/Practical-RIFE) [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI) 84 | -------------------------------------------------------------------------------- /assert/result.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/assert/result.gif -------------------------------------------------------------------------------- /models/IFNet_HDv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.model_pg104.warplayer import warp 5 | 6 | # from train_log.refine import * 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 12 | return nn.Sequential( 13 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 14 | padding=padding, dilation=dilation, bias=True), 15 | nn.LeakyReLU(0.2, True) 16 | ) 17 | 18 | 19 | def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 20 | return nn.Sequential( 21 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 22 | padding=padding, dilation=dilation, bias=False), 23 | nn.BatchNorm2d(out_planes), 24 | nn.LeakyReLU(0.2, True) 25 | ) 26 | 27 | 28 | class ResConv(nn.Module): 29 | def __init__(self, c, dilation=1): 30 | super(ResConv, self).__init__() 31 | self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1 \ 32 | ) 33 | self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 34 | self.relu = nn.LeakyReLU(0.2, True) 35 | 36 | def forward(self, x): 37 | return self.relu(self.conv(x) * self.beta + x) 38 | 39 | 40 | class IFBlock(nn.Module): 41 | def __init__(self, in_planes, c=64): 42 | super(IFBlock, self).__init__() 43 | self.conv0 = nn.Sequential( 44 | conv(in_planes, c // 2, 3, 2, 1), 45 | conv(c // 2, c, 3, 2, 1), 46 | ) 47 | self.convblock = nn.Sequential( 48 | ResConv(c), 49 | ResConv(c), 50 | ResConv(c), 51 | ResConv(c), 52 | ResConv(c), 53 | ResConv(c), 54 | ResConv(c), 55 | ResConv(c), 56 | ) 57 | self.lastconv = nn.Sequential( 58 | nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), 59 | nn.PixelShuffle(2) 60 | ) 61 | 62 | def forward(self, x, flow=None, scale=1): 63 | x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False) 64 | if flow is not None: 65 | flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False) * 1. 
/ scale 66 | x = torch.cat((x, flow), 1) 67 | feat = self.conv0(x) 68 | feat = self.convblock(feat) 69 | tmp = self.lastconv(feat) 70 | tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) 71 | flow = tmp[:, :4] * scale 72 | mask = tmp[:, 4:5] 73 | return flow, mask 74 | 75 | 76 | class IFNet(nn.Module): 77 | def __init__(self): 78 | super(IFNet, self).__init__() 79 | self.block0 = IFBlock(7 + 8, c=192) 80 | self.block1 = IFBlock(8 + 4 + 8, c=128) 81 | self.block2 = IFBlock(8 + 4 + 8, c=96) 82 | self.block3 = IFBlock(8 + 4 + 8, c=64) 83 | self.encode = nn.Sequential( 84 | nn.Conv2d(3, 16, 3, 2, 1), 85 | nn.ConvTranspose2d(16, 4, 4, 2, 1) 86 | ) 87 | # self.contextnet = Contextnet() 88 | # self.unet = Unet() 89 | 90 | def forward(self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): 91 | if ensemble: 92 | print('ensemble is removed') 93 | if training == False: 94 | channel = x.shape[1] // 2 95 | img0 = x[:, :channel] 96 | img1 = x[:, channel:] 97 | if not torch.is_tensor(timestep): 98 | timestep = (x[:, :1].clone() * 0 + 1) * timestep 99 | else: 100 | timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) 101 | f0 = self.encode(img0[:, :3]) 102 | f1 = self.encode(img1[:, :3]) 103 | flow_list = [] 104 | merged = [] 105 | mask_list = [] 106 | warped_img0 = img0 107 | warped_img1 = img1 108 | flow = None 109 | mask = None 110 | loss_cons = 0 111 | block = [self.block0, self.block1, self.block2, self.block3] 112 | for i in range(4): 113 | if flow is None: 114 | flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, 115 | scale=scale_list[i]) 116 | else: 117 | fd, mask = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], warp(f0, flow[:, :2]), 118 | warp(f1, flow[:, 2:4]), timestep, mask), 1), flow, scale=scale_list[i]) 119 | flow = flow + fd 120 | mask_list.append(mask) 121 | flow_list.append(flow) 122 | warped_img0 = warp(img0, flow[:, :2]) 123 | warped_img1 = warp(img1, flow[:, 2:4]) 124 | merged.append((warped_img0, warped_img1)) 125 | mask = torch.sigmoid(mask) 126 | merged[3] = (warped_img0 * mask + warped_img1 * (1 - mask)) 127 | if not fastmode: 128 | print('contextnet is removed') 129 | ''' 130 | c0 = self.contextnet(img0, flow[:, :2]) 131 | c1 = self.contextnet(img1, flow[:, 2:4]) 132 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) 133 | res = tmp[:, :3] * 2 - 1 134 | merged[3] = torch.clamp(merged[3] + res, 0, 1) 135 | ''' 136 | return merged[3] 137 | -------------------------------------------------------------------------------- /models/gimm/configs/gimm/gimm.yaml: -------------------------------------------------------------------------------- 1 | trainer: stage_inr 2 | dataset: 3 | type: fast_vimeo_flow 4 | path: ./data/vimeo90k/vimeo_triplet 5 | add_objects: false 6 | expansion: false 7 | random_t: false 8 | aug: true 9 | t_scale: 10 10 | pair: false 11 | 12 | arch: # needs to add encoder, modulation type 13 | type: gimm 14 | ema: null 15 | 16 | modulated_layer_idxs: [1] 17 | 18 | coord_range: [-1., 1.] 
19 | 20 | hyponet: 21 | type: mlp 22 | n_layer: 5 # including the output layer 23 | hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1] 24 | use_bias: true 25 | input_dim: 3 26 | output_dim: 2 27 | output_bias: 0.5 28 | activation: 29 | type: siren 30 | siren_w0: 1.0 31 | initialization: 32 | weight_init_type: siren 33 | bias_init_type: siren 34 | 35 | loss: 36 | type: mse #now unnecessary 37 | 38 | optimizer: 39 | type: adam 40 | init_lr: 0.0001 41 | weight_decay: 0.0 42 | betas: [0.9, 0.999] #[0.9, 0.95] 43 | ft: false 44 | warmup: 45 | epoch: 0 46 | multiplier: 1 47 | buffer_epoch: 0 48 | min_lr: 0.0001 49 | mode: fix 50 | start_from_zero: True 51 | max_gn: null 52 | 53 | experiment: 54 | amp: True 55 | batch_size: 32 56 | total_batch_size: 64 57 | epochs: 400 58 | save_ckpt_freq: 20 59 | test_freq: 10 60 | test_imlog_freq: 10 61 | 62 | -------------------------------------------------------------------------------- /models/gimm/configs/gimmvfi/gimmvfi_f_arb.yaml: -------------------------------------------------------------------------------- 1 | trainer: stage_inr 2 | dataset: 3 | type: vimeo_arb 4 | path: ./data/vimeo90k/vimeo_septuplet 5 | aug: true 6 | 7 | arch: 8 | type: gimmvfi_f 9 | ema: true 10 | modulated_layer_idxs: [1] 11 | 12 | coord_range: [-1., 1.] 13 | 14 | hyponet: 15 | type: mlp 16 | n_layer: 5 # including the output layer 17 | hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1] 18 | use_bias: true 19 | input_dim: 3 20 | output_dim: 2 21 | output_bias: 0.5 22 | activation: 23 | type: siren 24 | siren_w0: 1.0 25 | initialization: 26 | weight_init_type: siren 27 | bias_init_type: siren 28 | 29 | loss: 30 | subsample: 31 | type: random 32 | ratio: 0.1 33 | 34 | optimizer: 35 | type: adamw 36 | init_lr: 0.00008 37 | weight_decay: 0.00004 38 | betas: [0.9, 0.999] 39 | ft: true 40 | warmup: 41 | epoch: 1 42 | multiplier: 1 43 | buffer_epoch: 0 44 | min_lr: 0.000008 45 | mode: fix 46 | start_from_zero: True 47 | max_gn: null 48 | 49 | experiment: 50 | amp: True 51 | batch_size: 4 52 | total_batch_size: 32 53 | epochs: 60 54 | save_ckpt_freq: 10 55 | test_freq: 10 56 | test_imlog_freq: 10 57 | 58 | -------------------------------------------------------------------------------- /models/gimm/configs/gimmvfi/gimmvfi_r_arb.yaml: -------------------------------------------------------------------------------- 1 | trainer: stage_inr 2 | dataset: 3 | type: vimeo_arb 4 | path: ./data/vimeo90k/vimeo_septuplet 5 | aug: true 6 | 7 | arch: 8 | type: gimmvfi_r 9 | ema: true 10 | modulated_layer_idxs: [1] 11 | 12 | coord_range: [-1., 1.] 
13 | 14 | hyponet: 15 | type: mlp 16 | n_layer: 5 # including the output layer 17 | hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1] 18 | use_bias: true 19 | input_dim: 3 20 | output_dim: 2 21 | output_bias: 0.5 22 | activation: 23 | type: siren 24 | siren_w0: 1.0 25 | initialization: 26 | weight_init_type: siren 27 | bias_init_type: siren 28 | 29 | loss: 30 | subsample: 31 | type: random 32 | ratio: 0.1 33 | 34 | optimizer: 35 | type: adamw 36 | init_lr: 0.00008 37 | weight_decay: 0.00004 38 | betas: [0.9, 0.999] 39 | ft: true 40 | warmup: 41 | epoch: 1 42 | multiplier: 1 43 | buffer_epoch: 0 44 | min_lr: 0.000008 45 | mode: fix 46 | start_from_zero: True 47 | max_gn: null 48 | 49 | experiment: 50 | amp: True 51 | batch_size: 4 52 | total_batch_size: 32 53 | epochs: 60 54 | save_ckpt_freq: 10 55 | test_freq: 10 56 | test_imlog_freq: 10 57 | 58 | -------------------------------------------------------------------------------- /models/gimm/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | from .ema import ExponentialMovingAverage 12 | from .generalizable_INR import gimmvfi_f, gimmvfi_r, gimm 13 | 14 | 15 | def create_model(config, ema=False): 16 | model_type = config.type.lower() 17 | if model_type == "gimm": 18 | model = gimm(config) 19 | model_ema = gimm(config) if ema else None 20 | elif model_type == "gimmvfi_f": 21 | model = gimmvfi_f(config) 22 | model_ema = gimmvfi_f(config) if ema else None 23 | elif model_type == "gimmvfi_r": 24 | model = gimmvfi_r(config) 25 | model_ema = gimmvfi_r(config) if ema else None 26 | else: 27 | raise ValueError(f"{model_type} is invalid..") 28 | 29 | if ema: 30 | mu = config.ema 31 | if config.ema_value is not None: 32 | mu = config.ema_value 33 | model_ema = ExponentialMovingAverage(model_ema, mu) 34 | model_ema.eval() 35 | model_ema.update(model, step=-1) 36 | 37 | return model, model_ema 38 | -------------------------------------------------------------------------------- /models/gimm/src/models/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | import logging 12 | import torch 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class ExponentialMovingAverage(torch.nn.Module): 18 | def __init__(self, init_module, mu): 19 | super(ExponentialMovingAverage, self).__init__() 20 | 21 | self.module = init_module 22 | self.mu = mu 23 | 24 | def forward(self, x, *args, **kwargs): 25 | return self.module(x, *args, **kwargs) 26 | 27 | def update(self, module, step=None): 28 | if step is None or not isinstance(self.mu, bool): 29 | # print(['use ema value:', self.mu]) 30 | mu = self.mu 31 | else: 32 | # see : https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/ExponentialMovingAverage?hl=PL 33 | mu = min(self.mu, (1.0 + step) / (10.0 + step)) 34 | 35 | state_dict = {} 36 | with torch.no_grad(): 37 | for (name, m1), (name2, m2) in zip( 38 | self.module.state_dict().items(), module.state_dict().items() 39 | ): 40 | if name != name2: 41 | logger.warning( 42 | "[ExpoentialMovingAverage] not matched keys %s, %s", name, name2 43 | ) 44 | 45 | if step is not None and step < 0: 46 | state_dict[name] = m2.clone().detach() 47 | else: 48 | state_dict[name] = ((mu * m1) + ((1.0 - mu) * m2)).clone().detach() 49 | 50 | self.module.load_state_dict(state_dict) 51 | 52 | def compute_psnr(self, *args, **kwargs): 53 | return self.module.compute_psnr(*args, **kwargs) 54 | 55 | def get_recon_imgs(self, *args, **kwargs): 56 | return self.module.get_recon_imgs(*args, **kwargs) 57 | 58 | def sample_coord_input(self, *args, **kwargs): 59 | return self.module.sample_coord_input(*args, **kwargs) 60 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | from .gimm import GIMM 12 | 13 | from .gimmvfi_f import GIMMVFI_F 14 | from .gimmvfi_r import GIMMVFI_R 15 | 16 | 17 | def gimm(config): 18 | return GIMM(config) 19 | 20 | 21 | def gimmvfi_f(config): 22 | return GIMMVFI_F(config) 23 | 24 | 25 | def gimmvfi_r(config): 26 | return GIMMVFI_R(config) 27 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | from typing import List, Optional 12 | from dataclasses import dataclass 13 | 14 | from omegaconf import OmegaConf, MISSING 15 | from .modules.module_config import HypoNetConfig 16 | 17 | 18 | @dataclass 19 | class GIMMConfig: 20 | type: str = "gimm" 21 | ema: Optional[bool] = None 22 | ema_value: Optional[float] = None 23 | fwarp_type: str = "linear" 24 | hyponet: HypoNetConfig = HypoNetConfig() 25 | coord_range: List[float] = MISSING 26 | modulated_layer_idxs: Optional[List[int]] = None 27 | 28 | @classmethod 29 | def create(cls, config): 30 | # We need to specify the type of the default DataEncoderConfig. 31 | # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value) 32 | # hence merging with the config with other type would cause config error. 33 | defaults = OmegaConf.structured(cls(ema=False)) 34 | config = OmegaConf.merge(defaults, config) 35 | return config 36 | 37 | 38 | @dataclass 39 | class GIMMVFIConfig: 40 | type: str = "gimmvfi" 41 | ema: Optional[bool] = None 42 | ema_value: Optional[float] = None 43 | fwarp_type: str = "linear" 44 | rec_weight: float = 0.1 45 | hyponet: HypoNetConfig = HypoNetConfig() 46 | raft_iter: int = 20 47 | coord_range: List[float] = MISSING 48 | modulated_layer_idxs: Optional[List[int]] = None 49 | 50 | @classmethod 51 | def create(cls, config): 52 | # We need to specify the type of the default DataEncoderConfig. 53 | # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value) 54 | # hence merging with the config with other type would cause config error. 
55 | defaults = OmegaConf.structured(cls(ema=False)) 56 | config = OmegaConf.merge(defaults, config) 57 | return config 58 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .configs.submission import get_cfg 3 | from .core.FlowFormer import build_flowformer 4 | 5 | 6 | def initialize_Flowformer(): 7 | cfg = get_cfg() 8 | model = build_flowformer(cfg) 9 | 10 | ckpt = torch.load(cfg.model, map_location="cpu") 11 | 12 | def convert(param): 13 | return {k.replace("module.", ""): v for k, v in param.items() if "module" in k} 14 | 15 | ckpt = convert(ckpt) 16 | model.load_state_dict(ckpt) 17 | 18 | return model 19 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name="correlation", 7 | ext_modules=[ 8 | CUDAExtension( 9 | "alt_cuda_corr", 10 | sources=["correlation.cpp", "correlation_kernel.cu"], 11 | extra_compile_args={"cxx": [], "nvcc": ["-O3"]}, 12 | ), 13 | ], 14 | cmdclass={"build_ext": BuildExtension}, 15 | ) 16 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/flowformer/assets/teaser.png 
-------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/configs/default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | _CN.name = "default" 6 | _CN.suffix = "arxiv2" 7 | _CN.gamma = 0.8 8 | _CN.max_flow = 400 9 | _CN.batch_size = 8 10 | _CN.sum_freq = 100 11 | _CN.val_freq = 5000 12 | _CN.image_size = [368, 496] 13 | _CN.add_noise = True 14 | _CN.critical_params = [] 15 | 16 | _CN.transformer = "latentcostformer" 17 | _CN.restore_ckpt = None 18 | 19 | ########################################### 20 | # latentcostformer 21 | _CN.latentcostformer = CN() 22 | _CN.latentcostformer.pe = "linear" 23 | _CN.latentcostformer.dropout = 0.0 24 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 25 | _CN.latentcostformer.query_latent_dim = 64 26 | _CN.latentcostformer.cost_latent_input_dim = 64 27 | _CN.latentcostformer.cost_latent_token_num = 8 28 | _CN.latentcostformer.cost_latent_dim = 128 29 | _CN.latentcostformer.predictor_dim = 128 30 | _CN.latentcostformer.motion_feature_dim = 209 # use concat, so double query_latent_dim 31 | _CN.latentcostformer.arc_type = "transformer" 32 | _CN.latentcostformer.cost_heads_num = 1 33 | # encoder 34 | _CN.latentcostformer.pretrain = True 35 | _CN.latentcostformer.context_concat = False 36 | _CN.latentcostformer.encoder_depth = 3 37 | _CN.latentcostformer.feat_cross_attn = False 38 | _CN.latentcostformer.patch_size = 8 39 | _CN.latentcostformer.patch_embed = "single" 40 | _CN.latentcostformer.gma = True 41 | _CN.latentcostformer.rm_res = True 42 | _CN.latentcostformer.vert_c_dim = 64 43 | _CN.latentcostformer.cost_encoder_res = True 44 | _CN.latentcostformer.cnet = "twins" 45 | _CN.latentcostformer.fnet = "twins" 46 | _CN.latentcostformer.only_global = False 47 | _CN.latentcostformer.add_flow_token = True 48 | _CN.latentcostformer.use_mlp = False 49 | _CN.latentcostformer.vertical_conv = False 50 | # decoder 51 | _CN.latentcostformer.decoder_depth = 12 52 | _CN.latentcostformer.critical_params = [ 53 | "cost_heads_num", 54 | "vert_c_dim", 55 | "cnet", 56 | "pretrain", 57 | "add_flow_token", 58 | "encoder_depth", 59 | "gma", 60 | "cost_encoder_res", 61 | ] 62 | ########################################## 63 | 64 | ### TRAINER 65 | _CN.trainer = CN() 66 | _CN.trainer.scheduler = "OneCycleLR" 67 | 68 | _CN.trainer.optimizer = "adamw" 69 | _CN.trainer.canonical_lr = 25e-5 70 | _CN.trainer.adamw_decay = 1e-4 71 | _CN.trainer.clip = 1.0 72 | _CN.trainer.num_steps = 120000 73 | _CN.trainer.epsilon = 1e-8 74 | _CN.trainer.anneal_strategy = "linear" 75 | 76 | 77 | def get_cfg(): 78 | return _CN.clone() 79 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/configs/kitti.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | _CN.name = "kitti" 6 | _CN.suffix = "kitti" 7 | _CN.gamma = 0.85 8 | _CN.max_flow = 400 9 | _CN.batch_size = 6 10 | _CN.sum_freq = 100 11 | _CN.val_freq = 499999999 12 | _CN.image_size = [432, 960] 13 | _CN.add_noise = True 14 | _CN.critical_params = [] 15 | 16 | _CN.transformer = "latentcostformer" 17 | 18 | _CN.model = None 19 | 20 | _CN.restore_ckpt = "checkpoints/sintel.pth" 21 | 22 | # latentcostformer 23 | _CN.latentcostformer = CN() 24 | _CN.latentcostformer.pe = 
"linear" 25 | _CN.latentcostformer.dropout = 0.0 26 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 27 | _CN.latentcostformer.query_latent_dim = 64 28 | _CN.latentcostformer.cost_latent_input_dim = 64 29 | _CN.latentcostformer.cost_latent_token_num = 8 30 | _CN.latentcostformer.cost_latent_dim = 128 31 | _CN.latentcostformer.predictor_dim = 128 32 | _CN.latentcostformer.motion_feature_dim = 209 # use concat, so double query_latent_dim 33 | _CN.latentcostformer.arc_type = "transformer" 34 | _CN.latentcostformer.cost_heads_num = 1 35 | # encoder 36 | _CN.latentcostformer.pretrain = True 37 | _CN.latentcostformer.context_concat = False 38 | _CN.latentcostformer.encoder_depth = 3 39 | _CN.latentcostformer.feat_cross_attn = False 40 | _CN.latentcostformer.vertical_encoder_attn = "twins" 41 | _CN.latentcostformer.patch_size = 8 42 | _CN.latentcostformer.patch_embed = "single" 43 | _CN.latentcostformer.gma = "GMA" 44 | _CN.latentcostformer.rm_res = True 45 | _CN.latentcostformer.vert_c_dim = 64 46 | _CN.latentcostformer.cost_encoder_res = True 47 | _CN.latentcostformer.pwc_aug = False 48 | _CN.latentcostformer.cnet = "twins" 49 | _CN.latentcostformer.fnet = "twins" 50 | _CN.latentcostformer.no_sc = False 51 | _CN.latentcostformer.use_rpe = False 52 | _CN.latentcostformer.only_global = False 53 | _CN.latentcostformer.add_flow_token = True 54 | _CN.latentcostformer.use_mlp = False 55 | _CN.latentcostformer.vertical_conv = False 56 | # decoder 57 | _CN.latentcostformer.decoder_depth = 12 58 | _CN.latentcostformer.critical_params = [ 59 | "cost_heads_num", 60 | "vert_c_dim", 61 | "cnet", 62 | "pretrain", 63 | "add_flow_token", 64 | "encoder_depth", 65 | "gma", 66 | "cost_encoder_res", 67 | ] 68 | 69 | 70 | ### TRAINER 71 | _CN.trainer = CN() 72 | _CN.trainer.scheduler = "OneCycleLR" 73 | _CN.trainer.optimizer = "adamw" 74 | _CN.trainer.canonical_lr = 12.5e-5 75 | _CN.trainer.adamw_decay = 1e-5 76 | _CN.trainer.clip = 1.0 77 | _CN.trainer.num_steps = 50000 78 | _CN.trainer.epsilon = 1e-8 79 | _CN.trainer.anneal_strategy = "linear" 80 | 81 | 82 | def get_cfg(): 83 | return _CN.clone() 84 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/configs/sintel.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | _CN.name = "default" 6 | _CN.suffix = "sintel" 7 | _CN.gamma = 0.85 8 | _CN.max_flow = 400 9 | _CN.batch_size = 6 10 | _CN.sum_freq = 100 11 | _CN.val_freq = 5000000 12 | _CN.image_size = [432, 960] 13 | _CN.add_noise = True 14 | _CN.critical_params = [] 15 | 16 | _CN.transformer = "latentcostformer" 17 | _CN.restore_ckpt = "checkpoints/things.pth" 18 | 19 | # latentcostformer 20 | _CN.latentcostformer = CN() 21 | _CN.latentcostformer.pe = "linear" 22 | _CN.latentcostformer.dropout = 0.0 23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 24 | _CN.latentcostformer.query_latent_dim = 64 25 | _CN.latentcostformer.cost_latent_input_dim = 64 26 | _CN.latentcostformer.cost_latent_token_num = 8 27 | _CN.latentcostformer.cost_latent_dim = 128 28 | _CN.latentcostformer.arc_type = "transformer" 29 | _CN.latentcostformer.cost_heads_num = 1 30 | # encoder 31 | _CN.latentcostformer.pretrain = True 32 | _CN.latentcostformer.context_concat = False 33 | _CN.latentcostformer.encoder_depth = 3 34 | _CN.latentcostformer.feat_cross_attn = False 35 | _CN.latentcostformer.patch_size = 8 36 | 
_CN.latentcostformer.patch_embed = "single" 37 | _CN.latentcostformer.no_pe = False 38 | _CN.latentcostformer.gma = "GMA" 39 | _CN.latentcostformer.kernel_size = 9 40 | _CN.latentcostformer.rm_res = True 41 | _CN.latentcostformer.vert_c_dim = 64 42 | _CN.latentcostformer.cost_encoder_res = True 43 | _CN.latentcostformer.cnet = "twins" 44 | _CN.latentcostformer.fnet = "twins" 45 | _CN.latentcostformer.no_sc = False 46 | _CN.latentcostformer.only_global = False 47 | _CN.latentcostformer.add_flow_token = True 48 | _CN.latentcostformer.use_mlp = False 49 | _CN.latentcostformer.vertical_conv = False 50 | 51 | # decoder 52 | _CN.latentcostformer.decoder_depth = 12 53 | _CN.latentcostformer.critical_params = [ 54 | "cost_heads_num", 55 | "vert_c_dim", 56 | "cnet", 57 | "pretrain", 58 | "add_flow_token", 59 | "encoder_depth", 60 | "gma", 61 | "cost_encoder_res", 62 | ] 63 | 64 | ### TRAINER 65 | _CN.trainer = CN() 66 | _CN.trainer.scheduler = "OneCycleLR" 67 | _CN.trainer.optimizer = "adamw" 68 | _CN.trainer.canonical_lr = 12.5e-5 69 | _CN.trainer.adamw_decay = 1e-5 70 | _CN.trainer.clip = 1.0 71 | _CN.trainer.num_steps = 120000 72 | _CN.trainer.epsilon = 1e-8 73 | _CN.trainer.anneal_strategy = "linear" 74 | 75 | 76 | def get_cfg(): 77 | return _CN.clone() 78 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/configs/small_things_eval.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | _CN.name = "" 6 | _CN.suffix = "" 7 | _CN.gamma = 0.8 8 | _CN.max_flow = 400 9 | _CN.batch_size = 6 10 | _CN.sum_freq = 100 11 | _CN.val_freq = 5000000 12 | _CN.image_size = [432, 960] 13 | _CN.add_noise = False 14 | _CN.critical_params = [] 15 | 16 | _CN.transformer = "latentcostformer" 17 | _CN.model = "checkpoints/flowformer-small/things.pth" 18 | 19 | # latentcostformer 20 | _CN.latentcostformer = CN() 21 | _CN.latentcostformer.pe = "linear" 22 | _CN.latentcostformer.dropout = 0.0 23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 24 | _CN.latentcostformer.query_latent_dim = 64 25 | _CN.latentcostformer.cost_latent_input_dim = 64 26 | _CN.latentcostformer.cost_latent_token_num = 4 27 | _CN.latentcostformer.cost_latent_dim = 32 28 | _CN.latentcostformer.arc_type = "transformer" 29 | _CN.latentcostformer.cost_heads_num = 1 30 | # encoder 31 | _CN.latentcostformer.pretrain = True 32 | _CN.latentcostformer.context_concat = False 33 | _CN.latentcostformer.encoder_depth = 1 34 | _CN.latentcostformer.feat_cross_attn = False 35 | _CN.latentcostformer.patch_size = 8 36 | _CN.latentcostformer.patch_embed = "single" 37 | _CN.latentcostformer.no_pe = False 38 | _CN.latentcostformer.gma = "GMA" 39 | _CN.latentcostformer.kernel_size = 9 40 | _CN.latentcostformer.rm_res = True 41 | _CN.latentcostformer.vert_c_dim = 0 42 | _CN.latentcostformer.cost_encoder_res = True 43 | _CN.latentcostformer.cnet = "basicencoder" 44 | _CN.latentcostformer.fnet = "basicencoder" 45 | _CN.latentcostformer.no_sc = False 46 | _CN.latentcostformer.only_global = False 47 | _CN.latentcostformer.add_flow_token = True 48 | _CN.latentcostformer.use_mlp = False 49 | _CN.latentcostformer.vertical_conv = False 50 | 51 | # decoder 52 | _CN.latentcostformer.decoder_depth = 32 53 | _CN.latentcostformer.critical_params = [ 54 | "cost_heads_num", 55 | "vert_c_dim", 56 | "cnet", 57 | "pretrain", 58 | "add_flow_token", 59 | "encoder_depth", 60 | "gma", 61 | 
"cost_encoder_res", 62 | ] 63 | 64 | ### TRAINER 65 | _CN.trainer = CN() 66 | _CN.trainer.scheduler = "OneCycleLR" 67 | _CN.trainer.optimizer = "adamw" 68 | _CN.trainer.canonical_lr = 12.5e-5 69 | _CN.trainer.adamw_decay = 1e-4 70 | _CN.trainer.clip = 1.0 71 | _CN.trainer.num_steps = 120000 72 | _CN.trainer.epsilon = 1e-8 73 | _CN.trainer.anneal_strategy = "linear" 74 | 75 | 76 | def get_cfg(): 77 | return _CN.clone() 78 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/configs/submission.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | _CN.name = "" 6 | _CN.suffix = "" 7 | _CN.gamma = 0.8 8 | _CN.max_flow = 400 9 | _CN.batch_size = 6 10 | _CN.sum_freq = 100 11 | _CN.val_freq = 5000000 12 | _CN.image_size = [432, 960] 13 | _CN.add_noise = False 14 | _CN.critical_params = [] 15 | 16 | _CN.transformer = "latentcostformer" 17 | _CN.model = r"E:\Work\VFI\Algorithm\GIMM-VFI\pretrained_ckpt\flowformer_sintel.pth" 18 | 19 | # latentcostformer 20 | _CN.latentcostformer = CN() 21 | _CN.latentcostformer.pe = "linear" 22 | _CN.latentcostformer.dropout = 0.0 23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 24 | _CN.latentcostformer.query_latent_dim = 64 25 | _CN.latentcostformer.cost_latent_input_dim = 64 26 | _CN.latentcostformer.cost_latent_token_num = 8 27 | _CN.latentcostformer.cost_latent_dim = 128 28 | _CN.latentcostformer.arc_type = "transformer" 29 | _CN.latentcostformer.cost_heads_num = 1 30 | # encoder 31 | _CN.latentcostformer.pretrain = True 32 | _CN.latentcostformer.context_concat = False 33 | _CN.latentcostformer.encoder_depth = 3 34 | _CN.latentcostformer.feat_cross_attn = False 35 | _CN.latentcostformer.patch_size = 8 36 | _CN.latentcostformer.patch_embed = "single" 37 | _CN.latentcostformer.no_pe = False 38 | _CN.latentcostformer.gma = "GMA" 39 | _CN.latentcostformer.kernel_size = 9 40 | _CN.latentcostformer.rm_res = True 41 | _CN.latentcostformer.vert_c_dim = 64 42 | _CN.latentcostformer.cost_encoder_res = True 43 | _CN.latentcostformer.cnet = "twins" 44 | _CN.latentcostformer.fnet = "twins" 45 | _CN.latentcostformer.no_sc = False 46 | _CN.latentcostformer.only_global = False 47 | _CN.latentcostformer.add_flow_token = True 48 | _CN.latentcostformer.use_mlp = False 49 | _CN.latentcostformer.vertical_conv = False 50 | 51 | # decoder 52 | _CN.latentcostformer.decoder_depth = 32 53 | _CN.latentcostformer.critical_params = [ 54 | "cost_heads_num", 55 | "vert_c_dim", 56 | "cnet", 57 | "pretrain", 58 | "add_flow_token", 59 | "encoder_depth", 60 | "gma", 61 | "cost_encoder_res", 62 | ] 63 | 64 | ### TRAINER 65 | _CN.trainer = CN() 66 | _CN.trainer.scheduler = "OneCycleLR" 67 | _CN.trainer.optimizer = "adamw" 68 | _CN.trainer.canonical_lr = 12.5e-5 69 | _CN.trainer.adamw_decay = 1e-4 70 | _CN.trainer.clip = 1.0 71 | _CN.trainer.num_steps = 120000 72 | _CN.trainer.epsilon = 1e-8 73 | _CN.trainer.anneal_strategy = "linear" 74 | 75 | 76 | def get_cfg(): 77 | return _CN.clone() 78 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/configs/things.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | _CN.name = "" 6 | _CN.suffix = "" 7 | _CN.gamma = 0.8 8 | _CN.max_flow = 400 9 | _CN.batch_size = 6 10 | 
_CN.sum_freq = 100 11 | _CN.val_freq = 5000000 12 | _CN.image_size = [432, 960] 13 | _CN.add_noise = True 14 | _CN.critical_params = [] 15 | 16 | _CN.transformer = "latentcostformer" 17 | _CN.restore_ckpt = "checkpoints/chairs.pth" 18 | 19 | ####################################### 20 | _CN.latentcostformer = CN() 21 | _CN.latentcostformer.pe = "linear" 22 | _CN.latentcostformer.dropout = 0.0 23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 24 | _CN.latentcostformer.query_latent_dim = 64 25 | _CN.latentcostformer.cost_latent_input_dim = 64 26 | _CN.latentcostformer.cost_latent_token_num = 8 27 | _CN.latentcostformer.cost_latent_dim = 128 28 | _CN.latentcostformer.cost_heads_num = 1 29 | # encoder 30 | _CN.latentcostformer.pretrain = True 31 | _CN.latentcostformer.context_concat = False 32 | _CN.latentcostformer.encoder_depth = 3 33 | _CN.latentcostformer.feat_cross_attn = False 34 | _CN.latentcostformer.nat_rep = "abs" 35 | _CN.latentcostformer.patch_size = 8 36 | _CN.latentcostformer.patch_embed = "single" 37 | _CN.latentcostformer.no_pe = False 38 | _CN.latentcostformer.gma = "GMA" 39 | _CN.latentcostformer.kernel_size = 9 40 | _CN.latentcostformer.rm_res = True 41 | _CN.latentcostformer.vert_c_dim = 64 42 | _CN.latentcostformer.cost_encoder_res = True 43 | _CN.latentcostformer.cnet = "twins" 44 | _CN.latentcostformer.fnet = "twins" 45 | _CN.latentcostformer.only_global = False 46 | _CN.latentcostformer.add_flow_token = True 47 | _CN.latentcostformer.use_mlp = False 48 | _CN.latentcostformer.vertical_conv = False 49 | 50 | # decoder 51 | _CN.latentcostformer.decoder_depth = 12 52 | _CN.latentcostformer.critical_params = [ 53 | "cost_heads_num", 54 | "vert_c_dim", 55 | "cnet", 56 | "pretrain", 57 | "add_flow_token", 58 | "encoder_depth", 59 | "gma", 60 | "cost_encoder_res", 61 | ] 62 | 63 | ### TRAINER 64 | _CN.trainer = CN() 65 | _CN.trainer.scheduler = "OneCycleLR" 66 | _CN.trainer.optimizer = "adamw" 67 | _CN.trainer.canonical_lr = 12.5e-5 68 | _CN.trainer.adamw_decay = 1e-4 69 | _CN.trainer.clip = 1.0 70 | _CN.trainer.num_steps = 120000 71 | _CN.trainer.epsilon = 1e-8 72 | _CN.trainer.anneal_strategy = "linear" 73 | 74 | 75 | def get_cfg(): 76 | return _CN.clone() 77 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/configs/things_eval.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | _CN.name = "" 6 | _CN.suffix = "" 7 | _CN.gamma = 0.8 8 | _CN.max_flow = 400 9 | _CN.batch_size = 6 10 | _CN.sum_freq = 100 11 | _CN.val_freq = 5000000 12 | _CN.image_size = [432, 960] 13 | _CN.add_noise = False 14 | _CN.critical_params = [] 15 | 16 | _CN.transformer = "latentcostformer" 17 | _CN.model = "checkpoints/things.pth" 18 | 19 | # latentcostformer 20 | _CN.latentcostformer = CN() 21 | _CN.latentcostformer.pe = "linear" 22 | _CN.latentcostformer.dropout = 0.0 23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 24 | _CN.latentcostformer.query_latent_dim = 64 25 | _CN.latentcostformer.cost_latent_input_dim = 64 26 | _CN.latentcostformer.cost_latent_token_num = 8 27 | _CN.latentcostformer.cost_latent_dim = 128 28 | _CN.latentcostformer.arc_type = "transformer" 29 | _CN.latentcostformer.cost_heads_num = 1 30 | # encoder 31 | _CN.latentcostformer.pretrain = True 32 | _CN.latentcostformer.context_concat = False 33 | _CN.latentcostformer.encoder_depth = 3 34 | 
_CN.latentcostformer.feat_cross_attn = False 35 | _CN.latentcostformer.patch_size = 8 36 | _CN.latentcostformer.patch_embed = "single" 37 | _CN.latentcostformer.no_pe = False 38 | _CN.latentcostformer.gma = "GMA" 39 | _CN.latentcostformer.kernel_size = 9 40 | _CN.latentcostformer.rm_res = True 41 | _CN.latentcostformer.vert_c_dim = 64 42 | _CN.latentcostformer.cost_encoder_res = True 43 | _CN.latentcostformer.cnet = "twins" 44 | _CN.latentcostformer.fnet = "twins" 45 | _CN.latentcostformer.no_sc = False 46 | _CN.latentcostformer.only_global = False 47 | _CN.latentcostformer.add_flow_token = True 48 | _CN.latentcostformer.use_mlp = False 49 | _CN.latentcostformer.vertical_conv = False 50 | 51 | # decoder 52 | _CN.latentcostformer.decoder_depth = 32 53 | _CN.latentcostformer.critical_params = [ 54 | "cost_heads_num", 55 | "vert_c_dim", 56 | "cnet", 57 | "pretrain", 58 | "add_flow_token", 59 | "encoder_depth", 60 | "gma", 61 | "cost_encoder_res", 62 | ] 63 | 64 | ### TRAINER 65 | _CN.trainer = CN() 66 | _CN.trainer.scheduler = "OneCycleLR" 67 | _CN.trainer.optimizer = "adamw" 68 | _CN.trainer.canonical_lr = 12.5e-5 69 | _CN.trainer.adamw_decay = 1e-4 70 | _CN.trainer.clip = 1.0 71 | _CN.trainer.num_steps = 120000 72 | _CN.trainer.epsilon = 1e-8 73 | _CN.trainer.anneal_strategy = "linear" 74 | 75 | 76 | def get_cfg(): 77 | return _CN.clone() 78 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/configs/things_flowformer_sharp.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | _CN.name = "" 6 | _CN.suffix = "" 7 | _CN.gamma = 0.8 8 | _CN.max_flow = 400 9 | _CN.batch_size = 6 10 | _CN.sum_freq = 100 11 | _CN.val_freq = 5000000 12 | _CN.image_size = [400, 720] 13 | _CN.add_noise = True 14 | _CN.critical_params = [] 15 | 16 | _CN.transformer = "latentcostformer" 17 | _CN.restore_ckpt = "checkpoints/chairs.pth" 18 | 19 | ####################################### 20 | _CN.latentcostformer = CN() 21 | _CN.latentcostformer.pe = "linear" 22 | _CN.latentcostformer.dropout = 0.0 23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 24 | _CN.latentcostformer.query_latent_dim = 64 25 | _CN.latentcostformer.cost_latent_input_dim = 64 26 | _CN.latentcostformer.cost_latent_token_num = 8 27 | _CN.latentcostformer.cost_latent_dim = 128 28 | _CN.latentcostformer.cost_heads_num = 1 29 | # encoder 30 | _CN.latentcostformer.pretrain = True 31 | _CN.latentcostformer.context_concat = False 32 | _CN.latentcostformer.encoder_depth = 3 33 | _CN.latentcostformer.feat_cross_attn = False 34 | _CN.latentcostformer.nat_rep = "abs" 35 | _CN.latentcostformer.patch_size = 8 36 | _CN.latentcostformer.patch_embed = "single" 37 | _CN.latentcostformer.no_pe = False 38 | _CN.latentcostformer.gma = "GMA" 39 | _CN.latentcostformer.kernel_size = 9 40 | _CN.latentcostformer.rm_res = True 41 | _CN.latentcostformer.vert_c_dim = 64 42 | _CN.latentcostformer.cost_encoder_res = True 43 | _CN.latentcostformer.cnet = "twins" 44 | _CN.latentcostformer.fnet = "twins" 45 | _CN.latentcostformer.only_global = False 46 | _CN.latentcostformer.add_flow_token = True 47 | _CN.latentcostformer.use_mlp = False 48 | _CN.latentcostformer.vertical_conv = False 49 | 50 | # decoder 51 | _CN.latentcostformer.decoder_depth = 12 52 | _CN.latentcostformer.critical_params = [ 53 | "cost_heads_num", 54 | "vert_c_dim", 55 | "cnet", 56 | "pretrain", 57 | 
"add_flow_token", 58 | "encoder_depth", 59 | "gma", 60 | "cost_encoder_res", 61 | ] 62 | 63 | ### TRAINER 64 | _CN.trainer = CN() 65 | _CN.trainer.scheduler = "OneCycleLR" 66 | _CN.trainer.optimizer = "adamw" 67 | _CN.trainer.canonical_lr = 12.5e-5 68 | _CN.trainer.adamw_decay = 1e-4 69 | _CN.trainer.clip = 1.0 70 | _CN.trainer.num_steps = 120000 71 | _CN.trainer.epsilon = 1e-8 72 | _CN.trainer.anneal_strategy = "linear" 73 | 74 | 75 | def get_cfg(): 76 | return _CN.clone() 77 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/convnext.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | class ConvNextLayer(nn.Module): 9 | def __init__(self, dim, depth=4): 10 | super().__init__() 11 | self.net = nn.Sequential(*[ConvNextBlock(dim=dim) for j in range(depth)]) 12 | 13 | def forward(self, x): 14 | return self.net(x) 15 | 16 | def compute_params(self): 17 | num = 0 18 | for param in self.parameters(): 19 | num += np.prod(param.size()) 20 | 21 | return num 22 | 23 | 24 | class ConvNextBlock(nn.Module): 25 | r"""ConvNeXt Block. There are two equivalent implementations: 26 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 27 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 28 | We use (2) as we find it slightly faster in PyTorch 29 | 30 | Args: 31 | dim (int): Number of input channels. 32 | drop_path (float): Stochastic depth rate. Default: 0.0 33 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 34 | """ 35 | 36 | def __init__(self, dim, layer_scale_init_value=1e-6): 37 | super().__init__() 38 | self.dwconv = nn.Conv2d( 39 | dim, dim, kernel_size=7, padding=3, groups=dim 40 | ) # depthwise conv 41 | self.norm = LayerNorm(dim, eps=1e-6) 42 | self.pwconv1 = nn.Linear( 43 | dim, 4 * dim 44 | ) # pointwise/1x1 convs, implemented with linear layers 45 | self.act = nn.GELU() 46 | self.pwconv2 = nn.Linear(4 * dim, dim) 47 | self.gamma = ( 48 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 49 | if layer_scale_init_value > 0 50 | else None 51 | ) 52 | # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 53 | # print(f"conv next layer") 54 | 55 | def forward(self, x): 56 | input = x 57 | x = self.dwconv(x) 58 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 59 | x = self.norm(x) 60 | x = self.pwconv1(x) 61 | x = self.act(x) 62 | x = self.pwconv2(x) 63 | if self.gamma is not None: 64 | x = self.gamma * x 65 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 66 | 67 | x = input + x 68 | return x 69 | 70 | 71 | class LayerNorm(nn.Module): 72 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. 
73 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 74 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 75 | with shape (batch_size, channels, height, width). 76 | """ 77 | 78 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 79 | super().__init__() 80 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 81 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 82 | self.eps = eps 83 | self.data_format = data_format 84 | if self.data_format not in ["channels_last", "channels_first"]: 85 | raise NotImplementedError 86 | self.normalized_shape = (normalized_shape,) 87 | 88 | def forward(self, x): 89 | if self.data_format == "channels_last": 90 | return F.layer_norm( 91 | x, self.normalized_shape, self.weight, self.bias, self.eps 92 | ) 93 | elif self.data_format == "channels_first": 94 | u = x.mean(1, keepdim=True) 95 | s = (x - u).pow(2).mean(1, keepdim=True) 96 | x = (x - u) / torch.sqrt(s + self.eps) 97 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 98 | return x 99 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | 6 | class RelPosEmb(nn.Module): 7 | def __init__(self, max_pos_size, dim_head): 8 | super().__init__() 9 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) 10 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) 11 | 12 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange( 13 | max_pos_size 14 | ).view(-1, 1) 15 | rel_ind = deltas + max_pos_size - 1 16 | self.register_buffer("rel_ind", rel_ind) 17 | 18 | def forward(self, q): 19 | batch, heads, h, w, c = q.shape 20 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) 21 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) 22 | 23 | height_emb = rearrange(height_emb, "(x u) d -> x u () d", x=h) 24 | width_emb = rearrange(width_emb, "(y v) d -> y () v d", y=w) 25 | 26 | height_score = einsum("b h x y d, x u v d -> b h x y u v", q, height_emb) 27 | width_score = einsum("b h x y d, y u v d -> b h x y u v", q, width_emb) 28 | 29 | return height_score + width_score 30 | 31 | 32 | class Attention(nn.Module): 33 | def __init__( 34 | self, 35 | *, 36 | args, 37 | dim, 38 | max_pos_size=100, 39 | heads=4, 40 | dim_head=128, 41 | ): 42 | super().__init__() 43 | self.args = args 44 | self.heads = heads 45 | self.scale = dim_head**-0.5 46 | inner_dim = heads * dim_head 47 | 48 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) 49 | 50 | self.pos_emb = RelPosEmb(max_pos_size, dim_head) 51 | for param in self.pos_emb.parameters(): 52 | param.requires_grad = False 53 | 54 | def forward(self, fmap): 55 | heads, b, c, h, w = self.heads, *fmap.shape 56 | 57 | q, k = self.to_qk(fmap).chunk(2, dim=1) 58 | 59 | q, k = map(lambda t: rearrange(t, "b (h d) x y -> b h x y d", h=heads), (q, k)) 60 | q = self.scale * q 61 | 62 | # if self.args.position_only: 63 | # sim = self.pos_emb(q) 64 | 65 | # elif self.args.position_and_content: 66 | # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 67 | # sim_pos = self.pos_emb(q) 68 | # sim = sim_content + sim_pos 69 | 70 | # else: 71 | sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k) 72 | 73 | 
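# sim: [B, heads, H, W, H, W] — similarity between every query position (x, y) and every key position (u, v);
# the rearrange below flattens both spatial grids so the softmax normalizes over all key positions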
sim = rearrange(sim, "b h x y u v -> b h (x y) (u v)") 74 | attn = sim.softmax(dim=-1) 75 | 76 | return attn 77 | 78 | 79 | class Aggregate(nn.Module): 80 | def __init__( 81 | self, 82 | args, 83 | dim, 84 | heads=4, 85 | dim_head=128, 86 | ): 87 | super().__init__() 88 | self.args = args 89 | self.heads = heads 90 | self.scale = dim_head**-0.5 91 | inner_dim = heads * dim_head 92 | 93 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 94 | 95 | self.gamma = nn.Parameter(torch.zeros(1)) 96 | 97 | if dim != inner_dim: 98 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) 99 | else: 100 | self.project = None 101 | 102 | def forward(self, attn, fmap): 103 | heads, b, c, h, w = self.heads, *fmap.shape 104 | 105 | v = self.to_v(fmap) 106 | v = rearrange(v, "b (h d) x y -> b h (x y) d", h=heads) 107 | out = einsum("b h i j, b h j d -> b h i d", attn, v) 108 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 109 | 110 | if self.project is not None: 111 | out = self.project(out) 112 | 113 | out = fmap + self.gamma * out 114 | 115 | return out 116 | 117 | 118 | if __name__ == "__main__": 119 | att = Attention(dim=128, heads=1) 120 | fmap = torch.randn(2, 128, 40, 90) 121 | out = att(fmap) 122 | 123 | print(out.shape) 124 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/mlpmixer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from einops.layers.torch import Rearrange, Reduce 3 | from functools import partial 4 | import numpy as np 5 | 6 | 7 | class PreNormResidual(nn.Module): 8 | def __init__(self, dim, fn): 9 | super().__init__() 10 | self.fn = fn 11 | self.norm = nn.LayerNorm(dim) 12 | 13 | def forward(self, x): 14 | return self.fn(self.norm(x)) + x 15 | 16 | 17 | def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear): 18 | return nn.Sequential( 19 | dense(dim, dim * expansion_factor), 20 | nn.GELU(), 21 | nn.Dropout(dropout), 22 | dense(dim * expansion_factor, dim), 23 | nn.Dropout(dropout), 24 | ) 25 | 26 | 27 | class MLPMixerLayer(nn.Module): 28 | def __init__(self, dim, cfg, drop_path=0.0, dropout=0.0): 29 | super(MLPMixerLayer, self).__init__() 30 | 31 | # print(f"use mlp mixer layer") 32 | K = cfg.cost_latent_token_num 33 | expansion_factor = cfg.mlp_expansion_factor 34 | chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear 35 | 36 | self.mlpmixer = nn.Sequential( 37 | PreNormResidual(dim, FeedForward(K, expansion_factor, dropout, chan_first)), 38 | PreNormResidual( 39 | dim, FeedForward(dim, expansion_factor, dropout, chan_last) 40 | ), 41 | ) 42 | 43 | def compute_params(self): 44 | num = 0 45 | for param in self.mlpmixer.parameters(): 46 | num += np.prod(param.size()) 47 | 48 | return num 49 | 50 | def forward(self, x): 51 | """ 52 | x: [BH1W1, K, D] 53 | """ 54 | 55 | return self.mlpmixer(x) 56 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/transformer.py: -------------------------------------------------------------------------------- 1 | import loguru 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import einsum 6 | 7 | from einops.layers.torch import Rearrange 8 | from einops import rearrange 9 | 10 | from ...utils.utils import coords_grid, bilinear_sampler, upflow8 11 | from 
..common import ( 12 | FeedForward, 13 | pyramid_retrieve_tokens, 14 | sampler, 15 | sampler_gaussian_fix, 16 | retrieve_tokens, 17 | MultiHeadAttention, 18 | MLP, 19 | ) 20 | from ..encoders import twins_svt_large_context, twins_svt_large 21 | from ...position_encoding import PositionEncodingSine, LinearPositionEncoding 22 | from .twins import PosConv 23 | from .encoder import MemoryEncoder 24 | from .decoder import MemoryDecoder 25 | from .cnn import BasicEncoder 26 | 27 | 28 | class FlowFormer(nn.Module): 29 | def __init__(self, cfg): 30 | super(FlowFormer, self).__init__() 31 | self.cfg = cfg 32 | 33 | self.memory_encoder = MemoryEncoder(cfg) 34 | self.memory_decoder = MemoryDecoder(cfg) 35 | if cfg.cnet == "twins": 36 | self.context_encoder = twins_svt_large(pretrained=self.cfg.pretrain) 37 | elif cfg.cnet == "basicencoder": 38 | self.context_encoder = BasicEncoder(output_dim=256, norm_fn="instance") 39 | 40 | def build_coord(self, img): 41 | N, C, H, W = img.shape 42 | coords = coords_grid(N, H // 8, W // 8) 43 | return coords 44 | 45 | 46 | def forward( 47 | self, image1, image2, output=None, flow_init=None, return_feat=False, iters=None 48 | ): 49 | # Following https://github.com/princeton-vl/RAFT/ 50 | image1 = 2 * (image1 / 255.0) - 1.0 51 | image2 = 2 * (image2 / 255.0) - 1.0 52 | 53 | data = {} 54 | 55 | if self.cfg.context_concat: 56 | context = self.context_encoder(torch.cat([image1, image2], dim=1)) 57 | else: 58 | if return_feat: 59 | context, cfeat = self.context_encoder(image1, return_feat=return_feat) 60 | else: 61 | context = self.context_encoder(image1) 62 | if return_feat: 63 | cost_memory, ffeat = self.memory_encoder( 64 | image1, image2, data, context, return_feat=return_feat 65 | ) 66 | else: 67 | cost_memory = self.memory_encoder(image1, image2, data, context) 68 | 69 | flow_predictions = self.memory_decoder( 70 | cost_memory, context, data, flow_init=flow_init, iters=iters 71 | ) 72 | 73 | if return_feat: 74 | return flow_predictions, cfeat, ffeat 75 | return flow_predictions 76 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def build_flowformer(cfg): 5 | name = cfg.transformer 6 | if name == "latentcostformer": 7 | from .LatentCostFormer.transformer import FlowFormer 8 | else: 9 | raise ValueError(f"FlowFormer = {name} is not a valid architecture!") 10 | 11 | return FlowFormer(cfg[name]) 12 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import numpy as np 5 | 6 | 7 | class twins_svt_large(nn.Module): 8 | def __init__(self, pretrained=True): 9 | super().__init__() 10 | self.svt = timm.create_model("twins_svt_large", pretrained=pretrained) 11 | 12 | del self.svt.head 13 | del self.svt.patch_embeds[2] 14 | del self.svt.patch_embeds[2] 15 | del self.svt.blocks[2] 16 | del self.svt.blocks[2] 17 | del self.svt.pos_block[2] 18 | del self.svt.pos_block[2] 19 | self.svt.norm.weight.requires_grad = False 20 | self.svt.norm.bias.requires_grad = False 21 | 22 | def forward(self, x, data=None, layer=2, return_feat=False): 23 | B = x.shape[0] 24 | if return_feat: 25 | feat = [] 26 | for i, 
(embed, drop, blocks, pos_blk) in enumerate( 27 | zip( 28 | self.svt.patch_embeds, 29 | self.svt.pos_drops, 30 | self.svt.blocks, 31 | self.svt.pos_block, 32 | ) 33 | ): 34 | x, size = embed(x) 35 | x = drop(x) 36 | for j, blk in enumerate(blocks): 37 | x = blk(x, size) 38 | if j == 0: 39 | x = pos_blk(x, size) 40 | if i < len(self.svt.depths) - 1: 41 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 42 | if return_feat: 43 | feat.append(x) 44 | if i == layer - 1: 45 | break 46 | if return_feat: 47 | return x, feat 48 | return x 49 | 50 | def compute_params(self, layer=2): 51 | num = 0 52 | for i, (embed, drop, blocks, pos_blk) in enumerate( 53 | zip( 54 | self.svt.patch_embeds, 55 | self.svt.pos_drops, 56 | self.svt.blocks, 57 | self.svt.pos_block, 58 | ) 59 | ): 60 | for param in embed.parameters(): 61 | num += np.prod(param.size()) 62 | 63 | for param in drop.parameters(): 64 | num += np.prod(param.size()) 65 | 66 | for param in blocks.parameters(): 67 | num += np.prod(param.size()) 68 | 69 | for param in pos_blk.parameters(): 70 | num += np.prod(param.size()) 71 | 72 | if i == layer - 1: 73 | break 74 | 75 | for param in self.svt.head.parameters(): 76 | num += np.prod(param.size()) 77 | 78 | return num 79 | 80 | 81 | class twins_svt_large_context(nn.Module): 82 | def __init__(self, pretrained=True): 83 | super().__init__() 84 | self.svt = timm.create_model("twins_svt_large_context", pretrained=pretrained) 85 | 86 | def forward(self, x, data=None, layer=2): 87 | B = x.shape[0] 88 | for i, (embed, drop, blocks, pos_blk) in enumerate( 89 | zip( 90 | self.svt.patch_embeds, 91 | self.svt.pos_drops, 92 | self.svt.blocks, 93 | self.svt.pos_block, 94 | ) 95 | ): 96 | x, size = embed(x) 97 | x = drop(x) 98 | for j, blk in enumerate(blocks): 99 | x = blk(x, size) 100 | if j == 0: 101 | x = pos_blk(x, size) 102 | if i < len(self.svt.depths) - 1: 103 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 104 | 105 | if i == layer - 1: 106 | break 107 | 108 | return x 109 | 110 | 111 | if __name__ == "__main__": 112 | m = twins_svt_large() 113 | input = torch.randn(2, 3, 400, 800) 114 | out = m.extract_feature(input) 115 | print(out.shape) 116 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/flowformer/core/__init__.py -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels - 1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | 
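# each 2x2 average pooling halves the (h2, w2) extent of the cost volume, building a coarse-to-fine correlation pyramid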
self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2 * r + 1) 38 | dy = torch.linspace(-r, r, 2 * r + 1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | corr = bilinear_sampler(corr, coords_lvl) 45 | corr = corr.view(batch, h1, w1, -1) 46 | out_pyramid.append(corr) 47 | 48 | out = torch.cat(out_pyramid, dim=-1) 49 | return out.permute(0, 3, 1, 2).contiguous().float() 50 | 51 | @staticmethod 52 | def corr(fmap1, fmap2): 53 | batch, dim, ht, wd = fmap1.shape 54 | fmap1 = fmap1.view(batch, dim, ht * wd) 55 | fmap2 = fmap2.view(batch, dim, ht * wd) 56 | 57 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 58 | corr = corr.view(batch, ht, wd, 1, ht, wd) 59 | return corr / torch.sqrt(torch.tensor(dim).float()) 60 | 61 | 62 | class AlternateCorrBlock: 63 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 64 | self.num_levels = num_levels 65 | self.radius = radius 66 | 67 | self.pyramid = [(fmap1, fmap2)] 68 | for i in range(self.num_levels): 69 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 70 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 71 | self.pyramid.append((fmap1, fmap2)) 72 | 73 | def __call__(self, coords): 74 | coords = coords.permute(0, 2, 3, 1) 75 | B, H, W, _ = coords.shape 76 | dim = self.pyramid[0][0].shape[1] 77 | 78 | corr_list = [] 79 | for i in range(self.num_levels): 80 | r = self.radius 81 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 82 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 83 | 84 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 85 | (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 86 | corr_list.append(corr.squeeze(1)) 87 | 88 | corr = torch.stack(corr_list, dim=1) 89 | corr = corr.reshape(B, -1, H, W) 90 | return corr / torch.sqrt(torch.tensor(dim).float()) 91 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | MAX_FLOW = 400 4 | 5 | 6 | def sequence_loss(flow_preds, flow_gt, valid, cfg): 7 | """Loss function defined over sequence of flow predictions""" 8 | 9 | gamma = cfg.gamma 10 | max_flow = cfg.max_flow 11 | n_predictions = len(flow_preds) 12 | flow_loss = 0.0 13 | flow_gt_thresholds = [5, 10, 20] 14 | 15 | # exlude invalid pixels and extremely large diplacements 16 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 17 | valid = (valid >= 0.5) & (mag < max_flow) 18 | 19 | for i in range(n_predictions): 20 | i_weight = gamma ** (n_predictions - i - 1) 21 | i_loss = (flow_preds[i] - flow_gt).abs() 22 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 23 | 24 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt() 25 | epe = epe.view(-1)[valid.view(-1)] 26 | 27 | metrics = { 28 | "epe": epe.mean().item(), 29 | "1px": (epe < 1).float().mean().item(), 30 | "3px": (epe < 3).float().mean().item(), 31 | "5px": (epe < 5).float().mean().item(), 32 | } 33 | 34 | flow_gt_length = torch.sum(flow_gt**2, dim=1).sqrt() 35 | flow_gt_length = 
flow_gt_length.view(-1)[valid.view(-1)] 36 | for t in flow_gt_thresholds: 37 | e = epe[flow_gt_length < t] 38 | metrics.update({f"{t}-th-5px": (e < 5).float().mean().item()}) 39 | 40 | return flow_loss, metrics 41 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import ( 3 | MultiStepLR, 4 | CosineAnnealingLR, 5 | ExponentialLR, 6 | OneCycleLR, 7 | ) 8 | 9 | 10 | def fetch_optimizer(model, cfg): 11 | """Create the optimizer and learning rate scheduler""" 12 | # optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 13 | 14 | # scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, 15 | # pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 16 | optimizer = build_optimizer(model, cfg) 17 | scheduler = build_scheduler(cfg, optimizer) 18 | 19 | return optimizer, scheduler 20 | 21 | 22 | def build_optimizer(model, config): 23 | name = config.optimizer 24 | lr = config.canonical_lr 25 | 26 | if name == "adam": 27 | return torch.optim.Adam( 28 | model.parameters(), 29 | lr=lr, 30 | weight_decay=config.adam_decay, 31 | eps=config.epsilon, 32 | ) 33 | elif name == "adamw": 34 | if hasattr(config, "twins_lr_factor"): 35 | factor = config.twins_lr_factor 36 | print("[Decrease lr of pre-trained model by factor {}]".format(factor)) 37 | param_dicts = [ 38 | { 39 | "params": [ 40 | p 41 | for n, p in model.named_parameters() 42 | if "feat_encoder" not in n 43 | and "context_encoder" not in n 44 | and p.requires_grad 45 | ] 46 | }, 47 | { 48 | "params": [ 49 | p 50 | for n, p in model.named_parameters() 51 | if ("feat_encoder" in n or "context_encoder" in n) 52 | and p.requires_grad 53 | ], 54 | "lr": lr * factor, 55 | }, 56 | ] 57 | full = [n for n, _ in model.named_parameters()] 58 | return torch.optim.AdamW( 59 | param_dicts, lr=lr, weight_decay=config.adamw_decay, eps=config.epsilon 60 | ) 61 | else: 62 | return torch.optim.AdamW( 63 | model.parameters(), 64 | lr=lr, 65 | weight_decay=config.adamw_decay, 66 | eps=config.epsilon, 67 | ) 68 | else: 69 | raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") 70 | 71 | 72 | def build_scheduler(config, optimizer): 73 | """ 74 | Returns: 75 | scheduler (dict):{ 76 | 'scheduler': lr_scheduler, 77 | 'interval': 'step', # or 'epoch' 78 | } 79 | """ 80 | # scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} 81 | name = config.scheduler 82 | lr = config.canonical_lr 83 | 84 | if name == "OneCycleLR": 85 | # scheduler = OneCycleLR(optimizer, ) 86 | if hasattr(config, "twins_lr_factor"): 87 | factor = config.twins_lr_factor 88 | scheduler = OneCycleLR( 89 | optimizer, 90 | [lr, lr * factor], 91 | config.num_steps + 100, 92 | pct_start=0.05, 93 | cycle_momentum=False, 94 | anneal_strategy=config.anneal_strategy, 95 | ) 96 | else: 97 | scheduler = OneCycleLR( 98 | optimizer, 99 | lr, 100 | config.num_steps + 100, 101 | pct_start=0.05, 102 | cycle_momentum=False, 103 | anneal_strategy=config.anneal_strategy, 104 | ) 105 | # elif name == 'MultiStepLR': 106 | # scheduler.update( 107 | # {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) 108 | # elif name == 'CosineAnnealing': 109 | # scheduler = CosineAnnealingLR(optimizer, config.num_steps+100) 110 | # 
scheduler.update( 111 | # {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) 112 | # elif name == 'ExponentialLR': 113 | # scheduler.update( 114 | # {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) 115 | else: 116 | raise NotImplementedError() 117 | 118 | return scheduler 119 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/position_encoding.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import math 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class PositionEncodingSine(nn.Module): 8 | """ 9 | This is a sinusoidal position encoding that generalized to 2-dimensional images 10 | """ 11 | 12 | def __init__(self, d_model, max_shape=(256, 256)): 13 | """ 14 | Args: 15 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels 16 | """ 17 | super().__init__() 18 | 19 | pe = torch.zeros((d_model, *max_shape)) 20 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) 21 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) 22 | div_term = torch.exp( 23 | torch.arange(0, d_model // 2, 2).float() 24 | * (-math.log(10000.0) / d_model // 2) 25 | ) 26 | div_term = div_term[:, None, None] # [C//4, 1, 1] 27 | pe[0::4, :, :] = torch.sin(x_position * div_term) 28 | pe[1::4, :, :] = torch.cos(x_position * div_term) 29 | pe[2::4, :, :] = torch.sin(y_position * div_term) 30 | pe[3::4, :, :] = torch.cos(y_position * div_term) 31 | 32 | self.register_buffer("pe", pe.unsqueeze(0)) # [1, C, H, W] 33 | 34 | def forward(self, x): 35 | """ 36 | Args: 37 | x: [N, C, H, W] 38 | """ 39 | return x + self.pe[:, :, : x.size(2), : x.size(3)] 40 | 41 | 42 | class LinearPositionEncoding(nn.Module): 43 | """ 44 | This is a sinusoidal position encoding that generalized to 2-dimensional images 45 | """ 46 | 47 | def __init__(self, d_model, max_shape=(256, 256)): 48 | """ 49 | Args: 50 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels 51 | """ 52 | super().__init__() 53 | 54 | pe = torch.zeros((d_model, *max_shape)) 55 | y_position = ( 56 | torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1 57 | ) / max_shape[0] 58 | x_position = ( 59 | torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1 60 | ) / max_shape[1] 61 | div_term = torch.arange(0, d_model // 2, 2).float() 62 | div_term = div_term[:, None, None] # [C//4, 1, 1] 63 | pe[0::4, :, :] = torch.sin(x_position * div_term * math.pi) 64 | pe[1::4, :, :] = torch.cos(x_position * div_term * math.pi) 65 | pe[2::4, :, :] = torch.sin(y_position * div_term * math.pi) 66 | pe[3::4, :, :] = torch.cos(y_position * div_term * math.pi) 67 | 68 | self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W] 69 | 70 | def forward(self, x): 71 | """ 72 | Args: 73 | x: [N, C, H, W] 74 | """ 75 | # assert x.shape[2] == 80 and x.shape[3] == 80 76 | 77 | return x + self.pe[:, :, : x.size(2), : x.size(3)] 78 | 79 | 80 | class LearnedPositionEncoding(nn.Module): 81 | """ 82 | This is a sinusoidal position encoding that generalized to 2-dimensional images 83 | """ 84 | 85 | def __init__(self, d_model, max_shape=(80, 80)): 86 | """ 87 | Args: 88 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels 89 | """ 90 | super().__init__() 91 | 92 | self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model)) 93 | 94 | def forward(self, x): 
95 | """ 96 | Args: 97 | x: [N, C, H, W] 98 | """ 99 | # assert x.shape[2] == 80 and x.shape[3] == 80 100 | 101 | return x + self.pe 102 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock, SmallUpdateBlock 7 | from extractor import BasicEncoder, SmallEncoder 8 | from corr import CorrBlock, AlternateCorrBlock 9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | pass 24 | 25 | 26 | class RAFT(nn.Module): 27 | def __init__(self, args): 28 | super(RAFT, self).__init__() 29 | self.args = args 30 | 31 | if args.small: 32 | self.hidden_dim = hdim = 96 33 | self.context_dim = cdim = 64 34 | args.corr_levels = 4 35 | args.corr_radius = 3 36 | 37 | else: 38 | self.hidden_dim = hdim = 128 39 | self.context_dim = cdim = 128 40 | args.corr_levels = 4 41 | args.corr_radius = 4 42 | 43 | if "dropout" not in self.args: 44 | self.args.dropout = 0 45 | 46 | if "alternate_corr" not in self.args: 47 | self.args.alternate_corr = False 48 | 49 | # feature network, context network, and update block 50 | if args.small: 51 | self.fnet = SmallEncoder( 52 | output_dim=128, norm_fn="instance", dropout=args.dropout 53 | ) 54 | self.cnet = SmallEncoder( 55 | output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout 56 | ) 57 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 58 | 59 | else: 60 | self.fnet = BasicEncoder( 61 | output_dim=256, norm_fn="instance", dropout=args.dropout 62 | ) 63 | self.cnet = BasicEncoder( 64 | output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout 65 | ) 66 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 67 | 68 | def freeze_bn(self): 69 | for m in self.modules(): 70 | if isinstance(m, nn.BatchNorm2d): 71 | m.eval() 72 | 73 | def initialize_flow(self, img): 74 | """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 75 | N, C, H, W = img.shape 76 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 77 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 78 | 79 | # optical flow computed as difference: flow = coords1 - coords0 80 | return coords0, coords1 81 | 82 | def upsample_flow(self, flow, mask): 83 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" 84 | N, _, H, W = flow.shape 85 | mask = mask.view(N, 1, 9, 8, 8, H, W) 86 | mask = torch.softmax(mask, dim=2) 87 | 88 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 89 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 90 | 91 | up_flow = torch.sum(mask * up_flow, dim=2) 92 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 93 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 94 | 95 | def forward( 96 | self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False 97 | ): 98 | """Estimate optical flow between pair of frames""" 99 | 100 | image1 = 2 * (image1 / 255.0) - 1.0 101 | image2 = 2 * (image2 / 255.0) - 1.0 102 | 103 | image1 = image1.contiguous() 104 | image2 = image2.contiguous() 105 | 106 | hdim = self.hidden_dim 107 
| cdim = self.context_dim 108 | 109 | # run the feature network 110 | with autocast(enabled=self.args.mixed_precision): 111 | fmap1, fmap2 = self.fnet([image1, image2]) 112 | 113 | fmap1 = fmap1.float() 114 | fmap2 = fmap2.float() 115 | if self.args.alternate_corr: 116 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 117 | else: 118 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 119 | 120 | # run the context network 121 | with autocast(enabled=self.args.mixed_precision): 122 | cnet = self.cnet(image1) 123 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 124 | net = torch.tanh(net) 125 | inp = torch.relu(inp) 126 | 127 | coords0, coords1 = self.initialize_flow(image1) 128 | 129 | if flow_init is not None: 130 | coords1 = coords1 + flow_init 131 | 132 | flow_predictions = [] 133 | for itr in range(iters): 134 | coords1 = coords1.detach() 135 | corr = corr_fn(coords1) # index correlation volume 136 | 137 | flow = coords1 - coords0 138 | with autocast(enabled=self.args.mixed_precision): 139 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 140 | 141 | # F(t+1) = F(t) + \Delta(t) 142 | coords1 = coords1 + delta_flow 143 | 144 | # upsample predictions 145 | if up_mask is None: 146 | flow_up = upflow8(coords1 - coords0) 147 | else: 148 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 149 | 150 | flow_predictions.append(flow_up) 151 | 152 | if test_mode: 153 | return coords1 - coords0, flow_up 154 | 155 | return flow_predictions 156 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/flowformer/core/utils/__init__.py -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | 21 | def make_colorwheel(): 22 | """ 23 | Generates a color wheel for optical flow visualization as presented in: 24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 26 | 27 | Code follows the original C++ source code of Daniel Scharstein. 28 | Code follows the the Matlab source code of Deqing Sun. 
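The wheel has RY + YG + GC + CB + BM + MR = 55 hue bins, one RGB color per row.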
29 | 30 | Returns: 31 | np.ndarray: Color wheel 32 | """ 33 | 34 | RY = 15 35 | YG = 6 36 | GC = 4 37 | CB = 11 38 | BM = 13 39 | MR = 6 40 | 41 | ncols = RY + YG + GC + CB + BM + MR 42 | colorwheel = np.zeros((ncols, 3)) 43 | col = 0 44 | 45 | # RY 46 | colorwheel[0:RY, 0] = 255 47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 48 | col = col + RY 49 | # YG 50 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 51 | colorwheel[col : col + YG, 1] = 255 52 | col = col + YG 53 | # GC 54 | colorwheel[col : col + GC, 1] = 255 55 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 56 | col = col + GC 57 | # CB 58 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 59 | colorwheel[col : col + CB, 2] = 255 60 | col = col + CB 61 | # BM 62 | colorwheel[col : col + BM, 2] = 255 63 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 64 | col = col + BM 65 | # MR 66 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 67 | colorwheel[col : col + MR, 0] = 255 68 | return colorwheel 69 | 70 | 71 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 72 | """ 73 | Applies the flow color wheel to (possibly clipped) flow components u and v. 74 | 75 | According to the C++ source code of Daniel Scharstein 76 | According to the Matlab source code of Deqing Sun 77 | 78 | Args: 79 | u (np.ndarray): Input horizontal flow of shape [H,W] 80 | v (np.ndarray): Input vertical flow of shape [H,W] 81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 82 | 83 | Returns: 84 | np.ndarray: Flow visualization image of shape [H,W,3] 85 | """ 86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 87 | colorwheel = make_colorwheel() # shape [55x3] 88 | ncols = colorwheel.shape[0] 89 | rad = np.sqrt(np.square(u) + np.square(v)) 90 | a = np.arctan2(-v, -u) / np.pi 91 | fk = (a + 1) / 2 * (ncols - 1) 92 | k0 = np.floor(fk).astype(np.int32) 93 | k1 = k0 + 1 94 | k1[k1 == ncols] = 0 95 | f = fk - k0 96 | for i in range(colorwheel.shape[1]): 97 | tmp = colorwheel[:, i] 98 | col0 = tmp[k0] / 255.0 99 | col1 = tmp[k1] / 255.0 100 | col = (1 - f) * col0 + f * col1 101 | idx = rad <= 1 102 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 103 | col[~idx] = col[~idx] * 0.75 # out of range 104 | # Note the 2-i => BGR instead of RGB 105 | ch_idx = 2 - i if convert_to_bgr else i 106 | flow_image[:, :, ch_idx] = np.floor(255 * col) 107 | return flow_image 108 | 109 | 110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, max_flow=None): 111 | """ 112 | Expects a two dimensional flow image of shape. 113 | 114 | Args: 115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
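        max_flow (float, optional): Normalize by this flow magnitude instead of the per-image maximum radius. Defaults to None.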
118 | 119 | Returns: 120 | np.ndarray: Flow visualization image of shape [H,W,3] 121 | """ 122 | assert flow_uv.ndim == 3, "input flow must have three dimensions" 123 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" 124 | if clip_flow is not None: 125 | flow_uv = np.clip(flow_uv, 0, clip_flow) 126 | u = flow_uv[:, :, 0] 127 | v = flow_uv[:, :, 1] 128 | if max_flow is None: 129 | rad = np.sqrt(np.square(u) + np.square(v)) 130 | rad_max = np.max(rad) 131 | else: 132 | rad_max = max_flow 133 | epsilon = 1e-5 134 | u = u / (rad_max + epsilon) 135 | v = v / (rad_max + epsilon) 136 | return flow_uv_to_colors(u, v, convert_to_bgr) 137 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | 14 | def readFlow(fn): 15 | """Read .flo file in Middlebury format""" 16 | # Code adapted from: 17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 18 | 19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 20 | # print 'fn = %s'%(fn) 21 | with open(fn, "rb") as f: 22 | magic = np.fromfile(f, np.float32, count=1) 23 | if 202021.25 != magic: 24 | print("Magic number incorrect. Invalid .flo file") 25 | return None 26 | else: 27 | w = np.fromfile(f, np.int32, count=1) 28 | h = np.fromfile(f, np.int32, count=1) 29 | # print 'Reading %d x %d flo file\n' % (w, h) 30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 31 | # Reshape data into 3D array (columns, rows, bands) 32 | # The reshape here is for visualization, the original code is (w,h,2) 33 | return np.resize(data, (int(h), int(w), 2)) 34 | 35 | 36 | def readPFM(file): 37 | file = open(file, "rb") 38 | 39 | color = None 40 | width = None 41 | height = None 42 | scale = None 43 | endian = None 44 | 45 | header = file.readline().rstrip() 46 | if header == b"PF": 47 | color = True 48 | elif header == b"Pf": 49 | color = False 50 | else: 51 | raise Exception("Not a PFM file.") 52 | 53 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) 54 | if dim_match: 55 | width, height = map(int, dim_match.groups()) 56 | else: 57 | raise Exception("Malformed PFM header.") 58 | 59 | scale = float(file.readline().rstrip()) 60 | if scale < 0: # little-endian 61 | endian = "<" 62 | scale = -scale 63 | else: 64 | endian = ">" # big-endian 65 | 66 | data = np.fromfile(file, endian + "f") 67 | shape = (height, width, 3) if color else (height, width) 68 | 69 | data = np.reshape(data, shape) 70 | data = np.flipud(data) 71 | return data 72 | 73 | 74 | def writeFlow(filename, uv, v=None): 75 | """Write optical flow to file. 76 | 77 | If v is None, uv is assumed to contain both u and v channels, 78 | stacked in depth. 79 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
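    Args:
        filename (str): Path of the .flo file to write.
        uv (np.ndarray): Flow of shape [H,W,2], or only the u component when v is given separately.
        v (np.ndarray, optional): Vertical flow of shape [H,W]. Defaults to None.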
80 | """ 81 | nBands = 2 82 | 83 | if v is None: 84 | assert uv.ndim == 3 85 | assert uv.shape[2] == 2 86 | u = uv[:, :, 0] 87 | v = uv[:, :, 1] 88 | else: 89 | u = uv 90 | 91 | assert u.shape == v.shape 92 | height, width = u.shape 93 | f = open(filename, "wb") 94 | # write the header 95 | f.write(TAG_CHAR) 96 | np.array(width).astype(np.int32).tofile(f) 97 | np.array(height).astype(np.int32).tofile(f) 98 | # arrange into matrix form 99 | tmp = np.zeros((height, width * nBands)) 100 | tmp[:, np.arange(width) * 2] = u 101 | tmp[:, np.arange(width) * 2 + 1] = v 102 | tmp.astype(np.float32).tofile(f) 103 | f.close() 104 | 105 | 106 | def readFlowKITTI(filename): 107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 108 | flow = flow[:, :, ::-1].astype(np.float32) 109 | flow, valid = flow[:, :, :2], flow[:, :, 2] 110 | flow = (flow - 2**15) / 64.0 111 | return flow, valid 112 | 113 | 114 | def readDispKITTI(filename): 115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 116 | valid = disp > 0.0 117 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 118 | return flow, valid 119 | 120 | 121 | def writeFlowKITTI(filename, uv): 122 | uv = 64.0 * uv + 2**15 123 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 125 | cv2.imwrite(filename, uv[..., ::-1]) 126 | 127 | 128 | def read_gen(file_name, pil=False): 129 | ext = splitext(file_name)[-1] 130 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": 131 | return Image.open(file_name) 132 | elif ext == ".bin" or ext == ".raw": 133 | return np.load(file_name) 134 | elif ext == ".flo": 135 | return readFlow(file_name).astype(np.float32) 136 | elif ext == ".pfm": 137 | flow = readPFM(file_name).astype(np.float32) 138 | if len(flow.shape) == 2: 139 | return flow 140 | else: 141 | return flow[:, :, :-1] 142 | return [] 143 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/utils/logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from loguru import logger as loguru_logger 3 | 4 | 5 | class Logger: 6 | def __init__(self, model, scheduler, cfg): 7 | self.model = model 8 | self.scheduler = scheduler 9 | self.total_steps = 0 10 | self.running_loss = {} 11 | self.writer = None 12 | self.cfg = cfg 13 | 14 | def _print_training_status(self): 15 | metrics_data = [ 16 | self.running_loss[k] / self.cfg.sum_freq 17 | for k in sorted(self.running_loss.keys()) 18 | ] 19 | training_str = "[{:6d}, {}] ".format( 20 | self.total_steps + 1, self.scheduler.get_last_lr() 21 | ) 22 | metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data) 23 | 24 | # print the training status 25 | loguru_logger.info(training_str + metrics_str) 26 | 27 | if self.writer is None: 28 | if self.cfg.log_dir is None: 29 | self.writer = SummaryWriter() 30 | else: 31 | self.writer = SummaryWriter(self.cfg.log_dir) 32 | 33 | for k in self.running_loss: 34 | self.writer.add_scalar( 35 | k, self.running_loss[k] / self.cfg.sum_freq, self.total_steps 36 | ) 37 | self.running_loss[k] = 0.0 38 | 39 | def push(self, metrics): 40 | self.total_steps += 1 41 | 42 | for key in metrics: 43 | if key not in self.running_loss: 44 | self.running_loss[key] = 0.0 45 | 46 | self.running_loss[key] += metrics[key] 47 | 48 | if self.total_steps % self.cfg.sum_freq == self.cfg.sum_freq - 1: 49 | 
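# report averaged metrics and reset the running totals once every cfg.sum_freq steps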
self._print_training_status() 50 | self.running_loss = {} 51 | 52 | def write_dict(self, results): 53 | if self.writer is None: 54 | self.writer = SummaryWriter() 55 | 56 | for key in results: 57 | self.writer.add_scalar(key, results[key], self.total_steps) 58 | 59 | def close(self): 60 | self.writer.close() 61 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import shutil 4 | 5 | 6 | def process_transformer_cfg(cfg): 7 | log_dir = "" 8 | if "critical_params" in cfg: 9 | critical_params = [cfg[key] for key in cfg.critical_params] 10 | for name, param in zip(cfg["critical_params"], critical_params): 11 | log_dir += "{:s}[{:s}]".format(name, str(param)) 12 | 13 | return log_dir 14 | 15 | 16 | def process_cfg(cfg): 17 | log_dir = "logs/" + cfg.name + "/" + cfg.transformer + "/" 18 | critical_params = [cfg.trainer[key] for key in cfg.critical_params] 19 | for name, param in zip(cfg["critical_params"], critical_params): 20 | log_dir += "{:s}[{:s}]".format(name, str(param)) 21 | 22 | log_dir += process_transformer_cfg(cfg[cfg.transformer]) 23 | 24 | now = time.localtime() 25 | now_time = "{:02d}_{:02d}_{:02d}_{:02d}".format( 26 | now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min 27 | ) 28 | log_dir += cfg.suffix + "(" + now_time + ")" 29 | cfg.log_dir = log_dir 30 | os.makedirs(log_dir) 31 | 32 | shutil.copytree("configs", f"{log_dir}/configs") 33 | shutil.copytree("core/FlowFormer", f"{log_dir}/FlowFormer") 34 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """Pads images such that dimensions are divisible by 8""" 9 | 10 | def __init__(self, dims, mode="sintel"): 11 | self.ht, self.wd = dims[-2:] 12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 14 | if mode == "sintel": 15 | self._pad = [ 16 | pad_wd // 2, 17 | pad_wd - pad_wd // 2, 18 | pad_ht // 2, 19 | pad_ht - pad_ht // 2, 20 | ] 21 | elif mode == "kitti400": 22 | self._pad = [0, 0, 0, 400 - self.ht] 23 | else: 24 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 25 | 26 | def pad(self, *inputs): 27 | return [F.pad(x, self._pad, mode="replicate") for x in inputs] 28 | 29 | def unpad(self, x): 30 | ht, wd = x.shape[-2:] 31 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 32 | return x[..., c[0] : c[1], c[2] : c[3]] 33 | 34 | 35 | def forward_interpolate(flow): 36 | flow = flow.detach().cpu().numpy() 37 | dx, dy = flow[0], flow[1] 38 | 39 | ht, wd = dx.shape 40 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 41 | 42 | x1 = x0 + dx 43 | y1 = y0 + dy 44 | 45 | x1 = x1.reshape(-1) 46 | y1 = y1.reshape(-1) 47 | dx = dx.reshape(-1) 48 | dy = dy.reshape(-1) 49 | 50 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 51 | x1 = x1[valid] 52 | y1 = y1[valid] 53 | dx = dx[valid] 54 | dy = dy[valid] 55 | 56 | flow_x = interpolate.griddata( 57 | (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 58 | ) 59 | 60 | flow_y = interpolate.griddata( 61 | (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 62 
| ) 63 | 64 | flow = np.stack([flow_x, flow_y], axis=0) 65 | return torch.from_numpy(flow).float() 66 | 67 | 68 | def bilinear_sampler(img, coords, mode="bilinear", mask=False): 69 | """Wrapper for grid_sample, uses pixel coordinates""" 70 | H, W = img.shape[-2:] 71 | xgrid, ygrid = coords.split([1, 1], dim=-1) 72 | xgrid = 2 * xgrid / (W - 1) - 1 73 | ygrid = 2 * ygrid / (H - 1) - 1 74 | 75 | grid = torch.cat([xgrid, ygrid], dim=-1) 76 | img = F.grid_sample(img, grid, align_corners=True) 77 | 78 | if mask: 79 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 80 | return img, mask.float() 81 | 82 | return img 83 | 84 | 85 | def indexing(img, coords, mask=False): 86 | """Wrapper for grid_sample, uses pixel coordinates""" 87 | """ 88 | TODO: directly indexing features instead of sampling 89 | """ 90 | H, W = img.shape[-2:] 91 | xgrid, ygrid = coords.split([1, 1], dim=-1) 92 | xgrid = 2 * xgrid / (W - 1) - 1 93 | ygrid = 2 * ygrid / (H - 1) - 1 94 | 95 | grid = torch.cat([xgrid, ygrid], dim=-1) 96 | img = F.grid_sample(img, grid, align_corners=True, mode="nearest") 97 | 98 | if mask: 99 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 100 | return img, mask.float() 101 | 102 | return img 103 | 104 | 105 | def coords_grid(batch, ht, wd): 106 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 107 | coords = torch.stack(coords[::-1], dim=0).float() 108 | return coords[None].repeat(batch, 1, 1, 1) 109 | 110 | 111 | def upflow8(flow, mode="bilinear"): 112 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 113 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 114 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/flowformer/run_train.sh: -------------------------------------------------------------------------------- 1 | mkdir -p checkpoints 2 | python -u train_FlowFormer.py --name chairs --stage chairs --validation chairs 3 | python -u train_FlowFormer.py --name things --stage things --validation sintel 4 | python -u train_FlowFormer.py --name sintel --stage sintel --validation sintel 5 | python -u train_FlowFormer.py --name kitti --stage kitti --validation kitti -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/gmflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .gmflow import GMFlow 2 | import argparse 3 | import torch 4 | 5 | 6 | def initialize_GMFlow(model_path="pretrained_ckpt/gmflow_sintel_with_refinement.pkl", device="cuda"): 7 | """Initializes the RAFT model.""" 8 | 9 | model = GMFlow() 10 | ckpt = torch.load(model_path, map_location="cpu") 11 | 12 | # def convert(param): 13 | # return {k.replace("module.", ""): v for k, v in param.items() if "module" in k} 14 | # 15 | # ckpt = convert(ckpt) 16 | model.load_state_dict(ckpt, strict=True) 17 | print("load gmflow from " + model_path) 18 | 19 | return model 20 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/gmflow/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .trident_conv import MultiScaleTridentConv 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, 8 | ): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = 
nn.Conv2d(in_planes, planes, kernel_size=3, 12 | dilation=dilation, padding=dilation, stride=stride, bias=False) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | self.norm1 = norm_layer(planes) 18 | self.norm2 = norm_layer(planes) 19 | if not stride == 1 or in_planes != planes: 20 | self.norm3 = norm_layer(planes) 21 | 22 | if stride == 1 and in_planes == planes: 23 | self.downsample = None 24 | else: 25 | self.downsample = nn.Sequential( 26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 27 | 28 | def forward(self, x): 29 | y = x 30 | y = self.relu(self.norm1(self.conv1(y))) 31 | y = self.relu(self.norm2(self.conv2(y))) 32 | 33 | if self.downsample is not None: 34 | x = self.downsample(x) 35 | 36 | return self.relu(x + y) 37 | 38 | 39 | class CNNEncoder(nn.Module): 40 | def __init__(self, output_dim=128, 41 | norm_layer=nn.InstanceNorm2d, 42 | num_output_scales=1, 43 | **kwargs, 44 | ): 45 | super(CNNEncoder, self).__init__() 46 | self.num_branch = num_output_scales 47 | 48 | feature_dims = [64, 96, 128] 49 | 50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 51 | self.norm1 = norm_layer(feature_dims[0]) 52 | self.relu1 = nn.ReLU(inplace=True) 53 | 54 | self.in_planes = feature_dims[0] 55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 57 | 58 | # highest resolution 1/4 or 1/8 59 | stride = 2 if num_output_scales == 1 else 1 60 | self.layer3 = self._make_layer(feature_dims[2], stride=stride, 61 | norm_layer=norm_layer, 62 | ) # 1/4 or 1/8 63 | 64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) 65 | 66 | if self.num_branch > 1: 67 | if self.num_branch == 4: 68 | strides = (1, 2, 4, 8) 69 | elif self.num_branch == 3: 70 | strides = (1, 2, 4) 71 | elif self.num_branch == 2: 72 | strides = (1, 2) 73 | else: 74 | raise ValueError 75 | 76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, 77 | kernel_size=3, 78 | strides=strides, 79 | paddings=1, 80 | num_branch=self.num_branch, 81 | ) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 87 | if m.weight is not None: 88 | nn.init.constant_(m.weight, 1) 89 | if m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): 93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) 94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) 95 | 96 | layers = (layer1, layer2) 97 | 98 | self.in_planes = dim 99 | return nn.Sequential(*layers) 100 | 101 | def forward(self, x): 102 | x = self.conv1(x) 103 | x = self.norm1(x) 104 | x = self.relu1(x) 105 | 106 | x = self.layer1(x) # 1/2 107 | x = self.layer2(x) # 1/4 108 | x = self.layer3(x) # 1/8 or 1/4 109 | 110 | x = self.conv2(x) 111 | 112 | if self.num_branch > 1: 113 | out = self.trident_conv([x] * self.num_branch) # high to low res 114 | else: 115 | out = [x] 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/gmflow/geometry.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def coords_grid(b, h, w, homogeneous=False, device=None): 6 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] 7 | 8 | stacks = [x, y] 9 | 10 | if homogeneous: 11 | ones = torch.ones_like(x) # [H, W] 12 | stacks.append(ones) 13 | 14 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] 15 | 16 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 17 | 18 | if device is not None: 19 | grid = grid.to(device) 20 | 21 | return grid 22 | 23 | 24 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): 25 | assert device is not None 26 | 27 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), 28 | torch.linspace(h_min, h_max, len_h, device=device)], 29 | ) 30 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] 31 | 32 | return grid 33 | 34 | 35 | def normalize_coords(coords, h, w): 36 | # coords: [B, H, W, 2] 37 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) 38 | return (coords - c) / c # [-1, 1] 39 | 40 | 41 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): 42 | # img: [B, C, H, W] 43 | # sample_coords: [B, 2, H, W] in image scale 44 | if sample_coords.size(1) != 2: # [B, H, W, 2] 45 | sample_coords = sample_coords.permute(0, 3, 1, 2) 46 | 47 | b, _, h, w = sample_coords.shape 48 | 49 | # Normalize to [-1, 1] 50 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 51 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 52 | 53 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] 54 | 55 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 56 | 57 | if return_mask: 58 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] 59 | 60 | return img, mask 61 | 62 | return img 63 | 64 | 65 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'): 66 | b, c, h, w = feature.size() 67 | assert flow.size(1) == 2 68 | 69 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] 70 | 71 | return bilinear_sample(feature, grid, padding_mode=padding_mode, 72 | return_mask=mask) 73 | 74 | 75 | def forward_backward_consistency_check(fwd_flow, bwd_flow, 76 | alpha=0.01, 77 | beta=0.5 78 | ): 79 | # fwd_flow, bwd_flow: [B, 2, H, W] 80 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 81 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 82 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 83 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 84 | 85 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] 86 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] 87 | 88 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] 89 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) 90 | 91 | threshold = alpha * flow_mag + beta 92 | 93 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W] 94 | bwd_occ = (diff_bwd > threshold).float() 95 | 96 | return fwd_occ, bwd_occ 97 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/gmflow/matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .geometry import coords_grid, 
generate_window_grid, normalize_coords 5 | 6 | 7 | def global_correlation_softmax(feature0, feature1, 8 | pred_bidir_flow=False, 9 | ): 10 | # global correlation 11 | b, c, h, w = feature0.shape 12 | feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] 13 | feature1 = feature1.view(b, c, -1) # [B, C, H*W] 14 | 15 | correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] 16 | 17 | # flow from softmax 18 | init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] 19 | grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] 20 | 21 | correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] 22 | 23 | if pred_bidir_flow: 24 | correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] 25 | init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] 26 | grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] 27 | b = b * 2 28 | 29 | prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] 30 | 31 | correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] 32 | 33 | # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow 34 | flow = correspondence - init_grid 35 | 36 | return flow, prob 37 | 38 | 39 | def local_correlation_softmax(feature0, feature1, local_radius, 40 | padding_mode='zeros', 41 | ): 42 | b, c, h, w = feature0.size() 43 | coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] 44 | coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] 45 | 46 | local_h = 2 * local_radius + 1 47 | local_w = 2 * local_radius + 1 48 | 49 | window_grid = generate_window_grid(-local_radius, local_radius, 50 | -local_radius, local_radius, 51 | local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] 52 | window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] 53 | sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] 54 | 55 | sample_coords_softmax = sample_coords 56 | 57 | # exclude coords that are out of image space 58 | valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] 59 | valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] 60 | 61 | valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax 62 | 63 | # normalize coordinates to [-1, 1] 64 | sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] 65 | window_feature = F.grid_sample(feature1, sample_coords_norm, 66 | padding_mode=padding_mode, align_corners=True 67 | ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] 68 | feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] 69 | 70 | corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] 71 | 72 | # mask invalid locations 73 | corr[~valid] = -1e9 74 | 75 | prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] 76 | 77 | correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( 78 | b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] 79 | 80 | flow = correspondence - coords_init 81 | match_prob = prob 82 | 83 | return flow, match_prob 84 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/gmflow/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 13 | """ 14 | 15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 16 | super().__init__() 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi 24 | self.scale = scale 25 | 26 | def forward(self, x): 27 | # x = tensor_list.tensors # [B, C, H, W] 28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 29 | b, c, h, w = x.size() 30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 31 | y_embed = mask.cumsum(1, dtype=torch.float32) 32 | x_embed = mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos 47 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/gmflow/trident_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.utils import _pair 8 | 9 | 10 | class MultiScaleTridentConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | strides=1, 18 | paddings=0, 19 | dilations=1, 20 | dilation=1, 21 | groups=1, 22 | num_branch=1, 23 | test_branch_idx=-1, 24 | bias=False, 25 | norm=None, 26 | activation=None, 27 | ): 28 | super(MultiScaleTridentConv, self).__init__() 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = _pair(kernel_size) 32 | self.num_branch = num_branch 33 | self.stride = _pair(stride) 34 | self.groups = groups 35 | self.with_bias = bias 36 | self.dilation = dilation 37 | if isinstance(paddings, int): 38 | paddings = [paddings] * self.num_branch 39 | if isinstance(dilations, int): 40 | dilations = [dilations] * self.num_branch 41 | if isinstance(strides, int): 42 | strides = [strides] * self.num_branch 43 | self.paddings = [_pair(padding) for padding in paddings] 44 | self.dilations = [_pair(dilation) for dilation in dilations] 45 | self.strides = [_pair(stride) for stride in strides] 46 | self.test_branch_idx = test_branch_idx 47 | self.norm = norm 48 | self.activation = activation 49 | 50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 51 | 52 | self.weight = nn.Parameter( 53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 54 | ) 55 | if bias: 56 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 57 | else: 58 | self.bias = None 59 | 60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 61 | if self.bias is not None: 62 | nn.init.constant_(self.bias, 0) 63 | 64 | def forward(self, inputs): 65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 66 | assert len(inputs) == num_branch 67 | 68 | if self.training or self.test_branch_idx == -1: 69 | outputs = [ 70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) 71 | for input, stride, padding in zip(inputs, self.strides, self.paddings) 72 | ] 73 | else: 74 | outputs = [ 75 | F.conv2d( 76 | inputs[0], 77 | self.weight, 78 | self.bias, 79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], 80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], 81 | self.dilation, 82 | self.groups, 83 | ) 84 | ] 85 | 86 | if self.norm is not None: 87 | outputs = [self.norm(x) for x in outputs] 88 | if self.activation is not None: 89 | outputs = [self.activation(x) for x in outputs] 90 | return outputs 91 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/gmflow/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .position import PositionEmbeddingSine 3 | 4 | 5 | def split_feature(feature, 6 | num_splits=2, 7 | channel_last=False, 8 | ): 9 | if channel_last: # [B, H, W, C] 10 | b, h, w, c = feature.size() 11 | assert h % num_splits == 0 and w % num_splits == 0 12 | 13 | b_new = b * num_splits * num_splits 14 | h_new = h // num_splits 15 | w_new = w // num_splits 16 | 17 | feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c 18 | ).permute(0, 1, 3, 2, 4, 
5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] 19 | else: # [B, C, H, W] 20 | b, c, h, w = feature.size() 21 | assert h % num_splits == 0 and w % num_splits == 0 22 | 23 | b_new = b * num_splits * num_splits 24 | h_new = h // num_splits 25 | w_new = w // num_splits 26 | 27 | feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits 28 | ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] 29 | 30 | return feature 31 | 32 | 33 | def merge_splits(splits, 34 | num_splits=2, 35 | channel_last=False, 36 | ): 37 | if channel_last: # [B*K*K, H/K, W/K, C] 38 | b, h, w, c = splits.size() 39 | new_b = b // num_splits // num_splits 40 | 41 | splits = splits.view(new_b, num_splits, num_splits, h, w, c) 42 | merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( 43 | new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] 44 | else: # [B*K*K, C, H/K, W/K] 45 | b, c, h, w = splits.size() 46 | new_b = b // num_splits // num_splits 47 | 48 | splits = splits.view(new_b, num_splits, num_splits, c, h, w) 49 | merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( 50 | new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] 51 | 52 | return merge 53 | 54 | 55 | def normalize_img(img0, img1): 56 | # loaded images are in [0, 255] 57 | # normalize by ImageNet mean and std 58 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) 59 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) 60 | img0 = (img0 / 255. - mean) / std 61 | img1 = (img1 / 255. - mean) / std 62 | 63 | return img0, img1 64 | 65 | 66 | def feature_add_position(feature0, feature1, attn_splits, feature_channels): 67 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) 68 | 69 | if attn_splits > 1: # add position in splited window 70 | feature0_splits = split_feature(feature0, num_splits=attn_splits) 71 | feature1_splits = split_feature(feature1, num_splits=attn_splits) 72 | 73 | position = pos_enc(feature0_splits) 74 | 75 | feature0_splits = feature0_splits + position 76 | feature1_splits = feature1_splits + position 77 | 78 | feature0 = merge_splits(feature0_splits, num_splits=attn_splits) 79 | feature1 = merge_splits(feature1_splits, num_splits=attn_splits) 80 | else: 81 | position = pos_enc(feature0) 82 | 83 | feature0 = feature0 + position 84 | feature1 = feature1 + position 85 | 86 | return feature0, feature1 87 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/modules/__init__.py -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/modules/coord_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class CoordSampler3D(nn.Module): 16 | def __init__(self, coord_range, t_coord_only=False): 17 | super().__init__() 18 | self.coord_range = coord_range 19 | self.t_coord_only = t_coord_only 20 | 21 | def shape2coordinate( 22 | self, 23 | batch_size, 24 | spatial_shape, 25 | t_ids, 26 | coord_range=(-1.0, 1.0), 27 | upsample_ratio=1, 28 | device=None, 29 | ): 30 | coords = [] 31 | assert isinstance(t_ids, list) 32 | _coords = torch.tensor(t_ids, device=device) / 1.0 33 | coords.append(_coords.to(torch.float32)) 34 | for num_s in spatial_shape: 35 | num_s = int(num_s * upsample_ratio) 36 | _coords = (0.5 + torch.arange(num_s, device=device)) / num_s 37 | _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords 38 | coords.append(_coords) 39 | coords = torch.meshgrid(*coords, indexing="ij") 40 | coords = torch.stack(coords, dim=-1) 41 | ones_like_shape = (1,) * coords.ndim 42 | coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape) 43 | return coords # (B,T,H,W,3) 44 | 45 | def batchshape2coordinate( 46 | self, 47 | batch_size, 48 | spatial_shape, 49 | t_ids, 50 | coord_range=(-1.0, 1.0), 51 | upsample_ratio=1, 52 | device=None, 53 | ): 54 | coords = [] 55 | _coords = torch.tensor(1, device=device) 56 | coords.append(_coords.to(torch.float32)) 57 | for num_s in spatial_shape: 58 | num_s = int(num_s * upsample_ratio) 59 | _coords = (0.5 + torch.arange(num_s, device=device)) / num_s 60 | _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords 61 | coords.append(_coords) 62 | coords = torch.meshgrid(*coords, indexing="ij") 63 | coords = torch.stack(coords, dim=-1) 64 | ones_like_shape = (1,) * coords.ndim 65 | # Now coords b,1,h,w,3, coords[...,0]=1. 66 | coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape) 67 | # assign per-sample timestep within the batch 68 | coords[..., :1] = coords[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1) 69 | return coords 70 | 71 | def forward( 72 | self, 73 | batch_size, 74 | s_shape, 75 | t_ids, 76 | coord_range=None, 77 | upsample_ratio=1.0, 78 | device=None, 79 | ): 80 | coord_range = self.coord_range if coord_range is None else coord_range 81 | if isinstance(t_ids, list): 82 | coords = self.shape2coordinate( 83 | batch_size, s_shape, t_ids, coord_range, upsample_ratio, device 84 | ) 85 | elif isinstance(t_ids, torch.Tensor): 86 | coords = self.batchshape2coordinate( 87 | batch_size, s_shape, t_ids, coord_range, upsample_ratio, device 88 | ) 89 | if self.t_coord_only: 90 | coords = coords[..., :1] 91 | return coords 92 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/modules/fi_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # raft: https://github.com/princeton-vl/RAFT 9 | # ema-vfi: https://github.com/MCG-NJU/EMA-VFI 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | backwarp_tenGrid = {} 17 | 18 | 19 | def warp(tenInput, tenFlow): 20 | k = (str(tenFlow.device), str(tenFlow.size())) 21 | if k not in backwarp_tenGrid: 22 | tenHorizontal = ( 23 | torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device) 24 | .view(1, 1, 1, tenFlow.shape[3]) 25 | .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 26 | ) 27 | tenVertical = ( 28 | torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device) 29 | .view(1, 1, tenFlow.shape[2], 1) 30 | .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 31 | ) 32 | backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device) 33 | 34 | tenFlow = torch.cat( 35 | [ 36 | tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 37 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), 38 | ], 39 | 1, 40 | ) 41 | 42 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 43 | return torch.nn.functional.grid_sample( 44 | input=tenInput, 45 | grid=g, 46 | mode="bilinear", 47 | padding_mode="border", 48 | align_corners=True, 49 | ) 50 | 51 | 52 | def normalize_flow(flows): 53 | flow_scaler = torch.max(torch.abs(flows).flatten(1), dim=-1)[0].reshape( 54 | -1, 1, 1, 1, 1 55 | ) 56 | flows = flows / flow_scaler # [-1,1] 57 | # # Adapt to [0,1] 58 | flows = (flows + 1.0) / 2.0 59 | return flows, flow_scaler 60 | 61 | 62 | def unnormalize_flow(flows, flow_scaler): 63 | return (flows * 2.0 - 1.0) * flow_scaler 64 | 65 | 66 | def resize(x, scale_factor): 67 | return F.interpolate( 68 | x, scale_factor=scale_factor, mode="bilinear", align_corners=False 69 | ) 70 | 71 | 72 | def coords_grid(batch, ht, wd): 73 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 74 | coords = torch.stack(coords[::-1], dim=0).float() 75 | return coords[None].repeat(batch, 1, 1, 1) 76 | 77 | 78 | def build_coord(img): 79 | N, C, H, W = img.shape 80 | coords = coords_grid(N, H // 8, W // 8) 81 | return coords 82 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/modules/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | 8 | from torch import nn 9 | import torch 10 | 11 | 12 | # define siren layer & Siren model 13 | class Sine(nn.Module): 14 | """Sine activation with scaling. 15 | 16 | Args: 17 | w0 (float): Omega_0 parameter from SIREN paper. 18 | """ 19 | 20 | def __init__(self, w0=1.0): 21 | super().__init__() 22 | self.w0 = w0 23 | 24 | def forward(self, x): 25 | return torch.sin(self.w0 * x) 26 | 27 | 28 | # Damping activation from http://arxiv.org/abs/2306.15242 29 | class Damping(nn.Module): 30 | """Sine activation with sublinear factor 31 | 32 | Args: 33 | w0 (float): Omega_0 parameter from SIREN paper. 
34 | """ 35 | 36 | def __init__(self, w0=1.0): 37 | super().__init__() 38 | self.w0 = w0 39 | 40 | def forward(self, x): 41 | x = torch.clamp(x, min=1e-30) 42 | return torch.sin(self.w0 * x) * torch.sqrt(x.abs()) 43 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/modules/module_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | from typing import List, Optional 12 | from dataclasses import dataclass 13 | from omegaconf import MISSING 14 | 15 | 16 | @dataclass 17 | class HypoNetActivationConfig: 18 | type: str = "relu" 19 | siren_w0: Optional[float] = 30.0 20 | 21 | 22 | @dataclass 23 | class HypoNetInitConfig: 24 | weight_init_type: Optional[str] = "kaiming_uniform" 25 | bias_init_type: Optional[str] = "zero" 26 | 27 | 28 | @dataclass 29 | class HypoNetConfig: 30 | type: str = "mlp" 31 | n_layer: int = 5 32 | hidden_dim: List[int] = MISSING 33 | use_bias: bool = True 34 | input_dim: int = 2 35 | output_dim: int = 3 36 | output_bias: float = 0.5 37 | activation: HypoNetActivationConfig = HypoNetActivationConfig() 38 | initialization: HypoNetInitConfig = HypoNetInitConfig() 39 | 40 | normalize_weight: bool = True 41 | linear_interpo: bool = False 42 | 43 | 44 | @dataclass 45 | class CoordSamplerConfig: 46 | data_type: str = "image" 47 | t_coord_only: bool = False 48 | coord_range: List[float] = MISSING 49 | time_range: List[float] = MISSING 50 | train_strategy: Optional[str] = MISSING 51 | val_strategy: Optional[str] = MISSING 52 | patch_size: Optional[int] = MISSING 53 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/modules/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | 15 | from .layers import Sine, Damping 16 | 17 | 18 | def convert_int_to_list(size, len_list=2): 19 | if isinstance(size, int): 20 | return [size] * len_list 21 | else: 22 | assert len(size) == len_list 23 | return size 24 | 25 | 26 | def initialize_params(params, init_type, **kwargs): 27 | fan_in, fan_out = params.shape[0], params.shape[1] 28 | if init_type is None or init_type == "normal": 29 | nn.init.normal_(params) 30 | elif init_type == "kaiming_uniform": 31 | nn.init.kaiming_uniform_(params, a=math.sqrt(5)) 32 | elif init_type == "uniform_fan_in": 33 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 34 | nn.init.uniform_(params, -bound, bound) 35 | elif init_type == "zero": 36 | nn.init.zeros_(params) 37 | elif "siren" == init_type: 38 | assert "siren_w0" in kwargs.keys() and "is_first" in kwargs.keys() 39 | w0 = kwargs["siren_w0"] 40 | if kwargs["is_first"]: 41 | w_std = 1 / fan_in 42 | else: 43 | w_std = math.sqrt(6.0 / fan_in) / w0 44 | nn.init.uniform_(params, -w_std, w_std) 45 | else: 46 | raise NotImplementedError 47 | 48 | 49 | def create_params_with_init( 50 | shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs 51 | ): 52 | if not include_bias: 53 | params = torch.empty([shape[0], shape[1]]) 54 | initialize_params(params, init_type, **kwargs) 55 | return params 56 | else: 57 | params = torch.empty([shape[0] - 1, shape[1]]) 58 | bias = torch.empty([1, shape[1]]) 59 | 60 | initialize_params(params, init_type, **kwargs) 61 | initialize_params(bias, bias_init_type, **kwargs) 62 | return torch.cat([params, bias], dim=0) 63 | 64 | 65 | def create_activation(config): 66 | if config.type == "relu": 67 | activation = nn.ReLU() 68 | elif config.type == "siren": 69 | activation = Sine(config.siren_w0) 70 | elif config.type == "silu": 71 | activation = nn.SiLU() 72 | elif config.type == "damp": 73 | activation = Damping(config.siren_w0) 74 | else: 75 | raise NotImplementedError 76 | return activation 77 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/raft/__init__.py: -------------------------------------------------------------------------------- 1 | from .raft import RAFT 2 | import argparse 3 | import torch 4 | from .extractor import BasicEncoder 5 | 6 | 7 | def initialize_RAFT(model_path="weights/raft-things.pth", device="cuda"): 8 | """Initializes the RAFT model.""" 9 | args = argparse.ArgumentParser() 10 | args.raft_model = model_path 11 | args.small = False 12 | args.mixed_precision = False 13 | args.alternate_corr = False 14 | model = RAFT(args) 15 | ckpt = torch.load(args.raft_model, map_location="cpu") 16 | 17 | def convert(param): 18 | return {k.replace("module.", ""): v for k, v in param.items() if "module" in k} 19 | 20 | ckpt = convert(ckpt) 21 | model.load_state_dict(ckpt, strict=True) 22 | print("load raft from " + model_path) 23 | 24 | return model 25 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/raft/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/raft/utils/__init__.py -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/raft/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | 21 | def make_colorwheel(): 22 | """ 23 | Generates a color wheel for optical flow visualization as presented in: 24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 26 | 27 | Code follows the original C++ source code of Daniel Scharstein. 28 | Code follows the the Matlab source code of Deqing Sun. 29 | 30 | Returns: 31 | np.ndarray: Color wheel 32 | """ 33 | 34 | RY = 15 35 | YG = 6 36 | GC = 4 37 | CB = 11 38 | BM = 13 39 | MR = 6 40 | 41 | ncols = RY + YG + GC + CB + BM + MR 42 | colorwheel = np.zeros((ncols, 3)) 43 | col = 0 44 | 45 | # RY 46 | colorwheel[0:RY, 0] = 255 47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 48 | col = col + RY 49 | # YG 50 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 51 | colorwheel[col : col + YG, 1] = 255 52 | col = col + YG 53 | # GC 54 | colorwheel[col : col + GC, 1] = 255 55 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 56 | col = col + GC 57 | # CB 58 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 59 | colorwheel[col : col + CB, 2] = 255 60 | col = col + CB 61 | # BM 62 | colorwheel[col : col + BM, 2] = 255 63 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 64 | col = col + BM 65 | # MR 66 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 67 | colorwheel[col : col + MR, 0] = 255 68 | return colorwheel 69 | 70 | 71 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 72 | """ 73 | Applies the flow color wheel to (possibly clipped) flow components u and v. 74 | 75 | According to the C++ source code of Daniel Scharstein 76 | According to the Matlab source code of Deqing Sun 77 | 78 | Args: 79 | u (np.ndarray): Input horizontal flow of shape [H,W] 80 | v (np.ndarray): Input vertical flow of shape [H,W] 81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
82 | 83 | Returns: 84 | np.ndarray: Flow visualization image of shape [H,W,3] 85 | """ 86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 87 | colorwheel = make_colorwheel() # shape [55x3] 88 | ncols = colorwheel.shape[0] 89 | rad = np.sqrt(np.square(u) + np.square(v)) 90 | a = np.arctan2(-v, -u) / np.pi 91 | fk = (a + 1) / 2 * (ncols - 1) 92 | k0 = np.floor(fk).astype(np.int32) 93 | k1 = k0 + 1 94 | k1[k1 == ncols] = 0 95 | f = fk - k0 96 | for i in range(colorwheel.shape[1]): 97 | tmp = colorwheel[:, i] 98 | col0 = tmp[k0] / 255.0 99 | col1 = tmp[k1] / 255.0 100 | col = (1 - f) * col0 + f * col1 101 | idx = rad <= 1 102 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 103 | col[~idx] = col[~idx] * 0.75 # out of range 104 | # Note the 2-i => BGR instead of RGB 105 | ch_idx = 2 - i if convert_to_bgr else i 106 | flow_image[:, :, ch_idx] = np.floor(255 * col) 107 | return flow_image 108 | 109 | 110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 111 | """ 112 | Expects a two dimensional flow image of shape. 113 | 114 | Args: 115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 118 | 119 | Returns: 120 | np.ndarray: Flow visualization image of shape [H,W,3] 121 | """ 122 | assert flow_uv.ndim == 3, "input flow must have three dimensions" 123 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" 124 | if clip_flow is not None: 125 | flow_uv = np.clip(flow_uv, 0, clip_flow) 126 | u = flow_uv[:, :, 0] 127 | v = flow_uv[:, :, 1] 128 | rad = np.sqrt(np.square(u) + np.square(v)) 129 | rad_max = np.max(rad) 130 | epsilon = 1e-5 131 | u = u / (rad_max + epsilon) 132 | v = v / (rad_max + epsilon) 133 | return flow_uv_to_colors(u, v, convert_to_bgr) 134 | -------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/raft/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | 14 | def readFlow(fn): 15 | """Read .flo file in Middlebury format""" 16 | # Code adapted from: 17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 18 | 19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 20 | # print 'fn = %s'%(fn) 21 | with open(fn, "rb") as f: 22 | magic = np.fromfile(f, np.float32, count=1) 23 | if 202021.25 != magic: 24 | print("Magic number incorrect. 
Invalid .flo file") 25 | return None 26 | else: 27 | w = np.fromfile(f, np.int32, count=1) 28 | h = np.fromfile(f, np.int32, count=1) 29 | # print 'Reading %d x %d flo file\n' % (w, h) 30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 31 | # Reshape data into 3D array (columns, rows, bands) 32 | # The reshape here is for visualization, the original code is (w,h,2) 33 | return np.resize(data, (int(h), int(w), 2)) 34 | 35 | 36 | def readPFM(file): 37 | file = open(file, "rb") 38 | 39 | color = None 40 | width = None 41 | height = None 42 | scale = None 43 | endian = None 44 | 45 | header = file.readline().rstrip() 46 | if header == b"PF": 47 | color = True 48 | elif header == b"Pf": 49 | color = False 50 | else: 51 | raise Exception("Not a PFM file.") 52 | 53 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) 54 | if dim_match: 55 | width, height = map(int, dim_match.groups()) 56 | else: 57 | raise Exception("Malformed PFM header.") 58 | 59 | scale = float(file.readline().rstrip()) 60 | if scale < 0: # little-endian 61 | endian = "<" 62 | scale = -scale 63 | else: 64 | endian = ">" # big-endian 65 | 66 | data = np.fromfile(file, endian + "f") 67 | shape = (height, width, 3) if color else (height, width) 68 | 69 | data = np.reshape(data, shape) 70 | data = np.flipud(data) 71 | return data 72 | 73 | 74 | def writeFlow(filename, uv, v=None): 75 | """Write optical flow to file. 76 | 77 | If v is None, uv is assumed to contain both u and v channels, 78 | stacked in depth. 79 | Original code by Deqing Sun, adapted from Daniel Scharstein. 80 | """ 81 | nBands = 2 82 | 83 | if v is None: 84 | assert uv.ndim == 3 85 | assert uv.shape[2] == 2 86 | u = uv[:, :, 0] 87 | v = uv[:, :, 1] 88 | else: 89 | u = uv 90 | 91 | assert u.shape == v.shape 92 | height, width = u.shape 93 | f = open(filename, "wb") 94 | # write the header 95 | f.write(TAG_CHAR) 96 | np.array(width).astype(np.int32).tofile(f) 97 | np.array(height).astype(np.int32).tofile(f) 98 | # arrange into matrix form 99 | tmp = np.zeros((height, width * nBands)) 100 | tmp[:, np.arange(width) * 2] = u 101 | tmp[:, np.arange(width) * 2 + 1] = v 102 | tmp.astype(np.float32).tofile(f) 103 | f.close() 104 | 105 | 106 | def readFlowKITTI(filename): 107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 108 | flow = flow[:, :, ::-1].astype(np.float32) 109 | flow, valid = flow[:, :, :2], flow[:, :, 2] 110 | flow = (flow - 2**15) / 64.0 111 | return flow, valid 112 | 113 | 114 | def readDispKITTI(filename): 115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 116 | valid = disp > 0.0 117 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 118 | return flow, valid 119 | 120 | 121 | def writeFlowKITTI(filename, uv): 122 | uv = 64.0 * uv + 2**15 123 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 125 | cv2.imwrite(filename, uv[..., ::-1]) 126 | 127 | 128 | def read_gen(file_name, pil=False): 129 | ext = splitext(file_name)[-1] 130 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": 131 | return Image.open(file_name) 132 | elif ext == ".bin" or ext == ".raw": 133 | return np.load(file_name) 134 | elif ext == ".flo": 135 | return readFlow(file_name).astype(np.float32) 136 | elif ext == ".pfm": 137 | flow = readPFM(file_name).astype(np.float32) 138 | if len(flow.shape) == 2: 139 | return flow 140 | else: 141 | return flow[:, :, :-1] 142 | return [] 143 | 
-------------------------------------------------------------------------------- /models/gimm/src/models/generalizable_INR/raft/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """Pads images such that dimensions are divisible by 8""" 9 | 10 | def __init__(self, dims, mode="sintel"): 11 | self.ht, self.wd = dims[-2:] 12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 14 | if mode == "sintel": 15 | self._pad = [ 16 | pad_wd // 2, 17 | pad_wd - pad_wd // 2, 18 | pad_ht // 2, 19 | pad_ht - pad_ht // 2, 20 | ] 21 | else: 22 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] 23 | 24 | def pad(self, *inputs): 25 | return [F.pad(x, self._pad, mode="replicate") for x in inputs] 26 | 27 | def unpad(self, x): 28 | ht, wd = x.shape[-2:] 29 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 30 | return x[..., c[0] : c[1], c[2] : c[3]] 31 | 32 | 33 | def forward_interpolate(flow): 34 | flow = flow.detach().cpu().numpy() 35 | dx, dy = flow[0], flow[1] 36 | 37 | ht, wd = dx.shape 38 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 39 | 40 | x1 = x0 + dx 41 | y1 = y0 + dy 42 | 43 | x1 = x1.reshape(-1) 44 | y1 = y1.reshape(-1) 45 | dx = dx.reshape(-1) 46 | dy = dy.reshape(-1) 47 | 48 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 49 | x1 = x1[valid] 50 | y1 = y1[valid] 51 | dx = dx[valid] 52 | dy = dy[valid] 53 | 54 | flow_x = interpolate.griddata( 55 | (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 56 | ) 57 | 58 | flow_y = interpolate.griddata( 59 | (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 60 | ) 61 | 62 | flow = np.stack([flow_x, flow_y], axis=0) 63 | return torch.from_numpy(flow).float() 64 | 65 | 66 | def bilinear_sampler(img, coords, mode="bilinear", mask=False): 67 | """Wrapper for grid_sample, uses pixel coordinates""" 68 | H, W = img.shape[-2:] 69 | xgrid, ygrid = coords.split([1, 1], dim=-1) 70 | xgrid = 2 * xgrid / (W - 1) - 1 71 | ygrid = 2 * ygrid / (H - 1) - 1 72 | 73 | grid = torch.cat([xgrid, ygrid], dim=-1) 74 | img = F.grid_sample(img, grid, align_corners=True) 75 | 76 | if mask: 77 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 78 | return img, mask.float() 79 | 80 | return img 81 | 82 | 83 | def coords_grid(batch, ht, wd, device): 84 | coords = torch.meshgrid( 85 | torch.arange(ht, device=device), torch.arange(wd, device=device) 86 | ) 87 | coords = torch.stack(coords[::-1], dim=0).float() 88 | return coords[None].repeat(batch, 1, 1, 1) 89 | 90 | 91 | def upflow8(flow, mode="bilinear"): 92 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 93 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 94 | -------------------------------------------------------------------------------- /models/gimm/src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/utils/__init__.py -------------------------------------------------------------------------------- /models/gimm/src/utils/accumulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | import models.gimm.src.utils.dist as dist_utils 13 | 14 | 15 | class AccmStageINR: 16 | def __init__( 17 | self, 18 | scalar_metric_names, 19 | vector_metric_names=(), 20 | vector_metric_lengths=(), 21 | device="cpu", 22 | ): 23 | self.device = device 24 | 25 | assert len(vector_metric_lengths) == len(vector_metric_names) 26 | 27 | self.scalar_metric_names = scalar_metric_names 28 | self.vector_metric_names = vector_metric_names 29 | self.metrics_sum = {} 30 | 31 | for n in self.scalar_metric_names: 32 | self.metrics_sum[n] = torch.zeros(1, device=self.device) 33 | 34 | for n, length in zip(self.vector_metric_names, vector_metric_lengths): 35 | self.metrics_sum[n] = torch.zeros(length, device=self.device) 36 | 37 | self.counter = 0 38 | self.Summary.scalar_metric_names = self.scalar_metric_names 39 | self.Summary.vector_metric_names = self.vector_metric_names 40 | 41 | @torch.no_grad() 42 | def update(self, metrics_to_add, count=None, sync=False, distenv=None): 43 | # we assume that value is simultaneously None (or not None) for every process 44 | metrics_to_add = { 45 | name: value for (name, value) in metrics_to_add.items() if value is not None 46 | } 47 | 48 | if sync: 49 | for name, value in metrics_to_add.items(): 50 | gathered_value = dist_utils.all_gather_cat(distenv, value.unsqueeze(0)) 51 | gathered_value = gathered_value.sum(dim=0).detach() 52 | metrics_to_add[name] = gathered_value 53 | 54 | for name, value in metrics_to_add.items(): 55 | if name not in self.metrics_sum: 56 | raise KeyError(f"unexpected metric name: {name}") 57 | self.metrics_sum[name] += value 58 | 59 | self.counter += count if not sync else count * distenv.world_size 60 | 61 | @torch.no_grad() 62 | def get_summary(self, n_samples=None): 63 | n_samples = n_samples if n_samples else self.counter 64 | return self.Summary({k: v / n_samples for k, v in self.metrics_sum.items()}) 65 | 66 | class Summary: 67 | scalar_metric_names = () 68 | vector_metric_names = () 69 | 70 | def __init__(self, metrics): 71 | for key, value in metrics.items(): 72 | self[key] = value 73 | 74 | def print_line(self): 75 | reprs = [] 76 | for k in self.scalar_metric_names: 77 | v = self[k] 78 | repr = f"{k}: {v.item():.4f}" 79 | reprs.append(repr) 80 | 81 | for k in self.vector_metric_names: 82 | v = self[k] 83 | array_repr = ",".join([f"{v_i.item():.4f}" for v_i in v]) 84 | repr = f"{k}: [{array_repr}]" 85 | reprs.append(repr) 86 | 87 | return ", ".join(reprs) 88 | 89 | def tb_like(self): 90 | tb_summary = {} 91 | for k in self.scalar_metric_names: 92 | v = self[k] 93 | tb_summary[f"loss/{k}"] = v 94 | 95 | for k in self.vector_metric_names: 96 | v = self[k] 97 | for i, v_i in enumerate(v): 98 | tb_summary[f"loss/{k}_{i}"] = v_i 99 | 100 | return tb_summary 101 | 102 | def __getitem__(self, item): 103 | return getattr(self, item) 104 | 105 | def __setitem__(self, key, value): 106 | setattr(self, key, value) 107 | -------------------------------------------------------------------------------- /models/gimm/src/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | from omegaconf import OmegaConf 12 | from easydict import EasyDict as edict 13 | import yaml 14 | 15 | from models.gimm.src.models.generalizable_INR.configs import GIMMConfig, GIMMVFIConfig 16 | import os.path as osp 17 | 18 | 19 | def easydict_to_dict(obj): 20 | if not isinstance(obj, edict): 21 | return obj 22 | else: 23 | return {k: easydict_to_dict(v) for k, v in obj.items()} 24 | 25 | 26 | def load_config(config_path): 27 | with open(config_path) as f: 28 | config = yaml.load(f, Loader=yaml.FullLoader) 29 | config = easydict_to_dict(config) 30 | config = OmegaConf.create(config) 31 | return config 32 | 33 | 34 | def augment_arch_defaults(arch_config): 35 | if arch_config.type == "gimm": 36 | arch_defaults = GIMMConfig.create(arch_config) 37 | elif arch_config.type == "gimmvfi": 38 | arch_defaults = GIMMVFIConfig.create(arch_config) 39 | elif arch_config.type == "gimmvfi_f" or arch_config.type == "gimmvfi_r": 40 | arch_defaults = GIMMVFIConfig.create(arch_config) 41 | else: 42 | raise ValueError(f"{arch_config.type} is not implemented for default arguments") 43 | 44 | return OmegaConf.merge(arch_defaults, arch_config) 45 | 46 | 47 | def augment_optimizer_defaults(optim_config): 48 | defaults = OmegaConf.create( 49 | { 50 | "type": "adamW", 51 | "max_gn": None, 52 | "warmup": { 53 | "mode": "linear", 54 | "start_from_zero": (True if optim_config.warmup.epoch > 0 else False), 55 | }, 56 | } 57 | ) 58 | return OmegaConf.merge(defaults, optim_config) 59 | 60 | 61 | def augment_defaults(config): 62 | defaults = OmegaConf.create( 63 | { 64 | "arch": augment_arch_defaults(config.arch), 65 | "dataset": { 66 | "transforms": {"type": None}, 67 | }, 68 | "optimizer": augment_optimizer_defaults(config.optimizer), 69 | "experiment": { 70 | "test_freq": 10, 71 | "amp": False, 72 | }, 73 | } 74 | ) 75 | 76 | if "inr" in config.arch.type or "gimm" in config.arch.type: 77 | subsample_defaults = OmegaConf.create({"type": None, "ratio": 1.0}) 78 | loss_defaults = OmegaConf.create( 79 | { 80 | "loss": { 81 | "type": "mse", 82 | "subsample": subsample_defaults, 83 | "coord_noise": None, 84 | } 85 | } 86 | ) 87 | defaults = OmegaConf.merge(defaults, loss_defaults) 88 | config = OmegaConf.merge(defaults, config) 89 | return config 90 | 91 | 92 | def augment_dist_defaults(config, distenv): 93 | config = config.copy() 94 | local_batch_size = config.experiment.batch_size 95 | world_batch_size = distenv.world_size * local_batch_size 96 | total_batch_size = config.experiment.get("total_batch_size", world_batch_size) 97 | 98 | if total_batch_size % world_batch_size != 0: 99 | raise ValueError("total batch size must be divisible by world batch size") 100 | else: 101 | grad_accm_steps = total_batch_size // world_batch_size 102 | 103 | config.optimizer.grad_accm_steps = grad_accm_steps 104 | config.experiment.total_batch_size = total_batch_size 105 | return config 106 | 107 | 108 | def config_setup(args, distenv, config_path, extra_args=()): 109 | if not osp.isfile(config_path): 110 | config_path = args.model_config 111 | if args.eval: 112 | config = load_config(config_path) 113 | config = augment_defaults(config) 114 | if hasattr(args, "test_batch_size"): 
115 | config.experiment.batch_size = args.test_batch_size 116 | if not hasattr(config, "seed"): 117 | config.seed = args.seed 118 | 119 | elif args.resume: 120 | config = load_config(config_path) 121 | if distenv.world_size != config.runtime.distenv.world_size: 122 | raise ValueError("world_size not identical to the resuming config") 123 | config.runtime = {"args": vars(args), "distenv": distenv} 124 | 125 | else: # training 126 | config_path = args.model_config 127 | config = load_config(config_path) 128 | 129 | extra_config = OmegaConf.from_dotlist(extra_args) 130 | config = OmegaConf.merge(config, extra_config) 131 | 132 | config = augment_defaults(config) 133 | config = augment_dist_defaults(config, distenv) 134 | 135 | config.seed = args.seed 136 | config.runtime = { 137 | "args": vars(args), 138 | "extra_config": extra_config, 139 | "distenv": distenv, 140 | } 141 | 142 | return config 143 | -------------------------------------------------------------------------------- /models/gimm/src/utils/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | from dataclasses import dataclass 12 | import datetime 13 | import os 14 | 15 | import torch 16 | import torch.distributed as dist 17 | from torch.nn.parallel import DistributedDataParallel 18 | 19 | 20 | @dataclass 21 | class DistEnv: 22 | world_size: int 23 | world_rank: int 24 | local_rank: int 25 | num_gpus: int 26 | master: bool 27 | device_name: str 28 | 29 | 30 | def initialize(args, logger=None): 31 | args.rank = int(os.environ.get("RANK", 0)) 32 | args.world_size = int(os.environ.get("WORLD_SIZE", 1)) 33 | args.local_rank = int(os.environ.get("LOCAL_RANK", 0)) 34 | 35 | if args.world_size > 1: 36 | os.environ["RANK"] = str(args.rank) 37 | os.environ["WORLD_SIZE"] = str(args.world_size) 38 | os.environ["LOCAL_RANK"] = str(args.local_rank) 39 | 40 | print(f"[dist] Distributed: wait dist process group:{args.local_rank}") 41 | dist.init_process_group( 42 | backend=args.dist_backend, 43 | init_method="env://", 44 | world_size=args.world_size, 45 | timeout=datetime.timedelta(0, args.timeout), 46 | ) 47 | assert args.world_size == dist.get_world_size() 48 | print( 49 | f"""[dist] Distributed: success device:{args.local_rank}, """, 50 | f"""{dist.get_rank()}/{dist.get_world_size()}""", 51 | ) 52 | distenv = DistEnv( 53 | world_size=dist.get_world_size(), 54 | world_rank=dist.get_rank(), 55 | local_rank=args.local_rank, 56 | num_gpus=1, 57 | master=(dist.get_rank() == 0), 58 | device_name=torch.cuda.get_device_name(), 59 | ) 60 | else: 61 | print("[dist] Single processed") 62 | distenv = DistEnv( 63 | 1, 0, 0, torch.cuda.device_count(), True, torch.cuda.get_device_name() 64 | ) 65 | 66 | print(f"[dist] {distenv}") 67 | 68 | if logger is not None: 69 | logger.info(distenv) 70 | 71 | return distenv 72 | 73 | 74 | def dataparallel_and_sync( 75 | distenv, model, find_unused_parameters=False, static_graph=False 76 | ): 77 | if dist.is_initialized(): 78 | model = DistributedDataParallel( 79 | model, 80 | device_ids=[distenv.local_rank], 81 | output_device=distenv.local_rank, 82 | 
find_unused_parameters=find_unused_parameters, 83 | # Available only with PyTorch 1.11 or above. 84 | # When set to ``True``, DDP knows the trained graph is static. 85 | # Especially, this enables activation checkpointing multiple times 86 | # which was not supported in the previous versions. 87 | # See the docstring of DistributedDataParallel for more details. 88 | static_graph=static_graph, 89 | ) 90 | for _, param in model.state_dict().items(): 91 | dist.broadcast(param, 0) 92 | 93 | dist.barrier() 94 | else: 95 | model = torch.nn.DataParallel(model) 96 | torch.cuda.synchronize() 97 | 98 | return model 99 | 100 | 101 | def param_sync(param): 102 | dist.broadcast(param, 0) 103 | dist.barrier() 104 | torch.cuda.synchronize() 105 | 106 | 107 | @torch.no_grad() 108 | def all_gather_cat(distenv, tensor, dim=0): 109 | if distenv.world_size == 1: 110 | return tensor 111 | 112 | g_tensor = [torch.ones_like(tensor) for _ in range(distenv.world_size)] 113 | dist.all_gather(g_tensor, tensor) 114 | g_tensor = torch.cat(g_tensor, dim=dim) 115 | 116 | return g_tensor 117 | -------------------------------------------------------------------------------- /models/gimm/src/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | 21 | def make_colorwheel(): 22 | """ 23 | Generates a color wheel for optical flow visualization as presented in: 24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 26 | 27 | Code follows the original C++ source code of Daniel Scharstein. 28 | Code follows the the Matlab source code of Deqing Sun. 
29 | 30 | Returns: 31 | np.ndarray: Color wheel 32 | """ 33 | 34 | RY = 15 35 | YG = 6 36 | GC = 4 37 | CB = 11 38 | BM = 13 39 | MR = 6 40 | 41 | ncols = RY + YG + GC + CB + BM + MR 42 | colorwheel = np.zeros((ncols, 3)) 43 | col = 0 44 | 45 | # RY 46 | colorwheel[0:RY, 0] = 255 47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 48 | col = col + RY 49 | # YG 50 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 51 | colorwheel[col : col + YG, 1] = 255 52 | col = col + YG 53 | # GC 54 | colorwheel[col : col + GC, 1] = 255 55 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 56 | col = col + GC 57 | # CB 58 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 59 | colorwheel[col : col + CB, 2] = 255 60 | col = col + CB 61 | # BM 62 | colorwheel[col : col + BM, 2] = 255 63 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 64 | col = col + BM 65 | # MR 66 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 67 | colorwheel[col : col + MR, 0] = 255 68 | return colorwheel 69 | 70 | 71 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 72 | """ 73 | Applies the flow color wheel to (possibly clipped) flow components u and v. 74 | 75 | According to the C++ source code of Daniel Scharstein 76 | According to the Matlab source code of Deqing Sun 77 | 78 | Args: 79 | u (np.ndarray): Input horizontal flow of shape [H,W] 80 | v (np.ndarray): Input vertical flow of shape [H,W] 81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 82 | 83 | Returns: 84 | np.ndarray: Flow visualization image of shape [H,W,3] 85 | """ 86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 87 | colorwheel = make_colorwheel() # shape [55x3] 88 | ncols = colorwheel.shape[0] 89 | rad = np.sqrt(np.square(u) + np.square(v)) 90 | a = np.arctan2(-v, -u) / np.pi 91 | fk = (a + 1) / 2 * (ncols - 1) 92 | k0 = np.floor(fk).astype(np.int32) 93 | k1 = k0 + 1 94 | k1[k1 == ncols] = 0 95 | f = fk - k0 96 | for i in range(colorwheel.shape[1]): 97 | tmp = colorwheel[:, i] 98 | col0 = tmp[k0] / 255.0 99 | col1 = tmp[k1] / 255.0 100 | col = (1 - f) * col0 + f * col1 101 | idx = rad <= 1 102 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 103 | col[~idx] = col[~idx] * 0.75 # out of range 104 | # Note the 2-i => BGR instead of RGB 105 | ch_idx = 2 - i if convert_to_bgr else i 106 | flow_image[:, :, ch_idx] = np.floor(255 * col) 107 | return flow_image 108 | 109 | 110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, max_flow=None): 111 | """ 112 | Expects a two dimensional flow image of shape. 113 | 114 | Args: 115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
118 | 119 | Returns: 120 | np.ndarray: Flow visualization image of shape [H,W,3] 121 | """ 122 | assert flow_uv.ndim == 3, "input flow must have three dimensions" 123 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" 124 | if clip_flow is not None: 125 | flow_uv = np.clip(flow_uv, 0, clip_flow) 126 | u = flow_uv[:, :, 0] 127 | v = flow_uv[:, :, 1] 128 | if max_flow is None: 129 | rad = np.sqrt(np.square(u) + np.square(v)) 130 | rad_max = np.max(rad) 131 | else: 132 | rad_max = max_flow 133 | epsilon = 1e-5 134 | u = u / (rad_max + epsilon) 135 | v = v / (rad_max + epsilon) 136 | return flow_uv_to_colors(u, v, convert_to_bgr) 137 | -------------------------------------------------------------------------------- /models/gimm/src/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # FlowFormer: https://github.com/drinkingcoder/FlowFormer-Official 9 | # -------------------------------------------------------- 10 | 11 | import numpy as np 12 | from PIL import Image 13 | from os.path import * 14 | import re 15 | 16 | import cv2 17 | 18 | cv2.setNumThreads(0) 19 | cv2.ocl.setUseOpenCL(False) 20 | 21 | TAG_CHAR = np.array([202021.25], np.float32) 22 | 23 | 24 | def readFlow(fn): 25 | """Read .flo file in Middlebury format""" 26 | # Code adapted from: 27 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 28 | 29 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 30 | # print 'fn = %s'%(fn) 31 | with open(fn, "rb") as f: 32 | magic = np.fromfile(f, np.float32, count=1) 33 | if 202021.25 != magic: 34 | print("Magic number incorrect. Invalid .flo file") 35 | return None 36 | else: 37 | w = np.fromfile(f, np.int32, count=1) 38 | h = np.fromfile(f, np.int32, count=1) 39 | # print 'Reading %d x %d flo file\n' % (w, h) 40 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 41 | # Reshape data into 3D array (columns, rows, bands) 42 | # The reshape here is for visualization, the original code is (w,h,2) 43 | return np.resize(data, (int(h), int(w), 2)) 44 | 45 | 46 | def readPFM(file): 47 | file = open(file, "rb") 48 | 49 | color = None 50 | width = None 51 | height = None 52 | scale = None 53 | endian = None 54 | 55 | header = file.readline().rstrip() 56 | if header == b"PF": 57 | color = True 58 | elif header == b"Pf": 59 | color = False 60 | else: 61 | raise Exception("Not a PFM file.") 62 | 63 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) 64 | if dim_match: 65 | width, height = map(int, dim_match.groups()) 66 | else: 67 | raise Exception("Malformed PFM header.") 68 | 69 | scale = float(file.readline().rstrip()) 70 | if scale < 0: # little-endian 71 | endian = "<" 72 | scale = -scale 73 | else: 74 | endian = ">" # big-endian 75 | 76 | data = np.fromfile(file, endian + "f") 77 | shape = (height, width, 3) if color else (height, width) 78 | 79 | data = np.reshape(data, shape) 80 | data = np.flipud(data) 81 | return data 82 | 83 | 84 | def writeFlow(filename, uv, v=None): 85 | """Write optical flow to file. 86 | 87 | If v is None, uv is assumed to contain both u and v channels, 88 | stacked in depth. 
89 | Original code by Deqing Sun, adapted from Daniel Scharstein. 90 | """ 91 | nBands = 2 92 | 93 | if v is None: 94 | assert uv.ndim == 3 95 | assert uv.shape[2] == 2 96 | u = uv[:, :, 0] 97 | v = uv[:, :, 1] 98 | else: 99 | u = uv 100 | 101 | assert u.shape == v.shape 102 | height, width = u.shape 103 | f = open(filename, "wb") 104 | # write the header 105 | f.write(TAG_CHAR) 106 | np.array(width).astype(np.int32).tofile(f) 107 | np.array(height).astype(np.int32).tofile(f) 108 | # arrange into matrix form 109 | tmp = np.zeros((height, width * nBands)) 110 | tmp[:, np.arange(width) * 2] = u 111 | tmp[:, np.arange(width) * 2 + 1] = v 112 | tmp.astype(np.float32).tofile(f) 113 | f.close() 114 | 115 | 116 | def readFlowKITTI(filename): 117 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 118 | flow = flow[:, :, ::-1].astype(np.float32) 119 | flow, valid = flow[:, :, :2], flow[:, :, 2] 120 | flow = (flow - 2**15) / 64.0 121 | return flow, valid 122 | 123 | 124 | def readDispKITTI(filename): 125 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 126 | valid = disp > 0.0 127 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 128 | return flow, valid 129 | 130 | 131 | def writeFlowKITTI(filename, uv): 132 | uv = 64.0 * uv + 2**15 133 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 134 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 135 | cv2.imwrite(filename, uv[..., ::-1]) 136 | 137 | 138 | def read_gen(file_name, pil=False): 139 | ext = splitext(file_name)[-1] 140 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": 141 | return Image.open(file_name) 142 | elif ext == ".bin" or ext == ".raw": 143 | return np.load(file_name) 144 | elif ext == ".flo": 145 | return readFlow(file_name).astype(np.float32) 146 | elif ext == ".pfm": 147 | flow = readPFM(file_name).astype(np.float32) 148 | if len(flow.shape) == 2: 149 | return flow 150 | else: 151 | return flow[:, :, :-1] 152 | return [] 153 | -------------------------------------------------------------------------------- /models/gimm/src/utils/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # EMA-VFI: https://github.com/MCG-NJU/EMA-VFI 9 | # IFRNet: https://github.com/ltkong218/IFRNet 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import numpy as np 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | ## Laploss 21 | def gauss_kernel(channels=3): 22 | kernel = torch.tensor( 23 | [ 24 | [1.0, 4.0, 6.0, 4.0, 1], 25 | [4.0, 16.0, 24.0, 16.0, 4.0], 26 | [6.0, 24.0, 36.0, 24.0, 6.0], 27 | [4.0, 16.0, 24.0, 16.0, 4.0], 28 | [1.0, 4.0, 6.0, 4.0, 1.0], 29 | ] 30 | ) 31 | kernel /= 256.0 32 | kernel = kernel.repeat(channels, 1, 1, 1) 33 | kernel = kernel.to(device) 34 | return kernel 35 | 36 | 37 | def downsample(x): 38 | return x[:, :, ::2, ::2] 39 | 40 | 41 | def upsample(x): 42 | cc = torch.cat( 43 | [x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], 44 | dim=3, 45 | ) 46 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3]) 47 | cc = cc.permute(0, 1, 3, 2) 48 | cc = torch.cat( 49 | [ 50 | cc, 51 | torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device), 52 | ], 53 | dim=3, 54 | ) 55 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2) 56 | x_up = cc.permute(0, 1, 3, 2) 57 | return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1])) 58 | 59 | 60 | def conv_gauss(img, kernel): 61 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect") 62 | out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1]) 63 | return out 64 | 65 | 66 | def laplacian_pyramid(img, kernel, max_levels=3): 67 | current = img 68 | pyr = [] 69 | for level in range(max_levels): 70 | filtered = conv_gauss(current, kernel) 71 | down = downsample(filtered) 72 | up = upsample(down) 73 | diff = current - up 74 | pyr.append(diff) 75 | current = down 76 | return pyr 77 | 78 | 79 | class LapLoss(torch.nn.Module): 80 | def __init__(self, max_levels=5, channels=3): 81 | super(LapLoss, self).__init__() 82 | self.max_levels = max_levels 83 | self.gauss_kernel = gauss_kernel(channels=channels) 84 | 85 | def forward(self, input, target): 86 | pyr_input = laplacian_pyramid( 87 | img=input, kernel=self.gauss_kernel, max_levels=self.max_levels 88 | ) 89 | pyr_target = laplacian_pyramid( 90 | img=target, kernel=self.gauss_kernel, max_levels=self.max_levels 91 | ) 92 | return sum( 93 | torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target) 94 | ) 95 | 96 | 97 | class Ternary(nn.Module): 98 | def __init__(self, patch_size=7): 99 | super(Ternary, self).__init__() 100 | self.patch_size = patch_size 101 | out_channels = patch_size * patch_size 102 | self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels)) 103 | self.w = np.transpose(self.w, (3, 2, 0, 1)) 104 | self.w = torch.tensor(self.w).float().to(device) 105 | 106 | def transform(self, tensor): 107 | tensor_ = tensor.mean(dim=1, keepdim=True) 108 | patches = F.conv2d(tensor_, self.w, padding=self.patch_size // 2, bias=None) 109 | loc_diff = patches - tensor_ 110 | loc_diff_norm = loc_diff / torch.sqrt(0.81 + loc_diff**2) 111 | return loc_diff_norm 112 | 113 | def valid_mask(self, tensor): 114 | padding = self.patch_size // 2 115 | b, c, h, w = tensor.size() 116 | inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor) 117 | mask = F.pad(inner, [padding] * 4) 118 | return mask 119 | 120 | def 
forward(self, x, y): # pred,gt 121 | loc_diff_x = self.transform(x) 122 | loc_diff_y = self.transform(y) 123 | diff = loc_diff_x - loc_diff_y.detach() 124 | dist = (diff**2 / (0.1 + diff**2)).mean(dim=1, keepdim=True) 125 | mask = self.valid_mask(x) 126 | loss = (dist * mask).mean() 127 | return loss 128 | 129 | 130 | class Charbonnier_L1(nn.Module): 131 | def __init__(self): 132 | super(Charbonnier_L1, self).__init__() 133 | 134 | def forward(self, pred, gt, mask=None): 135 | diff = pred - gt 136 | if mask is None: 137 | loss = ((diff**2 + 1e-6) ** 0.5).mean() 138 | else: 139 | loss = (((diff**2 + 1e-6) ** 0.5) * mask).mean() / (mask.mean() + 1e-9) 140 | return loss 141 | 142 | 143 | class Charbonnier_Ada(nn.Module): 144 | def __init__(self): 145 | super(Charbonnier_Ada, self).__init__() 146 | 147 | def forward(self, diff, weight): 148 | alpha = weight / 2 149 | epsilon = 10 ** (-(10 * weight - 1) / 3) 150 | loss = ((diff**2 + epsilon**2) ** alpha).mean() 151 | return loss 152 | -------------------------------------------------------------------------------- /models/gimm/src/utils/lpips/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/utils/lpips/alex.pth -------------------------------------------------------------------------------- /models/gimm/src/utils/profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | class Profiler: 12 | opts_model_size = {"trainable-only", "transformer-block-only"} 13 | 14 | def __init__(self, logger): 15 | self._logger = logger 16 | 17 | def get_model_size(self, model, opt=None): 18 | if opt is None: 19 | self._logger.info( 20 | "[OPTION: ALL] #parameters: %.4fM", 21 | sum(p.numel() for p in model.parameters()) / 1e6, 22 | ) 23 | else: 24 | assert ( 25 | opt in self.opts_model_size 26 | ), f"{opt} is not in {self.opts_model_size}" 27 | 28 | if opt == "trainable-only": 29 | self._logger.info( 30 | "[OPTION: %s] #parameters: %.4fM", 31 | opt, 32 | sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6, 33 | ) 34 | else: 35 | if hasattr(model, "blocks"): 36 | self._logger.info( 37 | "[OPTION: %s] #parameters: %.4fM", 38 | opt, 39 | sum(p.numel() for p in model.blocks.parameters()) / 1e6, 40 | ) 41 | -------------------------------------------------------------------------------- /models/gimm/src/utils/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | from datetime import datetime 12 | import logging 13 | import inspect 14 | import os 15 | import shutil 16 | from pathlib import Path 17 | 18 | from omegaconf import OmegaConf 19 | 20 | from .writer import Writer 21 | from .config import config_setup 22 | from .dist import initialize as dist_init 23 | 24 | 25 | def logger_setup(log_path, eval=False): 26 | log_fname = os.path.join(log_path, "val.log" if eval else "train.log") 27 | 28 | for hdlr in logging.root.handlers: 29 | logging.root.removeHandler(hdlr) 30 | 31 | SMOKE_TEST = bool(os.environ.get("SMOKE_TEST", 0)) 32 | 33 | logging.basicConfig( 34 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 35 | datefmt="%m/%d/%Y %H:%M:%S", 36 | level=(logging.DEBUG if SMOKE_TEST else logging.INFO), 37 | handlers=[logging.FileHandler(log_fname), logging.StreamHandler()], 38 | ) 39 | main_filename, *_ = inspect.getframeinfo(inspect.currentframe().f_back.f_back) 40 | 41 | logger = logging.getLogger(Path(main_filename).name) 42 | writer = Writer(log_path) 43 | 44 | return logger, writer 45 | 46 | 47 | def setup(args, extra_args=()): 48 | """ 49 | meaning of args.result_path: 50 | - if args.eval, directory where the model is 51 | - if args.resume, no meaning 52 | - otherwise, path to store the logs 53 | 54 | Returns: 55 | config, logger, writer 56 | """ 57 | 58 | distenv = dist_init(args) 59 | 60 | args.result_path = Path(args.result_path).absolute().as_posix() 61 | args.model_config = Path(args.model_config).absolute().resolve().as_posix() 62 | 63 | now = datetime.now().strftime("%d%m%Y_%H%M%S") 64 | 65 | if args.eval: 66 | config_path = Path(args.result_path).joinpath("config.yaml") 67 | log_path = Path(args.result_path).joinpath("val", now) 68 | 69 | elif args.resume: 70 | load_path = Path(args.load_path) 71 | if not load_path.is_file(): 72 | raise ValueError("load_path must be a valid filename") 73 | 74 | config_path = load_path.parent.joinpath("config.yaml").absolute() 75 | log_path = load_path.parent.parent.joinpath(now) 76 | 77 | else: 78 | config_path = Path(args.model_config).absolute() 79 | task_name = config_path.stem 80 | if args.postfix: 81 | task_name += f"__{args.postfix}" 82 | log_path = Path(args.result_path).joinpath(task_name, now) 83 | 84 | config = config_setup(args, distenv, config_path, extra_args=extra_args) 85 | config.result_path = log_path.absolute().resolve().as_posix() 86 | 87 | if distenv.master: 88 | if not log_path.exists(): 89 | os.makedirs(log_path) 90 | logger, writer = logger_setup(log_path) 91 | logger.info(distenv) 92 | logger.info(f"log_path: {log_path}") 93 | logger.info("\n" + OmegaConf.to_yaml(config)) 94 | OmegaConf.save(config, log_path.joinpath("config.yaml")) 95 | 96 | src_dir = Path(os.getcwd()).joinpath("src") 97 | shutil.copytree(src_dir, log_path.joinpath("src")) 98 | logger.info(f"source copied to {log_path}/src") 99 | else: 100 | logger, writer, log_path = None, None, None 101 | 102 | return config, logger, writer 103 | 104 | 105 | def single_setup(args, extra_args=(), train=True): 106 | assert args.eval 107 | args.model_config = Path(args.model_config).absolute().resolve().as_posix() 108 | config_path = args.model_config 109 | config = config_setup(args, None, config_path, extra_args=extra_args) 110 | return config 111 | 
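`single_setup` above is the eval-only entry point that `models/vfi.py` later uses to turn a GIMM-VFI YAML into an OmegaConf config (note the `assert args.eval`). A minimal sketch mirroring that call site; the checkpoint filename assumes the weights archive from the README has been unpacked into `./weights/`:

```python
# Sketch mirroring the call site in models/vfi.py: build an eval-only
# GIMM-VFI config and instantiate the model from its `arch` section.
# The config path follows the repo layout; the checkpoint name is the
# one vfi.py expects and may differ in your setup.
import argparse
from models.gimm.src.utils.setup import single_setup
from models.gimm.src.models import create_model

args = argparse.Namespace(
    model_config="models/gimm/configs/gimmvfi/gimmvfi_r_arb.yaml",
    load_path="weights/gimmvfi_r_arb_lpips.pt",
    ds_factor=1.0,   # plays the role of `scale` in the VFI wrapper
    eval=True,       # single_setup asserts this flag
    seed=0,
)
config = single_setup(args)           # reads and merges the YAML config
model, _ = create_model(config.arch)  # same pattern as models/vfi.py
```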
-------------------------------------------------------------------------------- /models/gimm/src/utils/writer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc 9 | # -------------------------------------------------------- 10 | 11 | 12 | import os 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | class Writer: 17 | def __init__(self, result_path): 18 | self.result_path = result_path 19 | 20 | self.writer_trn = SummaryWriter(os.path.join(result_path, "train")) 21 | self.writer_val = SummaryWriter(os.path.join(result_path, "valid")) 22 | self.writer_val_ema = SummaryWriter(os.path.join(result_path, "valid_ema")) 23 | 24 | def _get_writer(self, mode): 25 | if mode == "train": 26 | writer = self.writer_trn 27 | elif mode == "valid": 28 | writer = self.writer_val 29 | elif mode == "valid_ema": 30 | writer = self.writer_val_ema 31 | else: 32 | raise ValueError(f"{mode} is not valid..") 33 | 34 | return writer 35 | 36 | def add_scalar(self, tag, scalar, mode, epoch=0): 37 | writer = self._get_writer(mode) 38 | writer.add_scalar(tag, scalar, epoch) 39 | 40 | def add_image(self, tag, image, mode, epoch=0): 41 | writer = self._get_writer(mode) 42 | writer.add_image(tag, image, epoch) 43 | 44 | def add_text(self, tag, text, mode, epoch=0): 45 | writer = self._get_writer(mode) 46 | writer.add_text(tag, text, epoch) 47 | 48 | def add_audio(self, tag, audio, mode, sampling_rate=16000, epoch=0): 49 | writer = self._get_writer(mode) 50 | writer.add_audio(tag, audio, epoch, sampling_rate) 51 | 52 | def close(self): 53 | self.writer_trn.close() 54 | self.writer_val.close() 55 | self.writer_val_ema.close() 56 | -------------------------------------------------------------------------------- /models/gmflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gmflow/__init__.py -------------------------------------------------------------------------------- /models/gmflow/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.gmflow.trident_conv import MultiScaleTridentConv 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, 8 | ): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, 12 | dilation=dilation, padding=dilation, stride=stride, bias=False) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | self.norm1 = norm_layer(planes) 18 | self.norm2 = norm_layer(planes) 19 | if not stride == 1 or in_planes != planes: 20 | self.norm3 = norm_layer(planes) 21 | 22 | if stride == 1 and in_planes == planes: 23 | self.downsample = None 24 | else: 25 | self.downsample = nn.Sequential( 26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 27 | 28 | def forward(self, x): 29 | y = x 30 | y = 
self.relu(self.norm1(self.conv1(y))) 31 | y = self.relu(self.norm2(self.conv2(y))) 32 | 33 | if self.downsample is not None: 34 | x = self.downsample(x) 35 | 36 | return self.relu(x + y) 37 | 38 | 39 | class CNNEncoder(nn.Module): 40 | def __init__(self, output_dim=128, 41 | norm_layer=nn.InstanceNorm2d, 42 | num_output_scales=1, 43 | **kwargs, 44 | ): 45 | super(CNNEncoder, self).__init__() 46 | self.num_branch = num_output_scales 47 | 48 | feature_dims = [64, 96, 128] 49 | 50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 51 | self.norm1 = norm_layer(feature_dims[0]) 52 | self.relu1 = nn.ReLU(inplace=True) 53 | 54 | self.in_planes = feature_dims[0] 55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 57 | 58 | # highest resolution 1/4 or 1/8 59 | stride = 2 if num_output_scales == 1 else 1 60 | self.layer3 = self._make_layer(feature_dims[2], stride=stride, 61 | norm_layer=norm_layer, 62 | ) # 1/4 or 1/8 63 | 64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) 65 | 66 | if self.num_branch > 1: 67 | if self.num_branch == 4: 68 | strides = (1, 2, 4, 8) 69 | elif self.num_branch == 3: 70 | strides = (1, 2, 4) 71 | elif self.num_branch == 2: 72 | strides = (1, 2) 73 | else: 74 | raise ValueError 75 | 76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, 77 | kernel_size=3, 78 | strides=strides, 79 | paddings=1, 80 | num_branch=self.num_branch, 81 | ) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 87 | if m.weight is not None: 88 | nn.init.constant_(m.weight, 1) 89 | if m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): 93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) 94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) 95 | 96 | layers = (layer1, layer2) 97 | 98 | self.in_planes = dim 99 | return nn.Sequential(*layers) 100 | 101 | def forward(self, x): 102 | x = self.conv1(x) 103 | x = self.norm1(x) 104 | x = self.relu1(x) 105 | 106 | x = self.layer1(x) # 1/2 107 | x = self.layer2(x) # 1/4 108 | x = self.layer3(x) # 1/8 or 1/4 109 | 110 | x = self.conv2(x) 111 | 112 | if self.num_branch > 1: 113 | out = self.trident_conv([x] * self.num_branch) # high to low res 114 | else: 115 | out = [x] 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /models/gmflow/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | coords_grid_cache = {} 4 | 5 | def coords_grid(b, h, w, homogeneous=False, device=None, dtype: torch.dtype=torch.float32): 6 | k = (str(device), str((b, h, w))) 7 | if k in coords_grid_cache: 8 | return coords_grid_cache[k] 9 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] 10 | 11 | stacks = [x, y] 12 | 13 | if homogeneous: 14 | ones = torch.ones_like(x) # [H, W] 15 | stacks.append(ones) 16 | 17 | grid = torch.stack(stacks, dim=0) # [2, H, W] or [3, H, W] 18 | 19 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] 20 | 21 | if device is 
not None: 22 | grid = grid.to(device, dtype=dtype) 23 | coords_grid_cache[k] = grid 24 | return grid 25 | 26 | window_grid_cache = {} 27 | 28 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None, dtype=torch.float32): 29 | assert device is not None 30 | k = (str(device), str((h_min, h_max, w_min, w_max, len_h, len_w))) 31 | if k in window_grid_cache: 32 | return window_grid_cache[k] 33 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), 34 | torch.linspace(h_min, h_max, len_h, device=device)], 35 | ) 36 | grid = torch.stack((x, y), -1).transpose(0, 1).to(device, dtype=dtype) # [H, W, 2] 37 | window_grid_cache[k] = grid 38 | return grid 39 | 40 | normalize_coords_cache = {} 41 | 42 | def normalize_coords(coords, h, w): 43 | # coords: [B, H, W, 2] 44 | k = (str(coords.device), str((h, w))) 45 | if k in normalize_coords_cache: 46 | c = normalize_coords_cache[k] 47 | else: 48 | c = torch.tensor([(w - 1) / 2., (h - 1) / 2.], dtype=coords.dtype, device=coords.device) 49 | normalize_coords_cache[k] = c 50 | return (coords - c) / c # [-1, 1] 51 | 52 | 53 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): 54 | # img: [B, C, H, W] 55 | # sample_coords: [B, 2, H, W] in image scale 56 | if sample_coords.size(1) != 2: # [B, H, W, 2] 57 | sample_coords = sample_coords.permute(0, 3, 1, 2) 58 | 59 | b, _, h, w = sample_coords.shape 60 | 61 | # Normalize to [-1, 1] 62 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 63 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 64 | 65 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] 66 | 67 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 68 | 69 | if return_mask: 70 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] 71 | 72 | return img, mask 73 | 74 | return img 75 | 76 | 77 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'): 78 | b, c, h, w = feature.size() 79 | assert flow.size(1) == 2 80 | 81 | grid = coords_grid(b, h, w, device=flow.device, dtype=flow.dtype) + flow # [B, 2, H, W] 82 | 83 | return bilinear_sample(feature, grid, padding_mode=padding_mode, 84 | return_mask=mask) 85 | 86 | 87 | def forward_backward_consistency_check(fwd_flow, bwd_flow, 88 | alpha=0.01, 89 | beta=0.5 90 | ): 91 | # fwd_flow, bwd_flow: [B, 2, H, W] 92 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) 93 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 94 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 95 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] 96 | 97 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] 98 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] 99 | 100 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] 101 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) 102 | 103 | threshold = alpha * flow_mag + beta 104 | 105 | fwd_occ = (diff_fwd > threshold).to(fwd_flow) # [B, H, W] 106 | bwd_occ = (diff_bwd > threshold).to(bwd_flow) 107 | 108 | return fwd_occ, bwd_occ 109 | -------------------------------------------------------------------------------- /models/gmflow/matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from models.gmflow.geometry import coords_grid, generate_window_grid, normalize_coords 5 | 6 | 7 | def 
global_correlation_softmax(feature0, feature1, 8 | pred_bidir_flow=False, 9 | ): 10 | # global correlation 11 | b, c, h, w = feature0.shape 12 | feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] 13 | feature1 = feature1.view(b, c, -1) # [B, C, H*W] 14 | 15 | correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] 16 | 17 | # flow from softmax 18 | init_grid = coords_grid(b, h, w, device=correlation.device, dtype=feature0.dtype) # [B, 2, H, W] 19 | grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] 20 | 21 | correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] 22 | 23 | if pred_bidir_flow: 24 | correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] 25 | init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] 26 | grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] 27 | b = b * 2 28 | 29 | prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] 30 | 31 | correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] 32 | 33 | # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow 34 | flow = correspondence - init_grid 35 | 36 | return flow, prob 37 | 38 | 39 | def local_correlation_softmax(feature0, feature1, local_radius, 40 | padding_mode='zeros', 41 | ): 42 | b, c, h, w = feature0.size() 43 | coords_init = coords_grid(b, h, w, device=feature0.device, dtype=feature0.dtype) # [B, 2, H, W] 44 | coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] 45 | 46 | local_h = 2 * local_radius + 1 47 | local_w = 2 * local_radius + 1 48 | 49 | window_grid = generate_window_grid(-local_radius, local_radius, 50 | -local_radius, local_radius, 51 | local_h, local_w, device=feature0.device, dtype=feature0.dtype) # [2R+1, 2R+1, 2] 52 | window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] 53 | sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] 54 | 55 | sample_coords_softmax = sample_coords 56 | 57 | # exclude coords that are out of image space 58 | valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] 59 | valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] 60 | 61 | valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax 62 | 63 | # normalize coordinates to [-1, 1] 64 | sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] 65 | window_feature = F.grid_sample(feature1.contiguous(), sample_coords_norm.contiguous(), 66 | padding_mode=padding_mode, align_corners=True 67 | ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] 68 | feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] 69 | 70 | corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] 71 | 72 | # mask invalid locations 73 | corr[~valid] = -1e4 74 | 75 | prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] 76 | 77 | correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( 78 | b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] 79 | 80 | flow = correspondence - coords_init 81 | match_prob = prob 82 | 83 | return flow, match_prob 84 | -------------------------------------------------------------------------------- /models/gmflow/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from models.utils.tools import get_ones_tensor_size 10 | 11 | tensor_cache = dict() 12 | 13 | class PositionEmbeddingSine(nn.Module): 14 | """ 15 | This is a more standard version of the position embedding, very similar to the one 16 | used by the Attention is all you need paper, generalized to work on images. 17 | """ 18 | 19 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 20 | super().__init__() 21 | self.num_pos_feats = num_pos_feats 22 | self.temperature = temperature 23 | self.normalize = normalize 24 | if scale is not None and normalize is False: 25 | raise ValueError("normalize should be True if scale is passed") 26 | if scale is None: 27 | scale = 2 * math.pi 28 | self.scale = scale 29 | 30 | def forward(self, x): 31 | # x = tensor_list.tensors # [B, C, H, W] 32 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 33 | b, c, h, w = x.size() 34 | mask = get_ones_tensor_size((b, h, w), device=x.device, dtype=x.dtype) # [B, H, W] 35 | y_embed = mask.cumsum(1) 36 | x_embed = mask.cumsum(2) 37 | if self.normalize: 38 | eps = 1e-6 39 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 40 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 41 | 42 | if 'dim_t' not in tensor_cache: 43 | dim_t = torch.arange(self.num_pos_feats, device=x.device, dtype=x.dtype) 44 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 45 | tensor_cache['dim_t'] = dim_t 46 | else: 47 | dim_t = tensor_cache['dim_t'] 48 | 49 | pos_x = x_embed[:, :, :, None] / dim_t 50 | pos_y = y_embed[:, :, :, None] / dim_t 51 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 52 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 53 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 54 | return pos 55 | -------------------------------------------------------------------------------- /models/gmflow/trident_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.utils import _pair 8 | 9 | 10 | class MultiScaleTridentConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | strides=1, 18 | paddings=0, 19 | dilations=1, 20 | dilation=1, 21 | groups=1, 22 | num_branch=1, 23 | test_branch_idx=-1, 24 | bias=False, 25 | norm=None, 26 | activation=None, 27 | ): 28 | super(MultiScaleTridentConv, self).__init__() 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = _pair(kernel_size) 32 | self.num_branch = num_branch 33 | self.stride = _pair(stride) 34 | self.groups = groups 35 | self.with_bias = bias 36 | self.dilation = dilation 37 | if isinstance(paddings, int): 38 | paddings = [paddings] * self.num_branch 39 | if isinstance(dilations, int): 40 | dilations = [dilations] * self.num_branch 41 | if isinstance(strides, int): 42 | strides = [strides] * self.num_branch 43 | self.paddings = [_pair(padding) for padding in paddings] 44 | self.dilations = [_pair(dilation) for dilation in dilations] 45 | self.strides = [_pair(stride) for stride in strides] 46 | self.test_branch_idx = test_branch_idx 47 | self.norm = norm 48 | self.activation = activation 49 | 50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 51 | 52 | self.weight = nn.Parameter( 53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 54 | ) 55 | if bias: 56 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 57 | else: 58 | self.bias = None 59 | 60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 61 | if self.bias is not None: 62 | nn.init.constant_(self.bias, 0) 63 | 64 | def forward(self, inputs): 65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 66 | assert len(inputs) == num_branch 67 | 68 | if self.training or self.test_branch_idx == -1: 69 | outputs = [ 70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) 71 | for input, stride, padding in zip(inputs, self.strides, self.paddings) 72 | ] 73 | else: 74 | outputs = [ 75 | F.conv2d( 76 | inputs[0], 77 | self.weight, 78 | self.bias, 79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], 80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], 81 | self.dilation, 82 | self.groups, 83 | ) 84 | ] 85 | 86 | if self.norm is not None: 87 | outputs = [self.norm(x) for x in outputs] 88 | if self.activation is not None: 89 | outputs = [self.activation(x) for x in outputs] 90 | return outputs 91 | -------------------------------------------------------------------------------- /models/gmflow/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.gmflow.position import PositionEmbeddingSine 3 | 4 | 5 | def split_feature(feature, 6 | num_splits=2, 7 | channel_last=False, 8 | ): 9 | if channel_last: # [B, H, W, C] 10 | b, h, w, c = feature.size() 11 | assert h % num_splits == 0 and w % num_splits == 0 12 | 13 | b_new = b * num_splits * num_splits 14 | h_new = h // num_splits 15 | w_new = w // num_splits 16 | 17 | feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c 18 | ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, 
w_new, c) # [B*K*K, H/K, W/K, C] 19 | else: # [B, C, H, W] 20 | b, c, h, w = feature.size() 21 | assert h % num_splits == 0 and w % num_splits == 0 22 | 23 | b_new = b * num_splits * num_splits 24 | h_new = h // num_splits 25 | w_new = w // num_splits 26 | 27 | feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits 28 | ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] 29 | 30 | return feature 31 | 32 | 33 | def merge_splits(splits, 34 | num_splits=2, 35 | channel_last=False, 36 | ): 37 | if channel_last: # [B*K*K, H/K, W/K, C] 38 | b, h, w, c = splits.size() 39 | new_b = b // num_splits // num_splits 40 | 41 | splits = splits.view(new_b, num_splits, num_splits, h, w, c) 42 | merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( 43 | new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] 44 | else: # [B*K*K, C, H/K, W/K] 45 | b, c, h, w = splits.size() 46 | new_b = b // num_splits // num_splits 47 | 48 | splits = splits.view(new_b, num_splits, num_splits, c, h, w) 49 | merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( 50 | new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] 51 | 52 | return merge 53 | 54 | 55 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 56 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 57 | is_mean_std_loaded = False 58 | 59 | 60 | def normalize_img(img0, img1): 61 | # loaded images are in [0, 255] 62 | # normalize by ImageNet mean and std 63 | global mean, std, is_mean_std_loaded 64 | if not is_mean_std_loaded: 65 | mean = mean.to(img0) 66 | std = std.to(img0) 67 | is_mean_std_loaded = True 68 | img0 = (img0 - mean) / std 69 | img1 = (img1 - mean) / std 70 | 71 | return img0, img1 72 | 73 | 74 | def feature_add_position(feature0, feature1, attn_splits, feature_channels): 75 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) 76 | 77 | if attn_splits > 1: # add position in splited window 78 | feature0_splits = split_feature(feature0, num_splits=attn_splits) 79 | feature1_splits = split_feature(feature1, num_splits=attn_splits) 80 | 81 | position = pos_enc(feature0_splits) 82 | 83 | feature0_splits = feature0_splits + position 84 | feature1_splits = feature1_splits + position 85 | 86 | feature0 = merge_splits(feature0_splits, num_splits=attn_splits) 87 | feature1 = merge_splits(feature1_splits, num_splits=attn_splits) 88 | else: 89 | position = pos_enc(feature0) 90 | 91 | feature0 = feature0 + position 92 | feature1 = feature1 + position 93 | 94 | return feature0, feature1 95 | -------------------------------------------------------------------------------- /models/model_pg104/FeatureNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FeatureNet(nn.Module): 7 | """The quadratic model""" 8 | def __init__(self): 9 | super(FeatureNet, self).__init__() 10 | self.block1 = nn.Sequential( 11 | nn.PReLU(), 12 | nn.Conv2d(3, 64, 3, 2, 1), 13 | nn.PReLU(), 14 | nn.Conv2d(64, 64, 3, 1, 1), 15 | ) 16 | self.block2 = nn.Sequential( 17 | nn.PReLU(), 18 | nn.Conv2d(64, 128, 3, 2, 1), 19 | nn.PReLU(), 20 | nn.Conv2d(128, 128, 3, 1, 1), 21 | ) 22 | self.block3 = nn.Sequential( 23 | nn.PReLU(), 24 | nn.Conv2d(128, 192, 3, 2, 1), 25 | nn.PReLU(), 26 | nn.Conv2d(192, 192, 3, 1, 1), 27 | ) 28 | 29 | def forward(self, x): 30 | x1 = self.block1(x) 31 | x2 = self.block2(x1) 32 | x3 = self.block3(x2) 33 | 34 | return x1, x2, x3 
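As a quick sanity check on the pyramid this encoder produces (strides 2, 4, and 8 with 64, 128, and 192 channels), a small self-contained sketch; the input resolution is an arbitrary example:

```python
# Sketch: FeatureNet returns a 3-level feature pyramid at 1/2, 1/4 and 1/8
# of the input resolution with 64, 128 and 192 channels respectively.
import torch
from models.model_pg104.FeatureNet import FeatureNet

net = FeatureNet().eval()
x = torch.randn(1, 3, 256, 448)      # dummy RGB frame, NCHW
with torch.no_grad():
    f1, f2, f3 = net(x)
print(f1.shape)  # torch.Size([1, 64, 128, 224])
print(f2.shape)  # torch.Size([1, 128, 64, 112])
print(f3.shape)  # torch.Size([1, 192, 32, 56])
```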
-------------------------------------------------------------------------------- /models/model_pg104/IFNet_HDv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.model_pg104.warplayer import warp 6 | 7 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 8 | return nn.Sequential( 9 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 10 | padding=padding, dilation=dilation, bias=True), 11 | nn.LeakyReLU(0.2, True) 12 | ) 13 | 14 | class ResConv(nn.Module): 15 | def __init__(self, c, dilation=1): 16 | super(ResConv, self).__init__() 17 | self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1) 18 | self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) 19 | self.relu = nn.LeakyReLU(0.2, True) 20 | 21 | def forward(self, x): 22 | return self.relu(self.conv(x) * self.beta + x) 23 | 24 | class IFBlock(nn.Module): 25 | def __init__(self, in_planes, c=64): 26 | super(IFBlock, self).__init__() 27 | self.conv0 = nn.Sequential( 28 | conv(in_planes, c//2, 3, 2, 1), 29 | conv(c//2, c, 3, 2, 1), 30 | ) 31 | self.convblock = nn.Sequential( 32 | ResConv(c), 33 | ResConv(c), 34 | ResConv(c), 35 | ResConv(c), 36 | ResConv(c), 37 | ResConv(c), 38 | ResConv(c), 39 | ResConv(c), 40 | ) 41 | self.lastconv = nn.Sequential( 42 | nn.ConvTranspose2d(c, 4*6, 4, 2, 1), 43 | nn.PixelShuffle(2) 44 | ) 45 | 46 | def forward(self, x, flow=None, scale=1): 47 | x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False) 48 | if flow is not None: 49 | flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. 
/ scale 50 | x = torch.cat((x, flow), 1) 51 | feat = self.conv0(x) 52 | feat = self.convblock(feat) 53 | tmp = self.lastconv(feat) 54 | tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) 55 | flow = tmp[:, :4] * scale 56 | mask = tmp[:, 4:5] 57 | return flow, mask 58 | 59 | class IFNet(nn.Module): 60 | def __init__(self): 61 | super(IFNet, self).__init__() 62 | self.block0 = IFBlock(7, c=192) 63 | self.block1 = IFBlock(8+4, c=128) 64 | self.block2 = IFBlock(8+4, c=96) 65 | self.block3 = IFBlock(8+4, c=64) 66 | 67 | def forward( self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False): 68 | if training == False: 69 | channel = x.shape[1] // 2 70 | img0 = x[:, :channel] 71 | img1 = x[:, channel:] 72 | if not torch.is_tensor(timestep): 73 | timestep = (x[:, :1].clone() * 0 + 1) * timestep 74 | else: 75 | timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) 76 | flow = None 77 | block = [self.block0, self.block1, self.block2, self.block3] 78 | for i in range(4): 79 | if flow is None: 80 | flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], timestep), 1), None, scale=scale_list[i]) 81 | if ensemble: 82 | f1, m1 = block[i](torch.cat((img1[:, :3], img0[:, :3], 1-timestep), 1), None, scale=scale_list[i]) 83 | flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 84 | mask = (mask + (-m1)) / 2 85 | else: 86 | f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], timestep, mask), 1), flow, scale=scale_list[i]) 87 | if ensemble: 88 | f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) 89 | f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 90 | m0 = (m0 + (-m1)) / 2 91 | flow = flow + f0 92 | mask = mask + m0 93 | warped_img0 = warp(img0, flow[:, :2]) 94 | warped_img1 = warp(img1, flow[:, 2:4]) 95 | mask = torch.sigmoid(mask) 96 | merged = warped_img0 * mask + warped_img1 * (1 - mask) 97 | return merged 98 | -------------------------------------------------------------------------------- /models/model_pg104/MetricNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.gmflow.geometry import forward_backward_consistency_check 6 | 7 | 8 | backwarp_tenGrid = {} 9 | 10 | def backwarp(tenIn, tenflow): 11 | if str(tenflow.shape) not in backwarp_tenGrid: 12 | tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenflow.shape[3], dtype=tenflow.dtype, device=tenflow.device).view(1, 1, 1, -1).repeat(1, 1, tenflow.shape[2], 1) 13 | tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenflow.shape[2], dtype=tenflow.dtype, device=tenflow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenflow.shape[3]) 14 | 15 | backwarp_tenGrid[str(tenflow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() 16 | # end 17 | 18 | tenflow = torch.cat([tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1) 19 | 20 | return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[str(tenflow.shape)] + tenflow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) 21 | 22 | 23 | class MetricNet(nn.Module): 24 | def __init__(self): 25 | super(MetricNet, self).__init__() 26 | self.metric_in = nn.Conv2d(14, 64, 3, 1, 1) 27 | self.metric_net1 = nn.Sequential( 28 | nn.PReLU(), 29 | nn.Conv2d(64, 64, 3, 1, 1) 30 | ) 
31 | self.metric_net2 = nn.Sequential( 32 | nn.PReLU(), 33 | nn.Conv2d(64, 64, 3, 1, 1) 34 | ) 35 | self.metric_net3 = nn.Sequential( 36 | nn.PReLU(), 37 | nn.Conv2d(64, 64, 3, 1, 1) 38 | ) 39 | self.metric_out = nn.Sequential( 40 | nn.PReLU(), 41 | nn.Conv2d(64, 2, 3, 1, 1), 42 | nn.Tanh() 43 | ) 44 | 45 | def forward(self, img0, img1, flow01, flow10): 46 | metric0 = F.l1_loss(img0, backwarp(img1, flow01), reduction='none').mean([1], True) 47 | metric1 = F.l1_loss(img1, backwarp(img0, flow10), reduction='none').mean([1], True) 48 | 49 | fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10) 50 | 51 | flow01 = torch.cat([flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0), flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0)], 1) 52 | flow10 = torch.cat([flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0), flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0)], 1) 53 | 54 | img = torch.cat((img0, img1), 1) 55 | metric = torch.cat((-metric0, -metric1), 1) 56 | flow = torch.cat((flow01, flow10), 1) 57 | occ = torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1) 58 | 59 | feat = self.metric_in(torch.cat((img, metric, flow, occ), 1)) 60 | feat = self.metric_net1(feat) + feat 61 | feat = self.metric_net2(feat) + feat 62 | feat = self.metric_net3(feat) + feat 63 | metric = self.metric_out(feat) * 10 64 | 65 | return metric[:, :1], metric[:, 1:2] 66 | -------------------------------------------------------------------------------- /models/model_pg104/warplayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | backwarp_tenGrid = {} 6 | 7 | 8 | def warp(tenInput, tenFlow): 9 | k = (str(tenFlow.device), str(tenFlow.size())) 10 | if k not in backwarp_tenGrid: 11 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( 12 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 13 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( 14 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 15 | backwarp_tenGrid[k] = torch.cat( 16 | [tenHorizontal, tenVertical], 1).to(device) 17 | 18 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 19 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) 20 | 21 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 22 | return torch.nn.functional.grid_sample(input=tenInput, grid=g.to(tenInput.dtype), mode='bilinear', padding_mode='border', align_corners=True) -------------------------------------------------------------------------------- /models/vfi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.IFNet_HDv3 import IFNet 3 | from models.gimm.src.utils.setup import single_setup 4 | from models.gimm.src.models import create_model 5 | from models.model_pg104.GMFSS import Model as GMFSS 6 | import argparse 7 | from models.utils.tools import * 8 | 9 | 10 | class VFI: 11 | def __init__(self, model_type='rife', weights='weights', scale=1.0, 12 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu")): 13 | if model_type == 'rife': 14 | model = IFNet() 15 | model.load_state_dict(convert(torch.load(f'{weights}/rife48.pkl'))) 16 | elif model_type == 'gmfss': 17 | model = GMFSS() 18 | model.load_model(f'{weights}/train_log_pg104', -1) 19 | else: 20 | args = 
argparse.Namespace( 21 | model_config=r"models/gimm/configs/gimmvfi/gimmvfi_r_arb.yaml", 22 | load_path=f"{weights}/gimmvfi_r_arb_lpips.pt", 23 | ds_factor=scale, 24 | eval=True, 25 | seed=0 26 | ) 27 | config = single_setup(args) 28 | model, _ = create_model(config.arch) 29 | 30 | # Checkpoint loading 31 | if "ours" in args.load_path: 32 | ckpt = torch.load(args.load_path, map_location="cpu") 33 | 34 | def convert_gimm(param): 35 | return { 36 | k.replace("module.feature_bone", "frame_encoder"): v 37 | for k, v in param.items() 38 | if "feature_bone" in k 39 | } 40 | 41 | ckpt = convert_gimm(ckpt) 42 | model.load_state_dict(ckpt, strict=False) 43 | else: 44 | ckpt = torch.load(args.load_path, map_location="cpu") 45 | model.load_state_dict(ckpt["state_dict"], strict=False) 46 | 47 | model.eval() 48 | if model_type == 'gmfss': 49 | model.device() 50 | else: 51 | model.to(device) 52 | 53 | self.model = model 54 | self.model_type = model_type 55 | base_pads = { 56 | 'gimm': 64, 57 | 'rife': 64, 58 | 'gmfss': 128, 59 | } 60 | self.pad_size = base_pads[model_type] / scale 61 | self.device = device 62 | self.saved_result = {} 63 | self.scale = scale 64 | 65 | @torch.inference_mode() 66 | def gen_ts_frame(self, x, y, ts): 67 | _outputs = list() 68 | head = [x] if 0 in ts else [] 69 | tail = [y] if 1 in ts else [] 70 | if 0 in ts: 71 | ts.remove(0) 72 | if 1 in ts: 73 | ts.remove(1) 74 | with torch.autocast(str(self.device)): 75 | _reuse_things = self.model.reuse(x, y, self.scale) if self.model_type == 'gmfss' else None 76 | if self.model_type in ['rife', 'gmfss']: 77 | for t in ts: 78 | if self.model_type == 'rife': 79 | scale_list = [8 / self.scale, 4 / self.scale, 2 / self.scale, 1 / self.scale] 80 | _out = self.model(torch.cat((x, y), dim=1), t, scale_list) 81 | elif self.model_type == 'gmfss': 82 | _out = self.model.inference(x, y, _reuse_things, t) 83 | _outputs.append(_out) 84 | elif self.model_type == 'gimm': 85 | xs = torch.cat((x.unsqueeze(2), y.unsqueeze(2)), dim=2).to( 86 | self.device, non_blocking=True 87 | ) 88 | self.model.zero_grad() 89 | coord_inputs = [ 90 | ( 91 | self.model.sample_coord_input( 92 | xs.shape[0], 93 | xs.shape[-2:], 94 | [t], 95 | device=xs.device, 96 | upsample_ratio=self.scale, 97 | ), 98 | None, 99 | ) 100 | for t in ts 101 | ] 102 | timesteps = [ 103 | t * torch.ones(xs.shape[0]).to(xs.device).to(torch.float) 104 | for t in ts 105 | ] 106 | all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.scale) 107 | 108 | _outputs = all_outputs["imgt_pred"] 109 | 110 | _outputs = head + _outputs + tail 111 | 112 | return _outputs 113 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ffmpeg-python>=0.2.0 2 | numpy>=1.16, <=1.23.5 3 | torch>=2.5.1 4 | torchvision>=0.20.1 5 | tqdm>=4.35.0 6 | opencv-python>=4.1.2 7 | cupy-cuda11x 8 | decorator==5.1.1 9 | easydict==1.9 10 | einops==0.6.1 11 | imageio==2.31.2 12 | importlib-metadata==6.7.0 13 | omegaconf==2.3.0 14 | Pillow==9.5.0 15 | protobuf==3.20.0 16 | safetensors==0.4.2 17 | scikit-image==0.19.3 18 | scikit-learn==0.24.0 19 | scipy==1.7.3 20 | tensorboard==2.3.0 21 | tensorboard-plugin-wit==1.8.1 22 | timm==0.4.12 23 | typing_extensions==4.7.1 24 | yacs==0.1.6 --------------------------------------------------------------------------------
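With the dependencies above installed and the weights from the README unpacked into `./weights/`, a minimal sketch of driving the `VFI` wrapper from `models/vfi.py` directly; frame padding and video I/O are simplified here, and the image filenames are placeholders:

```python
# Sketch: interpolate one midpoint frame between two RGB frames with the
# RIFE path of the VFI wrapper. Assumes weights/rife48.pkl is present and
# that the frames are already a multiple of the model's pad size (64 here);
# a real pipeline would pad/crop accordingly.
import cv2
import torch
from models.vfi import VFI

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vfi = VFI(model_type='rife', weights='weights', scale=1.0, device=device)

def to_tensor(path):
    img = cv2.imread(path)[:, :, ::-1].copy()                 # BGR -> RGB
    t = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
    return t.unsqueeze(0).to(device)                          # [1, 3, H, W]

x, y = to_tensor("frame0.png"), to_tensor("frame1.png")
mid = vfi.gen_ts_frame(x, y, ts=[0.5])[0]                     # single timestep
out = mid[0].float().clamp(0, 1).mul(255).byte().permute(1, 2, 0).cpu().numpy()
cv2.imwrite("mid.png", out[:, :, ::-1])                       # RGB -> BGR
```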