├── LICENSE
├── MultiPassDedup.ipynb
├── README.md
├── assert
│   └── result.gif
├── infer.py
├── models
├── IFNet_HDv3.py
├── gimm
│ ├── configs
│ │ ├── gimm
│ │ │ └── gimm.yaml
│ │ └── gimmvfi
│ │ │ ├── gimmvfi_f_arb.yaml
│ │ │ └── gimmvfi_r_arb.yaml
│ └── src
│ │ ├── models
│ │ ├── __init__.py
│ │ ├── ema.py
│ │ └── generalizable_INR
│ │ │ ├── __init__.py
│ │ │ ├── configs.py
│ │ │ ├── flowformer
│ │ │ ├── LICENSE
│ │ │ ├── README.md
│ │ │ ├── __init__.py
│ │ │ ├── alt_cuda_corr
│ │ │ │ ├── correlation.cpp
│ │ │ │ ├── correlation_kernel.cu
│ │ │ │ └── setup.py
│ │ │ ├── assets
│ │ │ │ └── teaser.png
│ │ │ ├── chairs_split.txt
│ │ │ ├── configs
│ │ │ │ ├── default.py
│ │ │ │ ├── kitti.py
│ │ │ │ ├── sintel.py
│ │ │ │ ├── small_things_eval.py
│ │ │ │ ├── submission.py
│ │ │ │ ├── things.py
│ │ │ │ ├── things_eval.py
│ │ │ │ └── things_flowformer_sharp.py
│ │ │ ├── core
│ │ │ │ ├── FlowFormer
│ │ │ │ │ ├── LatentCostFormer
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── attention.py
│ │ │ │ │ │ ├── cnn.py
│ │ │ │ │ │ ├── convnext.py
│ │ │ │ │ │ ├── decoder.py
│ │ │ │ │ │ ├── encoder.py
│ │ │ │ │ │ ├── gma.py
│ │ │ │ │ │ ├── gru.py
│ │ │ │ │ │ ├── mlpmixer.py
│ │ │ │ │ │ ├── transformer.py
│ │ │ │ │ │ └── twins.py
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── common.py
│ │ │ │ │ └── encoders.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── corr.py
│ │ │ │ ├── datasets.py
│ │ │ │ ├── extractor.py
│ │ │ │ ├── loss.py
│ │ │ │ ├── optimizer
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── position_encoding.py
│ │ │ │ ├── raft.py
│ │ │ │ ├── update.py
│ │ │ │ └── utils
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── augmentor.py
│ │ │ │ │ ├── datasets.py
│ │ │ │ │ ├── flow_transforms.py
│ │ │ │ │ ├── flow_viz.py
│ │ │ │ │ ├── frame_utils.py
│ │ │ │ │ ├── logger.py
│ │ │ │ │ ├── misc.py
│ │ │ │ │ └── utils.py
│ │ │ ├── evaluate_FlowFormer.py
│ │ │ ├── evaluate_FlowFormer_tile.py
│ │ │ ├── run_train.sh
│ │ │ ├── train_FlowFormer.py
│ │ │ └── visualize_flow.py
│ │ │ ├── gimm.py
│ │ │ ├── gimmvfi_f.py
│ │ │ ├── gimmvfi_r.py
│ │ │ ├── gmflow
│ │ │ ├── __init__.py
│ │ │ ├── backbone.py
│ │ │ ├── geometry.py
│ │ │ ├── gmflow.py
│ │ │ ├── matching.py
│ │ │ ├── position.py
│ │ │ ├── transformer.py
│ │ │ ├── trident_conv.py
│ │ │ └── utils.py
│ │ │ ├── modules
│ │ │ ├── __init__.py
│ │ │ ├── coord_sampler.py
│ │ │ ├── fi_components.py
│ │ │ ├── fi_utils.py
│ │ │ ├── hyponet.py
│ │ │ ├── layers.py
│ │ │ ├── module_config.py
│ │ │ ├── softsplat.py
│ │ │ └── utils.py
│ │ │ └── raft
│ │ │ ├── __init__.py
│ │ │ ├── corr.py
│ │ │ ├── extractor.py
│ │ │ ├── other_raft.py
│ │ │ ├── raft.py
│ │ │ ├── update.py
│ │ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── augmentor.py
│ │ │ ├── flow_viz.py
│ │ │ ├── frame_utils.py
│ │ │ └── utils.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── accumulator.py
│ │ ├── config.py
│ │ ├── dist.py
│ │ ├── flow_viz.py
│ │ ├── frame_utils.py
│ │ ├── loss.py
│ │ ├── lpips
│ │ ├── __init__.py
│ │ ├── alex.pth
│ │ ├── lpips.py
│ │ └── pretrained_networks.py
│ │ ├── profiler.py
│ │ ├── setup.py
│ │ ├── utils.py
│ │ └── writer.py
├── gmflow
│ ├── __init__.py
│ ├── backbone.py
│ ├── geometry.py
│ ├── gmflow.py
│ ├── matching.py
│ ├── position.py
│ ├── transformer.py
│ ├── trident_conv.py
│ └── utils.py
├── model_pg104
│ ├── FeatureNet.py
│ ├── FusionNet.py
│ ├── GMFSS.py
│ ├── IFNet_HDv3.py
│ ├── MetricNet.py
│ ├── softsplat.py
│ └── warplayer.py
├── pytorch_msssim
│ └── __init__.py
├── softsplat
│ ├── softsplat.py
│ └── softsplat_torch.py
├── utils
│ └── tools.py
└── vfi.py
└── requirements.txt
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 hyw-dev
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 📖MultiPassDedup
2 |
3 | ### Efficient Deduplication for Anime Video Frame Interpolation
4 | > When performing frame interpolation on anime footage, conventional deduplication methods usually identify and drop duplicate frames, which has many drawbacks, such as losing background textures and failing to correctly handle multiple characters drawn at different cadences in the same scene.
5 | >
6 | > By observing common patterns in anime videos, we found that repeatedly updating the original frames provides an easier and more effective solution to these issues.
7 | We therefore developed this project to implement this approach. Combined with the powerful GMFSS interpolation algorithm, it achieves excellent results in most anime scenarios.
8 |
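The idea can be sketched in a few lines of Python. This is only an illustration of the multi-pass update strategy, not the actual implementation in infer.py; `interpolate` and `is_duplicate` stand in for the real GMFSS/RIFE inference and duplicate check.

```python
def one_pass(frames, interpolate, is_duplicate):
    """One pass: re-synthesize every frame that duplicates its predecessor
    from its neighbours instead of dropping it (illustrative only)."""
    out = list(frames)
    for i in range(1, len(frames) - 1):
        if is_duplicate(out[i - 1], frames[i]):
            out[i] = interpolate(out[i - 1], frames[i + 1], t=0.5)
    return out


def multi_pass_dedup(frames, interpolate, is_duplicate, n_pass=3):
    """Repeatedly update the original frames; n_pass corresponds to the
    longest run of duplicated frames (the -np option of infer.py)."""
    for _ in range(n_pass):
        frames = one_pass(frames, interpolate, is_duplicate)
    return frames
```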
9 | 
10 |
11 |
12 |
13 | ## 👀Demo Videos (BiliBili)
14 | ### [Jujutsu Kaisen S2 NCOP](https://www.bilibili.com/video/BV16W421N7s5/?share_source=copy_web&vd_source=8a8926eb0f1d5f0f1cab7529c8f51282)
15 | ### [Houseki no Kuni NCOP](https://www.bilibili.com/video/BV1py4y1A7qj/?share_source=copy_web&vd_source=8a8926eb0f1d5f0f1cab7529c8f51282)
16 |
17 | ## 🔧Installation
18 | ```bash
19 | git clone https://github.com/routineLife1/MultiPassDedup.git
20 | cd MultiPassDedup
21 | pip3 install -r requirements.txt
22 | ```
23 | Download the weights from [Google Drive](https://drive.google.com/file/d/1gXyqRiLgZ0sQEuDl4vbbxIgbUvg3k50x/view?usp=sharing), unzip the archive, and put the files in ./weights/
24 |
25 |
26 | The cupy package is included in the requirements, but its installation is optional. It is used to accelerate computation. If you encounter difficulties while installing this package, you can skip it.
27 |
28 |
29 | ## ⚡Usage
30 | - normalize the source video to 24000/1001 fps with the following ffmpeg command **(If the INPUT video framerate is around 23.976, skip this step.)**
31 | ```bash
32 | ffmpeg -i INPUT -crf 16 -r 24000/1001 -preset slow -c:v libx265 -x265-params profile=main10 -c:a copy OUTPUT
33 | ```
34 | - open the video and check its maximum consecutive duplicate count (3 -> on Three, 2 -> on Two, 0 -> AUTO); a rough way to estimate this is sketched below the commands **(If the INPUT video framerate is around 23.976, skip this step.)**
35 | - run the following command to finish the interpolation
36 | (N_PASS = max_consistent_deduplication_counts) **(In most circumstances, -np 0 can automatically determine an appropriate n_pass value.)**
37 | ```bash
38 | python infer.py -i [VIDEO] -o [VIDEO_OUTPUT] -np [N_PASS] -t [TIMES] -m [MODEL_TYPE] -s -st 0.3 -scale [SCALE]
39 | # or use the following command to export video at any frame rate
40 | python infer.py -i [VIDEO] -o [VIDEO_OUTPUT] -np [N_PASS] -fps [OUTPUT_FPS] -m [MODEL_TYPE] -s -st 0.3 -scale [SCALE]
41 | ```
42 |
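If you are unsure which N_PASS to use, the sketch below shows one rough way to measure the longest run of identical frames. It is only a heuristic, not part of infer.py: the video path and the difference threshold are placeholders, and it assumes OpenCV and NumPy are installed.

```python
import cv2
import numpy as np

cap = cv2.VideoCapture("E:/MyVideo/01.mkv")  # placeholder path
prev, run, longest = None, 1, 1
while True:
    ok, frame = cap.read()
    if not ok:
        break
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)
    if prev is not None:
        # "same drawing" heuristic: tiny mean absolute difference; tune per source
        run = run + 1 if np.abs(gray - prev).mean() < 1.0 else 1
        longest = max(longest, run)
    prev = gray
cap.release()
print("longest run of identical frames (candidate N_PASS):", longest)
```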
43 | **Example (smooth a 23.976fps video that is drawn on threes and interpolate it to 60fps):**
44 |
45 | ```bash
46 | ffmpeg -i E:/Myvideo/01_src.mkv -crf 16 -r 24000/1001 -preset slow -c:v libx265 -x265-params profile=main10 -c:a copy E:/Myvideo/01.mkv
47 |
48 | python infer.py -i E:/MyVideo/01.mkv -o E:/MyVideo/out.mkv -np 3 -fps 60 -m gmfss -s -st 0.3 -scale 1.0
49 | ```
50 |
51 | **Full Usage**
52 | ```bash
53 | Usage: python infer.py -i in_video -o out_video [options]...
54 |
55 | -h show this help
56 | -i input input video path (absolute path of input video)
57 | -o output output video path (absolute path of output video)
58 | -fps dst_fps target frame rate (default=60)
59 | -s enable_scdet enable scene change detection (default Enable)
60 | -st scdet_threshold ssim scene detection threshold (default=0.3)
61 | -hw hwaccel enable hardware-accelerated encoding (default Enable) (requires an NVIDIA graphics card)
62 | -scale scale flow scale factor (default=1.0), generally use 1.0 with 1080P and 0.5 with 4K resolution
63 | -m model_type model type (default=gmfss)
64 | -np n_pass max consistent deduplication counts (default=3)
65 | ```
66 |
67 | - input accepts an absolute video file path. Example: E:/input.mp4
68 | - output accepts an absolute video file path. Example: E:/output.mp4
69 | - dst_fps = target interpolated video frame rate. Example: 60
70 | - enable_scdet = enable scene change detection.
71 | - scdet_threshold = scene change detection threshold. The larger the value, the more sensitive the detection.
72 | - hwaccel = enable hardware acceleration when encoding the output video.
73 | - scale = flow scale factor. Decrease this value to reduce the computational cost of the model at higher resolutions. Generally, use 1.0 for 1080P and 0.5 for 4K resolution.
74 | - model_type = model type. Currently, gmfss, rife and gimm are supported.
75 | - n_pass = max consistent deduplication counts.
76 |
77 | ## 🤗 Acknowledgement
78 | This project is supported by the [SVFI](https://doc.svfi.group/) Development Team.
79 |
80 | Thanks to [Q8sh2ing](https://github.com/Q8sh2ing) for implementing the online Colab Demo.
81 |
82 | ## Reference
83 | [GMFSS](https://github.com/98mxr/GMFSS_Fortuna) [Practical-RIFE](https://github.com/hzwer/Practical-RIFE) [GIMM-VFI](https://github.com/GSeanCDAT/GIMM-VFI)
84 |
--------------------------------------------------------------------------------
/assert/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/assert/result.gif
--------------------------------------------------------------------------------
/models/IFNet_HDv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.model_pg104.warplayer import warp
5 |
6 | # from train_log.refine import *
7 |
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 |
10 |
11 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
12 | return nn.Sequential(
13 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
14 | padding=padding, dilation=dilation, bias=True),
15 | nn.LeakyReLU(0.2, True)
16 | )
17 |
18 |
19 | def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
20 | return nn.Sequential(
21 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
22 | padding=padding, dilation=dilation, bias=False),
23 | nn.BatchNorm2d(out_planes),
24 | nn.LeakyReLU(0.2, True)
25 | )
26 |
27 |
28 | class ResConv(nn.Module):
29 | def __init__(self, c, dilation=1):
30 | super(ResConv, self).__init__()
31 | self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
32 |
33 | self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
34 | self.relu = nn.LeakyReLU(0.2, True)
35 |
36 | def forward(self, x):
37 | return self.relu(self.conv(x) * self.beta + x)
38 |
39 |
40 | class IFBlock(nn.Module):
41 | def __init__(self, in_planes, c=64):
42 | super(IFBlock, self).__init__()
43 | self.conv0 = nn.Sequential(
44 | conv(in_planes, c // 2, 3, 2, 1),
45 | conv(c // 2, c, 3, 2, 1),
46 | )
47 | self.convblock = nn.Sequential(
48 | ResConv(c),
49 | ResConv(c),
50 | ResConv(c),
51 | ResConv(c),
52 | ResConv(c),
53 | ResConv(c),
54 | ResConv(c),
55 | ResConv(c),
56 | )
57 | self.lastconv = nn.Sequential(
58 | nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1),
59 | nn.PixelShuffle(2)
60 | )
61 |
62 | def forward(self, x, flow=None, scale=1):
63 | x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False)
64 | if flow is not None:
65 | flow = F.interpolate(flow, scale_factor=1. / scale, mode="bilinear", align_corners=False) * 1. / scale
66 | x = torch.cat((x, flow), 1)
67 | feat = self.conv0(x)
68 | feat = self.convblock(feat)
69 | tmp = self.lastconv(feat)
70 | tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
71 | flow = tmp[:, :4] * scale
72 | mask = tmp[:, 4:5]
73 | return flow, mask
74 |
75 |
76 | class IFNet(nn.Module):
77 | def __init__(self):
78 | super(IFNet, self).__init__()
79 | self.block0 = IFBlock(7 + 8, c=192)
80 | self.block1 = IFBlock(8 + 4 + 8, c=128)
81 | self.block2 = IFBlock(8 + 4 + 8, c=96)
82 | self.block3 = IFBlock(8 + 4 + 8, c=64)
83 | self.encode = nn.Sequential(
84 | nn.Conv2d(3, 16, 3, 2, 1),
85 | nn.ConvTranspose2d(16, 4, 4, 2, 1)
86 | )
87 | # self.contextnet = Contextnet()
88 | # self.unet = Unet()
89 |
90 | def forward(self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False):
91 | if ensemble:
92 | print('ensemble is removed')
93 | if training == False:
94 | channel = x.shape[1] // 2
95 | img0 = x[:, :channel]
96 | img1 = x[:, channel:]
97 | if not torch.is_tensor(timestep):
98 | timestep = (x[:, :1].clone() * 0 + 1) * timestep
99 | else:
100 | timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
101 | f0 = self.encode(img0[:, :3])
102 | f1 = self.encode(img1[:, :3])
103 | flow_list = []
104 | merged = []
105 | mask_list = []
106 | warped_img0 = img0
107 | warped_img1 = img1
108 | flow = None
109 | mask = None
110 | loss_cons = 0
111 | block = [self.block0, self.block1, self.block2, self.block3]
112 | for i in range(4):
113 | if flow is None:
114 | flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None,
115 | scale=scale_list[i])
116 | else:
117 | fd, mask = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], warp(f0, flow[:, :2]),
118 | warp(f1, flow[:, 2:4]), timestep, mask), 1), flow, scale=scale_list[i])
119 | flow = flow + fd
120 | mask_list.append(mask)
121 | flow_list.append(flow)
122 | warped_img0 = warp(img0, flow[:, :2])
123 | warped_img1 = warp(img1, flow[:, 2:4])
124 | merged.append((warped_img0, warped_img1))
125 | mask = torch.sigmoid(mask)
126 | merged[3] = (warped_img0 * mask + warped_img1 * (1 - mask))
127 | if not fastmode:
128 | print('contextnet is removed')
129 | '''
130 | c0 = self.contextnet(img0, flow[:, :2])
131 | c1 = self.contextnet(img1, flow[:, 2:4])
132 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
133 | res = tmp[:, :3] * 2 - 1
134 | merged[3] = torch.clamp(merged[3] + res, 0, 1)
135 | '''
136 | return merged[3]
137 |
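A minimal usage sketch of the IFNet above, assuming this repository is on the import path so that `warp` resolves (random weights here, purely illustrative):

```python
import torch

net = IFNet().to(device).eval()
img0 = torch.rand(1, 3, 256, 256, device=device)
img1 = torch.rand(1, 3, 256, 256, device=device)
with torch.no_grad():
    # the two frames are concatenated along channels; timestep picks the instant
    mid = net(torch.cat((img0, img1), dim=1), timestep=0.5)
print(mid.shape)  # torch.Size([1, 3, 256, 256])
```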
--------------------------------------------------------------------------------
/models/gimm/configs/gimm/gimm.yaml:
--------------------------------------------------------------------------------
1 | trainer: stage_inr
2 | dataset:
3 | type: fast_vimeo_flow
4 | path: ./data/vimeo90k/vimeo_triplet
5 | add_objects: false
6 | expansion: false
7 | random_t: false
8 | aug: true
9 | t_scale: 10
10 | pair: false
11 |
12 | arch: # needs to add encoder, modulation type
13 | type: gimm
14 | ema: null
15 |
16 | modulated_layer_idxs: [1]
17 |
18 | coord_range: [-1., 1.]
19 |
20 | hyponet:
21 | type: mlp
22 | n_layer: 5 # including the output layer
23 | hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1]
24 | use_bias: true
25 | input_dim: 3
26 | output_dim: 2
27 | output_bias: 0.5
28 | activation:
29 | type: siren
30 | siren_w0: 1.0
31 | initialization:
32 | weight_init_type: siren
33 | bias_init_type: siren
34 |
35 | loss:
36 | type: mse #now unnecessary
37 |
38 | optimizer:
39 | type: adam
40 | init_lr: 0.0001
41 | weight_decay: 0.0
42 | betas: [0.9, 0.999] #[0.9, 0.95]
43 | ft: false
44 | warmup:
45 | epoch: 0
46 | multiplier: 1
47 | buffer_epoch: 0
48 | min_lr: 0.0001
49 | mode: fix
50 | start_from_zero: True
51 | max_gn: null
52 |
53 | experiment:
54 | amp: True
55 | batch_size: 32
56 | total_batch_size: 64
57 | epochs: 400
58 | save_ckpt_freq: 20
59 | test_freq: 10
60 | test_imlog_freq: 10
61 |
62 |
--------------------------------------------------------------------------------
/models/gimm/configs/gimmvfi/gimmvfi_f_arb.yaml:
--------------------------------------------------------------------------------
1 | trainer: stage_inr
2 | dataset:
3 | type: vimeo_arb
4 | path: ./data/vimeo90k/vimeo_septuplet
5 | aug: true
6 |
7 | arch:
8 | type: gimmvfi_f
9 | ema: true
10 | modulated_layer_idxs: [1]
11 |
12 | coord_range: [-1., 1.]
13 |
14 | hyponet:
15 | type: mlp
16 | n_layer: 5 # including the output layer
17 | hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1]
18 | use_bias: true
19 | input_dim: 3
20 | output_dim: 2
21 | output_bias: 0.5
22 | activation:
23 | type: siren
24 | siren_w0: 1.0
25 | initialization:
26 | weight_init_type: siren
27 | bias_init_type: siren
28 |
29 | loss:
30 | subsample:
31 | type: random
32 | ratio: 0.1
33 |
34 | optimizer:
35 | type: adamw
36 | init_lr: 0.00008
37 | weight_decay: 0.00004
38 | betas: [0.9, 0.999]
39 | ft: true
40 | warmup:
41 | epoch: 1
42 | multiplier: 1
43 | buffer_epoch: 0
44 | min_lr: 0.000008
45 | mode: fix
46 | start_from_zero: True
47 | max_gn: null
48 |
49 | experiment:
50 | amp: True
51 | batch_size: 4
52 | total_batch_size: 32
53 | epochs: 60
54 | save_ckpt_freq: 10
55 | test_freq: 10
56 | test_imlog_freq: 10
57 |
58 |
--------------------------------------------------------------------------------
/models/gimm/configs/gimmvfi/gimmvfi_r_arb.yaml:
--------------------------------------------------------------------------------
1 | trainer: stage_inr
2 | dataset:
3 | type: vimeo_arb
4 | path: ./data/vimeo90k/vimeo_septuplet
5 | aug: true
6 |
7 | arch:
8 | type: gimmvfi_r
9 | ema: true
10 | modulated_layer_idxs: [1]
11 |
12 | coord_range: [-1., 1.]
13 |
14 | hyponet:
15 | type: mlp
16 | n_layer: 5 # including the output layer
17 | hidden_dim: [128] # list, assert len(hidden_dim) in [1, n_layers-1]
18 | use_bias: true
19 | input_dim: 3
20 | output_dim: 2
21 | output_bias: 0.5
22 | activation:
23 | type: siren
24 | siren_w0: 1.0
25 | initialization:
26 | weight_init_type: siren
27 | bias_init_type: siren
28 |
29 | loss:
30 | subsample:
31 | type: random
32 | ratio: 0.1
33 |
34 | optimizer:
35 | type: adamw
36 | init_lr: 0.00008
37 | weight_decay: 0.00004
38 | betas: [0.9, 0.999]
39 | ft: true
40 | warmup:
41 | epoch: 1
42 | multiplier: 1
43 | buffer_epoch: 0
44 | min_lr: 0.000008
45 | mode: fix
46 | start_from_zero: True
47 | max_gn: null
48 |
49 | experiment:
50 | amp: True
51 | batch_size: 4
52 | total_batch_size: 32
53 | epochs: 60
54 | save_ckpt_freq: 10
55 | test_freq: 10
56 | test_imlog_freq: 10
57 |
58 |
--------------------------------------------------------------------------------
/models/gimm/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | from .ema import ExponentialMovingAverage
12 | from .generalizable_INR import gimmvfi_f, gimmvfi_r, gimm
13 |
14 |
15 | def create_model(config, ema=False):
16 | model_type = config.type.lower()
17 | if model_type == "gimm":
18 | model = gimm(config)
19 | model_ema = gimm(config) if ema else None
20 | elif model_type == "gimmvfi_f":
21 | model = gimmvfi_f(config)
22 | model_ema = gimmvfi_f(config) if ema else None
23 | elif model_type == "gimmvfi_r":
24 | model = gimmvfi_r(config)
25 | model_ema = gimmvfi_r(config) if ema else None
26 | else:
27 | raise ValueError(f"{model_type} is invalid..")
28 |
29 | if ema:
30 | mu = config.ema
31 | if config.ema_value is not None:
32 | mu = config.ema_value
33 | model_ema = ExponentialMovingAverage(model_ema, mu)
34 | model_ema.eval()
35 | model_ema.update(model, step=-1)
36 |
37 | return model, model_ema
38 |
--------------------------------------------------------------------------------
/models/gimm/src/models/ema.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | import logging
12 | import torch
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class ExponentialMovingAverage(torch.nn.Module):
18 | def __init__(self, init_module, mu):
19 | super(ExponentialMovingAverage, self).__init__()
20 |
21 | self.module = init_module
22 | self.mu = mu
23 |
24 | def forward(self, x, *args, **kwargs):
25 | return self.module(x, *args, **kwargs)
26 |
27 | def update(self, module, step=None):
28 | if step is None or not isinstance(self.mu, bool):
29 | # print(['use ema value:', self.mu])
30 | mu = self.mu
31 | else:
32 | # see : https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/ExponentialMovingAverage?hl=PL
33 | mu = min(self.mu, (1.0 + step) / (10.0 + step))
34 |
35 | state_dict = {}
36 | with torch.no_grad():
37 | for (name, m1), (name2, m2) in zip(
38 | self.module.state_dict().items(), module.state_dict().items()
39 | ):
40 | if name != name2:
41 | logger.warning(
42 | "[ExpoentialMovingAverage] not matched keys %s, %s", name, name2
43 | )
44 |
45 | if step is not None and step < 0:
46 | state_dict[name] = m2.clone().detach()
47 | else:
48 | state_dict[name] = ((mu * m1) + ((1.0 - mu) * m2)).clone().detach()
49 |
50 | self.module.load_state_dict(state_dict)
51 |
52 | def compute_psnr(self, *args, **kwargs):
53 | return self.module.compute_psnr(*args, **kwargs)
54 |
55 | def get_recon_imgs(self, *args, **kwargs):
56 | return self.module.get_recon_imgs(*args, **kwargs)
57 |
58 | def sample_coord_input(self, *args, **kwargs):
59 | return self.module.sample_coord_input(*args, **kwargs)
60 |
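A minimal usage sketch of the wrapper above (not from the repository): a copy of the model is wrapped, `step=-1` on the first call copies the online weights verbatim, and because `mu` is a float (not a bool) every later call uses the constant-decay branch.

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
ema = ExponentialMovingAverage(copy.deepcopy(model), mu=0.999)
ema.eval()
ema.update(model, step=-1)          # step < 0: copy the online weights as-is

opt = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(model, step=step)    # ema <- mu * ema + (1 - mu) * online
```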
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | from .gimm import GIMM
12 |
13 | from .gimmvfi_f import GIMMVFI_F
14 | from .gimmvfi_r import GIMMVFI_R
15 |
16 |
17 | def gimm(config):
18 | return GIMM(config)
19 |
20 |
21 | def gimmvfi_f(config):
22 | return GIMMVFI_F(config)
23 |
24 |
25 | def gimmvfi_r(config):
26 | return GIMMVFI_R(config)
27 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/configs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | from typing import List, Optional
12 | from dataclasses import dataclass
13 |
14 | from omegaconf import OmegaConf, MISSING
15 | from .modules.module_config import HypoNetConfig
16 |
17 |
18 | @dataclass
19 | class GIMMConfig:
20 | type: str = "gimm"
21 | ema: Optional[bool] = None
22 | ema_value: Optional[float] = None
23 | fwarp_type: str = "linear"
24 | hyponet: HypoNetConfig = HypoNetConfig()
25 | coord_range: List[float] = MISSING
26 | modulated_layer_idxs: Optional[List[int]] = None
27 |
28 | @classmethod
29 | def create(cls, config):
30 | # We need to specify the type of the default DataEncoderConfig.
31 | # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
32 | # hence merging with the config with other type would cause config error.
33 | defaults = OmegaConf.structured(cls(ema=False))
34 | config = OmegaConf.merge(defaults, config)
35 | return config
36 |
37 |
38 | @dataclass
39 | class GIMMVFIConfig:
40 | type: str = "gimmvfi"
41 | ema: Optional[bool] = None
42 | ema_value: Optional[float] = None
43 | fwarp_type: str = "linear"
44 | rec_weight: float = 0.1
45 | hyponet: HypoNetConfig = HypoNetConfig()
46 | raft_iter: int = 20
47 | coord_range: List[float] = MISSING
48 | modulated_layer_idxs: Optional[List[int]] = None
49 |
50 | @classmethod
51 | def create(cls, config):
52 | # We need to specify the type of the default DataEncoderConfig.
53 | # Otherwise, data_encoder will be initialized & structured as "unfold" type (which is default value)
54 | # hence merging with the config with other type would cause config error.
55 | defaults = OmegaConf.structured(cls(ema=False))
56 | config = OmegaConf.merge(defaults, config)
57 | return config
58 |
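A minimal sketch of how `create` is meant to be used (the config fragment below is illustrative, not one of the repository YAMLs): the structured defaults are built first and the user config is merged on top, so omitted fields keep their defaults.

```python
from omegaconf import OmegaConf

user_cfg = OmegaConf.create(
    {
        "type": "gimmvfi_f",
        "ema": True,
        "coord_range": [-1.0, 1.0],
        "modulated_layer_idxs": [1],
    }
)
cfg = GIMMVFIConfig.create(user_cfg)
print(cfg.type, cfg.raft_iter)  # "gimmvfi_f" and the default raft_iter of 20
```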
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .configs.submission import get_cfg
3 | from .core.FlowFormer import build_flowformer
4 |
5 |
6 | def initialize_Flowformer():
7 | cfg = get_cfg()
8 | model = build_flowformer(cfg)
9 |
10 | ckpt = torch.load(cfg.model, map_location="cpu")
11 |
12 | def convert(param):
13 | return {k.replace("module.", ""): v for k, v in param.items() if "module" in k}
14 |
15 | ckpt = convert(ckpt)
16 | model.load_state_dict(ckpt)
17 |
18 | return model
19 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | // CUDA forward declarations
5 | std::vector<torch::Tensor> corr_cuda_forward(
6 | torch::Tensor fmap1,
7 | torch::Tensor fmap2,
8 | torch::Tensor coords,
9 | int radius);
10 |
11 | std::vector<torch::Tensor> corr_cuda_backward(
12 | torch::Tensor fmap1,
13 | torch::Tensor fmap2,
14 | torch::Tensor coords,
15 | torch::Tensor corr_grad,
16 | int radius);
17 |
18 | // C++ interface
19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22 |
23 | std::vector<torch::Tensor> corr_forward(
24 | torch::Tensor fmap1,
25 | torch::Tensor fmap2,
26 | torch::Tensor coords,
27 | int radius) {
28 | CHECK_INPUT(fmap1);
29 | CHECK_INPUT(fmap2);
30 | CHECK_INPUT(coords);
31 |
32 | return corr_cuda_forward(fmap1, fmap2, coords, radius);
33 | }
34 |
35 |
36 | std::vector<torch::Tensor> corr_backward(
37 | torch::Tensor fmap1,
38 | torch::Tensor fmap2,
39 | torch::Tensor coords,
40 | torch::Tensor corr_grad,
41 | int radius) {
42 | CHECK_INPUT(fmap1);
43 | CHECK_INPUT(fmap2);
44 | CHECK_INPUT(coords);
45 | CHECK_INPUT(corr_grad);
46 |
47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48 | }
49 |
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &corr_forward, "CORR forward");
53 | m.def("backward", &corr_backward, "CORR backward");
54 | }
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | setup(
6 | name="correlation",
7 | ext_modules=[
8 | CUDAExtension(
9 | "alt_cuda_corr",
10 | sources=["correlation.cpp", "correlation_kernel.cu"],
11 | extra_compile_args={"cxx": [], "nvcc": ["-O3"]},
12 | ),
13 | ],
14 | cmdclass={"build_ext": BuildExtension},
15 | )
16 |
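For reference, the extension can also be JIT-compiled with PyTorch's cpp_extension loader instead of running `python setup.py install`; this is only a sketch and assumes a working CUDA toolchain and that it is run from this directory.

```python
from torch.utils.cpp_extension import load

alt_cuda_corr = load(
    name="alt_cuda_corr",
    sources=["correlation.cpp", "correlation_kernel.cu"],
    extra_cuda_cflags=["-O3"],
    verbose=True,
)
```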
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/flowformer/assets/teaser.png
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/configs/default.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | _CN.name = "default"
6 | _CN.suffix = "arxiv2"
7 | _CN.gamma = 0.8
8 | _CN.max_flow = 400
9 | _CN.batch_size = 8
10 | _CN.sum_freq = 100
11 | _CN.val_freq = 5000
12 | _CN.image_size = [368, 496]
13 | _CN.add_noise = True
14 | _CN.critical_params = []
15 |
16 | _CN.transformer = "latentcostformer"
17 | _CN.restore_ckpt = None
18 |
19 | ###########################################
20 | # latentcostformer
21 | _CN.latentcostformer = CN()
22 | _CN.latentcostformer.pe = "linear"
23 | _CN.latentcostformer.dropout = 0.0
24 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
25 | _CN.latentcostformer.query_latent_dim = 64
26 | _CN.latentcostformer.cost_latent_input_dim = 64
27 | _CN.latentcostformer.cost_latent_token_num = 8
28 | _CN.latentcostformer.cost_latent_dim = 128
29 | _CN.latentcostformer.predictor_dim = 128
30 | _CN.latentcostformer.motion_feature_dim = 209 # use concat, so double query_latent_dim
31 | _CN.latentcostformer.arc_type = "transformer"
32 | _CN.latentcostformer.cost_heads_num = 1
33 | # encoder
34 | _CN.latentcostformer.pretrain = True
35 | _CN.latentcostformer.context_concat = False
36 | _CN.latentcostformer.encoder_depth = 3
37 | _CN.latentcostformer.feat_cross_attn = False
38 | _CN.latentcostformer.patch_size = 8
39 | _CN.latentcostformer.patch_embed = "single"
40 | _CN.latentcostformer.gma = True
41 | _CN.latentcostformer.rm_res = True
42 | _CN.latentcostformer.vert_c_dim = 64
43 | _CN.latentcostformer.cost_encoder_res = True
44 | _CN.latentcostformer.cnet = "twins"
45 | _CN.latentcostformer.fnet = "twins"
46 | _CN.latentcostformer.only_global = False
47 | _CN.latentcostformer.add_flow_token = True
48 | _CN.latentcostformer.use_mlp = False
49 | _CN.latentcostformer.vertical_conv = False
50 | # decoder
51 | _CN.latentcostformer.decoder_depth = 12
52 | _CN.latentcostformer.critical_params = [
53 | "cost_heads_num",
54 | "vert_c_dim",
55 | "cnet",
56 | "pretrain",
57 | "add_flow_token",
58 | "encoder_depth",
59 | "gma",
60 | "cost_encoder_res",
61 | ]
62 | ##########################################
63 |
64 | ### TRAINER
65 | _CN.trainer = CN()
66 | _CN.trainer.scheduler = "OneCycleLR"
67 |
68 | _CN.trainer.optimizer = "adamw"
69 | _CN.trainer.canonical_lr = 25e-5
70 | _CN.trainer.adamw_decay = 1e-4
71 | _CN.trainer.clip = 1.0
72 | _CN.trainer.num_steps = 120000
73 | _CN.trainer.epsilon = 1e-8
74 | _CN.trainer.anneal_strategy = "linear"
75 |
76 |
77 | def get_cfg():
78 | return _CN.clone()
79 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/configs/kitti.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | _CN.name = "kitti"
6 | _CN.suffix = "kitti"
7 | _CN.gamma = 0.85
8 | _CN.max_flow = 400
9 | _CN.batch_size = 6
10 | _CN.sum_freq = 100
11 | _CN.val_freq = 499999999
12 | _CN.image_size = [432, 960]
13 | _CN.add_noise = True
14 | _CN.critical_params = []
15 |
16 | _CN.transformer = "latentcostformer"
17 |
18 | _CN.model = None
19 |
20 | _CN.restore_ckpt = "checkpoints/sintel.pth"
21 |
22 | # latentcostformer
23 | _CN.latentcostformer = CN()
24 | _CN.latentcostformer.pe = "linear"
25 | _CN.latentcostformer.dropout = 0.0
26 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
27 | _CN.latentcostformer.query_latent_dim = 64
28 | _CN.latentcostformer.cost_latent_input_dim = 64
29 | _CN.latentcostformer.cost_latent_token_num = 8
30 | _CN.latentcostformer.cost_latent_dim = 128
31 | _CN.latentcostformer.predictor_dim = 128
32 | _CN.latentcostformer.motion_feature_dim = 209 # use concat, so double query_latent_dim
33 | _CN.latentcostformer.arc_type = "transformer"
34 | _CN.latentcostformer.cost_heads_num = 1
35 | # encoder
36 | _CN.latentcostformer.pretrain = True
37 | _CN.latentcostformer.context_concat = False
38 | _CN.latentcostformer.encoder_depth = 3
39 | _CN.latentcostformer.feat_cross_attn = False
40 | _CN.latentcostformer.vertical_encoder_attn = "twins"
41 | _CN.latentcostformer.patch_size = 8
42 | _CN.latentcostformer.patch_embed = "single"
43 | _CN.latentcostformer.gma = "GMA"
44 | _CN.latentcostformer.rm_res = True
45 | _CN.latentcostformer.vert_c_dim = 64
46 | _CN.latentcostformer.cost_encoder_res = True
47 | _CN.latentcostformer.pwc_aug = False
48 | _CN.latentcostformer.cnet = "twins"
49 | _CN.latentcostformer.fnet = "twins"
50 | _CN.latentcostformer.no_sc = False
51 | _CN.latentcostformer.use_rpe = False
52 | _CN.latentcostformer.only_global = False
53 | _CN.latentcostformer.add_flow_token = True
54 | _CN.latentcostformer.use_mlp = False
55 | _CN.latentcostformer.vertical_conv = False
56 | # decoder
57 | _CN.latentcostformer.decoder_depth = 12
58 | _CN.latentcostformer.critical_params = [
59 | "cost_heads_num",
60 | "vert_c_dim",
61 | "cnet",
62 | "pretrain",
63 | "add_flow_token",
64 | "encoder_depth",
65 | "gma",
66 | "cost_encoder_res",
67 | ]
68 |
69 |
70 | ### TRAINER
71 | _CN.trainer = CN()
72 | _CN.trainer.scheduler = "OneCycleLR"
73 | _CN.trainer.optimizer = "adamw"
74 | _CN.trainer.canonical_lr = 12.5e-5
75 | _CN.trainer.adamw_decay = 1e-5
76 | _CN.trainer.clip = 1.0
77 | _CN.trainer.num_steps = 50000
78 | _CN.trainer.epsilon = 1e-8
79 | _CN.trainer.anneal_strategy = "linear"
80 |
81 |
82 | def get_cfg():
83 | return _CN.clone()
84 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/configs/sintel.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | _CN.name = "default"
6 | _CN.suffix = "sintel"
7 | _CN.gamma = 0.85
8 | _CN.max_flow = 400
9 | _CN.batch_size = 6
10 | _CN.sum_freq = 100
11 | _CN.val_freq = 5000000
12 | _CN.image_size = [432, 960]
13 | _CN.add_noise = True
14 | _CN.critical_params = []
15 |
16 | _CN.transformer = "latentcostformer"
17 | _CN.restore_ckpt = "checkpoints/things.pth"
18 |
19 | # latentcostformer
20 | _CN.latentcostformer = CN()
21 | _CN.latentcostformer.pe = "linear"
22 | _CN.latentcostformer.dropout = 0.0
23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
24 | _CN.latentcostformer.query_latent_dim = 64
25 | _CN.latentcostformer.cost_latent_input_dim = 64
26 | _CN.latentcostformer.cost_latent_token_num = 8
27 | _CN.latentcostformer.cost_latent_dim = 128
28 | _CN.latentcostformer.arc_type = "transformer"
29 | _CN.latentcostformer.cost_heads_num = 1
30 | # encoder
31 | _CN.latentcostformer.pretrain = True
32 | _CN.latentcostformer.context_concat = False
33 | _CN.latentcostformer.encoder_depth = 3
34 | _CN.latentcostformer.feat_cross_attn = False
35 | _CN.latentcostformer.patch_size = 8
36 | _CN.latentcostformer.patch_embed = "single"
37 | _CN.latentcostformer.no_pe = False
38 | _CN.latentcostformer.gma = "GMA"
39 | _CN.latentcostformer.kernel_size = 9
40 | _CN.latentcostformer.rm_res = True
41 | _CN.latentcostformer.vert_c_dim = 64
42 | _CN.latentcostformer.cost_encoder_res = True
43 | _CN.latentcostformer.cnet = "twins"
44 | _CN.latentcostformer.fnet = "twins"
45 | _CN.latentcostformer.no_sc = False
46 | _CN.latentcostformer.only_global = False
47 | _CN.latentcostformer.add_flow_token = True
48 | _CN.latentcostformer.use_mlp = False
49 | _CN.latentcostformer.vertical_conv = False
50 |
51 | # decoder
52 | _CN.latentcostformer.decoder_depth = 12
53 | _CN.latentcostformer.critical_params = [
54 | "cost_heads_num",
55 | "vert_c_dim",
56 | "cnet",
57 | "pretrain",
58 | "add_flow_token",
59 | "encoder_depth",
60 | "gma",
61 | "cost_encoder_res",
62 | ]
63 |
64 | ### TRAINER
65 | _CN.trainer = CN()
66 | _CN.trainer.scheduler = "OneCycleLR"
67 | _CN.trainer.optimizer = "adamw"
68 | _CN.trainer.canonical_lr = 12.5e-5
69 | _CN.trainer.adamw_decay = 1e-5
70 | _CN.trainer.clip = 1.0
71 | _CN.trainer.num_steps = 120000
72 | _CN.trainer.epsilon = 1e-8
73 | _CN.trainer.anneal_strategy = "linear"
74 |
75 |
76 | def get_cfg():
77 | return _CN.clone()
78 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/configs/small_things_eval.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | _CN.name = ""
6 | _CN.suffix = ""
7 | _CN.gamma = 0.8
8 | _CN.max_flow = 400
9 | _CN.batch_size = 6
10 | _CN.sum_freq = 100
11 | _CN.val_freq = 5000000
12 | _CN.image_size = [432, 960]
13 | _CN.add_noise = False
14 | _CN.critical_params = []
15 |
16 | _CN.transformer = "latentcostformer"
17 | _CN.model = "checkpoints/flowformer-small/things.pth"
18 |
19 | # latentcostformer
20 | _CN.latentcostformer = CN()
21 | _CN.latentcostformer.pe = "linear"
22 | _CN.latentcostformer.dropout = 0.0
23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
24 | _CN.latentcostformer.query_latent_dim = 64
25 | _CN.latentcostformer.cost_latent_input_dim = 64
26 | _CN.latentcostformer.cost_latent_token_num = 4
27 | _CN.latentcostformer.cost_latent_dim = 32
28 | _CN.latentcostformer.arc_type = "transformer"
29 | _CN.latentcostformer.cost_heads_num = 1
30 | # encoder
31 | _CN.latentcostformer.pretrain = True
32 | _CN.latentcostformer.context_concat = False
33 | _CN.latentcostformer.encoder_depth = 1
34 | _CN.latentcostformer.feat_cross_attn = False
35 | _CN.latentcostformer.patch_size = 8
36 | _CN.latentcostformer.patch_embed = "single"
37 | _CN.latentcostformer.no_pe = False
38 | _CN.latentcostformer.gma = "GMA"
39 | _CN.latentcostformer.kernel_size = 9
40 | _CN.latentcostformer.rm_res = True
41 | _CN.latentcostformer.vert_c_dim = 0
42 | _CN.latentcostformer.cost_encoder_res = True
43 | _CN.latentcostformer.cnet = "basicencoder"
44 | _CN.latentcostformer.fnet = "basicencoder"
45 | _CN.latentcostformer.no_sc = False
46 | _CN.latentcostformer.only_global = False
47 | _CN.latentcostformer.add_flow_token = True
48 | _CN.latentcostformer.use_mlp = False
49 | _CN.latentcostformer.vertical_conv = False
50 |
51 | # decoder
52 | _CN.latentcostformer.decoder_depth = 32
53 | _CN.latentcostformer.critical_params = [
54 | "cost_heads_num",
55 | "vert_c_dim",
56 | "cnet",
57 | "pretrain",
58 | "add_flow_token",
59 | "encoder_depth",
60 | "gma",
61 | "cost_encoder_res",
62 | ]
63 |
64 | ### TRAINER
65 | _CN.trainer = CN()
66 | _CN.trainer.scheduler = "OneCycleLR"
67 | _CN.trainer.optimizer = "adamw"
68 | _CN.trainer.canonical_lr = 12.5e-5
69 | _CN.trainer.adamw_decay = 1e-4
70 | _CN.trainer.clip = 1.0
71 | _CN.trainer.num_steps = 120000
72 | _CN.trainer.epsilon = 1e-8
73 | _CN.trainer.anneal_strategy = "linear"
74 |
75 |
76 | def get_cfg():
77 | return _CN.clone()
78 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/configs/submission.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | _CN.name = ""
6 | _CN.suffix = ""
7 | _CN.gamma = 0.8
8 | _CN.max_flow = 400
9 | _CN.batch_size = 6
10 | _CN.sum_freq = 100
11 | _CN.val_freq = 5000000
12 | _CN.image_size = [432, 960]
13 | _CN.add_noise = False
14 | _CN.critical_params = []
15 |
16 | _CN.transformer = "latentcostformer"
17 | _CN.model = r"E:\Work\VFI\Algorithm\GIMM-VFI\pretrained_ckpt\flowformer_sintel.pth"
18 |
19 | # latentcostformer
20 | _CN.latentcostformer = CN()
21 | _CN.latentcostformer.pe = "linear"
22 | _CN.latentcostformer.dropout = 0.0
23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
24 | _CN.latentcostformer.query_latent_dim = 64
25 | _CN.latentcostformer.cost_latent_input_dim = 64
26 | _CN.latentcostformer.cost_latent_token_num = 8
27 | _CN.latentcostformer.cost_latent_dim = 128
28 | _CN.latentcostformer.arc_type = "transformer"
29 | _CN.latentcostformer.cost_heads_num = 1
30 | # encoder
31 | _CN.latentcostformer.pretrain = True
32 | _CN.latentcostformer.context_concat = False
33 | _CN.latentcostformer.encoder_depth = 3
34 | _CN.latentcostformer.feat_cross_attn = False
35 | _CN.latentcostformer.patch_size = 8
36 | _CN.latentcostformer.patch_embed = "single"
37 | _CN.latentcostformer.no_pe = False
38 | _CN.latentcostformer.gma = "GMA"
39 | _CN.latentcostformer.kernel_size = 9
40 | _CN.latentcostformer.rm_res = True
41 | _CN.latentcostformer.vert_c_dim = 64
42 | _CN.latentcostformer.cost_encoder_res = True
43 | _CN.latentcostformer.cnet = "twins"
44 | _CN.latentcostformer.fnet = "twins"
45 | _CN.latentcostformer.no_sc = False
46 | _CN.latentcostformer.only_global = False
47 | _CN.latentcostformer.add_flow_token = True
48 | _CN.latentcostformer.use_mlp = False
49 | _CN.latentcostformer.vertical_conv = False
50 |
51 | # decoder
52 | _CN.latentcostformer.decoder_depth = 32
53 | _CN.latentcostformer.critical_params = [
54 | "cost_heads_num",
55 | "vert_c_dim",
56 | "cnet",
57 | "pretrain",
58 | "add_flow_token",
59 | "encoder_depth",
60 | "gma",
61 | "cost_encoder_res",
62 | ]
63 |
64 | ### TRAINER
65 | _CN.trainer = CN()
66 | _CN.trainer.scheduler = "OneCycleLR"
67 | _CN.trainer.optimizer = "adamw"
68 | _CN.trainer.canonical_lr = 12.5e-5
69 | _CN.trainer.adamw_decay = 1e-4
70 | _CN.trainer.clip = 1.0
71 | _CN.trainer.num_steps = 120000
72 | _CN.trainer.epsilon = 1e-8
73 | _CN.trainer.anneal_strategy = "linear"
74 |
75 |
76 | def get_cfg():
77 | return _CN.clone()
78 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/configs/things.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | _CN.name = ""
6 | _CN.suffix = ""
7 | _CN.gamma = 0.8
8 | _CN.max_flow = 400
9 | _CN.batch_size = 6
10 | _CN.sum_freq = 100
11 | _CN.val_freq = 5000000
12 | _CN.image_size = [432, 960]
13 | _CN.add_noise = True
14 | _CN.critical_params = []
15 |
16 | _CN.transformer = "latentcostformer"
17 | _CN.restore_ckpt = "checkpoints/chairs.pth"
18 |
19 | #######################################
20 | _CN.latentcostformer = CN()
21 | _CN.latentcostformer.pe = "linear"
22 | _CN.latentcostformer.dropout = 0.0
23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
24 | _CN.latentcostformer.query_latent_dim = 64
25 | _CN.latentcostformer.cost_latent_input_dim = 64
26 | _CN.latentcostformer.cost_latent_token_num = 8
27 | _CN.latentcostformer.cost_latent_dim = 128
28 | _CN.latentcostformer.cost_heads_num = 1
29 | # encoder
30 | _CN.latentcostformer.pretrain = True
31 | _CN.latentcostformer.context_concat = False
32 | _CN.latentcostformer.encoder_depth = 3
33 | _CN.latentcostformer.feat_cross_attn = False
34 | _CN.latentcostformer.nat_rep = "abs"
35 | _CN.latentcostformer.patch_size = 8
36 | _CN.latentcostformer.patch_embed = "single"
37 | _CN.latentcostformer.no_pe = False
38 | _CN.latentcostformer.gma = "GMA"
39 | _CN.latentcostformer.kernel_size = 9
40 | _CN.latentcostformer.rm_res = True
41 | _CN.latentcostformer.vert_c_dim = 64
42 | _CN.latentcostformer.cost_encoder_res = True
43 | _CN.latentcostformer.cnet = "twins"
44 | _CN.latentcostformer.fnet = "twins"
45 | _CN.latentcostformer.only_global = False
46 | _CN.latentcostformer.add_flow_token = True
47 | _CN.latentcostformer.use_mlp = False
48 | _CN.latentcostformer.vertical_conv = False
49 |
50 | # decoder
51 | _CN.latentcostformer.decoder_depth = 12
52 | _CN.latentcostformer.critical_params = [
53 | "cost_heads_num",
54 | "vert_c_dim",
55 | "cnet",
56 | "pretrain",
57 | "add_flow_token",
58 | "encoder_depth",
59 | "gma",
60 | "cost_encoder_res",
61 | ]
62 |
63 | ### TRAINER
64 | _CN.trainer = CN()
65 | _CN.trainer.scheduler = "OneCycleLR"
66 | _CN.trainer.optimizer = "adamw"
67 | _CN.trainer.canonical_lr = 12.5e-5
68 | _CN.trainer.adamw_decay = 1e-4
69 | _CN.trainer.clip = 1.0
70 | _CN.trainer.num_steps = 120000
71 | _CN.trainer.epsilon = 1e-8
72 | _CN.trainer.anneal_strategy = "linear"
73 |
74 |
75 | def get_cfg():
76 | return _CN.clone()
77 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/configs/things_eval.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | _CN.name = ""
6 | _CN.suffix = ""
7 | _CN.gamma = 0.8
8 | _CN.max_flow = 400
9 | _CN.batch_size = 6
10 | _CN.sum_freq = 100
11 | _CN.val_freq = 5000000
12 | _CN.image_size = [432, 960]
13 | _CN.add_noise = False
14 | _CN.critical_params = []
15 |
16 | _CN.transformer = "latentcostformer"
17 | _CN.model = "checkpoints/things.pth"
18 |
19 | # latentcostformer
20 | _CN.latentcostformer = CN()
21 | _CN.latentcostformer.pe = "linear"
22 | _CN.latentcostformer.dropout = 0.0
23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
24 | _CN.latentcostformer.query_latent_dim = 64
25 | _CN.latentcostformer.cost_latent_input_dim = 64
26 | _CN.latentcostformer.cost_latent_token_num = 8
27 | _CN.latentcostformer.cost_latent_dim = 128
28 | _CN.latentcostformer.arc_type = "transformer"
29 | _CN.latentcostformer.cost_heads_num = 1
30 | # encoder
31 | _CN.latentcostformer.pretrain = True
32 | _CN.latentcostformer.context_concat = False
33 | _CN.latentcostformer.encoder_depth = 3
34 | _CN.latentcostformer.feat_cross_attn = False
35 | _CN.latentcostformer.patch_size = 8
36 | _CN.latentcostformer.patch_embed = "single"
37 | _CN.latentcostformer.no_pe = False
38 | _CN.latentcostformer.gma = "GMA"
39 | _CN.latentcostformer.kernel_size = 9
40 | _CN.latentcostformer.rm_res = True
41 | _CN.latentcostformer.vert_c_dim = 64
42 | _CN.latentcostformer.cost_encoder_res = True
43 | _CN.latentcostformer.cnet = "twins"
44 | _CN.latentcostformer.fnet = "twins"
45 | _CN.latentcostformer.no_sc = False
46 | _CN.latentcostformer.only_global = False
47 | _CN.latentcostformer.add_flow_token = True
48 | _CN.latentcostformer.use_mlp = False
49 | _CN.latentcostformer.vertical_conv = False
50 |
51 | # decoder
52 | _CN.latentcostformer.decoder_depth = 32
53 | _CN.latentcostformer.critical_params = [
54 | "cost_heads_num",
55 | "vert_c_dim",
56 | "cnet",
57 | "pretrain",
58 | "add_flow_token",
59 | "encoder_depth",
60 | "gma",
61 | "cost_encoder_res",
62 | ]
63 |
64 | ### TRAINER
65 | _CN.trainer = CN()
66 | _CN.trainer.scheduler = "OneCycleLR"
67 | _CN.trainer.optimizer = "adamw"
68 | _CN.trainer.canonical_lr = 12.5e-5
69 | _CN.trainer.adamw_decay = 1e-4
70 | _CN.trainer.clip = 1.0
71 | _CN.trainer.num_steps = 120000
72 | _CN.trainer.epsilon = 1e-8
73 | _CN.trainer.anneal_strategy = "linear"
74 |
75 |
76 | def get_cfg():
77 | return _CN.clone()
78 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/configs/things_flowformer_sharp.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | _CN.name = ""
6 | _CN.suffix = ""
7 | _CN.gamma = 0.8
8 | _CN.max_flow = 400
9 | _CN.batch_size = 6
10 | _CN.sum_freq = 100
11 | _CN.val_freq = 5000000
12 | _CN.image_size = [400, 720]
13 | _CN.add_noise = True
14 | _CN.critical_params = []
15 |
16 | _CN.transformer = "latentcostformer"
17 | _CN.restore_ckpt = "checkpoints/chairs.pth"
18 |
19 | #######################################
20 | _CN.latentcostformer = CN()
21 | _CN.latentcostformer.pe = "linear"
22 | _CN.latentcostformer.dropout = 0.0
23 | _CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256
24 | _CN.latentcostformer.query_latent_dim = 64
25 | _CN.latentcostformer.cost_latent_input_dim = 64
26 | _CN.latentcostformer.cost_latent_token_num = 8
27 | _CN.latentcostformer.cost_latent_dim = 128
28 | _CN.latentcostformer.cost_heads_num = 1
29 | # encoder
30 | _CN.latentcostformer.pretrain = True
31 | _CN.latentcostformer.context_concat = False
32 | _CN.latentcostformer.encoder_depth = 3
33 | _CN.latentcostformer.feat_cross_attn = False
34 | _CN.latentcostformer.nat_rep = "abs"
35 | _CN.latentcostformer.patch_size = 8
36 | _CN.latentcostformer.patch_embed = "single"
37 | _CN.latentcostformer.no_pe = False
38 | _CN.latentcostformer.gma = "GMA"
39 | _CN.latentcostformer.kernel_size = 9
40 | _CN.latentcostformer.rm_res = True
41 | _CN.latentcostformer.vert_c_dim = 64
42 | _CN.latentcostformer.cost_encoder_res = True
43 | _CN.latentcostformer.cnet = "twins"
44 | _CN.latentcostformer.fnet = "twins"
45 | _CN.latentcostformer.only_global = False
46 | _CN.latentcostformer.add_flow_token = True
47 | _CN.latentcostformer.use_mlp = False
48 | _CN.latentcostformer.vertical_conv = False
49 |
50 | # decoder
51 | _CN.latentcostformer.decoder_depth = 12
52 | _CN.latentcostformer.critical_params = [
53 | "cost_heads_num",
54 | "vert_c_dim",
55 | "cnet",
56 | "pretrain",
57 | "add_flow_token",
58 | "encoder_depth",
59 | "gma",
60 | "cost_encoder_res",
61 | ]
62 |
63 | ### TRAINER
64 | _CN.trainer = CN()
65 | _CN.trainer.scheduler = "OneCycleLR"
66 | _CN.trainer.optimizer = "adamw"
67 | _CN.trainer.canonical_lr = 12.5e-5
68 | _CN.trainer.adamw_decay = 1e-4
69 | _CN.trainer.clip = 1.0
70 | _CN.trainer.num_steps = 120000
71 | _CN.trainer.epsilon = 1e-8
72 | _CN.trainer.anneal_strategy = "linear"
73 |
74 |
75 | def get_cfg():
76 | return _CN.clone()
77 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/__init__.py
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/convnext.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 |
8 | class ConvNextLayer(nn.Module):
9 | def __init__(self, dim, depth=4):
10 | super().__init__()
11 | self.net = nn.Sequential(*[ConvNextBlock(dim=dim) for j in range(depth)])
12 |
13 | def forward(self, x):
14 | return self.net(x)
15 |
16 | def compute_params(self):
17 | num = 0
18 | for param in self.parameters():
19 | num += np.prod(param.size())
20 |
21 | return num
22 |
23 |
24 | class ConvNextBlock(nn.Module):
25 | r"""ConvNeXt Block. There are two equivalent implementations:
26 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
27 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
28 | We use (2) as we find it slightly faster in PyTorch
29 |
30 | Args:
31 | dim (int): Number of input channels.
32 | drop_path (float): Stochastic depth rate. Default: 0.0
33 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
34 | """
35 |
36 | def __init__(self, dim, layer_scale_init_value=1e-6):
37 | super().__init__()
38 | self.dwconv = nn.Conv2d(
39 | dim, dim, kernel_size=7, padding=3, groups=dim
40 | ) # depthwise conv
41 | self.norm = LayerNorm(dim, eps=1e-6)
42 | self.pwconv1 = nn.Linear(
43 | dim, 4 * dim
44 | ) # pointwise/1x1 convs, implemented with linear layers
45 | self.act = nn.GELU()
46 | self.pwconv2 = nn.Linear(4 * dim, dim)
47 | self.gamma = (
48 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
49 | if layer_scale_init_value > 0
50 | else None
51 | )
52 | # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
53 | # print(f"conv next layer")
54 |
55 | def forward(self, x):
56 | input = x
57 | x = self.dwconv(x)
58 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
59 | x = self.norm(x)
60 | x = self.pwconv1(x)
61 | x = self.act(x)
62 | x = self.pwconv2(x)
63 | if self.gamma is not None:
64 | x = self.gamma * x
65 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
66 |
67 | x = input + x
68 | return x
69 |
70 |
71 | class LayerNorm(nn.Module):
72 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
73 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
74 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
75 | with shape (batch_size, channels, height, width).
76 | """
77 |
78 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
79 | super().__init__()
80 | self.weight = nn.Parameter(torch.ones(normalized_shape))
81 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
82 | self.eps = eps
83 | self.data_format = data_format
84 | if self.data_format not in ["channels_last", "channels_first"]:
85 | raise NotImplementedError
86 | self.normalized_shape = (normalized_shape,)
87 |
88 | def forward(self, x):
89 | if self.data_format == "channels_last":
90 | return F.layer_norm(
91 | x, self.normalized_shape, self.weight, self.bias, self.eps
92 | )
93 | elif self.data_format == "channels_first":
94 | u = x.mean(1, keepdim=True)
95 | s = (x - u).pow(2).mean(1, keepdim=True)
96 | x = (x - u) / torch.sqrt(s + self.eps)
97 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
98 | return x
99 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/gma.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange
4 |
5 |
6 | class RelPosEmb(nn.Module):
7 | def __init__(self, max_pos_size, dim_head):
8 | super().__init__()
9 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head)
10 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head)
11 |
12 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(
13 | max_pos_size
14 | ).view(-1, 1)
15 | rel_ind = deltas + max_pos_size - 1
16 | self.register_buffer("rel_ind", rel_ind)
17 |
18 | def forward(self, q):
19 | batch, heads, h, w, c = q.shape
20 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1))
21 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1))
22 |
23 | height_emb = rearrange(height_emb, "(x u) d -> x u () d", x=h)
24 | width_emb = rearrange(width_emb, "(y v) d -> y () v d", y=w)
25 |
26 | height_score = einsum("b h x y d, x u v d -> b h x y u v", q, height_emb)
27 | width_score = einsum("b h x y d, y u v d -> b h x y u v", q, width_emb)
28 |
29 | return height_score + width_score
30 |
31 |
32 | class Attention(nn.Module):
33 | def __init__(
34 | self,
35 | *,
36 | args,
37 | dim,
38 | max_pos_size=100,
39 | heads=4,
40 | dim_head=128,
41 | ):
42 | super().__init__()
43 | self.args = args
44 | self.heads = heads
45 | self.scale = dim_head**-0.5
46 | inner_dim = heads * dim_head
47 |
48 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)
49 |
50 | self.pos_emb = RelPosEmb(max_pos_size, dim_head)
51 | for param in self.pos_emb.parameters():
52 | param.requires_grad = False
53 |
54 | def forward(self, fmap):
55 | heads, b, c, h, w = self.heads, *fmap.shape
56 |
57 | q, k = self.to_qk(fmap).chunk(2, dim=1)
58 |
59 | q, k = map(lambda t: rearrange(t, "b (h d) x y -> b h x y d", h=heads), (q, k))
60 | q = self.scale * q
61 |
62 | # if self.args.position_only:
63 | # sim = self.pos_emb(q)
64 |
65 | # elif self.args.position_and_content:
66 | # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k)
67 | # sim_pos = self.pos_emb(q)
68 | # sim = sim_content + sim_pos
69 |
70 | # else:
71 | sim = einsum("b h x y d, b h u v d -> b h x y u v", q, k)
72 |
73 | sim = rearrange(sim, "b h x y u v -> b h (x y) (u v)")
74 | attn = sim.softmax(dim=-1)
75 |
76 | return attn
77 |
78 |
79 | class Aggregate(nn.Module):
80 | def __init__(
81 | self,
82 | args,
83 | dim,
84 | heads=4,
85 | dim_head=128,
86 | ):
87 | super().__init__()
88 | self.args = args
89 | self.heads = heads
90 | self.scale = dim_head**-0.5
91 | inner_dim = heads * dim_head
92 |
93 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False)
94 |
95 | self.gamma = nn.Parameter(torch.zeros(1))
96 |
97 | if dim != inner_dim:
98 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False)
99 | else:
100 | self.project = None
101 |
102 | def forward(self, attn, fmap):
103 | heads, b, c, h, w = self.heads, *fmap.shape
104 |
105 | v = self.to_v(fmap)
106 | v = rearrange(v, "b (h d) x y -> b h (x y) d", h=heads)
107 | out = einsum("b h i j, b h j d -> b h i d", attn, v)
108 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
109 |
110 | if self.project is not None:
111 | out = self.project(out)
112 |
113 | out = fmap + self.gamma * out
114 |
115 | return out
116 |
117 |
118 | if __name__ == "__main__":
119 | att = Attention(args=None, dim=128, heads=1)  # args is keyword-only and unused in forward()
120 | fmap = torch.randn(2, 128, 40, 90)
121 | out = att(fmap)
122 |
123 | print(out.shape)
124 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/mlpmixer.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from einops.layers.torch import Rearrange, Reduce
3 | from functools import partial
4 | import numpy as np
5 |
6 |
7 | class PreNormResidual(nn.Module):
8 | def __init__(self, dim, fn):
9 | super().__init__()
10 | self.fn = fn
11 | self.norm = nn.LayerNorm(dim)
12 |
13 | def forward(self, x):
14 | return self.fn(self.norm(x)) + x
15 |
16 |
17 | def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear):
18 | return nn.Sequential(
19 | dense(dim, dim * expansion_factor),
20 | nn.GELU(),
21 | nn.Dropout(dropout),
22 | dense(dim * expansion_factor, dim),
23 | nn.Dropout(dropout),
24 | )
25 |
26 |
27 | class MLPMixerLayer(nn.Module):
28 | def __init__(self, dim, cfg, drop_path=0.0, dropout=0.0):
29 | super(MLPMixerLayer, self).__init__()
30 |
31 | # print(f"use mlp mixer layer")
32 | K = cfg.cost_latent_token_num
33 | expansion_factor = cfg.mlp_expansion_factor
34 | chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
35 |
36 | self.mlpmixer = nn.Sequential(
37 | PreNormResidual(dim, FeedForward(K, expansion_factor, dropout, chan_first)),
38 | PreNormResidual(
39 | dim, FeedForward(dim, expansion_factor, dropout, chan_last)
40 | ),
41 | )
42 |
43 | def compute_params(self):
44 | num = 0
45 | for param in self.mlpmixer.parameters():
46 | num += np.prod(param.size())
47 |
48 | return num
49 |
50 | def forward(self, x):
51 | """
52 | x: [BH1W1, K, D]
53 | """
54 |
55 | return self.mlpmixer(x)
56 |
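A hedged usage sketch for MLPMixerLayer; the config field names are the ones read above, but the concrete values (and dim) are illustrative rather than taken from the project's yaml configs.

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(cost_latent_token_num=8, mlp_expansion_factor=4)  # illustrative values
layer = MLPMixerLayer(dim=128, cfg=cfg)

x = torch.randn(4, 8, 128)       # [BH1W1, K, D] with K == cost_latent_token_num
print(layer(x).shape)            # torch.Size([4, 8, 128])
print(layer.compute_params())    # parameter count of the token- and channel-mixing MLPs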
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/LatentCostFormer/transformer.py:
--------------------------------------------------------------------------------
1 | import loguru
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import einsum
6 |
7 | from einops.layers.torch import Rearrange
8 | from einops import rearrange
9 |
10 | from ...utils.utils import coords_grid, bilinear_sampler, upflow8
11 | from ..common import (
12 | FeedForward,
13 | pyramid_retrieve_tokens,
14 | sampler,
15 | sampler_gaussian_fix,
16 | retrieve_tokens,
17 | MultiHeadAttention,
18 | MLP,
19 | )
20 | from ..encoders import twins_svt_large_context, twins_svt_large
21 | from ...position_encoding import PositionEncodingSine, LinearPositionEncoding
22 | from .twins import PosConv
23 | from .encoder import MemoryEncoder
24 | from .decoder import MemoryDecoder
25 | from .cnn import BasicEncoder
26 |
27 |
28 | class FlowFormer(nn.Module):
29 | def __init__(self, cfg):
30 | super(FlowFormer, self).__init__()
31 | self.cfg = cfg
32 |
33 | self.memory_encoder = MemoryEncoder(cfg)
34 | self.memory_decoder = MemoryDecoder(cfg)
35 | if cfg.cnet == "twins":
36 | self.context_encoder = twins_svt_large(pretrained=self.cfg.pretrain)
37 | elif cfg.cnet == "basicencoder":
38 | self.context_encoder = BasicEncoder(output_dim=256, norm_fn="instance")
39 |
40 | def build_coord(self, img):
41 | N, C, H, W = img.shape
42 | coords = coords_grid(N, H // 8, W // 8)
43 | return coords
44 |
45 |
46 | def forward(
47 | self, image1, image2, output=None, flow_init=None, return_feat=False, iters=None
48 | ):
49 | # Following https://github.com/princeton-vl/RAFT/
50 | image1 = 2 * (image1 / 255.0) - 1.0
51 | image2 = 2 * (image2 / 255.0) - 1.0
52 |
53 | data = {}
54 |
55 | if self.cfg.context_concat:
56 | context = self.context_encoder(torch.cat([image1, image2], dim=1))
57 | else:
58 | if return_feat:
59 | context, cfeat = self.context_encoder(image1, return_feat=return_feat)
60 | else:
61 | context = self.context_encoder(image1)
62 | if return_feat:
63 | cost_memory, ffeat = self.memory_encoder(
64 | image1, image2, data, context, return_feat=return_feat
65 | )
66 | else:
67 | cost_memory = self.memory_encoder(image1, image2, data, context)
68 |
69 | flow_predictions = self.memory_decoder(
70 | cost_memory, context, data, flow_init=flow_init, iters=iters
71 | )
72 |
73 | if return_feat:
74 | return flow_predictions, cfeat, ffeat
75 | return flow_predictions
76 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def build_flowformer(cfg):
5 | name = cfg.transformer
6 | if name == "latentcostformer":
7 | from .LatentCostFormer.transformer import FlowFormer
8 | else:
9 | raise ValueError(f"FlowFormer = {name} is not a valid architecture!")
10 |
11 | return FlowFormer(cfg[name])
12 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/FlowFormer/encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import numpy as np
5 |
6 |
7 | class twins_svt_large(nn.Module):
8 | def __init__(self, pretrained=True):
9 | super().__init__()
10 | self.svt = timm.create_model("twins_svt_large", pretrained=pretrained)
11 |
12 | del self.svt.head
13 | del self.svt.patch_embeds[2]
14 | del self.svt.patch_embeds[2]
15 | del self.svt.blocks[2]
16 | del self.svt.blocks[2]
17 | del self.svt.pos_block[2]
18 | del self.svt.pos_block[2]
19 | self.svt.norm.weight.requires_grad = False
20 | self.svt.norm.bias.requires_grad = False
21 |
22 | def forward(self, x, data=None, layer=2, return_feat=False):
23 | B = x.shape[0]
24 | if return_feat:
25 | feat = []
26 | for i, (embed, drop, blocks, pos_blk) in enumerate(
27 | zip(
28 | self.svt.patch_embeds,
29 | self.svt.pos_drops,
30 | self.svt.blocks,
31 | self.svt.pos_block,
32 | )
33 | ):
34 | x, size = embed(x)
35 | x = drop(x)
36 | for j, blk in enumerate(blocks):
37 | x = blk(x, size)
38 | if j == 0:
39 | x = pos_blk(x, size)
40 | if i < len(self.svt.depths) - 1:
41 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
42 | if return_feat:
43 | feat.append(x)
44 | if i == layer - 1:
45 | break
46 | if return_feat:
47 | return x, feat
48 | return x
49 |
50 | def compute_params(self, layer=2):
51 | num = 0
52 | for i, (embed, drop, blocks, pos_blk) in enumerate(
53 | zip(
54 | self.svt.patch_embeds,
55 | self.svt.pos_drops,
56 | self.svt.blocks,
57 | self.svt.pos_block,
58 | )
59 | ):
60 | for param in embed.parameters():
61 | num += np.prod(param.size())
62 |
63 | for param in drop.parameters():
64 | num += np.prod(param.size())
65 |
66 | for param in blocks.parameters():
67 | num += np.prod(param.size())
68 |
69 | for param in pos_blk.parameters():
70 | num += np.prod(param.size())
71 |
72 | if i == layer - 1:
73 | break
74 |
75 | for param in self.svt.head.parameters():
76 | num += np.prod(param.size())
77 |
78 | return num
79 |
80 |
81 | class twins_svt_large_context(nn.Module):
82 | def __init__(self, pretrained=True):
83 | super().__init__()
84 | self.svt = timm.create_model("twins_svt_large_context", pretrained=pretrained)
85 |
86 | def forward(self, x, data=None, layer=2):
87 | B = x.shape[0]
88 | for i, (embed, drop, blocks, pos_blk) in enumerate(
89 | zip(
90 | self.svt.patch_embeds,
91 | self.svt.pos_drops,
92 | self.svt.blocks,
93 | self.svt.pos_block,
94 | )
95 | ):
96 | x, size = embed(x)
97 | x = drop(x)
98 | for j, blk in enumerate(blocks):
99 | x = blk(x, size)
100 | if j == 0:
101 | x = pos_blk(x, size)
102 | if i < len(self.svt.depths) - 1:
103 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
104 |
105 | if i == layer - 1:
106 | break
107 |
108 | return x
109 |
110 |
111 | if __name__ == "__main__":
112 | m = twins_svt_large()
113 | input = torch.randn(2, 3, 400, 800)
114 | out = m(input)  # the module defines forward() only; there is no extract_feature()
115 | print(out.shape)
116 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/flowformer/core/__init__.py
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels - 1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2 * r + 1)
38 | dy = torch.linspace(-r, r, 2 * r + 1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 | corr = bilinear_sampler(corr, coords_lvl)
45 | corr = corr.view(batch, h1, w1, -1)
46 | out_pyramid.append(corr)
47 |
48 | out = torch.cat(out_pyramid, dim=-1)
49 | return out.permute(0, 3, 1, 2).contiguous().float()
50 |
51 | @staticmethod
52 | def corr(fmap1, fmap2):
53 | batch, dim, ht, wd = fmap1.shape
54 | fmap1 = fmap1.view(batch, dim, ht * wd)
55 | fmap2 = fmap2.view(batch, dim, ht * wd)
56 |
57 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
58 | corr = corr.view(batch, ht, wd, 1, ht, wd)
59 | return corr / torch.sqrt(torch.tensor(dim).float())
60 |
61 |
62 | class AlternateCorrBlock:
63 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
64 | self.num_levels = num_levels
65 | self.radius = radius
66 |
67 | self.pyramid = [(fmap1, fmap2)]
68 | for i in range(self.num_levels):
69 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
70 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
71 | self.pyramid.append((fmap1, fmap2))
72 |
73 | def __call__(self, coords):
74 | coords = coords.permute(0, 2, 3, 1)
75 | B, H, W, _ = coords.shape
76 | dim = self.pyramid[0][0].shape[1]
77 |
78 | corr_list = []
79 | for i in range(self.num_levels):
80 | r = self.radius
81 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
82 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
83 |
84 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
85 | (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
86 | corr_list.append(corr.squeeze(1))
87 |
88 | corr = torch.stack(corr_list, dim=1)
89 | corr = corr.reshape(B, -1, H, W)
90 | return corr / torch.sqrt(torch.tensor(dim).float())
91 |
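A small shape check for CorrBlock (feature sizes are illustrative, e.g. 1/8-resolution features): each pyramid level contributes a (2r+1)^2 correlation window per pixel, concatenated along the channel dimension.

import torch

fmap1 = torch.randn(1, 256, 46, 62)
fmap2 = torch.randn(1, 256, 46, 62)
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

coords = coords_grid(1, 46, 62)  # [B, 2, H, W] pixel coordinates (imported above)
out = corr_fn(coords)
print(out.shape)                 # torch.Size([1, 324, 46, 62]) == 4 levels * 9 * 9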
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | MAX_FLOW = 400
4 |
5 |
6 | def sequence_loss(flow_preds, flow_gt, valid, cfg):
7 | """Loss function defined over sequence of flow predictions"""
8 |
9 | gamma = cfg.gamma
10 | max_flow = cfg.max_flow
11 | n_predictions = len(flow_preds)
12 | flow_loss = 0.0
13 | flow_gt_thresholds = [5, 10, 20]
14 |
15 | # exclude invalid pixels and extremely large displacements
16 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
17 | valid = (valid >= 0.5) & (mag < max_flow)
18 |
19 | for i in range(n_predictions):
20 | i_weight = gamma ** (n_predictions - i - 1)
21 | i_loss = (flow_preds[i] - flow_gt).abs()
22 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
23 |
24 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
25 | epe = epe.view(-1)[valid.view(-1)]
26 |
27 | metrics = {
28 | "epe": epe.mean().item(),
29 | "1px": (epe < 1).float().mean().item(),
30 | "3px": (epe < 3).float().mean().item(),
31 | "5px": (epe < 5).float().mean().item(),
32 | }
33 |
34 | flow_gt_length = torch.sum(flow_gt**2, dim=1).sqrt()
35 | flow_gt_length = flow_gt_length.view(-1)[valid.view(-1)]
36 | for t in flow_gt_thresholds:
37 | e = epe[flow_gt_length < t]
38 | metrics.update({f"{t}-th-5px": (e < 5).float().mean().item()})
39 |
40 | return flow_loss, metrics
41 |
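A minimal sketch of calling sequence_loss; the cfg values and tensors are illustrative (gamma weights later predictions more heavily, max_flow masks out extreme ground-truth displacements).

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(gamma=0.8, max_flow=400)  # illustrative values
flow_gt = torch.randn(2, 2, 64, 64)
valid = torch.ones(2, 64, 64)
flow_preds = [flow_gt + 0.5 * torch.randn_like(flow_gt) for _ in range(3)]

loss, metrics = sequence_loss(flow_preds, flow_gt, valid, cfg)
print(loss.item(), metrics["epe"])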
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import (
3 | MultiStepLR,
4 | CosineAnnealingLR,
5 | ExponentialLR,
6 | OneCycleLR,
7 | )
8 |
9 |
10 | def fetch_optimizer(model, cfg):
11 | """Create the optimizer and learning rate scheduler"""
12 | # optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
13 |
14 | # scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100,
15 | # pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
16 | optimizer = build_optimizer(model, cfg)
17 | scheduler = build_scheduler(cfg, optimizer)
18 |
19 | return optimizer, scheduler
20 |
21 |
22 | def build_optimizer(model, config):
23 | name = config.optimizer
24 | lr = config.canonical_lr
25 |
26 | if name == "adam":
27 | return torch.optim.Adam(
28 | model.parameters(),
29 | lr=lr,
30 | weight_decay=config.adam_decay,
31 | eps=config.epsilon,
32 | )
33 | elif name == "adamw":
34 | if hasattr(config, "twins_lr_factor"):
35 | factor = config.twins_lr_factor
36 | print("[Decrease lr of pre-trained model by factor {}]".format(factor))
37 | param_dicts = [
38 | {
39 | "params": [
40 | p
41 | for n, p in model.named_parameters()
42 | if "feat_encoder" not in n
43 | and "context_encoder" not in n
44 | and p.requires_grad
45 | ]
46 | },
47 | {
48 | "params": [
49 | p
50 | for n, p in model.named_parameters()
51 | if ("feat_encoder" in n or "context_encoder" in n)
52 | and p.requires_grad
53 | ],
54 | "lr": lr * factor,
55 | },
56 | ]
57 | full = [n for n, _ in model.named_parameters()]
58 | return torch.optim.AdamW(
59 | param_dicts, lr=lr, weight_decay=config.adamw_decay, eps=config.epsilon
60 | )
61 | else:
62 | return torch.optim.AdamW(
63 | model.parameters(),
64 | lr=lr,
65 | weight_decay=config.adamw_decay,
66 | eps=config.epsilon,
67 | )
68 | else:
69 | raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
70 |
71 |
72 | def build_scheduler(config, optimizer):
73 | """
74 | Returns:
75 | scheduler (dict):{
76 | 'scheduler': lr_scheduler,
77 | 'interval': 'step', # or 'epoch'
78 | }
79 | """
80 | # scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
81 | name = config.scheduler
82 | lr = config.canonical_lr
83 |
84 | if name == "OneCycleLR":
85 | # scheduler = OneCycleLR(optimizer, )
86 | if hasattr(config, "twins_lr_factor"):
87 | factor = config.twins_lr_factor
88 | scheduler = OneCycleLR(
89 | optimizer,
90 | [lr, lr * factor],
91 | config.num_steps + 100,
92 | pct_start=0.05,
93 | cycle_momentum=False,
94 | anneal_strategy=config.anneal_strategy,
95 | )
96 | else:
97 | scheduler = OneCycleLR(
98 | optimizer,
99 | lr,
100 | config.num_steps + 100,
101 | pct_start=0.05,
102 | cycle_momentum=False,
103 | anneal_strategy=config.anneal_strategy,
104 | )
105 | # elif name == 'MultiStepLR':
106 | # scheduler.update(
107 | # {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
108 | # elif name == 'CosineAnnealing':
109 | # scheduler = CosineAnnealingLR(optimizer, config.num_steps+100)
110 | # scheduler.update(
111 | # {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
112 | # elif name == 'ExponentialLR':
113 | # scheduler.update(
114 | # {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
115 | else:
116 | raise NotImplementedError()
117 |
118 | return scheduler
119 |
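A hedged sketch of driving fetch_optimizer; the attribute names match what build_optimizer and build_scheduler read above, but the values (and the toy model) are illustrative.

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(optimizer="adamw", scheduler="OneCycleLR", canonical_lr=2.5e-4,
                      adamw_decay=1e-4, epsilon=1e-8, num_steps=1000, anneal_strategy="linear")

model = torch.nn.Linear(8, 8)
optimizer, scheduler = fetch_optimizer(model, cfg)
optimizer.step()
scheduler.step()  # OneCycleLR is stepped per iteration, for num_steps + 100 total steps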
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/position_encoding.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | import math
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class PositionEncodingSine(nn.Module):
8 | """
9 | This is a sinusoidal position encoding that is generalized to 2-dimensional images
10 | """
11 |
12 | def __init__(self, d_model, max_shape=(256, 256)):
13 | """
14 | Args:
15 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
16 | """
17 | super().__init__()
18 |
19 | pe = torch.zeros((d_model, *max_shape))
20 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
21 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
22 | div_term = torch.exp(
23 | torch.arange(0, d_model // 2, 2).float()
24 | * (-math.log(10000.0) / d_model // 2)
25 | )
26 | div_term = div_term[:, None, None] # [C//4, 1, 1]
27 | pe[0::4, :, :] = torch.sin(x_position * div_term)
28 | pe[1::4, :, :] = torch.cos(x_position * div_term)
29 | pe[2::4, :, :] = torch.sin(y_position * div_term)
30 | pe[3::4, :, :] = torch.cos(y_position * div_term)
31 |
32 | self.register_buffer("pe", pe.unsqueeze(0)) # [1, C, H, W]
33 |
34 | def forward(self, x):
35 | """
36 | Args:
37 | x: [N, C, H, W]
38 | """
39 | return x + self.pe[:, :, : x.size(2), : x.size(3)]
40 |
41 |
42 | class LinearPositionEncoding(nn.Module):
43 | """
44 | This is a sinusoidal position encoding that is generalized to 2-dimensional images
45 | """
46 |
47 | def __init__(self, d_model, max_shape=(256, 256)):
48 | """
49 | Args:
50 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
51 | """
52 | super().__init__()
53 |
54 | pe = torch.zeros((d_model, *max_shape))
55 | y_position = (
56 | torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1
57 | ) / max_shape[0]
58 | x_position = (
59 | torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1
60 | ) / max_shape[1]
61 | div_term = torch.arange(0, d_model // 2, 2).float()
62 | div_term = div_term[:, None, None] # [C//4, 1, 1]
63 | pe[0::4, :, :] = torch.sin(x_position * div_term * math.pi)
64 | pe[1::4, :, :] = torch.cos(x_position * div_term * math.pi)
65 | pe[2::4, :, :] = torch.sin(y_position * div_term * math.pi)
66 | pe[3::4, :, :] = torch.cos(y_position * div_term * math.pi)
67 |
68 | self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W]
69 |
70 | def forward(self, x):
71 | """
72 | Args:
73 | x: [N, C, H, W]
74 | """
75 | # assert x.shape[2] == 80 and x.shape[3] == 80
76 |
77 | return x + self.pe[:, :, : x.size(2), : x.size(3)]
78 |
79 |
80 | class LearnedPositionEncoding(nn.Module):
81 | """
82 | This is a learned position encoding that is generalized to 2-dimensional images
83 | """
84 |
85 | def __init__(self, d_model, max_shape=(80, 80)):
86 | """
87 | Args:
88 | max_shape (tuple): for 1/8 featmap, the max length of 80 corresponds to 640 pixels
89 | """
90 | super().__init__()
91 |
92 | self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model))
93 |
94 | def forward(self, x):
95 | """
96 | Args:
97 | x: [N, C, H, W]
98 | """
99 | # assert x.shape[2] == 80 and x.shape[3] == 80
100 |
101 | return x + self.pe
102 |
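A minimal sketch of adding the sinusoidal encoding to a feature map; d_model must equal the channel count, and the spatial size may be anything up to max_shape (shapes illustrative).

import torch

pos_enc = PositionEncodingSine(d_model=256, max_shape=(256, 256))
feat = torch.randn(2, 256, 46, 62)  # [N, C, H, W]
print(pos_enc(feat).shape)          # torch.Size([2, 256, 46, 62])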
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock, SmallUpdateBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .corr import CorrBlock, AlternateCorrBlock
9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 |
19 | def __enter__(self):
20 | pass
21 |
22 | def __exit__(self, *args):
23 | pass
24 |
25 |
26 | class RAFT(nn.Module):
27 | def __init__(self, args):
28 | super(RAFT, self).__init__()
29 | self.args = args
30 |
31 | if args.small:
32 | self.hidden_dim = hdim = 96
33 | self.context_dim = cdim = 64
34 | args.corr_levels = 4
35 | args.corr_radius = 3
36 |
37 | else:
38 | self.hidden_dim = hdim = 128
39 | self.context_dim = cdim = 128
40 | args.corr_levels = 4
41 | args.corr_radius = 4
42 |
43 | if "dropout" not in self.args:
44 | self.args.dropout = 0
45 |
46 | if "alternate_corr" not in self.args:
47 | self.args.alternate_corr = False
48 |
49 | # feature network, context network, and update block
50 | if args.small:
51 | self.fnet = SmallEncoder(
52 | output_dim=128, norm_fn="instance", dropout=args.dropout
53 | )
54 | self.cnet = SmallEncoder(
55 | output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
56 | )
57 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
58 |
59 | else:
60 | self.fnet = BasicEncoder(
61 | output_dim=256, norm_fn="instance", dropout=args.dropout
62 | )
63 | self.cnet = BasicEncoder(
64 | output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
65 | )
66 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
67 |
68 | def freeze_bn(self):
69 | for m in self.modules():
70 | if isinstance(m, nn.BatchNorm2d):
71 | m.eval()
72 |
73 | def initialize_flow(self, img):
74 | """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
75 | N, C, H, W = img.shape
76 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
77 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device)
78 |
79 | # optical flow computed as difference: flow = coords1 - coords0
80 | return coords0, coords1
81 |
82 | def upsample_flow(self, flow, mask):
83 | """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
84 | N, _, H, W = flow.shape
85 | mask = mask.view(N, 1, 9, 8, 8, H, W)
86 | mask = torch.softmax(mask, dim=2)
87 |
88 | up_flow = F.unfold(8 * flow, [3, 3], padding=1)
89 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
90 |
91 | up_flow = torch.sum(mask * up_flow, dim=2)
92 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
93 | return up_flow.reshape(N, 2, 8 * H, 8 * W)
94 |
95 | def forward(
96 | self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
97 | ):
98 | """Estimate optical flow between pair of frames"""
99 |
100 | image1 = 2 * (image1 / 255.0) - 1.0
101 | image2 = 2 * (image2 / 255.0) - 1.0
102 |
103 | image1 = image1.contiguous()
104 | image2 = image2.contiguous()
105 |
106 | hdim = self.hidden_dim
107 | cdim = self.context_dim
108 |
109 | # run the feature network
110 | with autocast(enabled=self.args.mixed_precision):
111 | fmap1, fmap2 = self.fnet([image1, image2])
112 |
113 | fmap1 = fmap1.float()
114 | fmap2 = fmap2.float()
115 | if self.args.alternate_corr:
116 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
117 | else:
118 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
119 |
120 | # run the context network
121 | with autocast(enabled=self.args.mixed_precision):
122 | cnet = self.cnet(image1)
123 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
124 | net = torch.tanh(net)
125 | inp = torch.relu(inp)
126 |
127 | coords0, coords1 = self.initialize_flow(image1)
128 |
129 | if flow_init is not None:
130 | coords1 = coords1 + flow_init
131 |
132 | flow_predictions = []
133 | for itr in range(iters):
134 | coords1 = coords1.detach()
135 | corr = corr_fn(coords1) # index correlation volume
136 |
137 | flow = coords1 - coords0
138 | with autocast(enabled=self.args.mixed_precision):
139 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
140 |
141 | # F(t+1) = F(t) + \Delta(t)
142 | coords1 = coords1 + delta_flow
143 |
144 | # upsample predictions
145 | if up_mask is None:
146 | flow_up = upflow8(coords1 - coords0)
147 | else:
148 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
149 |
150 | flow_predictions.append(flow_up)
151 |
152 | if test_mode:
153 | return coords1 - coords0, flow_up
154 |
155 | return flow_predictions
156 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/flowformer/core/utils/__init__.py
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 |
21 | def make_colorwheel():
22 | """
23 | Generates a color wheel for optical flow visualization as presented in:
24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
26 |
27 | Code follows the original C++ source code of Daniel Scharstein.
28 | Code follows the Matlab source code of Deqing Sun.
29 |
30 | Returns:
31 | np.ndarray: Color wheel
32 | """
33 |
34 | RY = 15
35 | YG = 6
36 | GC = 4
37 | CB = 11
38 | BM = 13
39 | MR = 6
40 |
41 | ncols = RY + YG + GC + CB + BM + MR
42 | colorwheel = np.zeros((ncols, 3))
43 | col = 0
44 |
45 | # RY
46 | colorwheel[0:RY, 0] = 255
47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
48 | col = col + RY
49 | # YG
50 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
51 | colorwheel[col : col + YG, 1] = 255
52 | col = col + YG
53 | # GC
54 | colorwheel[col : col + GC, 1] = 255
55 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
56 | col = col + GC
57 | # CB
58 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
59 | colorwheel[col : col + CB, 2] = 255
60 | col = col + CB
61 | # BM
62 | colorwheel[col : col + BM, 2] = 255
63 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
64 | col = col + BM
65 | # MR
66 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
67 | colorwheel[col : col + MR, 0] = 255
68 | return colorwheel
69 |
70 |
71 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
72 | """
73 | Applies the flow color wheel to (possibly clipped) flow components u and v.
74 |
75 | According to the C++ source code of Daniel Scharstein
76 | According to the Matlab source code of Deqing Sun
77 |
78 | Args:
79 | u (np.ndarray): Input horizontal flow of shape [H,W]
80 | v (np.ndarray): Input vertical flow of shape [H,W]
81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
82 |
83 | Returns:
84 | np.ndarray: Flow visualization image of shape [H,W,3]
85 | """
86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
87 | colorwheel = make_colorwheel() # shape [55x3]
88 | ncols = colorwheel.shape[0]
89 | rad = np.sqrt(np.square(u) + np.square(v))
90 | a = np.arctan2(-v, -u) / np.pi
91 | fk = (a + 1) / 2 * (ncols - 1)
92 | k0 = np.floor(fk).astype(np.int32)
93 | k1 = k0 + 1
94 | k1[k1 == ncols] = 0
95 | f = fk - k0
96 | for i in range(colorwheel.shape[1]):
97 | tmp = colorwheel[:, i]
98 | col0 = tmp[k0] / 255.0
99 | col1 = tmp[k1] / 255.0
100 | col = (1 - f) * col0 + f * col1
101 | idx = rad <= 1
102 | col[idx] = 1 - rad[idx] * (1 - col[idx])
103 | col[~idx] = col[~idx] * 0.75 # out of range
104 | # Note the 2-i => BGR instead of RGB
105 | ch_idx = 2 - i if convert_to_bgr else i
106 | flow_image[:, :, ch_idx] = np.floor(255 * col)
107 | return flow_image
108 |
109 |
110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, max_flow=None):
111 | """
112 | Expects a two-dimensional flow image of shape [H,W,2].
113 |
114 | Args:
115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
118 |
119 | Returns:
120 | np.ndarray: Flow visualization image of shape [H,W,3]
121 | """
122 | assert flow_uv.ndim == 3, "input flow must have three dimensions"
123 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
124 | if clip_flow is not None:
125 | flow_uv = np.clip(flow_uv, 0, clip_flow)
126 | u = flow_uv[:, :, 0]
127 | v = flow_uv[:, :, 1]
128 | if max_flow is None:
129 | rad = np.sqrt(np.square(u) + np.square(v))
130 | rad_max = np.max(rad)
131 | else:
132 | rad_max = max_flow
133 | epsilon = 1e-5
134 | u = u / (rad_max + epsilon)
135 | v = v / (rad_max + epsilon)
136 | return flow_uv_to_colors(u, v, convert_to_bgr)
137 |
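A minimal sketch of the visualization entry point, using a synthetic constant flow field (values illustrative).

import numpy as np

flow = np.zeros((128, 160, 2), dtype=np.float32)
flow[..., 0] = 5.0   # constant horizontal motion
flow[..., 1] = -2.0  # constant vertical motion

rgb = flow_to_image(flow)
print(rgb.shape, rgb.dtype)  # (128, 160, 3) uint8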
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 |
8 | cv2.setNumThreads(0)
9 | cv2.ocl.setUseOpenCL(False)
10 |
11 | TAG_CHAR = np.array([202021.25], np.float32)
12 |
13 |
14 | def readFlow(fn):
15 | """Read .flo file in Middlebury format"""
16 | # Code adapted from:
17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
18 |
19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
20 | # print 'fn = %s'%(fn)
21 | with open(fn, "rb") as f:
22 | magic = np.fromfile(f, np.float32, count=1)
23 | if 202021.25 != magic:
24 | print("Magic number incorrect. Invalid .flo file")
25 | return None
26 | else:
27 | w = np.fromfile(f, np.int32, count=1)
28 | h = np.fromfile(f, np.int32, count=1)
29 | # print 'Reading %d x %d flo file\n' % (w, h)
30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
31 | # Reshape data into 3D array (columns, rows, bands)
32 | # The reshape here is for visualization, the original code is (w,h,2)
33 | return np.resize(data, (int(h), int(w), 2))
34 |
35 |
36 | def readPFM(file):
37 | file = open(file, "rb")
38 |
39 | color = None
40 | width = None
41 | height = None
42 | scale = None
43 | endian = None
44 |
45 | header = file.readline().rstrip()
46 | if header == b"PF":
47 | color = True
48 | elif header == b"Pf":
49 | color = False
50 | else:
51 | raise Exception("Not a PFM file.")
52 |
53 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
54 | if dim_match:
55 | width, height = map(int, dim_match.groups())
56 | else:
57 | raise Exception("Malformed PFM header.")
58 |
59 | scale = float(file.readline().rstrip())
60 | if scale < 0: # little-endian
61 | endian = "<"
62 | scale = -scale
63 | else:
64 | endian = ">" # big-endian
65 |
66 | data = np.fromfile(file, endian + "f")
67 | shape = (height, width, 3) if color else (height, width)
68 |
69 | data = np.reshape(data, shape)
70 | data = np.flipud(data)
71 | return data
72 |
73 |
74 | def writeFlow(filename, uv, v=None):
75 | """Write optical flow to file.
76 |
77 | If v is None, uv is assumed to contain both u and v channels,
78 | stacked in depth.
79 | Original code by Deqing Sun, adapted from Daniel Scharstein.
80 | """
81 | nBands = 2
82 |
83 | if v is None:
84 | assert uv.ndim == 3
85 | assert uv.shape[2] == 2
86 | u = uv[:, :, 0]
87 | v = uv[:, :, 1]
88 | else:
89 | u = uv
90 |
91 | assert u.shape == v.shape
92 | height, width = u.shape
93 | f = open(filename, "wb")
94 | # write the header
95 | f.write(TAG_CHAR)
96 | np.array(width).astype(np.int32).tofile(f)
97 | np.array(height).astype(np.int32).tofile(f)
98 | # arrange into matrix form
99 | tmp = np.zeros((height, width * nBands))
100 | tmp[:, np.arange(width) * 2] = u
101 | tmp[:, np.arange(width) * 2 + 1] = v
102 | tmp.astype(np.float32).tofile(f)
103 | f.close()
104 |
105 |
106 | def readFlowKITTI(filename):
107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
108 | flow = flow[:, :, ::-1].astype(np.float32)
109 | flow, valid = flow[:, :, :2], flow[:, :, 2]
110 | flow = (flow - 2**15) / 64.0
111 | return flow, valid
112 |
113 |
114 | def readDispKITTI(filename):
115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
116 | valid = disp > 0.0
117 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
118 | return flow, valid
119 |
120 |
121 | def writeFlowKITTI(filename, uv):
122 | uv = 64.0 * uv + 2**15
123 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
125 | cv2.imwrite(filename, uv[..., ::-1])
126 |
127 |
128 | def read_gen(file_name, pil=False):
129 | ext = splitext(file_name)[-1]
130 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
131 | return Image.open(file_name)
132 | elif ext == ".bin" or ext == ".raw":
133 | return np.load(file_name)
134 | elif ext == ".flo":
135 | return readFlow(file_name).astype(np.float32)
136 | elif ext == ".pfm":
137 | flow = readPFM(file_name).astype(np.float32)
138 | if len(flow.shape) == 2:
139 | return flow
140 | else:
141 | return flow[:, :, :-1]
142 | return []
143 |
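A minimal round-trip sketch for the Middlebury .flo helpers (the file name is illustrative and written to the current directory).

import numpy as np

flow = np.random.randn(64, 96, 2).astype(np.float32)
writeFlow("example.flo", flow)
restored = readFlow("example.flo")
print(restored.shape)               # (64, 96, 2)
print(np.allclose(flow, restored))  # True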
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/utils/logger.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | from loguru import logger as loguru_logger
3 |
4 |
5 | class Logger:
6 | def __init__(self, model, scheduler, cfg):
7 | self.model = model
8 | self.scheduler = scheduler
9 | self.total_steps = 0
10 | self.running_loss = {}
11 | self.writer = None
12 | self.cfg = cfg
13 |
14 | def _print_training_status(self):
15 | metrics_data = [
16 | self.running_loss[k] / self.cfg.sum_freq
17 | for k in sorted(self.running_loss.keys())
18 | ]
19 | training_str = "[{:6d}, {}] ".format(
20 | self.total_steps + 1, self.scheduler.get_last_lr()
21 | )
22 | metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
23 |
24 | # print the training status
25 | loguru_logger.info(training_str + metrics_str)
26 |
27 | if self.writer is None:
28 | if self.cfg.log_dir is None:
29 | self.writer = SummaryWriter()
30 | else:
31 | self.writer = SummaryWriter(self.cfg.log_dir)
32 |
33 | for k in self.running_loss:
34 | self.writer.add_scalar(
35 | k, self.running_loss[k] / self.cfg.sum_freq, self.total_steps
36 | )
37 | self.running_loss[k] = 0.0
38 |
39 | def push(self, metrics):
40 | self.total_steps += 1
41 |
42 | for key in metrics:
43 | if key not in self.running_loss:
44 | self.running_loss[key] = 0.0
45 |
46 | self.running_loss[key] += metrics[key]
47 |
48 | if self.total_steps % self.cfg.sum_freq == self.cfg.sum_freq - 1:
49 | self._print_training_status()
50 | self.running_loss = {}
51 |
52 | def write_dict(self, results):
53 | if self.writer is None:
54 | self.writer = SummaryWriter()
55 |
56 | for key in results:
57 | self.writer.add_scalar(key, results[key], self.total_steps)
58 |
59 | def close(self):
60 | self.writer.close()
61 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/utils/misc.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import shutil
4 |
5 |
6 | def process_transformer_cfg(cfg):
7 | log_dir = ""
8 | if "critical_params" in cfg:
9 | critical_params = [cfg[key] for key in cfg.critical_params]
10 | for name, param in zip(cfg["critical_params"], critical_params):
11 | log_dir += "{:s}[{:s}]".format(name, str(param))
12 |
13 | return log_dir
14 |
15 |
16 | def process_cfg(cfg):
17 | log_dir = "logs/" + cfg.name + "/" + cfg.transformer + "/"
18 | critical_params = [cfg.trainer[key] for key in cfg.critical_params]
19 | for name, param in zip(cfg["critical_params"], critical_params):
20 | log_dir += "{:s}[{:s}]".format(name, str(param))
21 |
22 | log_dir += process_transformer_cfg(cfg[cfg.transformer])
23 |
24 | now = time.localtime()
25 | now_time = "{:02d}_{:02d}_{:02d}_{:02d}".format(
26 | now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min
27 | )
28 | log_dir += cfg.suffix + "(" + now_time + ")"
29 | cfg.log_dir = log_dir
30 | os.makedirs(log_dir)
31 |
32 | shutil.copytree("configs", f"{log_dir}/configs")
33 | shutil.copytree("core/FlowFormer", f"{log_dir}/FlowFormer")
34 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """Pads images such that dimensions are divisible by 8"""
9 |
10 | def __init__(self, dims, mode="sintel"):
11 | self.ht, self.wd = dims[-2:]
12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
14 | if mode == "sintel":
15 | self._pad = [
16 | pad_wd // 2,
17 | pad_wd - pad_wd // 2,
18 | pad_ht // 2,
19 | pad_ht - pad_ht // 2,
20 | ]
21 | elif mode == "kitti400":
22 | self._pad = [0, 0, 0, 400 - self.ht]
23 | else:
24 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
25 |
26 | def pad(self, *inputs):
27 | return [F.pad(x, self._pad, mode="replicate") for x in inputs]
28 |
29 | def unpad(self, x):
30 | ht, wd = x.shape[-2:]
31 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
32 | return x[..., c[0] : c[1], c[2] : c[3]]
33 |
34 |
35 | def forward_interpolate(flow):
36 | flow = flow.detach().cpu().numpy()
37 | dx, dy = flow[0], flow[1]
38 |
39 | ht, wd = dx.shape
40 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
41 |
42 | x1 = x0 + dx
43 | y1 = y0 + dy
44 |
45 | x1 = x1.reshape(-1)
46 | y1 = y1.reshape(-1)
47 | dx = dx.reshape(-1)
48 | dy = dy.reshape(-1)
49 |
50 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
51 | x1 = x1[valid]
52 | y1 = y1[valid]
53 | dx = dx[valid]
54 | dy = dy[valid]
55 |
56 | flow_x = interpolate.griddata(
57 | (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
58 | )
59 |
60 | flow_y = interpolate.griddata(
61 | (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
62 | )
63 |
64 | flow = np.stack([flow_x, flow_y], axis=0)
65 | return torch.from_numpy(flow).float()
66 |
67 |
68 | def bilinear_sampler(img, coords, mode="bilinear", mask=False):
69 | """Wrapper for grid_sample, uses pixel coordinates"""
70 | H, W = img.shape[-2:]
71 | xgrid, ygrid = coords.split([1, 1], dim=-1)
72 | xgrid = 2 * xgrid / (W - 1) - 1
73 | ygrid = 2 * ygrid / (H - 1) - 1
74 |
75 | grid = torch.cat([xgrid, ygrid], dim=-1)
76 | img = F.grid_sample(img, grid, align_corners=True)
77 |
78 | if mask:
79 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
80 | return img, mask.float()
81 |
82 | return img
83 |
84 |
85 | def indexing(img, coords, mask=False):
86 | """Wrapper for grid_sample, uses pixel coordinates"""
87 | """
88 | TODO: directly indexing features instead of sampling
89 | """
90 | H, W = img.shape[-2:]
91 | xgrid, ygrid = coords.split([1, 1], dim=-1)
92 | xgrid = 2 * xgrid / (W - 1) - 1
93 | ygrid = 2 * ygrid / (H - 1) - 1
94 |
95 | grid = torch.cat([xgrid, ygrid], dim=-1)
96 | img = F.grid_sample(img, grid, align_corners=True, mode="nearest")
97 |
98 | if mask:
99 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
100 | return img, mask.float()
101 |
102 | return img
103 |
104 |
105 | def coords_grid(batch, ht, wd):
106 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
107 | coords = torch.stack(coords[::-1], dim=0).float()
108 | return coords[None].repeat(batch, 1, 1, 1)
109 |
110 |
111 | def upflow8(flow, mode="bilinear"):
112 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
113 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
114 |
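A minimal sketch of the pad/unpad cycle around a flow model (Sintel-sized frames, shapes illustrative).

import torch

img1 = torch.randn(1, 3, 436, 1024)
img2 = torch.randn(1, 3, 436, 1024)

padder = InputPadder(img1.shape)       # default "sintel" mode pads symmetrically
img1_p, img2_p = padder.pad(img1, img2)
print(img1_p.shape)                    # torch.Size([1, 3, 440, 1024])

flow_p = torch.randn(1, 2, 440, 1024)  # stand-in for a model's padded-resolution flow
print(padder.unpad(flow_p).shape)      # torch.Size([1, 2, 436, 1024])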
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/flowformer/run_train.sh:
--------------------------------------------------------------------------------
1 | mkdir -p checkpoints
2 | python -u train_FlowFormer.py --name chairs --stage chairs --validation chairs
3 | python -u train_FlowFormer.py --name things --stage things --validation sintel
4 | python -u train_FlowFormer.py --name sintel --stage sintel --validation sintel
5 | python -u train_FlowFormer.py --name kitti --stage kitti --validation kitti
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/gmflow/__init__.py:
--------------------------------------------------------------------------------
1 | from .gmflow import GMFlow
2 | import argparse
3 | import torch
4 |
5 |
6 | def initialize_GMFlow(model_path="pretrained_ckpt/gmflow_sintel_with_refinement.pkl", device="cuda"):
7 | """Initializes the RAFT model."""
8 |
9 | model = GMFlow()
10 | ckpt = torch.load(model_path, map_location="cpu")
11 |
12 | # def convert(param):
13 | # return {k.replace("module.", ""): v for k, v in param.items() if "module" in k}
14 | #
15 | # ckpt = convert(ckpt)
16 | model.load_state_dict(ckpt, strict=True)
17 | print("load gmflow from " + model_path)
18 |
19 | return model
20 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/gmflow/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .trident_conv import MultiScaleTridentConv
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
8 | ):
9 | super(ResidualBlock, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
12 | dilation=dilation, padding=dilation, stride=stride, bias=False)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
14 | dilation=dilation, padding=dilation, bias=False)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | self.norm1 = norm_layer(planes)
18 | self.norm2 = norm_layer(planes)
19 | if not stride == 1 or in_planes != planes:
20 | self.norm3 = norm_layer(planes)
21 |
22 | if stride == 1 and in_planes == planes:
23 | self.downsample = None
24 | else:
25 | self.downsample = nn.Sequential(
26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
27 |
28 | def forward(self, x):
29 | y = x
30 | y = self.relu(self.norm1(self.conv1(y)))
31 | y = self.relu(self.norm2(self.conv2(y)))
32 |
33 | if self.downsample is not None:
34 | x = self.downsample(x)
35 |
36 | return self.relu(x + y)
37 |
38 |
39 | class CNNEncoder(nn.Module):
40 | def __init__(self, output_dim=128,
41 | norm_layer=nn.InstanceNorm2d,
42 | num_output_scales=1,
43 | **kwargs,
44 | ):
45 | super(CNNEncoder, self).__init__()
46 | self.num_branch = num_output_scales
47 |
48 | feature_dims = [64, 96, 128]
49 |
50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
51 | self.norm1 = norm_layer(feature_dims[0])
52 | self.relu1 = nn.ReLU(inplace=True)
53 |
54 | self.in_planes = feature_dims[0]
55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
57 |
58 | # highest resolution 1/4 or 1/8
59 | stride = 2 if num_output_scales == 1 else 1
60 | self.layer3 = self._make_layer(feature_dims[2], stride=stride,
61 | norm_layer=norm_layer,
62 | ) # 1/4 or 1/8
63 |
64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
65 |
66 | if self.num_branch > 1:
67 | if self.num_branch == 4:
68 | strides = (1, 2, 4, 8)
69 | elif self.num_branch == 3:
70 | strides = (1, 2, 4)
71 | elif self.num_branch == 2:
72 | strides = (1, 2)
73 | else:
74 | raise ValueError
75 |
76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
77 | kernel_size=3,
78 | strides=strides,
79 | paddings=1,
80 | num_branch=self.num_branch,
81 | )
82 |
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv2d):
85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
87 | if m.weight is not None:
88 | nn.init.constant_(m.weight, 1)
89 | if m.bias is not None:
90 | nn.init.constant_(m.bias, 0)
91 |
92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
95 |
96 | layers = (layer1, layer2)
97 |
98 | self.in_planes = dim
99 | return nn.Sequential(*layers)
100 |
101 | def forward(self, x):
102 | x = self.conv1(x)
103 | x = self.norm1(x)
104 | x = self.relu1(x)
105 |
106 | x = self.layer1(x) # 1/2
107 | x = self.layer2(x) # 1/4
108 | x = self.layer3(x) # 1/8 or 1/4
109 |
110 | x = self.conv2(x)
111 |
112 | if self.num_branch > 1:
113 | out = self.trident_conv([x] * self.num_branch) # high to low res
114 | else:
115 | out = [x]
116 |
117 | return out
118 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/gmflow/geometry.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def coords_grid(b, h, w, homogeneous=False, device=None):
6 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
7 |
8 | stacks = [x, y]
9 |
10 | if homogeneous:
11 | ones = torch.ones_like(x) # [H, W]
12 | stacks.append(ones)
13 |
14 | grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
15 |
16 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
17 |
18 | if device is not None:
19 | grid = grid.to(device)
20 |
21 | return grid
22 |
23 |
24 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
25 | assert device is not None
26 |
27 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
28 | torch.linspace(h_min, h_max, len_h, device=device)],
29 | )
30 | grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
31 |
32 | return grid
33 |
34 |
35 | def normalize_coords(coords, h, w):
36 | # coords: [B, H, W, 2]
37 | c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
38 | return (coords - c) / c # [-1, 1]
39 |
40 |
41 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
42 | # img: [B, C, H, W]
43 | # sample_coords: [B, 2, H, W] in image scale
44 | if sample_coords.size(1) != 2: # [B, H, W, 2]
45 | sample_coords = sample_coords.permute(0, 3, 1, 2)
46 |
47 | b, _, h, w = sample_coords.shape
48 |
49 | # Normalize to [-1, 1]
50 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
51 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
52 |
53 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
54 |
55 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
56 |
57 | if return_mask:
58 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
59 |
60 | return img, mask
61 |
62 | return img
63 |
64 |
65 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
66 | b, c, h, w = feature.size()
67 | assert flow.size(1) == 2
68 |
69 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
70 |
71 | return bilinear_sample(feature, grid, padding_mode=padding_mode,
72 | return_mask=mask)
73 |
74 |
75 | def forward_backward_consistency_check(fwd_flow, bwd_flow,
76 | alpha=0.01,
77 | beta=0.5
78 | ):
79 | # fwd_flow, bwd_flow: [B, 2, H, W]
80 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
81 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
82 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
83 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
84 |
85 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
86 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
87 |
88 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
89 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
90 |
91 | threshold = alpha * flow_mag + beta
92 |
93 | fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
94 | bwd_occ = (diff_bwd > threshold).float()
95 |
96 | return fwd_occ, bwd_occ
97 |
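A minimal sketch of the consistency check on synthetic flows (values illustrative); the returned masks flag pixels where the forward/backward pair disagrees beyond the alpha/beta threshold.

import torch

fwd_flow = torch.randn(1, 2, 64, 96)
bwd_flow = -fwd_flow + 0.1 * torch.randn(1, 2, 64, 96)

fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)
print(fwd_occ.shape)          # torch.Size([1, 64, 96])
print(fwd_occ.mean().item())  # fraction of pixels flagged as occluded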
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/gmflow/matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from .geometry import coords_grid, generate_window_grid, normalize_coords
5 |
6 |
7 | def global_correlation_softmax(feature0, feature1,
8 | pred_bidir_flow=False,
9 | ):
10 | # global correlation
11 | b, c, h, w = feature0.shape
12 | feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
13 | feature1 = feature1.view(b, c, -1) # [B, C, H*W]
14 |
15 | correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W]
16 |
17 | # flow from softmax
18 | init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W]
19 | grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
20 |
21 | correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
22 |
23 | if pred_bidir_flow:
24 | correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W]
25 | init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W]
26 | grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2]
27 | b = b * 2
28 |
29 | prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
30 |
31 | correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
32 |
33 | # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
34 | flow = correspondence - init_grid
35 |
36 | return flow, prob
37 |
38 |
39 | def local_correlation_softmax(feature0, feature1, local_radius,
40 | padding_mode='zeros',
41 | ):
42 | b, c, h, w = feature0.size()
43 | coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
44 | coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
45 |
46 | local_h = 2 * local_radius + 1
47 | local_w = 2 * local_radius + 1
48 |
49 | window_grid = generate_window_grid(-local_radius, local_radius,
50 | -local_radius, local_radius,
51 | local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
52 | window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
53 | sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2]
54 |
55 | sample_coords_softmax = sample_coords
56 |
57 | # exclude coords that are out of image space
58 | valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
59 | valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
60 |
61 | valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
62 |
63 | # normalize coordinates to [-1, 1]
64 | sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
65 | window_feature = F.grid_sample(feature1, sample_coords_norm,
66 | padding_mode=padding_mode, align_corners=True
67 | ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
68 | feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
69 |
70 | corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
71 |
72 | # mask invalid locations
73 | corr[~valid] = -1e9
74 |
75 | prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2]
76 |
77 | correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
78 | b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
79 |
80 | flow = correspondence - coords_init
81 | match_prob = prob
82 |
83 | return flow, match_prob
84 |
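A minimal shape check for the global-correlation matching (1/8-resolution features, sizes illustrative); prob holds the H*W x H*W matching distribution from which the flow is derived.

import torch

feature0 = torch.randn(2, 128, 32, 48)
feature1 = torch.randn(2, 128, 32, 48)

flow, prob = global_correlation_softmax(feature0, feature1)
print(flow.shape)  # torch.Size([2, 2, 32, 48])
print(prob.shape)  # torch.Size([2, 1536, 1536]) == [B, H*W, H*W]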
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/gmflow/position.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | import math
7 |
8 |
9 | class PositionEmbeddingSine(nn.Module):
10 | """
11 | This is a more standard version of the position embedding, very similar to the one
12 | used by the Attention is all you need paper, generalized to work on images.
13 | """
14 |
15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
16 | super().__init__()
17 | self.num_pos_feats = num_pos_feats
18 | self.temperature = temperature
19 | self.normalize = normalize
20 | if scale is not None and normalize is False:
21 | raise ValueError("normalize should be True if scale is passed")
22 | if scale is None:
23 | scale = 2 * math.pi
24 | self.scale = scale
25 |
26 | def forward(self, x):
27 | # x = tensor_list.tensors # [B, C, H, W]
28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
29 | b, c, h, w = x.size()
30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
31 | y_embed = mask.cumsum(1, dtype=torch.float32)
32 | x_embed = mask.cumsum(2, dtype=torch.float32)
33 | if self.normalize:
34 | eps = 1e-6
35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
37 |
38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
40 |
41 | pos_x = x_embed[:, :, :, None] / dim_t
42 | pos_y = y_embed[:, :, :, None] / dim_t
43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
46 | return pos
47 |
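A short sketch (synthetic tensor, repository root on `PYTHONPATH`) showing that the encoding has the same shape as its input and can simply be added to a feature map:

```python
import torch
from models.gimm.src.models.generalizable_INR.gmflow.position import PositionEmbeddingSine

feature = torch.randn(2, 128, 24, 32)                    # [B, C, H, W]
pos_enc = PositionEmbeddingSine(num_pos_feats=128 // 2)  # num_pos_feats is per axis (y and x)
pos = pos_enc(feature)                                   # [B, C, H, W]
feature = feature + pos
print(pos.shape)  # torch.Size([2, 128, 24, 32])
```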
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/gmflow/trident_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn.modules.utils import _pair
8 |
9 |
10 | class MultiScaleTridentConv(nn.Module):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | stride=1,
17 | strides=1,
18 | paddings=0,
19 | dilations=1,
20 | dilation=1,
21 | groups=1,
22 | num_branch=1,
23 | test_branch_idx=-1,
24 | bias=False,
25 | norm=None,
26 | activation=None,
27 | ):
28 | super(MultiScaleTridentConv, self).__init__()
29 | self.in_channels = in_channels
30 | self.out_channels = out_channels
31 | self.kernel_size = _pair(kernel_size)
32 | self.num_branch = num_branch
33 | self.stride = _pair(stride)
34 | self.groups = groups
35 | self.with_bias = bias
36 | self.dilation = dilation
37 | if isinstance(paddings, int):
38 | paddings = [paddings] * self.num_branch
39 | if isinstance(dilations, int):
40 | dilations = [dilations] * self.num_branch
41 | if isinstance(strides, int):
42 | strides = [strides] * self.num_branch
43 | self.paddings = [_pair(padding) for padding in paddings]
44 | self.dilations = [_pair(dilation) for dilation in dilations]
45 | self.strides = [_pair(stride) for stride in strides]
46 | self.test_branch_idx = test_branch_idx
47 | self.norm = norm
48 | self.activation = activation
49 |
50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
51 |
52 | self.weight = nn.Parameter(
53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
54 | )
55 | if bias:
56 | self.bias = nn.Parameter(torch.Tensor(out_channels))
57 | else:
58 | self.bias = None
59 |
60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
61 | if self.bias is not None:
62 | nn.init.constant_(self.bias, 0)
63 |
64 | def forward(self, inputs):
65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
66 | assert len(inputs) == num_branch
67 |
68 | if self.training or self.test_branch_idx == -1:
69 | outputs = [
70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
71 | for input, stride, padding in zip(inputs, self.strides, self.paddings)
72 | ]
73 | else:
74 | outputs = [
75 | F.conv2d(
76 | inputs[0],
77 | self.weight,
78 | self.bias,
79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
81 | self.dilation,
82 | self.groups,
83 | )
84 | ]
85 |
86 | if self.norm is not None:
87 | outputs = [self.norm(x) for x in outputs]
88 | if self.activation is not None:
89 | outputs = [self.activation(x) for x in outputs]
90 | return outputs
91 |
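A hypothetical two-branch example: both branches share one weight tensor but run with different strides, which is how the GMFlow backbone derives multi-scale features. The channel counts and strides here are illustrative only:

```python
import torch
from models.gimm.src.models.generalizable_INR.gmflow.trident_conv import MultiScaleTridentConv

conv = MultiScaleTridentConv(
    in_channels=64, out_channels=96, kernel_size=3,
    strides=[1, 2], paddings=1, num_branch=2,
)

x = torch.randn(1, 64, 32, 32)
out1, out2 = conv([x, x])          # one input per branch, shared weights
print(out1.shape, out2.shape)      # [1, 96, 32, 32] [1, 96, 16, 16]
```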
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/gmflow/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .position import PositionEmbeddingSine
3 |
4 |
5 | def split_feature(feature,
6 | num_splits=2,
7 | channel_last=False,
8 | ):
9 | if channel_last: # [B, H, W, C]
10 | b, h, w, c = feature.size()
11 | assert h % num_splits == 0 and w % num_splits == 0
12 |
13 | b_new = b * num_splits * num_splits
14 | h_new = h // num_splits
15 | w_new = w // num_splits
16 |
17 | feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
18 | ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
19 | else: # [B, C, H, W]
20 | b, c, h, w = feature.size()
21 | assert h % num_splits == 0 and w % num_splits == 0
22 |
23 | b_new = b * num_splits * num_splits
24 | h_new = h // num_splits
25 | w_new = w // num_splits
26 |
27 | feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
28 | ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
29 |
30 | return feature
31 |
32 |
33 | def merge_splits(splits,
34 | num_splits=2,
35 | channel_last=False,
36 | ):
37 | if channel_last: # [B*K*K, H/K, W/K, C]
38 | b, h, w, c = splits.size()
39 | new_b = b // num_splits // num_splits
40 |
41 | splits = splits.view(new_b, num_splits, num_splits, h, w, c)
42 | merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
43 | new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
44 | else: # [B*K*K, C, H/K, W/K]
45 | b, c, h, w = splits.size()
46 | new_b = b // num_splits // num_splits
47 |
48 | splits = splits.view(new_b, num_splits, num_splits, c, h, w)
49 | merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
50 | new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
51 |
52 | return merge
53 |
54 |
55 | def normalize_img(img0, img1):
56 | # loaded images are in [0, 255]
57 | # normalize by ImageNet mean and std
58 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
59 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
60 | img0 = (img0 / 255. - mean) / std
61 | img1 = (img1 / 255. - mean) / std
62 |
63 | return img0, img1
64 |
65 |
66 | def feature_add_position(feature0, feature1, attn_splits, feature_channels):
67 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
68 |
69 |     if attn_splits > 1:  # add position within each split window
69 |     if attn_splits > 1:  # add position within each split window
70 | feature0_splits = split_feature(feature0, num_splits=attn_splits)
71 | feature1_splits = split_feature(feature1, num_splits=attn_splits)
72 |
73 | position = pos_enc(feature0_splits)
74 |
75 | feature0_splits = feature0_splits + position
76 | feature1_splits = feature1_splits + position
77 |
78 | feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
79 | feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
80 | else:
81 | position = pos_enc(feature0)
82 |
83 | feature0 = feature0 + position
84 | feature1 = feature1 + position
85 |
86 | return feature0, feature1
87 |
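A small sketch of the split/merge round trip that `feature_add_position` relies on; `merge_splits` exactly inverts `split_feature` (synthetic tensor, repository root on `PYTHONPATH`):

```python
import torch
from models.gimm.src.models.generalizable_INR.gmflow.utils import (
    split_feature, merge_splits, feature_add_position,
)

feature = torch.randn(1, 128, 32, 64)                  # H and W must be divisible by num_splits
splits = split_feature(feature, num_splits=2)          # [4, 128, 16, 32]
merged = merge_splits(splits, num_splits=2)            # back to [1, 128, 32, 64]
print(splits.shape, torch.allclose(merged, feature))   # torch.Size([4, 128, 16, 32]) True

# feature_add_position wraps the same round trip around a PositionEmbeddingSine call.
f0, f1 = feature_add_position(feature, feature.clone(), attn_splits=2, feature_channels=128)
print(f0.shape)                                        # torch.Size([1, 128, 32, 64])
```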
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/modules/__init__.py
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/modules/coord_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | class CoordSampler3D(nn.Module):
16 | def __init__(self, coord_range, t_coord_only=False):
17 | super().__init__()
18 | self.coord_range = coord_range
19 | self.t_coord_only = t_coord_only
20 |
21 | def shape2coordinate(
22 | self,
23 | batch_size,
24 | spatial_shape,
25 | t_ids,
26 | coord_range=(-1.0, 1.0),
27 | upsample_ratio=1,
28 | device=None,
29 | ):
30 | coords = []
31 | assert isinstance(t_ids, list)
32 | _coords = torch.tensor(t_ids, device=device) / 1.0
33 | coords.append(_coords.to(torch.float32))
34 | for num_s in spatial_shape:
35 | num_s = int(num_s * upsample_ratio)
36 | _coords = (0.5 + torch.arange(num_s, device=device)) / num_s
37 | _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords
38 | coords.append(_coords)
39 | coords = torch.meshgrid(*coords, indexing="ij")
40 | coords = torch.stack(coords, dim=-1)
41 | ones_like_shape = (1,) * coords.ndim
42 | coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape)
43 | return coords # (B,T,H,W,3)
44 |
45 | def batchshape2coordinate(
46 | self,
47 | batch_size,
48 | spatial_shape,
49 | t_ids,
50 | coord_range=(-1.0, 1.0),
51 | upsample_ratio=1,
52 | device=None,
53 | ):
54 | coords = []
55 | _coords = torch.tensor(1, device=device)
56 | coords.append(_coords.to(torch.float32))
57 | for num_s in spatial_shape:
58 | num_s = int(num_s * upsample_ratio)
59 | _coords = (0.5 + torch.arange(num_s, device=device)) / num_s
60 | _coords = coord_range[0] + (coord_range[1] - coord_range[0]) * _coords
61 | coords.append(_coords)
62 | coords = torch.meshgrid(*coords, indexing="ij")
63 | coords = torch.stack(coords, dim=-1)
64 | ones_like_shape = (1,) * coords.ndim
65 |         # coords is now (B, 1, H, W, 3) with coords[..., 0] = 1
66 | coords = coords.unsqueeze(0).repeat(batch_size, *ones_like_shape)
67 | # assign per-sample timestep within the batch
68 | coords[..., :1] = coords[..., :1] * t_ids.reshape(-1, 1, 1, 1, 1)
69 | return coords
70 |
71 | def forward(
72 | self,
73 | batch_size,
74 | s_shape,
75 | t_ids,
76 | coord_range=None,
77 | upsample_ratio=1.0,
78 | device=None,
79 | ):
80 | coord_range = self.coord_range if coord_range is None else coord_range
81 | if isinstance(t_ids, list):
82 | coords = self.shape2coordinate(
83 | batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
84 | )
85 | elif isinstance(t_ids, torch.Tensor):
86 | coords = self.batchshape2coordinate(
87 | batch_size, s_shape, t_ids, coord_range, upsample_ratio, device
88 | )
89 | if self.t_coord_only:
90 | coords = coords[..., :1]
91 | return coords
92 |
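A minimal sketch of both code paths (a list of timesteps shared by the batch vs. a per-sample timestep tensor); shapes are what the methods above produce and all values are synthetic:

```python
import torch
from models.gimm.src.models.generalizable_INR.modules.coord_sampler import CoordSampler3D

sampler = CoordSampler3D(coord_range=[-1.0, 1.0])

# Shared timesteps for the whole batch -> shape2coordinate
coords = sampler(batch_size=2, s_shape=(4, 6), t_ids=[0.0, 0.5], device=torch.device("cpu"))
print(coords.shape)    # torch.Size([2, 2, 4, 6, 3])  (B, T, H, W, 3)

# One timestep per sample -> batchshape2coordinate
t = torch.tensor([0.25, 0.75])
coords = sampler(batch_size=2, s_shape=(4, 6), t_ids=t, device=torch.device("cpu"))
print(coords.shape)    # torch.Size([2, 1, 4, 6, 3])
```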
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/modules/fi_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # raft: https://github.com/princeton-vl/RAFT
9 | # ema-vfi: https://github.com/MCG-NJU/EMA-VFI
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn.functional as F
14 |
15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16 | backwarp_tenGrid = {}
17 |
18 |
19 | def warp(tenInput, tenFlow):
20 | k = (str(tenFlow.device), str(tenFlow.size()))
21 | if k not in backwarp_tenGrid:
22 | tenHorizontal = (
23 | torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device)
24 | .view(1, 1, 1, tenFlow.shape[3])
25 | .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
26 | )
27 | tenVertical = (
28 | torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device)
29 | .view(1, 1, tenFlow.shape[2], 1)
30 | .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
31 | )
32 | backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device)
33 |
34 | tenFlow = torch.cat(
35 | [
36 | tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
37 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
38 | ],
39 | 1,
40 | )
41 |
42 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
43 | return torch.nn.functional.grid_sample(
44 | input=tenInput,
45 | grid=g,
46 | mode="bilinear",
47 | padding_mode="border",
48 | align_corners=True,
49 | )
50 |
51 |
52 | def normalize_flow(flows):
53 | flow_scaler = torch.max(torch.abs(flows).flatten(1), dim=-1)[0].reshape(
54 | -1, 1, 1, 1, 1
55 | )
56 | flows = flows / flow_scaler # [-1,1]
57 |     # Adapt to [0, 1]
58 | flows = (flows + 1.0) / 2.0
59 | return flows, flow_scaler
60 |
61 |
62 | def unnormalize_flow(flows, flow_scaler):
63 | return (flows * 2.0 - 1.0) * flow_scaler
64 |
65 |
66 | def resize(x, scale_factor):
67 | return F.interpolate(
68 | x, scale_factor=scale_factor, mode="bilinear", align_corners=False
69 | )
70 |
71 |
72 | def coords_grid(batch, ht, wd):
73 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
74 | coords = torch.stack(coords[::-1], dim=0).float()
75 | return coords[None].repeat(batch, 1, 1, 1)
76 |
77 |
78 | def build_coord(img):
79 | N, C, H, W = img.shape
80 | coords = coords_grid(N, H // 8, W // 8)
81 | return coords
82 |
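A sketch of the warp and flow-normalization helpers. Tensors are created on the module-level `device` (CUDA if available), since `warp` builds its sampling grid there; a zero-flow warp should return the input unchanged and the normalize/unnormalize pair should round-trip:

```python
import torch
from models.gimm.src.models.generalizable_INR.modules import fi_utils

dev = fi_utils.device                                  # cuda if available, else cpu
img = torch.rand(1, 3, 64, 64, device=dev)
zero_flow = torch.zeros(1, 2, 64, 64, device=dev)
warped = fi_utils.warp(img, zero_flow)                 # identity warp
print(torch.allclose(warped, img, atol=1e-5))          # True

flows = torch.randn(1, 2, 2, 64, 64, device=dev) * 10  # (B, T, 2, H, W)
normed, scaler = fi_utils.normalize_flow(flows)        # values mapped into [0, 1]
restored = fi_utils.unnormalize_flow(normed, scaler)
print(torch.allclose(restored, flows, atol=1e-4))      # True
```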
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/modules/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 |
8 | from torch import nn
9 | import torch
10 |
11 |
12 | # define siren layer & Siren model
13 | class Sine(nn.Module):
14 | """Sine activation with scaling.
15 |
16 | Args:
17 | w0 (float): Omega_0 parameter from SIREN paper.
18 | """
19 |
20 | def __init__(self, w0=1.0):
21 | super().__init__()
22 | self.w0 = w0
23 |
24 | def forward(self, x):
25 | return torch.sin(self.w0 * x)
26 |
27 |
28 | # Damping activation from http://arxiv.org/abs/2306.15242
29 | class Damping(nn.Module):
30 | """Sine activation with sublinear factor
31 |
32 | Args:
33 | w0 (float): Omega_0 parameter from SIREN paper.
34 | """
35 |
36 | def __init__(self, w0=1.0):
37 | super().__init__()
38 | self.w0 = w0
39 |
40 | def forward(self, x):
41 | x = torch.clamp(x, min=1e-30)
42 | return torch.sin(self.w0 * x) * torch.sqrt(x.abs())
43 |
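A tiny sketch of how these activations are typically combined with a linear layer in a SIREN-style MLP (the layer sizes and `w0` are arbitrary):

```python
import torch
from torch import nn
from models.gimm.src.models.generalizable_INR.modules.layers import Sine, Damping

siren_layer = nn.Sequential(nn.Linear(2, 64), Sine(w0=30.0))  # sin(w0 * (Wx + b))
coords = torch.rand(1024, 2) * 2 - 1                          # coordinates in [-1, 1]
print(siren_layer(coords).shape)                              # torch.Size([1024, 64])

damp = Damping(w0=30.0)                                       # sin(w0*x) * sqrt(|x|), x clamped >= 1e-30
print(damp(coords.abs()).shape)                               # torch.Size([1024, 2])
```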
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/modules/module_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | from typing import List, Optional
12 | from dataclasses import dataclass
13 | from omegaconf import MISSING
14 |
15 |
16 | @dataclass
17 | class HypoNetActivationConfig:
18 | type: str = "relu"
19 | siren_w0: Optional[float] = 30.0
20 |
21 |
22 | @dataclass
23 | class HypoNetInitConfig:
24 | weight_init_type: Optional[str] = "kaiming_uniform"
25 | bias_init_type: Optional[str] = "zero"
26 |
27 |
28 | @dataclass
29 | class HypoNetConfig:
30 | type: str = "mlp"
31 | n_layer: int = 5
32 | hidden_dim: List[int] = MISSING
33 | use_bias: bool = True
34 | input_dim: int = 2
35 | output_dim: int = 3
36 | output_bias: float = 0.5
37 | activation: HypoNetActivationConfig = HypoNetActivationConfig()
38 | initialization: HypoNetInitConfig = HypoNetInitConfig()
39 |
40 | normalize_weight: bool = True
41 | linear_interpo: bool = False
42 |
43 |
44 | @dataclass
45 | class CoordSamplerConfig:
46 | data_type: str = "image"
47 | t_coord_only: bool = False
48 | coord_range: List[float] = MISSING
49 | time_range: List[float] = MISSING
50 | train_strategy: Optional[str] = MISSING
51 | val_strategy: Optional[str] = MISSING
52 | patch_size: Optional[int] = MISSING
53 |
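A sketch of filling in these structured configs with OmegaConf, supplying the `MISSING` fields that a YAML file (e.g. `gimm.yaml`) would normally provide; this assumes a Python version that still accepts dataclass-instance defaults, as the upstream code does:

```python
from omegaconf import OmegaConf
from models.gimm.src.models.generalizable_INR.modules.module_config import HypoNetConfig

cfg = OmegaConf.structured(HypoNetConfig)
cfg = OmegaConf.merge(cfg, {"hidden_dim": [256, 256, 256, 256], "activation": {"type": "siren"}})
print(cfg.n_layer, list(cfg.hidden_dim), cfg.activation.type, cfg.activation.siren_w0)
# 5 [256, 256, 256, 256] siren 30.0
```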
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/modules/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | import math
12 | import torch
13 | import torch.nn as nn
14 |
15 | from .layers import Sine, Damping
16 |
17 |
18 | def convert_int_to_list(size, len_list=2):
19 | if isinstance(size, int):
20 | return [size] * len_list
21 | else:
22 | assert len(size) == len_list
23 | return size
24 |
25 |
26 | def initialize_params(params, init_type, **kwargs):
27 | fan_in, fan_out = params.shape[0], params.shape[1]
28 | if init_type is None or init_type == "normal":
29 | nn.init.normal_(params)
30 | elif init_type == "kaiming_uniform":
31 | nn.init.kaiming_uniform_(params, a=math.sqrt(5))
32 | elif init_type == "uniform_fan_in":
33 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
34 | nn.init.uniform_(params, -bound, bound)
35 | elif init_type == "zero":
36 | nn.init.zeros_(params)
37 | elif "siren" == init_type:
38 | assert "siren_w0" in kwargs.keys() and "is_first" in kwargs.keys()
39 | w0 = kwargs["siren_w0"]
40 | if kwargs["is_first"]:
41 | w_std = 1 / fan_in
42 | else:
43 | w_std = math.sqrt(6.0 / fan_in) / w0
44 | nn.init.uniform_(params, -w_std, w_std)
45 | else:
46 | raise NotImplementedError
47 |
48 |
49 | def create_params_with_init(
50 | shape, init_type="normal", include_bias=False, bias_init_type="zero", **kwargs
51 | ):
52 | if not include_bias:
53 | params = torch.empty([shape[0], shape[1]])
54 | initialize_params(params, init_type, **kwargs)
55 | return params
56 | else:
57 | params = torch.empty([shape[0] - 1, shape[1]])
58 | bias = torch.empty([1, shape[1]])
59 |
60 | initialize_params(params, init_type, **kwargs)
61 | initialize_params(bias, bias_init_type, **kwargs)
62 | return torch.cat([params, bias], dim=0)
63 |
64 |
65 | def create_activation(config):
66 | if config.type == "relu":
67 | activation = nn.ReLU()
68 | elif config.type == "siren":
69 | activation = Sine(config.siren_w0)
70 | elif config.type == "silu":
71 | activation = nn.SiLU()
72 | elif config.type == "damp":
73 | activation = Damping(config.siren_w0)
74 | else:
75 | raise NotImplementedError
76 | return activation
77 |
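A sketch of creating SIREN-initialized parameters with the bias row packed into the same matrix, plus the matching activation. The shapes and `w0` value are illustrative; `OmegaConf.create` is used only to mimic the attribute-style config that `create_activation` expects:

```python
import torch
from omegaconf import OmegaConf
from models.gimm.src.models.generalizable_INR.modules.utils import (
    create_params_with_init, create_activation,
)

fan_in, fan_out = 64, 64
params = create_params_with_init(
    shape=(fan_in + 1, fan_out),      # +1 row for the bias
    init_type="siren", include_bias=True, bias_init_type="zero",
    siren_w0=30.0, is_first=False,
)
print(params.shape)                   # torch.Size([65, 64])

act = create_activation(OmegaConf.create({"type": "siren", "siren_w0": 30.0}))
x = torch.cat([torch.randn(8, fan_in), torch.ones(8, 1)], dim=1)  # trailing 1s pick up the bias row
print(act(x @ params).shape)          # torch.Size([8, 64])
```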
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/raft/__init__.py:
--------------------------------------------------------------------------------
1 | from .raft import RAFT
2 | import argparse
3 | import torch
4 | from .extractor import BasicEncoder
5 |
6 |
7 | def initialize_RAFT(model_path="weights/raft-things.pth", device="cuda"):
8 | """Initializes the RAFT model."""
9 | args = argparse.ArgumentParser()
10 | args.raft_model = model_path
11 | args.small = False
12 | args.mixed_precision = False
13 | args.alternate_corr = False
14 | model = RAFT(args)
15 | ckpt = torch.load(args.raft_model, map_location="cpu")
16 |
17 | def convert(param):
18 | return {k.replace("module.", ""): v for k, v in param.items() if "module" in k}
19 |
20 | ckpt = convert(ckpt)
21 | model.load_state_dict(ckpt, strict=True)
22 | print("load raft from " + model_path)
23 |
24 | return model
25 |
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/raft/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/models/generalizable_INR/raft/utils/__init__.py
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/raft/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 |
21 | def make_colorwheel():
22 | """
23 | Generates a color wheel for optical flow visualization as presented in:
24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
26 |
27 | Code follows the original C++ source code of Daniel Scharstein.
28 |     Code follows the Matlab source code of Deqing Sun.
29 |
30 | Returns:
31 | np.ndarray: Color wheel
32 | """
33 |
34 | RY = 15
35 | YG = 6
36 | GC = 4
37 | CB = 11
38 | BM = 13
39 | MR = 6
40 |
41 | ncols = RY + YG + GC + CB + BM + MR
42 | colorwheel = np.zeros((ncols, 3))
43 | col = 0
44 |
45 | # RY
46 | colorwheel[0:RY, 0] = 255
47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
48 | col = col + RY
49 | # YG
50 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
51 | colorwheel[col : col + YG, 1] = 255
52 | col = col + YG
53 | # GC
54 | colorwheel[col : col + GC, 1] = 255
55 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
56 | col = col + GC
57 | # CB
58 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
59 | colorwheel[col : col + CB, 2] = 255
60 | col = col + CB
61 | # BM
62 | colorwheel[col : col + BM, 2] = 255
63 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
64 | col = col + BM
65 | # MR
66 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
67 | colorwheel[col : col + MR, 0] = 255
68 | return colorwheel
69 |
70 |
71 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
72 | """
73 | Applies the flow color wheel to (possibly clipped) flow components u and v.
74 |
75 | According to the C++ source code of Daniel Scharstein
76 | According to the Matlab source code of Deqing Sun
77 |
78 | Args:
79 | u (np.ndarray): Input horizontal flow of shape [H,W]
80 | v (np.ndarray): Input vertical flow of shape [H,W]
81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
82 |
83 | Returns:
84 | np.ndarray: Flow visualization image of shape [H,W,3]
85 | """
86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
87 | colorwheel = make_colorwheel() # shape [55x3]
88 | ncols = colorwheel.shape[0]
89 | rad = np.sqrt(np.square(u) + np.square(v))
90 | a = np.arctan2(-v, -u) / np.pi
91 | fk = (a + 1) / 2 * (ncols - 1)
92 | k0 = np.floor(fk).astype(np.int32)
93 | k1 = k0 + 1
94 | k1[k1 == ncols] = 0
95 | f = fk - k0
96 | for i in range(colorwheel.shape[1]):
97 | tmp = colorwheel[:, i]
98 | col0 = tmp[k0] / 255.0
99 | col1 = tmp[k1] / 255.0
100 | col = (1 - f) * col0 + f * col1
101 | idx = rad <= 1
102 | col[idx] = 1 - rad[idx] * (1 - col[idx])
103 | col[~idx] = col[~idx] * 0.75 # out of range
104 | # Note the 2-i => BGR instead of RGB
105 | ch_idx = 2 - i if convert_to_bgr else i
106 | flow_image[:, :, ch_idx] = np.floor(255 * col)
107 | return flow_image
108 |
109 |
110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
111 | """
112 |     Expects a two-dimensional flow field of shape [H,W,2].
113 |
114 | Args:
115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
118 |
119 | Returns:
120 | np.ndarray: Flow visualization image of shape [H,W,3]
121 | """
122 | assert flow_uv.ndim == 3, "input flow must have three dimensions"
123 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
124 | if clip_flow is not None:
125 | flow_uv = np.clip(flow_uv, 0, clip_flow)
126 | u = flow_uv[:, :, 0]
127 | v = flow_uv[:, :, 1]
128 | rad = np.sqrt(np.square(u) + np.square(v))
129 | rad_max = np.max(rad)
130 | epsilon = 1e-5
131 | u = u / (rad_max + epsilon)
132 | v = v / (rad_max + epsilon)
133 | return flow_uv_to_colors(u, v, convert_to_bgr)
134 |
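A sketch that renders a synthetic radial flow field with the color wheel above; the result is a regular `uint8` RGB image that can be saved with any image library:

```python
import numpy as np
from models.gimm.src.models.generalizable_INR.raft.utils.flow_viz import flow_to_image

h, w = 120, 160
ys, xs = np.mgrid[0:h, 0:w].astype(np.float32)
flow = np.stack([xs - w / 2, ys - h / 2], axis=-1)   # [H, W, 2], pointing away from the center
rgb = flow_to_image(flow)
print(rgb.shape, rgb.dtype)                          # (120, 160, 3) uint8
# e.g. cv2.imwrite("flow.png", flow_to_image(flow, convert_to_bgr=True))
```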
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/raft/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 |
8 | cv2.setNumThreads(0)
9 | cv2.ocl.setUseOpenCL(False)
10 |
11 | TAG_CHAR = np.array([202021.25], np.float32)
12 |
13 |
14 | def readFlow(fn):
15 | """Read .flo file in Middlebury format"""
16 | # Code adapted from:
17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
18 |
19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
20 | # print 'fn = %s'%(fn)
21 | with open(fn, "rb") as f:
22 | magic = np.fromfile(f, np.float32, count=1)
23 | if 202021.25 != magic:
24 | print("Magic number incorrect. Invalid .flo file")
25 | return None
26 | else:
27 | w = np.fromfile(f, np.int32, count=1)
28 | h = np.fromfile(f, np.int32, count=1)
29 | # print 'Reading %d x %d flo file\n' % (w, h)
30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
31 | # Reshape data into 3D array (columns, rows, bands)
32 | # The reshape here is for visualization, the original code is (w,h,2)
33 | return np.resize(data, (int(h), int(w), 2))
34 |
35 |
36 | def readPFM(file):
37 | file = open(file, "rb")
38 |
39 | color = None
40 | width = None
41 | height = None
42 | scale = None
43 | endian = None
44 |
45 | header = file.readline().rstrip()
46 | if header == b"PF":
47 | color = True
48 | elif header == b"Pf":
49 | color = False
50 | else:
51 | raise Exception("Not a PFM file.")
52 |
53 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
54 | if dim_match:
55 | width, height = map(int, dim_match.groups())
56 | else:
57 | raise Exception("Malformed PFM header.")
58 |
59 | scale = float(file.readline().rstrip())
60 | if scale < 0: # little-endian
61 | endian = "<"
62 | scale = -scale
63 | else:
64 | endian = ">" # big-endian
65 |
66 | data = np.fromfile(file, endian + "f")
67 | shape = (height, width, 3) if color else (height, width)
68 |
69 | data = np.reshape(data, shape)
70 | data = np.flipud(data)
71 | return data
72 |
73 |
74 | def writeFlow(filename, uv, v=None):
75 | """Write optical flow to file.
76 |
77 | If v is None, uv is assumed to contain both u and v channels,
78 | stacked in depth.
79 | Original code by Deqing Sun, adapted from Daniel Scharstein.
80 | """
81 | nBands = 2
82 |
83 | if v is None:
84 | assert uv.ndim == 3
85 | assert uv.shape[2] == 2
86 | u = uv[:, :, 0]
87 | v = uv[:, :, 1]
88 | else:
89 | u = uv
90 |
91 | assert u.shape == v.shape
92 | height, width = u.shape
93 | f = open(filename, "wb")
94 | # write the header
95 | f.write(TAG_CHAR)
96 | np.array(width).astype(np.int32).tofile(f)
97 | np.array(height).astype(np.int32).tofile(f)
98 | # arrange into matrix form
99 | tmp = np.zeros((height, width * nBands))
100 | tmp[:, np.arange(width) * 2] = u
101 | tmp[:, np.arange(width) * 2 + 1] = v
102 | tmp.astype(np.float32).tofile(f)
103 | f.close()
104 |
105 |
106 | def readFlowKITTI(filename):
107 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
108 | flow = flow[:, :, ::-1].astype(np.float32)
109 | flow, valid = flow[:, :, :2], flow[:, :, 2]
110 | flow = (flow - 2**15) / 64.0
111 | return flow, valid
112 |
113 |
114 | def readDispKITTI(filename):
115 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
116 | valid = disp > 0.0
117 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
118 | return flow, valid
119 |
120 |
121 | def writeFlowKITTI(filename, uv):
122 | uv = 64.0 * uv + 2**15
123 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
124 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
125 | cv2.imwrite(filename, uv[..., ::-1])
126 |
127 |
128 | def read_gen(file_name, pil=False):
129 | ext = splitext(file_name)[-1]
130 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
131 | return Image.open(file_name)
132 | elif ext == ".bin" or ext == ".raw":
133 | return np.load(file_name)
134 | elif ext == ".flo":
135 | return readFlow(file_name).astype(np.float32)
136 | elif ext == ".pfm":
137 | flow = readPFM(file_name).astype(np.float32)
138 | if len(flow.shape) == 2:
139 | return flow
140 | else:
141 | return flow[:, :, :-1]
142 | return []
143 |
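A sketch of the `.flo` round trip: `writeFlow` followed by `readFlow` should reproduce the array exactly (the temporary path is just for illustration):

```python
import os
import tempfile
import numpy as np
from models.gimm.src.models.generalizable_INR.raft.utils import frame_utils

flow = np.random.randn(48, 64, 2).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), "example.flo")

frame_utils.writeFlow(path, flow)
flow_back = frame_utils.readFlow(path)
print(flow_back.shape, np.allclose(flow, flow_back))   # (48, 64, 2) True
```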
--------------------------------------------------------------------------------
/models/gimm/src/models/generalizable_INR/raft/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """Pads images such that dimensions are divisible by 8"""
9 |
10 | def __init__(self, dims, mode="sintel"):
11 | self.ht, self.wd = dims[-2:]
12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
14 | if mode == "sintel":
15 | self._pad = [
16 | pad_wd // 2,
17 | pad_wd - pad_wd // 2,
18 | pad_ht // 2,
19 | pad_ht - pad_ht // 2,
20 | ]
21 | else:
22 | self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
23 |
24 | def pad(self, *inputs):
25 | return [F.pad(x, self._pad, mode="replicate") for x in inputs]
26 |
27 | def unpad(self, x):
28 | ht, wd = x.shape[-2:]
29 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
30 | return x[..., c[0] : c[1], c[2] : c[3]]
31 |
32 |
33 | def forward_interpolate(flow):
34 | flow = flow.detach().cpu().numpy()
35 | dx, dy = flow[0], flow[1]
36 |
37 | ht, wd = dx.shape
38 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
39 |
40 | x1 = x0 + dx
41 | y1 = y0 + dy
42 |
43 | x1 = x1.reshape(-1)
44 | y1 = y1.reshape(-1)
45 | dx = dx.reshape(-1)
46 | dy = dy.reshape(-1)
47 |
48 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
49 | x1 = x1[valid]
50 | y1 = y1[valid]
51 | dx = dx[valid]
52 | dy = dy[valid]
53 |
54 | flow_x = interpolate.griddata(
55 | (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
56 | )
57 |
58 | flow_y = interpolate.griddata(
59 | (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
60 | )
61 |
62 | flow = np.stack([flow_x, flow_y], axis=0)
63 | return torch.from_numpy(flow).float()
64 |
65 |
66 | def bilinear_sampler(img, coords, mode="bilinear", mask=False):
67 | """Wrapper for grid_sample, uses pixel coordinates"""
68 | H, W = img.shape[-2:]
69 | xgrid, ygrid = coords.split([1, 1], dim=-1)
70 | xgrid = 2 * xgrid / (W - 1) - 1
71 | ygrid = 2 * ygrid / (H - 1) - 1
72 |
73 | grid = torch.cat([xgrid, ygrid], dim=-1)
74 | img = F.grid_sample(img, grid, align_corners=True)
75 |
76 | if mask:
77 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
78 | return img, mask.float()
79 |
80 | return img
81 |
82 |
83 | def coords_grid(batch, ht, wd, device):
84 | coords = torch.meshgrid(
85 | torch.arange(ht, device=device), torch.arange(wd, device=device)
86 | )
87 | coords = torch.stack(coords[::-1], dim=0).float()
88 | return coords[None].repeat(batch, 1, 1, 1)
89 |
90 |
91 | def upflow8(flow, mode="bilinear"):
92 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
93 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
94 |
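A sketch of the usual RAFT-style plumbing with these helpers: pad the inputs to a multiple of 8, work at 1/8 resolution, then upsample and unpad (the low-resolution flow here is a zero placeholder rather than a real model output):

```python
import torch
from models.gimm.src.models.generalizable_INR.raft.utils.utils import InputPadder, coords_grid, upflow8

img0 = torch.rand(1, 3, 100, 222)
img1 = torch.rand(1, 3, 100, 222)

padder = InputPadder(img0.shape)             # "sintel" mode: symmetric padding
img0_p, img1_p = padder.pad(img0, img1)
print(img0_p.shape)                          # torch.Size([1, 3, 104, 224])

coords = coords_grid(1, 104 // 8, 224 // 8, device=img0.device)   # 1/8-resolution grid
flow8 = torch.zeros(1, 2, 104 // 8, 224 // 8)                     # placeholder low-res flow
flow = padder.unpad(upflow8(flow8))
print(flow.shape)                            # torch.Size([1, 2, 100, 222])
```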
--------------------------------------------------------------------------------
/models/gimm/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/utils/__init__.py
--------------------------------------------------------------------------------
/models/gimm/src/utils/accumulator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | import torch
12 | import models.gimm.src.utils.dist as dist_utils
13 |
14 |
15 | class AccmStageINR:
16 | def __init__(
17 | self,
18 | scalar_metric_names,
19 | vector_metric_names=(),
20 | vector_metric_lengths=(),
21 | device="cpu",
22 | ):
23 | self.device = device
24 |
25 | assert len(vector_metric_lengths) == len(vector_metric_names)
26 |
27 | self.scalar_metric_names = scalar_metric_names
28 | self.vector_metric_names = vector_metric_names
29 | self.metrics_sum = {}
30 |
31 | for n in self.scalar_metric_names:
32 | self.metrics_sum[n] = torch.zeros(1, device=self.device)
33 |
34 | for n, length in zip(self.vector_metric_names, vector_metric_lengths):
35 | self.metrics_sum[n] = torch.zeros(length, device=self.device)
36 |
37 | self.counter = 0
38 | self.Summary.scalar_metric_names = self.scalar_metric_names
39 | self.Summary.vector_metric_names = self.vector_metric_names
40 |
41 | @torch.no_grad()
42 | def update(self, metrics_to_add, count=None, sync=False, distenv=None):
43 | # we assume that value is simultaneously None (or not None) for every process
44 | metrics_to_add = {
45 | name: value for (name, value) in metrics_to_add.items() if value is not None
46 | }
47 |
48 | if sync:
49 | for name, value in metrics_to_add.items():
50 | gathered_value = dist_utils.all_gather_cat(distenv, value.unsqueeze(0))
51 | gathered_value = gathered_value.sum(dim=0).detach()
52 | metrics_to_add[name] = gathered_value
53 |
54 | for name, value in metrics_to_add.items():
55 | if name not in self.metrics_sum:
56 | raise KeyError(f"unexpected metric name: {name}")
57 | self.metrics_sum[name] += value
58 |
59 | self.counter += count if not sync else count * distenv.world_size
60 |
61 | @torch.no_grad()
62 | def get_summary(self, n_samples=None):
63 | n_samples = n_samples if n_samples else self.counter
64 | return self.Summary({k: v / n_samples for k, v in self.metrics_sum.items()})
65 |
66 | class Summary:
67 | scalar_metric_names = ()
68 | vector_metric_names = ()
69 |
70 | def __init__(self, metrics):
71 | for key, value in metrics.items():
72 | self[key] = value
73 |
74 | def print_line(self):
75 | reprs = []
76 | for k in self.scalar_metric_names:
77 | v = self[k]
78 | repr = f"{k}: {v.item():.4f}"
79 | reprs.append(repr)
80 |
81 | for k in self.vector_metric_names:
82 | v = self[k]
83 | array_repr = ",".join([f"{v_i.item():.4f}" for v_i in v])
84 | repr = f"{k}: [{array_repr}]"
85 | reprs.append(repr)
86 |
87 | return ", ".join(reprs)
88 |
89 | def tb_like(self):
90 | tb_summary = {}
91 | for k in self.scalar_metric_names:
92 | v = self[k]
93 | tb_summary[f"loss/{k}"] = v
94 |
95 | for k in self.vector_metric_names:
96 | v = self[k]
97 | for i, v_i in enumerate(v):
98 | tb_summary[f"loss/{k}_{i}"] = v_i
99 |
100 | return tb_summary
101 |
102 | def __getitem__(self, item):
103 | return getattr(self, item)
104 |
105 | def __setitem__(self, key, value):
106 | setattr(self, key, value)
107 |
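A single-process sketch (no distributed sync) of accumulating scalar metrics over a few batches and printing the averaged summary; the metric names and values are made up:

```python
import torch
from models.gimm.src.utils.accumulator import AccmStageINR

accm = AccmStageINR(scalar_metric_names=("loss_total", "psnr"), device="cpu")

for _ in range(4):                          # pretend: 4 batches of 8 samples each
    accm.update(
        {"loss_total": torch.tensor([0.25]) * 8, "psnr": torch.tensor([30.0]) * 8},
        count=8,
    )

summary = accm.get_summary()
print(summary.print_line())                 # loss_total: 0.2500, psnr: 30.0000
```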
--------------------------------------------------------------------------------
/models/gimm/src/utils/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | from omegaconf import OmegaConf
12 | from easydict import EasyDict as edict
13 | import yaml
14 |
15 | from models.gimm.src.models.generalizable_INR.configs import GIMMConfig, GIMMVFIConfig
16 | import os.path as osp
17 |
18 |
19 | def easydict_to_dict(obj):
20 | if not isinstance(obj, edict):
21 | return obj
22 | else:
23 | return {k: easydict_to_dict(v) for k, v in obj.items()}
24 |
25 |
26 | def load_config(config_path):
27 | with open(config_path) as f:
28 | config = yaml.load(f, Loader=yaml.FullLoader)
29 | config = easydict_to_dict(config)
30 | config = OmegaConf.create(config)
31 | return config
32 |
33 |
34 | def augment_arch_defaults(arch_config):
35 | if arch_config.type == "gimm":
36 | arch_defaults = GIMMConfig.create(arch_config)
37 | elif arch_config.type == "gimmvfi":
38 | arch_defaults = GIMMVFIConfig.create(arch_config)
39 | elif arch_config.type == "gimmvfi_f" or arch_config.type == "gimmvfi_r":
40 | arch_defaults = GIMMVFIConfig.create(arch_config)
41 | else:
42 | raise ValueError(f"{arch_config.type} is not implemented for default arguments")
43 |
44 | return OmegaConf.merge(arch_defaults, arch_config)
45 |
46 |
47 | def augment_optimizer_defaults(optim_config):
48 | defaults = OmegaConf.create(
49 | {
50 | "type": "adamW",
51 | "max_gn": None,
52 | "warmup": {
53 | "mode": "linear",
54 | "start_from_zero": (True if optim_config.warmup.epoch > 0 else False),
55 | },
56 | }
57 | )
58 | return OmegaConf.merge(defaults, optim_config)
59 |
60 |
61 | def augment_defaults(config):
62 | defaults = OmegaConf.create(
63 | {
64 | "arch": augment_arch_defaults(config.arch),
65 | "dataset": {
66 | "transforms": {"type": None},
67 | },
68 | "optimizer": augment_optimizer_defaults(config.optimizer),
69 | "experiment": {
70 | "test_freq": 10,
71 | "amp": False,
72 | },
73 | }
74 | )
75 |
76 | if "inr" in config.arch.type or "gimm" in config.arch.type:
77 | subsample_defaults = OmegaConf.create({"type": None, "ratio": 1.0})
78 | loss_defaults = OmegaConf.create(
79 | {
80 | "loss": {
81 | "type": "mse",
82 | "subsample": subsample_defaults,
83 | "coord_noise": None,
84 | }
85 | }
86 | )
87 | defaults = OmegaConf.merge(defaults, loss_defaults)
88 | config = OmegaConf.merge(defaults, config)
89 | return config
90 |
91 |
92 | def augment_dist_defaults(config, distenv):
93 | config = config.copy()
94 | local_batch_size = config.experiment.batch_size
95 | world_batch_size = distenv.world_size * local_batch_size
96 | total_batch_size = config.experiment.get("total_batch_size", world_batch_size)
97 |
98 | if total_batch_size % world_batch_size != 0:
99 | raise ValueError("total batch size must be divisible by world batch size")
100 | else:
101 | grad_accm_steps = total_batch_size // world_batch_size
102 |
103 | config.optimizer.grad_accm_steps = grad_accm_steps
104 | config.experiment.total_batch_size = total_batch_size
105 | return config
106 |
107 |
108 | def config_setup(args, distenv, config_path, extra_args=()):
109 | if not osp.isfile(config_path):
110 | config_path = args.model_config
111 | if args.eval:
112 | config = load_config(config_path)
113 | config = augment_defaults(config)
114 | if hasattr(args, "test_batch_size"):
115 | config.experiment.batch_size = args.test_batch_size
116 | if not hasattr(config, "seed"):
117 | config.seed = args.seed
118 |
119 | elif args.resume:
120 | config = load_config(config_path)
121 | if distenv.world_size != config.runtime.distenv.world_size:
122 | raise ValueError("world_size not identical to the resuming config")
123 | config.runtime = {"args": vars(args), "distenv": distenv}
124 |
125 | else: # training
126 | config_path = args.model_config
127 | config = load_config(config_path)
128 |
129 | extra_config = OmegaConf.from_dotlist(extra_args)
130 | config = OmegaConf.merge(config, extra_config)
131 |
132 | config = augment_defaults(config)
133 | config = augment_dist_defaults(config, distenv)
134 |
135 | config.seed = args.seed
136 | config.runtime = {
137 | "args": vars(args),
138 | "extra_config": extra_config,
139 | "distenv": distenv,
140 | }
141 |
142 | return config
143 |
--------------------------------------------------------------------------------
/models/gimm/src/utils/dist.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | from dataclasses import dataclass
12 | import datetime
13 | import os
14 |
15 | import torch
16 | import torch.distributed as dist
17 | from torch.nn.parallel import DistributedDataParallel
18 |
19 |
20 | @dataclass
21 | class DistEnv:
22 | world_size: int
23 | world_rank: int
24 | local_rank: int
25 | num_gpus: int
26 | master: bool
27 | device_name: str
28 |
29 |
30 | def initialize(args, logger=None):
31 | args.rank = int(os.environ.get("RANK", 0))
32 | args.world_size = int(os.environ.get("WORLD_SIZE", 1))
33 | args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
34 |
35 | if args.world_size > 1:
36 | os.environ["RANK"] = str(args.rank)
37 | os.environ["WORLD_SIZE"] = str(args.world_size)
38 | os.environ["LOCAL_RANK"] = str(args.local_rank)
39 |
40 | print(f"[dist] Distributed: wait dist process group:{args.local_rank}")
41 | dist.init_process_group(
42 | backend=args.dist_backend,
43 | init_method="env://",
44 | world_size=args.world_size,
45 | timeout=datetime.timedelta(0, args.timeout),
46 | )
47 | assert args.world_size == dist.get_world_size()
48 | print(
49 | f"""[dist] Distributed: success device:{args.local_rank}, """,
50 | f"""{dist.get_rank()}/{dist.get_world_size()}""",
51 | )
52 | distenv = DistEnv(
53 | world_size=dist.get_world_size(),
54 | world_rank=dist.get_rank(),
55 | local_rank=args.local_rank,
56 | num_gpus=1,
57 | master=(dist.get_rank() == 0),
58 | device_name=torch.cuda.get_device_name(),
59 | )
60 | else:
61 | print("[dist] Single processed")
62 | distenv = DistEnv(
63 | 1, 0, 0, torch.cuda.device_count(), True, torch.cuda.get_device_name()
64 | )
65 |
66 | print(f"[dist] {distenv}")
67 |
68 | if logger is not None:
69 | logger.info(distenv)
70 |
71 | return distenv
72 |
73 |
74 | def dataparallel_and_sync(
75 | distenv, model, find_unused_parameters=False, static_graph=False
76 | ):
77 | if dist.is_initialized():
78 | model = DistributedDataParallel(
79 | model,
80 | device_ids=[distenv.local_rank],
81 | output_device=distenv.local_rank,
82 | find_unused_parameters=find_unused_parameters,
83 | # Available only with PyTorch 1.11 or above.
84 | # When set to ``True``, DDP knows the trained graph is static.
85 | # Especially, this enables activation checkpointing multiple times
86 | # which was not supported in the previous versions.
87 | # See the docstring of DistributedDataParallel for more details.
88 | static_graph=static_graph,
89 | )
90 | for _, param in model.state_dict().items():
91 | dist.broadcast(param, 0)
92 |
93 | dist.barrier()
94 | else:
95 | model = torch.nn.DataParallel(model)
96 | torch.cuda.synchronize()
97 |
98 | return model
99 |
100 |
101 | def param_sync(param):
102 | dist.broadcast(param, 0)
103 | dist.barrier()
104 | torch.cuda.synchronize()
105 |
106 |
107 | @torch.no_grad()
108 | def all_gather_cat(distenv, tensor, dim=0):
109 | if distenv.world_size == 1:
110 | return tensor
111 |
112 | g_tensor = [torch.ones_like(tensor) for _ in range(distenv.world_size)]
113 | dist.all_gather(g_tensor, tensor)
114 | g_tensor = torch.cat(g_tensor, dim=dim)
115 |
116 | return g_tensor
117 |
--------------------------------------------------------------------------------
/models/gimm/src/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 |
21 | def make_colorwheel():
22 | """
23 | Generates a color wheel for optical flow visualization as presented in:
24 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
25 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
26 |
27 | Code follows the original C++ source code of Daniel Scharstein.
28 |     Code follows the Matlab source code of Deqing Sun.
29 |
30 | Returns:
31 | np.ndarray: Color wheel
32 | """
33 |
34 | RY = 15
35 | YG = 6
36 | GC = 4
37 | CB = 11
38 | BM = 13
39 | MR = 6
40 |
41 | ncols = RY + YG + GC + CB + BM + MR
42 | colorwheel = np.zeros((ncols, 3))
43 | col = 0
44 |
45 | # RY
46 | colorwheel[0:RY, 0] = 255
47 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
48 | col = col + RY
49 | # YG
50 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
51 | colorwheel[col : col + YG, 1] = 255
52 | col = col + YG
53 | # GC
54 | colorwheel[col : col + GC, 1] = 255
55 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
56 | col = col + GC
57 | # CB
58 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
59 | colorwheel[col : col + CB, 2] = 255
60 | col = col + CB
61 | # BM
62 | colorwheel[col : col + BM, 2] = 255
63 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
64 | col = col + BM
65 | # MR
66 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
67 | colorwheel[col : col + MR, 0] = 255
68 | return colorwheel
69 |
70 |
71 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
72 | """
73 | Applies the flow color wheel to (possibly clipped) flow components u and v.
74 |
75 | According to the C++ source code of Daniel Scharstein
76 | According to the Matlab source code of Deqing Sun
77 |
78 | Args:
79 | u (np.ndarray): Input horizontal flow of shape [H,W]
80 | v (np.ndarray): Input vertical flow of shape [H,W]
81 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
82 |
83 | Returns:
84 | np.ndarray: Flow visualization image of shape [H,W,3]
85 | """
86 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
87 | colorwheel = make_colorwheel() # shape [55x3]
88 | ncols = colorwheel.shape[0]
89 | rad = np.sqrt(np.square(u) + np.square(v))
90 | a = np.arctan2(-v, -u) / np.pi
91 | fk = (a + 1) / 2 * (ncols - 1)
92 | k0 = np.floor(fk).astype(np.int32)
93 | k1 = k0 + 1
94 | k1[k1 == ncols] = 0
95 | f = fk - k0
96 | for i in range(colorwheel.shape[1]):
97 | tmp = colorwheel[:, i]
98 | col0 = tmp[k0] / 255.0
99 | col1 = tmp[k1] / 255.0
100 | col = (1 - f) * col0 + f * col1
101 | idx = rad <= 1
102 | col[idx] = 1 - rad[idx] * (1 - col[idx])
103 | col[~idx] = col[~idx] * 0.75 # out of range
104 | # Note the 2-i => BGR instead of RGB
105 | ch_idx = 2 - i if convert_to_bgr else i
106 | flow_image[:, :, ch_idx] = np.floor(255 * col)
107 | return flow_image
108 |
109 |
110 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, max_flow=None):
111 | """
112 |     Expects a two-dimensional flow field of shape [H,W,2].
113 |
114 | Args:
115 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
116 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
117 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
118 |
119 | Returns:
120 | np.ndarray: Flow visualization image of shape [H,W,3]
121 | """
122 | assert flow_uv.ndim == 3, "input flow must have three dimensions"
123 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
124 | if clip_flow is not None:
125 | flow_uv = np.clip(flow_uv, 0, clip_flow)
126 | u = flow_uv[:, :, 0]
127 | v = flow_uv[:, :, 1]
128 | if max_flow is None:
129 | rad = np.sqrt(np.square(u) + np.square(v))
130 | rad_max = np.max(rad)
131 | else:
132 | rad_max = max_flow
133 | epsilon = 1e-5
134 | u = u / (rad_max + epsilon)
135 | v = v / (rad_max + epsilon)
136 | return flow_uv_to_colors(u, v, convert_to_bgr)
137 |
--------------------------------------------------------------------------------
/models/gimm/src/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # FlowFormer: https://github.com/drinkingcoder/FlowFormer-Official
9 | # --------------------------------------------------------
10 |
11 | import numpy as np
12 | from PIL import Image
13 | from os.path import *
14 | import re
15 |
16 | import cv2
17 |
18 | cv2.setNumThreads(0)
19 | cv2.ocl.setUseOpenCL(False)
20 |
21 | TAG_CHAR = np.array([202021.25], np.float32)
22 |
23 |
24 | def readFlow(fn):
25 | """Read .flo file in Middlebury format"""
26 | # Code adapted from:
27 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
28 |
29 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
30 | # print 'fn = %s'%(fn)
31 | with open(fn, "rb") as f:
32 | magic = np.fromfile(f, np.float32, count=1)
33 | if 202021.25 != magic:
34 | print("Magic number incorrect. Invalid .flo file")
35 | return None
36 | else:
37 | w = np.fromfile(f, np.int32, count=1)
38 | h = np.fromfile(f, np.int32, count=1)
39 | # print 'Reading %d x %d flo file\n' % (w, h)
40 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
41 | # Reshape data into 3D array (columns, rows, bands)
42 | # The reshape here is for visualization, the original code is (w,h,2)
43 | return np.resize(data, (int(h), int(w), 2))
44 |
45 |
46 | def readPFM(file):
47 | file = open(file, "rb")
48 |
49 | color = None
50 | width = None
51 | height = None
52 | scale = None
53 | endian = None
54 |
55 | header = file.readline().rstrip()
56 | if header == b"PF":
57 | color = True
58 | elif header == b"Pf":
59 | color = False
60 | else:
61 | raise Exception("Not a PFM file.")
62 |
63 | dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
64 | if dim_match:
65 | width, height = map(int, dim_match.groups())
66 | else:
67 | raise Exception("Malformed PFM header.")
68 |
69 | scale = float(file.readline().rstrip())
70 | if scale < 0: # little-endian
71 | endian = "<"
72 | scale = -scale
73 | else:
74 | endian = ">" # big-endian
75 |
76 | data = np.fromfile(file, endian + "f")
77 | shape = (height, width, 3) if color else (height, width)
78 |
79 | data = np.reshape(data, shape)
80 | data = np.flipud(data)
81 | return data
82 |
83 |
84 | def writeFlow(filename, uv, v=None):
85 | """Write optical flow to file.
86 |
87 | If v is None, uv is assumed to contain both u and v channels,
88 | stacked in depth.
89 | Original code by Deqing Sun, adapted from Daniel Scharstein.
90 | """
91 | nBands = 2
92 |
93 | if v is None:
94 | assert uv.ndim == 3
95 | assert uv.shape[2] == 2
96 | u = uv[:, :, 0]
97 | v = uv[:, :, 1]
98 | else:
99 | u = uv
100 |
101 | assert u.shape == v.shape
102 | height, width = u.shape
103 | f = open(filename, "wb")
104 | # write the header
105 | f.write(TAG_CHAR)
106 | np.array(width).astype(np.int32).tofile(f)
107 | np.array(height).astype(np.int32).tofile(f)
108 | # arrange into matrix form
109 | tmp = np.zeros((height, width * nBands))
110 | tmp[:, np.arange(width) * 2] = u
111 | tmp[:, np.arange(width) * 2 + 1] = v
112 | tmp.astype(np.float32).tofile(f)
113 | f.close()
114 |
115 |
116 | def readFlowKITTI(filename):
117 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
118 | flow = flow[:, :, ::-1].astype(np.float32)
119 | flow, valid = flow[:, :, :2], flow[:, :, 2]
120 | flow = (flow - 2**15) / 64.0
121 | return flow, valid
122 |
123 |
124 | def readDispKITTI(filename):
125 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
126 | valid = disp > 0.0
127 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
128 | return flow, valid
129 |
130 |
131 | def writeFlowKITTI(filename, uv):
132 | uv = 64.0 * uv + 2**15
133 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
134 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
135 | cv2.imwrite(filename, uv[..., ::-1])
136 |
137 |
138 | def read_gen(file_name, pil=False):
139 | ext = splitext(file_name)[-1]
140 | if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
141 | return Image.open(file_name)
142 | elif ext == ".bin" or ext == ".raw":
143 | return np.load(file_name)
144 | elif ext == ".flo":
145 | return readFlow(file_name).astype(np.float32)
146 | elif ext == ".pfm":
147 | flow = readPFM(file_name).astype(np.float32)
148 | if len(flow.shape) == 2:
149 | return flow
150 | else:
151 | return flow[:, :, :-1]
152 | return []
153 |
--------------------------------------------------------------------------------
/models/gimm/src/utils/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # EMA-VFI: https://github.com/MCG-NJU/EMA-VFI
9 | # IFRNet: https://github.com/ltkong218/IFRNet
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import numpy as np
16 |
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 |
19 |
20 | ## Laploss
21 | def gauss_kernel(channels=3):
22 | kernel = torch.tensor(
23 | [
24 | [1.0, 4.0, 6.0, 4.0, 1],
25 | [4.0, 16.0, 24.0, 16.0, 4.0],
26 | [6.0, 24.0, 36.0, 24.0, 6.0],
27 | [4.0, 16.0, 24.0, 16.0, 4.0],
28 | [1.0, 4.0, 6.0, 4.0, 1.0],
29 | ]
30 | )
31 | kernel /= 256.0
32 | kernel = kernel.repeat(channels, 1, 1, 1)
33 | kernel = kernel.to(device)
34 | return kernel
35 |
36 |
37 | def downsample(x):
38 | return x[:, :, ::2, ::2]
39 |
40 |
41 | def upsample(x):
42 | cc = torch.cat(
43 | [x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)],
44 | dim=3,
45 | )
46 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
47 | cc = cc.permute(0, 1, 3, 2)
48 | cc = torch.cat(
49 | [
50 | cc,
51 | torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device),
52 | ],
53 | dim=3,
54 | )
55 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
56 | x_up = cc.permute(0, 1, 3, 2)
57 | return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))
58 |
59 |
60 | def conv_gauss(img, kernel):
61 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode="reflect")
62 | out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
63 | return out
64 |
65 |
66 | def laplacian_pyramid(img, kernel, max_levels=3):
67 | current = img
68 | pyr = []
69 | for level in range(max_levels):
70 | filtered = conv_gauss(current, kernel)
71 | down = downsample(filtered)
72 | up = upsample(down)
73 | diff = current - up
74 | pyr.append(diff)
75 | current = down
76 | return pyr
77 |
78 |
79 | class LapLoss(torch.nn.Module):
80 | def __init__(self, max_levels=5, channels=3):
81 | super(LapLoss, self).__init__()
82 | self.max_levels = max_levels
83 | self.gauss_kernel = gauss_kernel(channels=channels)
84 |
85 | def forward(self, input, target):
86 | pyr_input = laplacian_pyramid(
87 | img=input, kernel=self.gauss_kernel, max_levels=self.max_levels
88 | )
89 | pyr_target = laplacian_pyramid(
90 | img=target, kernel=self.gauss_kernel, max_levels=self.max_levels
91 | )
92 | return sum(
93 | torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target)
94 | )
95 |
96 |
97 | class Ternary(nn.Module):
98 | def __init__(self, patch_size=7):
99 | super(Ternary, self).__init__()
100 | self.patch_size = patch_size
101 | out_channels = patch_size * patch_size
102 | self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
103 | self.w = np.transpose(self.w, (3, 2, 0, 1))
104 | self.w = torch.tensor(self.w).float().to(device)
105 |
106 | def transform(self, tensor):
107 | tensor_ = tensor.mean(dim=1, keepdim=True)
108 | patches = F.conv2d(tensor_, self.w, padding=self.patch_size // 2, bias=None)
109 | loc_diff = patches - tensor_
110 | loc_diff_norm = loc_diff / torch.sqrt(0.81 + loc_diff**2)
111 | return loc_diff_norm
112 |
113 | def valid_mask(self, tensor):
114 | padding = self.patch_size // 2
115 | b, c, h, w = tensor.size()
116 | inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor)
117 | mask = F.pad(inner, [padding] * 4)
118 | return mask
119 |
120 | def forward(self, x, y): # pred,gt
121 | loc_diff_x = self.transform(x)
122 | loc_diff_y = self.transform(y)
123 | diff = loc_diff_x - loc_diff_y.detach()
124 | dist = (diff**2 / (0.1 + diff**2)).mean(dim=1, keepdim=True)
125 | mask = self.valid_mask(x)
126 | loss = (dist * mask).mean()
127 | return loss
128 |
129 |
130 | class Charbonnier_L1(nn.Module):
131 | def __init__(self):
132 | super(Charbonnier_L1, self).__init__()
133 |
134 | def forward(self, pred, gt, mask=None):
135 | diff = pred - gt
136 | if mask is None:
137 | loss = ((diff**2 + 1e-6) ** 0.5).mean()
138 | else:
139 | loss = (((diff**2 + 1e-6) ** 0.5) * mask).mean() / (mask.mean() + 1e-9)
140 | return loss
141 |
142 |
143 | class Charbonnier_Ada(nn.Module):
144 | def __init__(self):
145 | super(Charbonnier_Ada, self).__init__()
146 |
147 | def forward(self, diff, weight):
148 | alpha = weight / 2
149 | epsilon = 10 ** (-(10 * weight - 1) / 3)
150 | loss = ((diff**2 + epsilon**2) ** alpha).mean()
151 | return loss
152 |
--------------------------------------------------------------------------------
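A short sketch exercising the losses above on random tensors; it assumes loss.py has been imported so that LapLoss, Ternary, Charbonnier_L1 and the module-level device are in scope.

import torch

pred = torch.rand(2, 3, 64, 64, device=device, requires_grad=True)
gt = torch.rand(2, 3, 64, 64, device=device)

lap = LapLoss(max_levels=3, channels=3)  # Laplacian-pyramid L1
census = Ternary(patch_size=7)           # census-transform distance
charb = Charbonnier_L1()                 # robust L1

total = lap(pred, gt) + census(pred, gt) + charb(pred, gt)
total.backward()  # gradients flow back to pred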
/models/gimm/src/utils/lpips/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gimm/src/utils/lpips/alex.pth
--------------------------------------------------------------------------------
/models/gimm/src/utils/profiler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | class Profiler:
12 | opts_model_size = {"trainable-only", "transformer-block-only"}
13 |
14 | def __init__(self, logger):
15 | self._logger = logger
16 |
17 | def get_model_size(self, model, opt=None):
18 | if opt is None:
19 | self._logger.info(
20 | "[OPTION: ALL] #parameters: %.4fM",
21 | sum(p.numel() for p in model.parameters()) / 1e6,
22 | )
23 | else:
24 | assert (
25 | opt in self.opts_model_size
26 | ), f"{opt} is not in {self.opts_model_size}"
27 |
28 | if opt == "trainable-only":
29 | self._logger.info(
30 | "[OPTION: %s] #parameters: %.4fM",
31 | opt,
32 | sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6,
33 | )
34 | else:
35 | if hasattr(model, "blocks"):
36 | self._logger.info(
37 | "[OPTION: %s] #parameters: %.4fM",
38 | opt,
39 | sum(p.numel() for p in model.blocks.parameters()) / 1e6,
40 | )
41 |
--------------------------------------------------------------------------------
/models/gimm/src/utils/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 | from datetime import datetime
12 | import logging
13 | import inspect
14 | import os
15 | import shutil
16 | from pathlib import Path
17 |
18 | from omegaconf import OmegaConf
19 |
20 | from .writer import Writer
21 | from .config import config_setup
22 | from .dist import initialize as dist_init
23 |
24 |
25 | def logger_setup(log_path, eval=False):
26 | log_fname = os.path.join(log_path, "val.log" if eval else "train.log")
27 |
28 | for hdlr in logging.root.handlers:
29 | logging.root.removeHandler(hdlr)
30 |
31 | SMOKE_TEST = bool(os.environ.get("SMOKE_TEST", 0))
32 |
33 | logging.basicConfig(
34 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
35 | datefmt="%m/%d/%Y %H:%M:%S",
36 | level=(logging.DEBUG if SMOKE_TEST else logging.INFO),
37 | handlers=[logging.FileHandler(log_fname), logging.StreamHandler()],
38 | )
39 | main_filename, *_ = inspect.getframeinfo(inspect.currentframe().f_back.f_back)
40 |
41 | logger = logging.getLogger(Path(main_filename).name)
42 | writer = Writer(log_path)
43 |
44 | return logger, writer
45 |
46 |
47 | def setup(args, extra_args=()):
48 | """
49 | meaning of args.result_path:
50 | - if args.eval, directory where the model is
51 | - if args.resume, no meaning
52 | - otherwise, path to store the logs
53 |
54 | Returns:
55 | config, logger, writer
56 | """
57 |
58 | distenv = dist_init(args)
59 |
60 | args.result_path = Path(args.result_path).absolute().as_posix()
61 | args.model_config = Path(args.model_config).absolute().resolve().as_posix()
62 |
63 | now = datetime.now().strftime("%d%m%Y_%H%M%S")
64 |
65 | if args.eval:
66 | config_path = Path(args.result_path).joinpath("config.yaml")
67 | log_path = Path(args.result_path).joinpath("val", now)
68 |
69 | elif args.resume:
70 | load_path = Path(args.load_path)
71 | if not load_path.is_file():
72 | raise ValueError("load_path must be a valid filename")
73 |
74 | config_path = load_path.parent.joinpath("config.yaml").absolute()
75 | log_path = load_path.parent.parent.joinpath(now)
76 |
77 | else:
78 | config_path = Path(args.model_config).absolute()
79 | task_name = config_path.stem
80 | if args.postfix:
81 | task_name += f"__{args.postfix}"
82 | log_path = Path(args.result_path).joinpath(task_name, now)
83 |
84 | config = config_setup(args, distenv, config_path, extra_args=extra_args)
85 | config.result_path = log_path.absolute().resolve().as_posix()
86 |
87 | if distenv.master:
88 | if not log_path.exists():
89 | os.makedirs(log_path)
90 | logger, writer = logger_setup(log_path)
91 | logger.info(distenv)
92 | logger.info(f"log_path: {log_path}")
93 | logger.info("\n" + OmegaConf.to_yaml(config))
94 | OmegaConf.save(config, log_path.joinpath("config.yaml"))
95 |
96 | src_dir = Path(os.getcwd()).joinpath("src")
97 | shutil.copytree(src_dir, log_path.joinpath("src"))
98 | logger.info(f"source copied to {log_path}/src")
99 | else:
100 | logger, writer, log_path = None, None, None
101 |
102 | return config, logger, writer
103 |
104 |
105 | def single_setup(args, extra_args=(), train=True):
106 | assert args.eval
107 | args.model_config = Path(args.model_config).absolute().resolve().as_posix()
108 | config_path = args.model_config
109 | config = config_setup(args, None, config_path, extra_args=extra_args)
110 | return config
111 |
--------------------------------------------------------------------------------
/models/gimm/src/utils/writer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ginr-ipc: https://github.com/kakaobrain/ginr-ipc
9 | # --------------------------------------------------------
10 |
11 |
12 | import os
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 |
16 | class Writer:
17 | def __init__(self, result_path):
18 | self.result_path = result_path
19 |
20 | self.writer_trn = SummaryWriter(os.path.join(result_path, "train"))
21 | self.writer_val = SummaryWriter(os.path.join(result_path, "valid"))
22 | self.writer_val_ema = SummaryWriter(os.path.join(result_path, "valid_ema"))
23 |
24 | def _get_writer(self, mode):
25 | if mode == "train":
26 | writer = self.writer_trn
27 | elif mode == "valid":
28 | writer = self.writer_val
29 | elif mode == "valid_ema":
30 | writer = self.writer_val_ema
31 | else:
32 |             raise ValueError(f"{mode} is not a valid writer mode")
33 |
34 | return writer
35 |
36 | def add_scalar(self, tag, scalar, mode, epoch=0):
37 | writer = self._get_writer(mode)
38 | writer.add_scalar(tag, scalar, epoch)
39 |
40 | def add_image(self, tag, image, mode, epoch=0):
41 | writer = self._get_writer(mode)
42 | writer.add_image(tag, image, epoch)
43 |
44 | def add_text(self, tag, text, mode, epoch=0):
45 | writer = self._get_writer(mode)
46 | writer.add_text(tag, text, epoch)
47 |
48 | def add_audio(self, tag, audio, mode, sampling_rate=16000, epoch=0):
49 | writer = self._get_writer(mode)
50 | writer.add_audio(tag, audio, epoch, sampling_rate)
51 |
52 | def close(self):
53 | self.writer_trn.close()
54 | self.writer_val.close()
55 | self.writer_val_ema.close()
56 |
--------------------------------------------------------------------------------
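A minimal sketch of the Writer above; the log directory name is hypothetical, and TensorBoard event files for the train/valid/valid_ema streams are created under it.

writer = Writer("runs_demo")
writer.add_scalar("loss/train", 0.123, mode="train", epoch=1)
writer.add_scalar("psnr", 31.7, mode="valid", epoch=1)
writer.close()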
/models/gmflow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/routineLife1/MultiPassDedup/fc724a0a99d4818366677b102049289126b61744/models/gmflow/__init__.py
--------------------------------------------------------------------------------
/models/gmflow/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from models.gmflow.trident_conv import MultiScaleTridentConv
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
8 | ):
9 | super(ResidualBlock, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
12 | dilation=dilation, padding=dilation, stride=stride, bias=False)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
14 | dilation=dilation, padding=dilation, bias=False)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | self.norm1 = norm_layer(planes)
18 | self.norm2 = norm_layer(planes)
19 | if not stride == 1 or in_planes != planes:
20 | self.norm3 = norm_layer(planes)
21 |
22 | if stride == 1 and in_planes == planes:
23 | self.downsample = None
24 | else:
25 | self.downsample = nn.Sequential(
26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
27 |
28 | def forward(self, x):
29 | y = x
30 | y = self.relu(self.norm1(self.conv1(y)))
31 | y = self.relu(self.norm2(self.conv2(y)))
32 |
33 | if self.downsample is not None:
34 | x = self.downsample(x)
35 |
36 | return self.relu(x + y)
37 |
38 |
39 | class CNNEncoder(nn.Module):
40 | def __init__(self, output_dim=128,
41 | norm_layer=nn.InstanceNorm2d,
42 | num_output_scales=1,
43 | **kwargs,
44 | ):
45 | super(CNNEncoder, self).__init__()
46 | self.num_branch = num_output_scales
47 |
48 | feature_dims = [64, 96, 128]
49 |
50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
51 | self.norm1 = norm_layer(feature_dims[0])
52 | self.relu1 = nn.ReLU(inplace=True)
53 |
54 | self.in_planes = feature_dims[0]
55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
57 |
58 | # highest resolution 1/4 or 1/8
59 | stride = 2 if num_output_scales == 1 else 1
60 | self.layer3 = self._make_layer(feature_dims[2], stride=stride,
61 | norm_layer=norm_layer,
62 | ) # 1/4 or 1/8
63 |
64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
65 |
66 | if self.num_branch > 1:
67 | if self.num_branch == 4:
68 | strides = (1, 2, 4, 8)
69 | elif self.num_branch == 3:
70 | strides = (1, 2, 4)
71 | elif self.num_branch == 2:
72 | strides = (1, 2)
73 | else:
74 | raise ValueError
75 |
76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
77 | kernel_size=3,
78 | strides=strides,
79 | paddings=1,
80 | num_branch=self.num_branch,
81 | )
82 |
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv2d):
85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
87 | if m.weight is not None:
88 | nn.init.constant_(m.weight, 1)
89 | if m.bias is not None:
90 | nn.init.constant_(m.bias, 0)
91 |
92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
95 |
96 | layers = (layer1, layer2)
97 |
98 | self.in_planes = dim
99 | return nn.Sequential(*layers)
100 |
101 | def forward(self, x):
102 | x = self.conv1(x)
103 | x = self.norm1(x)
104 | x = self.relu1(x)
105 |
106 | x = self.layer1(x) # 1/2
107 | x = self.layer2(x) # 1/4
108 | x = self.layer3(x) # 1/8 or 1/4
109 |
110 | x = self.conv2(x)
111 |
112 | if self.num_branch > 1:
113 | out = self.trident_conv([x] * self.num_branch) # high to low res
114 | else:
115 | out = [x]
116 |
117 | return out
118 |
--------------------------------------------------------------------------------
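A shape sketch for the CNNEncoder above, assuming the repository root is on PYTHONPATH: with a single output scale the input is downsampled by 8 and projected to output_dim channels.

import torch
from models.gmflow.backbone import CNNEncoder

encoder = CNNEncoder(output_dim=128, num_output_scales=1)
features = encoder(torch.rand(1, 3, 64, 96))  # the encoder always returns a list
print(len(features), features[0].shape)  # 1 torch.Size([1, 128, 8, 12])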
/models/gmflow/geometry.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | coords_grid_cache = {}
4 |
5 | def coords_grid(b, h, w, homogeneous=False, device=None, dtype: torch.dtype=torch.float32):
6 | k = (str(device), str((b, h, w)))
7 | if k in coords_grid_cache:
8 | return coords_grid_cache[k]
9 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
10 |
11 | stacks = [x, y]
12 |
13 | if homogeneous:
14 | ones = torch.ones_like(x) # [H, W]
15 | stacks.append(ones)
16 |
17 | grid = torch.stack(stacks, dim=0) # [2, H, W] or [3, H, W]
18 |
19 | grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
20 |
21 | if device is not None:
22 | grid = grid.to(device, dtype=dtype)
23 | coords_grid_cache[k] = grid
24 | return grid
25 |
26 | window_grid_cache = {}
27 |
28 | def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None, dtype=torch.float32):
29 | assert device is not None
30 | k = (str(device), str((h_min, h_max, w_min, w_max, len_h, len_w)))
31 | if k in window_grid_cache:
32 | return window_grid_cache[k]
33 | x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
34 | torch.linspace(h_min, h_max, len_h, device=device)],
35 | )
36 | grid = torch.stack((x, y), -1).transpose(0, 1).to(device, dtype=dtype) # [H, W, 2]
37 | window_grid_cache[k] = grid
38 | return grid
39 |
40 | normalize_coords_cache = {}
41 |
42 | def normalize_coords(coords, h, w):
43 | # coords: [B, H, W, 2]
44 | k = (str(coords.device), str((h, w)))
45 | if k in normalize_coords_cache:
46 | c = normalize_coords_cache[k]
47 | else:
48 | c = torch.tensor([(w - 1) / 2., (h - 1) / 2.], dtype=coords.dtype, device=coords.device)
49 | normalize_coords_cache[k] = c
50 | return (coords - c) / c # [-1, 1]
51 |
52 |
53 | def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
54 | # img: [B, C, H, W]
55 | # sample_coords: [B, 2, H, W] in image scale
56 | if sample_coords.size(1) != 2: # [B, H, W, 2]
57 | sample_coords = sample_coords.permute(0, 3, 1, 2)
58 |
59 | b, _, h, w = sample_coords.shape
60 |
61 | # Normalize to [-1, 1]
62 | x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
63 | y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
64 |
65 | grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
66 |
67 | img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
68 |
69 | if return_mask:
70 | mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
71 |
72 | return img, mask
73 |
74 | return img
75 |
76 |
77 | def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
78 | b, c, h, w = feature.size()
79 | assert flow.size(1) == 2
80 |
81 | grid = coords_grid(b, h, w, device=flow.device, dtype=flow.dtype) + flow # [B, 2, H, W]
82 |
83 | return bilinear_sample(feature, grid, padding_mode=padding_mode,
84 | return_mask=mask)
85 |
86 |
87 | def forward_backward_consistency_check(fwd_flow, bwd_flow,
88 | alpha=0.01,
89 | beta=0.5
90 | ):
91 | # fwd_flow, bwd_flow: [B, 2, H, W]
92 | # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
93 | assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
94 | assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
95 | flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
96 |
97 | warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
98 | warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
99 |
100 | diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
101 | diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
102 |
103 | threshold = alpha * flow_mag + beta
104 |
105 | fwd_occ = (diff_fwd > threshold).to(fwd_flow) # [B, H, W]
106 | bwd_occ = (diff_bwd > threshold).to(bwd_flow)
107 |
108 | return fwd_occ, bwd_occ
109 |
--------------------------------------------------------------------------------
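Two quick sanity checks for the geometry helpers above: warping with a zero flow field returns the input, and zero forward/backward flows produce empty occlusion masks.

import torch
from models.gmflow.geometry import flow_warp, forward_backward_consistency_check

feat = torch.rand(1, 8, 32, 32)
zero_flow = torch.zeros(1, 2, 32, 32)

assert torch.allclose(flow_warp(feat, zero_flow), feat, atol=1e-5)

fwd_occ, bwd_occ = forward_backward_consistency_check(zero_flow, zero_flow)
assert fwd_occ.sum() == 0 and bwd_occ.sum() == 0  # nothing exceeds the 0.5 px threshold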
/models/gmflow/matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from models.gmflow.geometry import coords_grid, generate_window_grid, normalize_coords
5 |
6 |
7 | def global_correlation_softmax(feature0, feature1,
8 | pred_bidir_flow=False,
9 | ):
10 | # global correlation
11 | b, c, h, w = feature0.shape
12 | feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
13 | feature1 = feature1.view(b, c, -1) # [B, C, H*W]
14 |
15 | correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W]
16 |
17 | # flow from softmax
18 | init_grid = coords_grid(b, h, w, device=correlation.device, dtype=feature0.dtype) # [B, 2, H, W]
19 | grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
20 |
21 | correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
22 |
23 | if pred_bidir_flow:
24 | correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W]
25 | init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W]
26 | grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2]
27 | b = b * 2
28 |
29 | prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
30 |
31 | correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
32 |
33 | # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
34 | flow = correspondence - init_grid
35 |
36 | return flow, prob
37 |
38 |
39 | def local_correlation_softmax(feature0, feature1, local_radius,
40 | padding_mode='zeros',
41 | ):
42 | b, c, h, w = feature0.size()
43 | coords_init = coords_grid(b, h, w, device=feature0.device, dtype=feature0.dtype) # [B, 2, H, W]
44 | coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
45 |
46 | local_h = 2 * local_radius + 1
47 | local_w = 2 * local_radius + 1
48 |
49 | window_grid = generate_window_grid(-local_radius, local_radius,
50 | -local_radius, local_radius,
51 | local_h, local_w, device=feature0.device, dtype=feature0.dtype) # [2R+1, 2R+1, 2]
52 | window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
53 | sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2]
54 |
55 | sample_coords_softmax = sample_coords
56 |
57 | # exclude coords that are out of image space
58 | valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
59 | valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
60 |
61 | valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
62 |
63 | # normalize coordinates to [-1, 1]
64 | sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
65 | window_feature = F.grid_sample(feature1.contiguous(), sample_coords_norm.contiguous(),
66 | padding_mode=padding_mode, align_corners=True
67 | ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
68 | feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
69 |
70 | corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
71 |
72 | # mask invalid locations
73 | corr[~valid] = -1e4
74 |
75 | prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2]
76 |
77 | correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
78 | b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
79 |
80 | flow = correspondence - coords_init
81 | match_prob = prob
82 |
83 | return flow, match_prob
84 |
--------------------------------------------------------------------------------
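A shape sketch for global_correlation_softmax above: two feature maps of matching size go in, a dense flow field and the full matching distribution over all H*W positions come out.

import torch
from models.gmflow.matching import global_correlation_softmax

f0 = torch.rand(1, 128, 24, 32)
f1 = torch.rand(1, 128, 24, 32)
flow, prob = global_correlation_softmax(f0, f1)
print(flow.shape, prob.shape)  # torch.Size([1, 2, 24, 32]) torch.Size([1, 768, 768])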
/models/gmflow/position.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from models.utils.tools import get_ones_tensor_size
10 |
11 | tensor_cache = dict()
12 |
13 | class PositionEmbeddingSine(nn.Module):
14 | """
15 | This is a more standard version of the position embedding, very similar to the one
16 | used by the Attention is all you need paper, generalized to work on images.
17 | """
18 |
19 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
20 | super().__init__()
21 | self.num_pos_feats = num_pos_feats
22 | self.temperature = temperature
23 | self.normalize = normalize
24 | if scale is not None and normalize is False:
25 | raise ValueError("normalize should be True if scale is passed")
26 | if scale is None:
27 | scale = 2 * math.pi
28 | self.scale = scale
29 |
30 | def forward(self, x):
31 | # x = tensor_list.tensors # [B, C, H, W]
32 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
33 | b, c, h, w = x.size()
34 | mask = get_ones_tensor_size((b, h, w), device=x.device, dtype=x.dtype) # [B, H, W]
35 | y_embed = mask.cumsum(1)
36 | x_embed = mask.cumsum(2)
37 | if self.normalize:
38 | eps = 1e-6
39 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
40 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
41 |
42 | if 'dim_t' not in tensor_cache:
43 | dim_t = torch.arange(self.num_pos_feats, device=x.device, dtype=x.dtype)
44 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
45 | tensor_cache['dim_t'] = dim_t
46 | else:
47 | dim_t = tensor_cache['dim_t']
48 |
49 | pos_x = x_embed[:, :, :, None] / dim_t
50 | pos_y = y_embed[:, :, :, None] / dim_t
51 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
52 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
53 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
54 | return pos
55 |
--------------------------------------------------------------------------------
/models/gmflow/trident_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn.modules.utils import _pair
8 |
9 |
10 | class MultiScaleTridentConv(nn.Module):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | stride=1,
17 | strides=1,
18 | paddings=0,
19 | dilations=1,
20 | dilation=1,
21 | groups=1,
22 | num_branch=1,
23 | test_branch_idx=-1,
24 | bias=False,
25 | norm=None,
26 | activation=None,
27 | ):
28 | super(MultiScaleTridentConv, self).__init__()
29 | self.in_channels = in_channels
30 | self.out_channels = out_channels
31 | self.kernel_size = _pair(kernel_size)
32 | self.num_branch = num_branch
33 | self.stride = _pair(stride)
34 | self.groups = groups
35 | self.with_bias = bias
36 | self.dilation = dilation
37 | if isinstance(paddings, int):
38 | paddings = [paddings] * self.num_branch
39 | if isinstance(dilations, int):
40 | dilations = [dilations] * self.num_branch
41 | if isinstance(strides, int):
42 | strides = [strides] * self.num_branch
43 | self.paddings = [_pair(padding) for padding in paddings]
44 | self.dilations = [_pair(dilation) for dilation in dilations]
45 | self.strides = [_pair(stride) for stride in strides]
46 | self.test_branch_idx = test_branch_idx
47 | self.norm = norm
48 | self.activation = activation
49 |
50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
51 |
52 | self.weight = nn.Parameter(
53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
54 | )
55 | if bias:
56 | self.bias = nn.Parameter(torch.Tensor(out_channels))
57 | else:
58 | self.bias = None
59 |
60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
61 | if self.bias is not None:
62 | nn.init.constant_(self.bias, 0)
63 |
64 | def forward(self, inputs):
65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
66 | assert len(inputs) == num_branch
67 |
68 | if self.training or self.test_branch_idx == -1:
69 | outputs = [
70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
71 | for input, stride, padding in zip(inputs, self.strides, self.paddings)
72 | ]
73 | else:
74 | outputs = [
75 | F.conv2d(
76 | inputs[0],
77 | self.weight,
78 | self.bias,
79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
81 | self.dilation,
82 | self.groups,
83 | )
84 | ]
85 |
86 | if self.norm is not None:
87 | outputs = [self.norm(x) for x in outputs]
88 | if self.activation is not None:
89 | outputs = [self.activation(x) for x in outputs]
90 | return outputs
91 |
--------------------------------------------------------------------------------
/models/gmflow/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.gmflow.position import PositionEmbeddingSine
3 |
4 |
5 | def split_feature(feature,
6 | num_splits=2,
7 | channel_last=False,
8 | ):
9 | if channel_last: # [B, H, W, C]
10 | b, h, w, c = feature.size()
11 | assert h % num_splits == 0 and w % num_splits == 0
12 |
13 | b_new = b * num_splits * num_splits
14 | h_new = h // num_splits
15 | w_new = w // num_splits
16 |
17 | feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
18 | ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
19 | else: # [B, C, H, W]
20 | b, c, h, w = feature.size()
21 | assert h % num_splits == 0 and w % num_splits == 0
22 |
23 | b_new = b * num_splits * num_splits
24 | h_new = h // num_splits
25 | w_new = w // num_splits
26 |
27 | feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
28 | ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
29 |
30 | return feature
31 |
32 |
33 | def merge_splits(splits,
34 | num_splits=2,
35 | channel_last=False,
36 | ):
37 | if channel_last: # [B*K*K, H/K, W/K, C]
38 | b, h, w, c = splits.size()
39 | new_b = b // num_splits // num_splits
40 |
41 | splits = splits.view(new_b, num_splits, num_splits, h, w, c)
42 | merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
43 | new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
44 | else: # [B*K*K, C, H/K, W/K]
45 | b, c, h, w = splits.size()
46 | new_b = b // num_splits // num_splits
47 |
48 | splits = splits.view(new_b, num_splits, num_splits, c, h, w)
49 | merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
50 | new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
51 |
52 | return merge
53 |
54 |
55 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
56 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
57 | is_mean_std_loaded = False
58 |
59 |
60 | def normalize_img(img0, img1):
61 | # loaded images are in [0, 255]
62 | # normalize by ImageNet mean and std
63 | global mean, std, is_mean_std_loaded
64 | if not is_mean_std_loaded:
65 | mean = mean.to(img0)
66 | std = std.to(img0)
67 | is_mean_std_loaded = True
68 | img0 = (img0 - mean) / std
69 | img1 = (img1 - mean) / std
70 |
71 | return img0, img1
72 |
73 |
74 | def feature_add_position(feature0, feature1, attn_splits, feature_channels):
75 | pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
76 |
77 |     if attn_splits > 1:  # add position in split windows
78 | feature0_splits = split_feature(feature0, num_splits=attn_splits)
79 | feature1_splits = split_feature(feature1, num_splits=attn_splits)
80 |
81 | position = pos_enc(feature0_splits)
82 |
83 | feature0_splits = feature0_splits + position
84 | feature1_splits = feature1_splits + position
85 |
86 | feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
87 | feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
88 | else:
89 | position = pos_enc(feature0)
90 |
91 | feature0 = feature0 + position
92 | feature1 = feature1 + position
93 |
94 | return feature0, feature1
95 |
--------------------------------------------------------------------------------
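A round-trip sketch for split_feature / merge_splits above: splitting into K x K tiles and merging back reproduces the original tensor exactly (H and W must both be divisible by K).

import torch
from models.gmflow.utils import split_feature, merge_splits

x = torch.rand(2, 16, 32, 48)
tiles = split_feature(x, num_splits=2)        # [B*4, C, H/2, W/2]
restored = merge_splits(tiles, num_splits=2)  # [B, C, H, W]
assert torch.equal(restored, x)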
/models/model_pg104/FeatureNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FeatureNet(nn.Module):
7 |     """Three-level pyramid feature extractor (1/2, 1/4, 1/8 resolution)."""
8 | def __init__(self):
9 | super(FeatureNet, self).__init__()
10 | self.block1 = nn.Sequential(
11 | nn.PReLU(),
12 | nn.Conv2d(3, 64, 3, 2, 1),
13 | nn.PReLU(),
14 | nn.Conv2d(64, 64, 3, 1, 1),
15 | )
16 | self.block2 = nn.Sequential(
17 | nn.PReLU(),
18 | nn.Conv2d(64, 128, 3, 2, 1),
19 | nn.PReLU(),
20 | nn.Conv2d(128, 128, 3, 1, 1),
21 | )
22 | self.block3 = nn.Sequential(
23 | nn.PReLU(),
24 | nn.Conv2d(128, 192, 3, 2, 1),
25 | nn.PReLU(),
26 | nn.Conv2d(192, 192, 3, 1, 1),
27 | )
28 |
29 | def forward(self, x):
30 | x1 = self.block1(x)
31 | x2 = self.block2(x1)
32 | x3 = self.block3(x2)
33 |
34 | return x1, x2, x3
--------------------------------------------------------------------------------
/models/model_pg104/IFNet_HDv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.model_pg104.warplayer import warp
6 |
7 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
8 | return nn.Sequential(
9 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
10 | padding=padding, dilation=dilation, bias=True),
11 | nn.LeakyReLU(0.2, True)
12 | )
13 |
14 | class ResConv(nn.Module):
15 | def __init__(self, c, dilation=1):
16 | super(ResConv, self).__init__()
17 | self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
18 | self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
19 | self.relu = nn.LeakyReLU(0.2, True)
20 |
21 | def forward(self, x):
22 | return self.relu(self.conv(x) * self.beta + x)
23 |
24 | class IFBlock(nn.Module):
25 | def __init__(self, in_planes, c=64):
26 | super(IFBlock, self).__init__()
27 | self.conv0 = nn.Sequential(
28 | conv(in_planes, c//2, 3, 2, 1),
29 | conv(c//2, c, 3, 2, 1),
30 | )
31 | self.convblock = nn.Sequential(
32 | ResConv(c),
33 | ResConv(c),
34 | ResConv(c),
35 | ResConv(c),
36 | ResConv(c),
37 | ResConv(c),
38 | ResConv(c),
39 | ResConv(c),
40 | )
41 | self.lastconv = nn.Sequential(
42 | nn.ConvTranspose2d(c, 4*6, 4, 2, 1),
43 | nn.PixelShuffle(2)
44 | )
45 |
46 | def forward(self, x, flow=None, scale=1):
47 | x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
48 | if flow is not None:
49 | flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
50 | x = torch.cat((x, flow), 1)
51 | feat = self.conv0(x)
52 | feat = self.convblock(feat)
53 | tmp = self.lastconv(feat)
54 | tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
55 | flow = tmp[:, :4] * scale
56 | mask = tmp[:, 4:5]
57 | return flow, mask
58 |
59 | class IFNet(nn.Module):
60 | def __init__(self):
61 | super(IFNet, self).__init__()
62 | self.block0 = IFBlock(7, c=192)
63 | self.block1 = IFBlock(8+4, c=128)
64 | self.block2 = IFBlock(8+4, c=96)
65 | self.block3 = IFBlock(8+4, c=64)
66 |
67 | def forward( self, x, timestep=0.5, scale_list=[8, 4, 2, 1], training=False, fastmode=True, ensemble=False):
68 | if training == False:
69 | channel = x.shape[1] // 2
70 | img0 = x[:, :channel]
71 | img1 = x[:, channel:]
72 | if not torch.is_tensor(timestep):
73 | timestep = (x[:, :1].clone() * 0 + 1) * timestep
74 | else:
75 | timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
76 | flow = None
77 | block = [self.block0, self.block1, self.block2, self.block3]
78 | for i in range(4):
79 | if flow is None:
80 | flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], timestep), 1), None, scale=scale_list[i])
81 | if ensemble:
82 | f1, m1 = block[i](torch.cat((img1[:, :3], img0[:, :3], 1-timestep), 1), None, scale=scale_list[i])
83 | flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
84 | mask = (mask + (-m1)) / 2
85 | else:
86 | f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], timestep, mask), 1), flow, scale=scale_list[i])
87 | if ensemble:
88 | f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
89 | f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
90 | m0 = (m0 + (-m1)) / 2
91 | flow = flow + f0
92 | mask = mask + m0
93 | warped_img0 = warp(img0, flow[:, :2])
94 | warped_img1 = warp(img1, flow[:, 2:4])
95 | mask = torch.sigmoid(mask)
96 | merged = warped_img0 * mask + warped_img1 * (1 - mask)
97 | return merged
98 |
--------------------------------------------------------------------------------
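A shape-level sketch of the IFNet above with untrained weights: two RGB frames concatenated along the channel dimension go in, one interpolated frame comes out. The device choice mirrors warplayer.py, which builds its sampling grid on CUDA when available.

import torch
from models.model_pg104.IFNet_HDv3 import IFNet

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = IFNet().to(dev).eval()
img0 = torch.rand(1, 3, 64, 64, device=dev)
img1 = torch.rand(1, 3, 64, 64, device=dev)
with torch.no_grad():
    mid = net(torch.cat((img0, img1), dim=1), timestep=0.5, scale_list=[8, 4, 2, 1])
print(mid.shape)  # torch.Size([1, 3, 64, 64])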
/models/model_pg104/MetricNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.gmflow.geometry import forward_backward_consistency_check
6 |
7 |
8 | backwarp_tenGrid = {}
9 |
10 | def backwarp(tenIn, tenflow):
11 | if str(tenflow.shape) not in backwarp_tenGrid:
12 | tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenflow.shape[3], dtype=tenflow.dtype, device=tenflow.device).view(1, 1, 1, -1).repeat(1, 1, tenflow.shape[2], 1)
13 | tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenflow.shape[2], dtype=tenflow.dtype, device=tenflow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenflow.shape[3])
14 |
15 |         backwarp_tenGrid[str(tenflow.shape)] = torch.cat([tenHor, tenVer], 1).to(tenflow.device)
16 | # end
17 |
18 | tenflow = torch.cat([tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1)
19 |
20 | return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[str(tenflow.shape)] + tenflow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True)
21 |
22 |
23 | class MetricNet(nn.Module):
24 | def __init__(self):
25 | super(MetricNet, self).__init__()
26 | self.metric_in = nn.Conv2d(14, 64, 3, 1, 1)
27 | self.metric_net1 = nn.Sequential(
28 | nn.PReLU(),
29 | nn.Conv2d(64, 64, 3, 1, 1)
30 | )
31 | self.metric_net2 = nn.Sequential(
32 | nn.PReLU(),
33 | nn.Conv2d(64, 64, 3, 1, 1)
34 | )
35 | self.metric_net3 = nn.Sequential(
36 | nn.PReLU(),
37 | nn.Conv2d(64, 64, 3, 1, 1)
38 | )
39 | self.metric_out = nn.Sequential(
40 | nn.PReLU(),
41 | nn.Conv2d(64, 2, 3, 1, 1),
42 | nn.Tanh()
43 | )
44 |
45 | def forward(self, img0, img1, flow01, flow10):
46 | metric0 = F.l1_loss(img0, backwarp(img1, flow01), reduction='none').mean([1], True)
47 | metric1 = F.l1_loss(img1, backwarp(img0, flow10), reduction='none').mean([1], True)
48 |
49 | fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10)
50 |
51 | flow01 = torch.cat([flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0), flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0)], 1)
52 | flow10 = torch.cat([flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0), flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0)], 1)
53 |
54 | img = torch.cat((img0, img1), 1)
55 | metric = torch.cat((-metric0, -metric1), 1)
56 | flow = torch.cat((flow01, flow10), 1)
57 | occ = torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1)
58 |
59 | feat = self.metric_in(torch.cat((img, metric, flow, occ), 1))
60 | feat = self.metric_net1(feat) + feat
61 | feat = self.metric_net2(feat) + feat
62 | feat = self.metric_net3(feat) + feat
63 | metric = self.metric_out(feat) * 10
64 |
65 | return metric[:, :1], metric[:, 1:2]
66 |
--------------------------------------------------------------------------------
/models/model_pg104/warplayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5 | backwarp_tenGrid = {}
6 |
7 |
8 | def warp(tenInput, tenFlow):
9 | k = (str(tenFlow.device), str(tenFlow.size()))
10 | if k not in backwarp_tenGrid:
11 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
12 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
13 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
14 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
15 | backwarp_tenGrid[k] = torch.cat(
16 | [tenHorizontal, tenVertical], 1).to(device)
17 |
18 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
19 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
20 |
21 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
22 | return torch.nn.functional.grid_sample(input=tenInput, grid=g.to(tenInput.dtype), mode='bilinear', padding_mode='border', align_corners=True)
--------------------------------------------------------------------------------
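A sanity sketch for warp above: a zero flow field reproduces the input frame. Tensors are placed on the module-level device chosen at import time so they match the cached sampling grid.

import torch
from models.model_pg104.warplayer import warp, device

img = torch.rand(1, 3, 16, 16, device=device)
flow = torch.zeros(1, 2, 16, 16, device=device)
assert torch.allclose(warp(img, flow), img, atol=1e-5)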
/models/vfi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.IFNet_HDv3 import IFNet
3 | from models.gimm.src.utils.setup import single_setup
4 | from models.gimm.src.models import create_model
5 | from models.model_pg104.GMFSS import Model as GMFSS
6 | import argparse
7 | from models.utils.tools import *
8 |
9 |
10 | class VFI:
11 | def __init__(self, model_type='rife', weights='weights', scale=1.0,
12 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
13 | if model_type == 'rife':
14 | model = IFNet()
15 | model.load_state_dict(convert(torch.load(f'{weights}/rife48.pkl')))
16 | elif model_type == 'gmfss':
17 | model = GMFSS()
18 | model.load_model(f'{weights}/train_log_pg104', -1)
19 | else:
20 | args = argparse.Namespace(
21 | model_config=r"models/gimm/configs/gimmvfi/gimmvfi_r_arb.yaml",
22 | load_path=f"{weights}/gimmvfi_r_arb_lpips.pt",
23 | ds_factor=scale,
24 | eval=True,
25 | seed=0
26 | )
27 | config = single_setup(args)
28 | model, _ = create_model(config.arch)
29 |
30 | # Checkpoint loading
31 | if "ours" in args.load_path:
32 | ckpt = torch.load(args.load_path, map_location="cpu")
33 |
34 | def convert_gimm(param):
35 | return {
36 | k.replace("module.feature_bone", "frame_encoder"): v
37 | for k, v in param.items()
38 | if "feature_bone" in k
39 | }
40 |
41 | ckpt = convert_gimm(ckpt)
42 | model.load_state_dict(ckpt, strict=False)
43 | else:
44 | ckpt = torch.load(args.load_path, map_location="cpu")
45 | model.load_state_dict(ckpt["state_dict"], strict=False)
46 |
47 | model.eval()
48 | if model_type == 'gmfss':
49 | model.device()
50 | else:
51 | model.to(device)
52 |
53 | self.model = model
54 | self.model_type = model_type
55 | base_pads = {
56 | 'gimm': 64,
57 | 'rife': 64,
58 | 'gmfss': 128,
59 | }
60 | self.pad_size = base_pads[model_type] / scale
61 | self.device = device
62 | self.saved_result = {}
63 | self.scale = scale
64 |
65 | @torch.inference_mode()
66 | def gen_ts_frame(self, x, y, ts):
67 | _outputs = list()
68 | head = [x] if 0 in ts else []
69 | tail = [y] if 1 in ts else []
70 | if 0 in ts:
71 | ts.remove(0)
72 | if 1 in ts:
73 | ts.remove(1)
74 | with torch.autocast(str(self.device)):
75 | _reuse_things = self.model.reuse(x, y, self.scale) if self.model_type == 'gmfss' else None
76 | if self.model_type in ['rife', 'gmfss']:
77 | for t in ts:
78 | if self.model_type == 'rife':
79 | scale_list = [8 / self.scale, 4 / self.scale, 2 / self.scale, 1 / self.scale]
80 | _out = self.model(torch.cat((x, y), dim=1), t, scale_list)
81 | elif self.model_type == 'gmfss':
82 | _out = self.model.inference(x, y, _reuse_things, t)
83 | _outputs.append(_out)
84 | elif self.model_type == 'gimm':
85 | xs = torch.cat((x.unsqueeze(2), y.unsqueeze(2)), dim=2).to(
86 | self.device, non_blocking=True
87 | )
88 | self.model.zero_grad()
89 | coord_inputs = [
90 | (
91 | self.model.sample_coord_input(
92 | xs.shape[0],
93 | xs.shape[-2:],
94 | [t],
95 | device=xs.device,
96 | upsample_ratio=self.scale,
97 | ),
98 | None,
99 | )
100 | for t in ts
101 | ]
102 | timesteps = [
103 | t * torch.ones(xs.shape[0]).to(xs.device).to(torch.float)
104 | for t in ts
105 | ]
106 | all_outputs = self.model(xs, coord_inputs, t=timesteps, ds_factor=self.scale)
107 |
108 | _outputs = all_outputs["imgt_pred"]
109 |
110 | _outputs = head + _outputs + tail
111 |
112 | return _outputs
113 |
--------------------------------------------------------------------------------
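A hypothetical call sketch for the VFI wrapper above; the weight files under weights/ (e.g. rife48.pkl for model_type='rife') are assumptions and must be present for the constructor to succeed.

import torch
from models.vfi import VFI

vfi = VFI(model_type='rife', weights='weights', scale=1.0)  # expects weights/rife48.pkl
x = torch.rand(1, 3, 256, 256, device=vfi.device)
y = torch.rand(1, 3, 256, 256, device=vfi.device)
frames = vfi.gen_ts_frame(x, y, ts=[0.25, 0.5, 0.75])
print(len(frames))  # 3 interpolated frames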
/requirements.txt:
--------------------------------------------------------------------------------
1 | ffmpeg-python>=0.2.0
2 | numpy>=1.16, <=1.23.5
3 | torch>=2.5.1
4 | torchvision>=0.20.1
5 | tqdm>=4.35.0
6 | opencv-python>=4.1.2
7 | cupy-cuda11x
8 | decorator==5.1.1
9 | easydict==1.9
10 | einops==0.6.1
11 | imageio==2.31.2
12 | importlib-metadata==6.7.0
13 | omegaconf==2.3.0
14 | Pillow==9.5.0
15 | protobuf==3.20.0
16 | safetensors==0.4.2
17 | scikit-image==0.19.3
18 | scikit-learn==0.24.0
19 | scipy==1.7.3
20 | tensorboard==2.3.0
21 | tensorboard-plugin-wit==1.8.1
22 | timm==0.4.12
23 | typing_extensions==4.7.1
24 | yacs==0.1.6
--------------------------------------------------------------------------------