├── .gitignore
├── LICENSE
├── README
├── README.md
├── __init__.py
├── configs
├── create_configs.py
├── flow_config.py
├── inj_config.py
├── pivot_config.py
├── rave_config.py
├── sca_config.py
└── temporal_config.py
├── constants.py
├── example_workflows
├── vv_sd15_example_workflow.json
└── vv_sdxl_example_workflow.json
├── modules
├── vv_attention.py
├── vv_block.py
├── vv_resnet.py
└── vv_unet.py
├── nodes
├── config_flow_attn_node.py
├── config_flow_inj_node.py
├── flow_config_node.py
├── get_flow_node.py
├── get_raft_flow_node.py
├── inj_config_node.py
├── pivot_config_node.py
├── rave_config_node.py
├── sca_config_node.py
├── temporal_config_node.py
├── vv_apply_node.py
├── vv_sampler_node.py
└── vv_unsampler_node.py
├── requirements.txt
├── sea_raft
├── .gitignore
├── LICENSE
├── README.md
├── assets
│ └── visualization.png
├── chairs_split.txt
├── config
│ ├── eval
│ │ ├── kitti-L.json
│ │ ├── kitti-M.json
│ │ ├── kitti-S.json
│ │ ├── sintel-L.json
│ │ ├── sintel-M.json
│ │ ├── sintel-S.json
│ │ ├── spring-L.json
│ │ ├── spring-M.json
│ │ └── spring-S.json
│ ├── parser.py
│ └── train
│ │ ├── Tartan-C-T-TSKH-kitti432x960-M.json
│ │ ├── Tartan-C-T-TSKH-kitti432x960-S.json
│ │ ├── Tartan-C-T-TSKH-spring540x960-M.json
│ │ ├── Tartan-C-T-TSKH-spring540x960-S.json
│ │ ├── Tartan-C-T-TSKH432x960-M.json
│ │ ├── Tartan-C-T-TSKH432x960-S.json
│ │ ├── Tartan-C-T432x960-M.json
│ │ ├── Tartan-C-T432x960-S.json
│ │ ├── Tartan-C368x496-M.json
│ │ ├── Tartan-C368x496-S.json
│ │ ├── Tartan480x640-M.json
│ │ └── Tartan480x640-S.json
├── core
│ ├── __init__.py
│ ├── corr.py
│ ├── extractor.py
│ ├── layer.py
│ ├── loss.py
│ ├── raft.py
│ ├── update.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── augmentor.py
│ │ ├── flow_transforms.py
│ │ ├── flow_viz.py
│ │ ├── frame_utils.py
│ │ └── utils.py
├── custom.py
├── custom
│ ├── flow.jpg
│ ├── heatmap.jpg
│ ├── image1.jpg
│ └── image2.jpg
├── ddp_utils.py
├── demo.py
├── eval_ptlflow.py
├── evaluate.py
├── profile_ptlflow.py
├── profiler.py
├── requirements.txt
├── scripts
│ ├── eval.sh
│ ├── submission.sh
│ └── train.sh
├── submission.py
└── train.py
├── unimatch
├── DATASETS.md
├── LICENSE
├── MODEL_ZOO.md
├── README.md
├── conda_environment.yml
├── dataloader
│ ├── __init__.py
│ ├── depth
│ │ ├── __init__.py
│ │ ├── augmentation.py
│ │ ├── datasets.py
│ │ ├── download_demon_test.sh
│ │ ├── download_demon_train.sh
│ │ ├── prepare_demon_test.py
│ │ ├── prepare_demon_train.py
│ │ ├── scannet_banet_test_pairs.txt
│ │ └── scannet_banet_train_pairs.txt
│ ├── flow
│ │ ├── __init__.py
│ │ ├── chairs_split.txt
│ │ ├── datasets.py
│ │ └── transforms.py
│ └── stereo
│ │ ├── __init__.py
│ │ ├── datasets.py
│ │ └── transforms.py
├── demo
│ ├── depth-scannet
│ │ ├── color
│ │ │ ├── 0048.png
│ │ │ ├── 0054.png
│ │ │ ├── 0060.png
│ │ │ └── 0066.png
│ │ ├── intrinsic
│ │ │ └── intrinsic_depth.txt
│ │ └── pose
│ │ │ ├── 0048.txt
│ │ │ ├── 0054.txt
│ │ │ ├── 0060.txt
│ │ │ └── 0066.txt
│ ├── flow-davis
│ │ ├── 00000.jpg
│ │ ├── 00001.jpg
│ │ └── 00002.jpg
│ ├── kitti.mp4
│ └── stereo-middlebury
│ │ ├── im0.png
│ │ └── im1.png
├── evaluate_depth.py
├── evaluate_flow.py
├── evaluate_stereo.py
├── loss
│ ├── __init__.py
│ ├── depth_loss.py
│ ├── flow_loss.py
│ └── stereo_metric.py
├── main_depth.py
├── main_flow.py
├── main_stereo.py
├── pip_install.sh
├── scripts
│ ├── gmdepth_demo.sh
│ ├── gmdepth_evaluate.sh
│ ├── gmdepth_scale1_regrefine1_train.sh
│ ├── gmdepth_scale1_train.sh
│ ├── gmflow_demo.sh
│ ├── gmflow_evaluate.sh
│ ├── gmflow_scale1_train.sh
│ ├── gmflow_scale2_regrefine6_train.sh
│ ├── gmflow_scale2_train.sh
│ ├── gmflow_submission.sh
│ ├── gmstereo_demo.sh
│ ├── gmstereo_evaluate.sh
│ ├── gmstereo_scale1_train.sh
│ ├── gmstereo_scale2_regrefine3_train.sh
│ ├── gmstereo_scale2_train.sh
│ └── gmstereo_submission.sh
├── unimatch
│ ├── __init__.py
│ ├── attention.py
│ ├── backbone.py
│ ├── geometry.py
│ ├── matching.py
│ ├── position.py
│ ├── reg_refine.py
│ ├── transformer.py
│ ├── trident_conv.py
│ ├── unimatch.py
│ └── utils.py
└── utils
│ ├── dist_utils.py
│ ├── file_io.py
│ ├── flow_viz.py
│ ├── frame_utils.py
│ ├── logger.py
│ ├── misc.py
│ ├── utils.py
│ └── visualization.py
├── utils
├── attention_utils.py
├── batching_utils.py
├── feature_utils.py
├── flow_utils.py
├── flow_viz.py
├── module_utils.py
├── noise_utils.py
├── rave_utils.py
└── sampler_utils.py
└── vv_defaults.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
3 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes.vv_apply_node import ApplyVVModel
2 | from .nodes.get_flow_node import FlowGetFlowNode
3 | from .nodes.vv_unsampler_node import VVUnsamplerSamplerNode
4 | from .nodes.vv_sampler_node import VVSamplerSamplerNode
5 | from .nodes.inj_config_node import InjectionConfigNode
6 | from .nodes.pivot_config_node import PivotConfigNode
7 | from .nodes.rave_config_node import RaveConfigNode
8 | from .nodes.sca_config_node import SCAConfigNode
9 | from .nodes.flow_config_node import FlowConfigNode
10 | from .nodes.temporal_config_node import TemporalConfigNode
11 | from .nodes.get_raft_flow_node import GetRaftFlowNode
12 |
13 |
14 | NODE_CLASS_MAPPINGS = {
15 | "ApplyVVModel": ApplyVVModel,
16 | "FlowGetFlow": FlowGetFlowNode,
17 | "VVUnsamplerSampler": VVUnsamplerSamplerNode,
18 | "VVSamplerSampler": VVSamplerSamplerNode,
19 | "InjectionConfig": InjectionConfigNode,
20 | "FlowConfig": FlowConfigNode,
21 | "PivotConfig": PivotConfigNode,
22 | "RaveConfig": RaveConfigNode,
23 | "SCAConfig": SCAConfigNode,
24 | "TemporalConfig": TemporalConfigNode,
25 | "GetRaftFlow": GetRaftFlowNode,
26 | }
27 |
28 | NODE_DISPLAY_NAME_MAPPINGS = {
29 | "ApplyVVModel": "VV] Apply Model",
30 | "FlowGetFlow": "VV] Get Flow",
31 | "VVUnsamplerSampler": "VV] Unsampler",
32 | "VVSamplerSampler": "VV] Sampler",
33 | "InjectionConfig": "VV] Injection Config",
34 | "FlowConfig": "VV] Flow Attn Config",
35 | "PivotConfig": "VV] Pivot Attn Config",
36 | "RaveConfig": "VV] Rave Attn Config",
37 | "SCAConfig": "VV] Sparse Casual Attn Config",
38 | "TemporalConfig": "VV] Temporal Attn Config",
39 | "VVGetRaftFlow": "VV] Get Raft Flow"
40 | }
41 |
--------------------------------------------------------------------------------
/configs/create_configs.py:
--------------------------------------------------------------------------------
1 | from ..configs.rave_config import RaveAttentionConfig
2 | from ..configs.pivot_config import PivotAttentionConfig
3 | from ..configs.sca_config import SparseCasualAttentionConfig
4 | from ..configs.flow_config import FlowAttentionConfig
5 | from ..configs.inj_config import InjectionConfig
6 | from ..configs.temporal_config import TemporalAttentionConfig
7 | from ..vv_defaults import SD1_MAPS, SDXL_MAPS, SD1_ATTN_INJ_DEFAULTS, SD1_RES_INJ_DEFAULTS, SD1_FLOW_MAP, SDXL_ATTN_INJ_DEFAULTS, SDXL_RES_INJ_DEFAULTS, SDXL_FLOW_MAP
8 |
9 |
10 | def create_configs(
11 | model,
12 | inj_config=None,
13 | flow_config=None,
14 | pivot_config=None,
15 | rave_config=None,
16 | sca_config=None,
17 | temporal_config=None,
18 | ):
19 | model_type = 'SD1.5'  # TODO: detect the model type from `model`; hard-coded to SD1.5 for now
20 | if model_type == 'SD1.5':
21 | if inj_config is not None:
22 | inj_config = InjectionConfig(
23 | SD1_ATTN_INJ_DEFAULTS,
24 | SD1_RES_INJ_DEFAULTS,
25 | inj_config['attn_injections'],
26 | inj_config['res_injections'],
27 | inj_config['attn_save_steps'],
28 | inj_config['res_save_steps'],
29 | inj_config['attn_sigmas'],
30 | inj_config['res_sigmas']
31 | )
32 | if flow_config is not None:
33 | flow_config = FlowAttentionConfig(
34 | SD1_MAPS[flow_config['targets']],
35 | flow_config['flow'],
36 | flow_config['start_percent'],
37 | flow_config['end_percent']
38 | )
39 | if rave_config is not None:
40 | rave_config = RaveAttentionConfig(
41 | SD1_MAPS[rave_config['targets']],
42 | rave_config['grid_size'],
43 | rave_config['seed'],
44 | rave_config['start_percent'],
45 | rave_config['end_percent'])
46 | if pivot_config is not None:
47 | pivot_config = PivotAttentionConfig(
48 | SD1_MAPS[pivot_config['targets']],
49 | pivot_config['batch_size'],
50 | pivot_config['seed'],
51 | pivot_config['start_percent'],
52 | pivot_config['end_percent'])
53 | if sca_config is not None:
54 | sca_config = SparseCasualAttentionConfig(
55 | SD1_MAPS[sca_config['targets']],
56 | sca_config['direction'],
57 | sca_config['start_percent'],
58 | sca_config['end_percent'])
59 | if temporal_config is not None:
60 | temporal_config = TemporalAttentionConfig(
61 | temporal_config['start_percent'],
62 | temporal_config['end_percent'],
63 | )
64 | else:
65 | if inj_config is not None:
66 | inj_config = InjectionConfig(
67 | SDXL_ATTN_INJ_DEFAULTS,
68 | SDXL_RES_INJ_DEFAULTS,
69 | inj_config['attn_injections'],
70 | inj_config['res_injections'],
71 | inj_config['attn_save_steps'],
72 | inj_config['res_save_steps'],
73 | inj_config['attn_sigmas'],
74 | inj_config['res_sigmas']
75 | )
76 | if flow_config is not None:
77 | flow_config = FlowAttentionConfig(
78 | SDXL_MAPS[flow_config['targets']],
79 | flow_config['flow'],
80 | flow_config['start_percent'],
81 | flow_config['end_percent']
82 | )
83 | if rave_config is not None:
84 | rave_config = RaveAttentionConfig(
85 | SDXL_MAPS[rave_config['targets']],
86 | rave_config['grid_size'],
87 | rave_config['seed'],
88 | rave_config['start_percent'],
89 | rave_config['end_percent'])
90 | if pivot_config is not None:
91 | pivot_config = PivotAttentionConfig(
92 | SDXL_MAPS[pivot_config['targets']],
93 | pivot_config['batch_size'],
94 | pivot_config['seed'],
95 | pivot_config['start_percent'],
96 | pivot_config['end_percent'])
97 | if sca_config is not None:
98 | sca_config = SparseCasualAttentionConfig(
99 | SDXL_MAPS[sca_config['targets']],
100 | sca_config['direction'],
101 | sca_config['start_percent'],
102 | sca_config['end_percent'])
103 | if temporal_config is not None:
104 | temporal_config = TemporalAttentionConfig(
105 | temporal_config['start_percent'],
106 | temporal_config['end_percent'],
107 | )
108 | return {
109 | 'INJ_CONFIG': inj_config,
110 | 'FLOW_CONFIG': flow_config,
111 | 'RAVE_CONFIG': rave_config,
112 | 'PIVOT_CONFIG': pivot_config,
113 | 'SCA_CONFIG': sca_config,
114 | 'TEMPORAL_CONFIG': temporal_config
115 | }
--------------------------------------------------------------------------------
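For orientation, here is a minimal sketch of how `create_configs` consumes the plain dicts built by the config nodes later in this repo. The values below are placeholders rather than ones taken from a real workflow, and `'full'` is assumed to be one of the keys of `SD1_MAPS` in `vv_defaults.py` (the target options exposed by the config nodes).

```python
# Hypothetical call; mirrors the dict shapes returned by RaveConfigNode.build
# and TemporalConfigNode.build further down in this repo.
configs = create_configs(
    model=None,  # placeholder -- the model type is currently hard-coded to SD1.5
    rave_config={'grid_size': 3, 'targets': 'full', 'seed': 0,
                 'start_percent': 0.0, 'end_percent': 1.0},
    temporal_config={'start_percent': 0.0, 'end_percent': 1.0},
)
# configs['RAVE_CONFIG'] is a RaveAttentionConfig dataclass and
# configs['TEMPORAL_CONFIG'] a TemporalAttentionConfig; the other entries are None.
```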
/configs/flow_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict
3 |
4 |
5 | @dataclass
6 | class FlowAttentionConfig:
7 | targets: set[tuple[str, int]]
8 | flow: Dict[Any, Any]
9 | start_percent: float
10 | end_percent: float
11 |
--------------------------------------------------------------------------------
/configs/inj_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, List, Union
3 |
4 | import torch
5 |
6 |
7 | @dataclass
8 | class InjectionConfig:
9 | attn_map: set[tuple[str, int]]
10 | res_map: set[tuple[str, int]]
11 | attn_injections: Dict[float, Dict[tuple[str, int], Union[torch.Tensor, List[torch.Tensor]]]]
12 | res_injections: Dict[float, Dict[tuple[str, int], Union[torch.Tensor, List[torch.Tensor]]]]
13 | attn_save_steps: int
14 | res_save_steps: int
15 | attn_sigmas: List[float]
16 | res_sigmas: List[float]
17 |
--------------------------------------------------------------------------------
/configs/pivot_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
4 | @dataclass
5 | class PivotAttentionConfig:
6 | targets: set[tuple[str, int]]
7 | batch_size: int
8 | seed: int
9 | start_percent: float
10 | end_percent: float
--------------------------------------------------------------------------------
/configs/rave_config.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | from dataclasses import dataclass
5 |
6 |
7 | @dataclass
8 | class RaveAttentionConfig:
9 | targets: set[tuple[str, int]]
10 | grid_size: int
11 | seed: int
12 | start_percent: float
13 | end_percent: float
14 |
15 |
--------------------------------------------------------------------------------
/configs/sca_config.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | from dataclasses import dataclass
5 |
6 |
7 | class SCADirection:
8 | PREVIOUS = 'PREVIOUS'
9 | NEXT = 'NEXT'
10 | BOTH = 'BOTH'
11 |
12 |
13 | @dataclass
14 | class SparseCasualAttentionConfig:
15 | targets: set[tuple[str, int]]
16 | direction: SCADirection
17 | start_percent: float
18 | end_percent: float
19 |
20 |
--------------------------------------------------------------------------------
/configs/temporal_config.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | from dataclasses import dataclass
5 |
6 |
7 | @dataclass
8 | class TemporalAttentionConfig:
9 | start_percent: float
10 | end_percent: float
11 |
12 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 |
2 | class ModelOptionKey:
3 | INJECTION_TYPE = "INJECTION_KEY"
4 | FLOW = "FLOW"
5 | STEP = "STEP"
6 | BANK = "BANK"
7 |
8 |
9 | class InjectionType:
10 | UNSAMPLING = "UNSAMPLING"
11 | SAMPLING = "SAMPLING"
12 | REFERENCE = "REFERENCE"
--------------------------------------------------------------------------------
/modules/vv_block.py:
--------------------------------------------------------------------------------
1 | from comfy.ldm.modules.attention import BasicTransformerBlock
2 | from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
3 |
4 | from ..utils.module_utils import isinstance_str
5 |
6 |
7 |
8 | class VVTransformerBlock(BasicTransformerBlock):
9 | def forward(self, x, context=None, transformer_options={}):
10 | extra_options = {}
11 | block = transformer_options.get("block", None)
12 | block_index = transformer_options.get("block_index", 0)
13 | transformer_patches = {}
14 | transformer_patches_replace = {}
15 |
16 | for k in transformer_options:
17 | if k == "patches":
18 | transformer_patches = transformer_options[k]
19 | elif k == "patches_replace":
20 | transformer_patches_replace = transformer_options[k]
21 | else:
22 | extra_options[k] = transformer_options[k]
23 |
24 | extra_options["n_heads"] = self.n_heads
25 | extra_options["dim_head"] = self.d_head
26 | extra_options["attn_precision"] = self.attn_precision
27 |
28 | if self.ff_in:
29 | x_skip = x
30 | x = self.ff_in(self.norm_in(x))
31 | if self.is_res:
32 | x += x_skip
33 |
34 | n = self.norm1(x)
35 |
36 | if self.disable_self_attn:
37 | context_attn1 = context
38 | else:
39 | context_attn1 = None
40 | value_attn1 = None
41 |
42 | if "attn1_patch" in transformer_patches:
43 | patch = transformer_patches["attn1_patch"]
44 | if context_attn1 is None:
45 | context_attn1 = n
46 | value_attn1 = context_attn1
47 | for p in patch:
48 | n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
49 |
50 | if block is not None:
51 | transformer_block = (block[0], block[1], block_index)
52 | else:
53 | transformer_block = None
54 | attn1_replace_patch = transformer_patches_replace.get("attn1", {})
55 | block_attn1 = transformer_block
56 | if block_attn1 not in attn1_replace_patch:
57 | block_attn1 = block
58 |
59 | if block_attn1 in attn1_replace_patch:
60 | if context_attn1 is None:
61 | context_attn1 = n
62 | value_attn1 = n
63 | n = self.attn1.to_q(n)
64 | context_attn1 = self.attn1.to_k(context_attn1)
65 | value_attn1 = self.attn1.to_v(value_attn1)
66 | n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
67 | n = self.attn1.to_out(n)
68 | else:
69 | n = self.attn1(n, extra_options=extra_options)
70 |
71 | if "attn1_output_patch" in transformer_patches:
72 | patch = transformer_patches["attn1_output_patch"]
73 | for p in patch:
74 | n = p(n, extra_options)
75 |
76 | x += n
77 | if "middle_patch" in transformer_patches:
78 | patch = transformer_patches["middle_patch"]
79 | for p in patch:
80 | x = p(x, extra_options)
81 |
82 | if self.attn2 is not None:
83 | n = self.norm2(x)
84 | if self.switch_temporal_ca_to_sa:
85 | context_attn2 = n
86 | else:
87 | context_attn2 = context
88 | value_attn2 = None
89 | if "attn2_patch" in transformer_patches:
90 | patch = transformer_patches["attn2_patch"]
91 | value_attn2 = context_attn2
92 | for p in patch:
93 | n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
94 |
95 | attn2_replace_patch = transformer_patches_replace.get("attn2", {})
96 | block_attn2 = transformer_block
97 | if block_attn2 not in attn2_replace_patch:
98 | block_attn2 = block
99 |
100 | if block_attn2 in attn2_replace_patch:
101 | if value_attn2 is None:
102 | value_attn2 = context_attn2
103 | n = self.attn2.to_q(n)
104 | context_attn2 = self.attn2.to_k(context_attn2)
105 | value_attn2 = self.attn2.to_v(value_attn2)
106 | n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
107 | n = self.attn2.to_out(n)
108 | else:
109 | n = self.attn2(n, context=context_attn2, value=value_attn2)
110 |
111 | if "attn2_output_patch" in transformer_patches:
112 | patch = transformer_patches["attn2_output_patch"]
113 | for p in patch:
114 | n = p(n, extra_options)
115 |
116 | x += n
117 | if self.is_res:
118 | x_skip = x
119 | x = self.ff(self.norm3(x))
120 | if self.is_res:
121 | x += x_skip
122 |
123 | return x
124 |
125 |
126 |
127 | def _get_block_modules(module):
128 | blocks = list(filter(lambda x: isinstance_str(x[1], 'BasicTransformerBlock'), module.named_modules()))
129 | return [block for _, block in blocks]
130 |
131 |
132 | def inject_vv_block(diffusion_model: UNetModel):
133 | blocks = _get_block_modules(diffusion_model)
134 |
135 | for block in blocks:
136 | block.__class__ = VVTransformerBlock
137 |
--------------------------------------------------------------------------------
/modules/vv_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock as ComfyResBlock, UNetModel
4 |
5 | from ..utils.module_utils import isinstance_str
6 |
7 |
8 | class ResBlock(ComfyResBlock):
9 | def init_module(self, block, idx):
10 | self.block = block
11 | self.idx = idx
12 | self.block_idx = (block, idx)
13 |
14 | def forward(self, x, emb, extra_options):
15 | if self.updown:
16 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
17 | h = in_rest(x)
18 | h = self.h_upd(h)
19 | x = self.x_upd(x)
20 | h = in_conv(h)
21 | else:
22 | h = self.in_layers(x)
23 |
24 | emb_out = None
25 | if not self.skip_t_emb:
26 | emb_out = self.emb_layers(emb).type(h.dtype)
27 | while len(emb_out.shape) < len(h.shape):
28 | emb_out = emb_out[..., None]
29 | if self.use_scale_shift_norm:
30 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
31 | h = out_norm(h)
32 | if emb_out is not None:
33 | scale, shift = torch.chunk(emb_out, 2, dim=1)
34 | h *= (1 + scale)
35 | h += shift
36 | h = out_rest(h)
37 | else:
38 | if emb_out is not None:
39 | if self.exchange_temb_dims:
40 | emb_out = emb_out.movedim(1, 2)
41 | h = h + emb_out
42 | h = self.out_layers(h)
43 |
44 | ad_params = extra_options.get('ad_params', {})
45 | sub_idxs = ad_params.get('sub_idxs', [0])
46 | if sub_idxs is None:
47 | sub_idx = 0
48 | else:
49 | sub_idx = sub_idxs[0]
50 |
51 | res_inj_steps = extra_options.get('RES_INJECTION_STEPS', 0)
52 | step = extra_options.get('STEP', 999)
53 | inj_config = extra_options.get('INJ_CONFIG', None)
54 | if inj_config and self.block_idx in inj_config.res_map:
55 | if extra_options['INJECTION_KEY'] == 'SAMPLING' and step < res_inj_steps and step < inj_config.res_save_steps:
56 | if extra_options['INJECTION_KEY'] == 'SAMPLING':
57 | len_cond = len(extra_options['cond_or_uncond'])
58 | sigma_key = inj_config.res_sigmas[step]
59 | res_inj = inj_config.res_injections[sigma_key][self.block_idx]
60 | if sub_idxs is not None:
61 | h = res_inj[sub_idxs].to(x.device)
62 | else:
63 | h = res_inj.to(x.device)
64 | if len_cond > 1:
65 | h = torch.cat([h]*len_cond)
66 | elif extra_options['INJECTION_KEY'] == 'UNSAMPLING' and step < inj_config.res_save_steps:
67 | overlap = extra_options.get('OVERLAP', None)
68 | sigma_key = inj_config.res_sigmas[step]
69 | res_inj = inj_config.res_injections[sigma_key][self.block_idx]
70 | if overlap == 0 or sub_idx == 0:
71 | res_inj.append(h.clone().detach().cpu())
72 | else:
73 | res_inj.append(h[overlap:].clone().detach().cpu())
74 |
75 | return self.skip_connection(x) + h
76 |
77 |
78 |
79 | def _get_resnet_modules(module):
80 | blocks = list(filter(lambda x: isinstance_str(x[1], 'ResBlock'), module.named_modules()))
81 | return [block for _, block in blocks]
82 |
83 |
84 | def inject_vv_resblock(diffusion_model: UNetModel):
85 | input = _get_resnet_modules(diffusion_model.input_blocks)
86 | middle = _get_resnet_modules(diffusion_model.middle_block)
87 | output = _get_resnet_modules(diffusion_model.output_blocks)
88 |
89 | for idx, resnet in enumerate(input):
90 | resnet.__class__ = ResBlock
91 | resnet.init_module('input', idx)
92 |
93 | for idx, resnet in enumerate(middle):
94 | resnet.__class__ = ResBlock
95 | resnet.init_module('middle', idx)
96 |
97 | for idx, resnet in enumerate(output):
98 | resnet.__class__ = ResBlock
99 | resnet.init_module('output', idx)
--------------------------------------------------------------------------------
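The `forward` above implements a save-then-replay scheme for residual features: during UNSAMPLING it stores `h` (detached, on CPU) in `inj_config.res_injections[sigma][block_idx]`, and during SAMPLING it swaps the freshly computed `h` for that stored tensor on the first few steps. A stripped-down sketch of the same idea, independent of the UNet classes (tensor shapes and sigma values here are made up):

```python
import torch

# Toy feature bank: sigma -> {block id -> stored residual feature}
bank = {}
block_idx = ('output', 4)
sigma = 9.7

# "Unsampling" pass: record the feature produced at this block/step.
h = torch.randn(1, 1280, 16, 16)
bank.setdefault(sigma, {})[block_idx] = h.clone().detach().cpu()

# "Sampling" pass: replace the newly computed feature with the stored one.
h_new = torch.randn(1, 1280, 16, 16)
h_new = bank[sigma][block_idx].to(h_new.device)
```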
/modules/vv_unet.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import List
3 | import torch
4 |
5 | from comfy.ldm.modules.attention import SpatialTransformer
6 | from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepBlock, UNetModel, Upsample, apply_control
7 | from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
8 |
9 |
10 | def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
11 | for layer in ts:
12 | if isinstance(layer, TimestepBlock):
13 | x = layer(x, emb, transformer_options)
14 | elif isinstance(layer, SpatialTransformer):
15 | x = layer(x, context, transformer_options)
16 | if "transformer_index" in transformer_options:
17 | transformer_options["transformer_index"] += 1
18 | elif isinstance(layer, Upsample):
19 | x = layer(x, output_shape=output_shape)
20 | elif 'raunet' in str(layer.__class__):
21 | x = layer(x, output_shape=output_shape, transformer_options=transformer_options)
22 | elif 'Temporal' in str(layer.__class__):
23 | temporal_config = transformer_options.get('TEMPORAL_CONFIG', None)
24 | step_percent = transformer_options.get('STEP_PERCENT', 999)
25 | if temporal_config is None or (temporal_config.start_percent <= step_percent <= temporal_config.end_percent):
26 | x = layer(x, context)
27 | else:
28 | x = layer(x)
29 | return x
30 |
31 |
32 | class VVUNetModel(UNetModel):
33 | def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
34 | return self._forward_sample(x, timesteps, context, y, control, transformer_options=transformer_options)
35 |
36 | def _forward_sample(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
37 | transformer_options["original_shape"] = list(x.shape)
38 | transformer_options["transformer_index"] = 0
39 | transformer_patches = transformer_options.get("patches", {})
40 |
41 | hs = []
42 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
43 | emb = self.time_embed(t_emb)
44 |
45 | if self.num_classes is not None:
46 | assert y.shape[0] == x.shape[0]
47 | emb = emb + self.label_emb(y)
48 |
49 | h = x
50 | for id, module in enumerate(self.input_blocks):
51 | transformer_options["block"] = ("input", id)
52 | h = forward_timestep_embed(module, h, emb, context, transformer_options)
53 | if control is not None:
54 | h = apply_control(h, control, 'input')
55 | if "input_block_patch" in transformer_patches:
56 | patch = transformer_patches["input_block_patch"]
57 | for p in patch:
58 | h = p(h, transformer_options)
59 |
60 | hs.append(h)
61 | if "input_block_patch_after_skip" in transformer_patches:
62 | patch = transformer_patches["input_block_patch_after_skip"]
63 | for p in patch:
64 | h = p(h, transformer_options)
65 |
66 | transformer_options["block"] = ("middle", 0)
67 | h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
68 | h = apply_control(h, control, 'middle')
69 |
70 | for id, module in enumerate(self.output_blocks):
71 | transformer_options["block"] = ("output", id)
72 | hsp = hs.pop()
73 | if control is not None:
74 | hsp = apply_control(hsp, control, 'output')
75 |
76 | if "output_block_patch" in transformer_patches:
77 | patch = transformer_patches["output_block_patch"]
78 | for p in patch:
79 | h, hsp = p(h, hsp, transformer_options)
80 |
81 | h = torch.cat([h, hsp], dim=1)
82 | del hsp
83 | if len(hs) > 0:
84 | output_shape = hs[-1].shape
85 | else:
86 | output_shape = None
87 | h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
88 |
89 | h = h.type(x.dtype)
90 | return self.out(h)
91 |
92 |
93 | def inject_unet(diffusion_model):
94 | diffusion_model.__class__ = VVUNetModel
95 |
--------------------------------------------------------------------------------
/nodes/config_flow_attn_node.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class ConfigFlowAttnFlowSD1Node:
4 | @classmethod
5 | def INPUT_TYPES(s):
6 | base = {"required": {
7 | }}
8 | for i in range(8):
9 | base['required'][f'input_{i}'] = ("BOOLEAN", { "default": i < 2 })
10 |
11 | base['required'][f'middle_0'] = ("BOOLEAN", { "default": False })
12 |
13 | for i in range(9):
14 | base['required'][f'output_{i}'] = ("BOOLEAN", { "default": i > 6 })
15 |
16 | return base
17 | RETURN_TYPES = ("ATTN_MAP",)
18 | FUNCTION = "apply"
19 |
20 | CATEGORY = "vv"
21 |
22 | def apply(self, **kwargs):
23 |
24 | attention_map = set()
25 | for key, value in kwargs.items():
26 | if value:
27 | block, idx = key.split('_')
28 | attention_map.add((block, int(idx)))
29 |
30 | return (attention_map, )
31 |
32 |
33 | SDXL_DEFAULT_ALL_ATTNS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35"
34 |
35 | class ConfigFlowAttnFlowSDXLNode:
36 | @classmethod
37 | def INPUT_TYPES(s):
38 | base = {"required": {
39 | "input_attns": ("STRING", {"multiline": True, "default": SDXL_DEFAULT_ALL_ATTNS }),
40 | "output_attns": ("STRING", {"multiline": True, "default": SDXL_DEFAULT_ALL_ATTNS }),
41 | }}
42 | return base
43 | RETURN_TYPES = ("ATTN_MAP",)
44 | FUNCTION = "apply"
45 |
46 | CATEGORY = "flow"
47 |
48 | def apply(self, input_attns, output_attns):
49 |
50 | attention_map = set()
51 | if input_attns != '' and input_attns is not None:
52 | for idx in input_attns.split(','):
53 | idx = idx.strip()
54 | if idx == '':
55 | continue
56 | attention_map.add(('input', int(idx)))
57 | if output_attns != '' and output_attns is not None:
58 | for idx in output_attns.split(','):
59 | idx = idx.strip()
60 | if idx == '':
61 | continue
62 | attention_map.add(('output', int(idx)))
63 |
64 | return (attention_map, )
65 |
66 |
67 | class ConfigFlowAttnCrossFrameSDXLNode:
68 | @classmethod
69 | def INPUT_TYPES(s):
70 | base = {"required": {
71 | "input_attns": ("STRING", {"multiline": True, "default": SDXL_DEFAULT_ALL_ATTNS }),
72 | "output_attns": ("STRING", {"multiline": True, "default": SDXL_DEFAULT_ALL_ATTNS }),
73 | }}
74 | return base
75 | RETURN_TYPES = ("ATTN_MAP",)
76 | FUNCTION = "apply"
77 |
78 | CATEGORY = "flow"
79 |
80 | def apply(self, input_attns, output_attns):
81 |
82 | attention_map = set()
83 | if input_attns != '' and input_attns is not None:
84 | for idx in input_attns.split(','):
85 | idx = idx.strip()
86 | if idx == '' or idx is None:
87 | continue
88 | attention_map.add(('input', int(idx)))
89 | if output_attns != '' and output_attns is not None:
90 | for idx in output_attns.split(','):
91 | idx = idx.strip()
92 | if idx == '' or idx is None:
93 | continue
94 | attention_map.add(('output', int(idx)))
95 |
96 | return (attention_map, )
--------------------------------------------------------------------------------
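A quick illustration of the parsing done by the SDXL node above (the index strings are chosen arbitrarily):

```python
# Hypothetical call to the node defined above, outside of a ComfyUI graph.
node = ConfigFlowAttnFlowSDXLNode()
(attn_map,) = node.apply(input_attns="0, 1", output_attns="2")
# attn_map == {('input', 0), ('input', 1), ('output', 2)}
```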
/nodes/config_flow_inj_node.py:
--------------------------------------------------------------------------------
1 | SD1_ATTN_DEFAULTS = set([1,2,3,4,5,6])
2 | SD1_RES_DEFAULTS = set([3,4,6])
3 |
4 |
5 | class ConfigFlowAttnInjSD1Node:
6 | @classmethod
7 | def INPUT_TYPES(s):
8 | base = {"required": {
9 | }}
10 |
11 | for i in range(9):
12 | base['required'][f'output_{i}'] = ("BOOLEAN", { "default": i in SD1_ATTN_DEFAULTS })
13 |
14 | return base
15 | RETURN_TYPES = ("ATTN_INJ_MAP",)
16 | FUNCTION = "apply"
17 |
18 | CATEGORY = "vv"
19 |
20 | def apply(self, **kwargs):
21 |
22 | attention_map = set()
23 | for key, value in kwargs.items():
24 | if value:
25 | block, idx = key.split('_')
26 | attention_map.add((block, int(idx)))
27 |
28 | return (attention_map, )
29 |
30 |
31 | # 0-35
32 | class ConfigFlowAttnInjSDXLNode:
33 | @classmethod
34 | def INPUT_TYPES(s):
35 | base = {"required": {
36 | "attns": ("STRING", {"multiline": True }),
37 | }}
38 | return base
39 | RETURN_TYPES = ("ATTN_INJ_MAP",)
40 | FUNCTION = "apply"
41 |
42 | CATEGORY = "flow"
43 |
44 | def apply(self, attns):
45 |
46 | attention_map = set()
47 | for idx in attns.split(','):
48 | idx = int(idx.strip())
49 | attention_map.add(('output', int(idx)))
50 |
51 | return (attention_map, )
52 |
53 |
54 | class ConfigFlowResInjSD1Node:
55 | @classmethod
56 | def INPUT_TYPES(s):
57 | base = {"required": {
58 | }}
59 |
60 | for i in range(9):
61 | base['required'][f'output_{i}'] = ("BOOLEAN", { "default": i in SD1_RES_DEFAULTS })
62 |
63 | return base
64 | RETURN_TYPES = ("RES_INJ_MAP",)
65 | FUNCTION = "apply"
66 |
67 | CATEGORY = "flow"
68 |
69 | def apply(self, **kwargs):
70 |
71 | resnet_map = set()
72 | for key, value in kwargs.items():
73 | if value:
74 | block, idx = key.split('_')
75 | resnet_map.add((block, int(idx)))
76 |
77 | return (resnet_map, )
78 |
79 |
80 | class ConfigFlowResInjSDXLNode:
81 | @classmethod
82 | def INPUT_TYPES(s):
83 | base = {"required": {
84 | }}
85 |
86 | for i in range(9):
87 | base['required'][f'output_{i}'] = ("BOOLEAN", { "default": i in SD1_RES_DEFAULTS })
88 |
89 | return base
90 | RETURN_TYPES = ("RES_INJ_MAP",)
91 | FUNCTION = "apply"
92 |
93 | CATEGORY = "flow"
94 |
95 | def apply(self, **kwargs):
96 |
97 | resnet_map = set()
98 | for key, value in kwargs.items():
99 | if value:
100 | block, idx = key.split('_')
101 | resnet_map.add((block, int(idx)))
102 |
103 | return (resnet_map, )
--------------------------------------------------------------------------------
/nodes/flow_config_node.py:
--------------------------------------------------------------------------------
1 |
2 | class FlowConfigNode:
3 | @classmethod
4 | def INPUT_TYPES(s):
5 | return {"required": {
6 | "flow": ("FLOW",),
7 | "targets": (['full', 'inner', 'outer', 'none'],),
8 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
9 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
10 | }}
11 | RETURN_TYPES = ("FLOW_CONFIG",)
12 | FUNCTION = "build"
13 |
14 | CATEGORY = "vv/configs"
15 |
16 | def build(self, flow, targets, start_percent, end_percent):
17 | return ({ 'targets': targets, 'flow': flow, 'start_percent': start_percent, 'end_percent': end_percent},)
--------------------------------------------------------------------------------
/nodes/inj_config_node.py:
--------------------------------------------------------------------------------
1 |
2 | def find_closest_index(lst, target):
3 | return min(range(len(lst)), key=lambda i: abs(lst[i] - target))
4 |
5 |
6 | class InjectionConfigNode:
7 | @classmethod
8 | def INPUT_TYPES(s):
9 | return {"required": {
10 | "unsampler_sigmas": ("SIGMAS",),
11 | "sampler_sigmas": ("SIGMAS",),
12 | "save_attn_steps": ("INT", {"default": 3, "min": 0, "max": 999, "step": 1}),
13 | "save_res_steps": ("INT", {"default": 3, "min": 0, "max": 999, "step": 1}),
14 | }}
15 | RETURN_TYPES = ("INJ_CONFIG",)
16 | FUNCTION = "create"
17 |
18 | CATEGORY = "vv/configs"
19 |
20 | def create(self, unsampler_sigmas, sampler_sigmas, save_attn_steps, save_res_steps):
21 | attn_injections = {}
22 | attn_sigmas = []
23 | for i in range(save_attn_steps):
24 | sampler_sigma = sampler_sigmas[i].item()
25 | unsampler_idx = find_closest_index(unsampler_sigmas, sampler_sigma)
26 | unsampler_sigma = unsampler_sigmas[unsampler_idx].item()
27 | attn_injections[unsampler_sigma] = {}
28 | attn_sigmas.append(unsampler_sigma)
29 |
30 | res_injections = {}
31 | res_sigmas = []
32 | for i in range(save_res_steps):
33 | sampler_sigma = sampler_sigmas[i].item()
34 | unsampler_idx = find_closest_index(unsampler_sigmas, sampler_sigma)
35 | unsampler_sigma = unsampler_sigmas[unsampler_idx].item()
36 | res_injections[unsampler_sigma] = {}
37 | res_sigmas.append(unsampler_sigma)
38 |
39 | config = {
40 | 'attn_injections': attn_injections,
41 | 'res_injections': res_injections,
42 | 'attn_save_steps': save_attn_steps,
43 | 'res_save_steps': save_res_steps,
44 | 'attn_sigmas': attn_sigmas,
45 | 'res_sigmas': res_sigmas
46 | }
47 | return (config, )
48 |
--------------------------------------------------------------------------------
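For intuition, a toy run of the closest-sigma matching that `create` performs above; the values are made up, and plain floats stand in for the SIGMAS tensors the node actually receives:

```python
unsampler_sigmas = [14.6, 9.7, 5.5, 2.9, 1.1, 0.3]  # made-up schedule
sampler_sigmas = [14.0, 8.0, 3.0]

for s in sampler_sigmas:
    idx = find_closest_index(unsampler_sigmas, s)
    print(s, '->', unsampler_sigmas[idx])
# 14.0 -> 14.6, 8.0 -> 9.7, 3.0 -> 2.9
```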
/nodes/pivot_config_node.py:
--------------------------------------------------------------------------------
1 | from ..vv_defaults import MAP_TYPES
2 |
3 | class PivotConfigNode:
4 | @classmethod
5 | def INPUT_TYPES(s):
6 | return {"required": {
7 | "batch_size": ("INT", {"default": 3, "min": 2, "max": 9, "step": 1}),
8 | "targets": (MAP_TYPES,),
9 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
10 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
11 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
12 | }}
13 | RETURN_TYPES = ("PIVOT_CONFIG",)
14 | FUNCTION = "build"
15 |
16 | CATEGORY = "vv/configs"
17 |
18 | def build(self, batch_size, targets, seed, start_percent, end_percent):
19 | return ({ 'batch_size': batch_size, 'targets': targets, 'seed': seed, 'start_percent': start_percent, 'end_percent': end_percent},)
20 |
--------------------------------------------------------------------------------
/nodes/rave_config_node.py:
--------------------------------------------------------------------------------
1 | from ..vv_defaults import MAP_TYPES
2 |
3 |
4 | class RaveConfigNode:
5 | @classmethod
6 | def INPUT_TYPES(s):
7 | return {"required": {
8 | "grid_size": ("INT", {"default": 3, "min": 2, "max": 9, "step": 1}),
9 | "targets": (MAP_TYPES,),
10 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
11 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
12 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
13 | }}
14 | RETURN_TYPES = ("RAVE_CONFIG",)
15 | FUNCTION = "build"
16 |
17 | CATEGORY = "vv/configs"
18 |
19 | def build(self, grid_size, targets, seed, start_percent, end_percent):
20 | return ({ 'grid_size': grid_size, 'targets': targets, 'seed': seed, 'start_percent': start_percent, 'end_percent': end_percent},)
--------------------------------------------------------------------------------
/nodes/sca_config_node.py:
--------------------------------------------------------------------------------
1 | from ..vv_defaults import MAP_TYPES
2 |
3 |
4 | class SCAConfigNode:
5 | @classmethod
6 | def INPUT_TYPES(s):
7 | return {"required": {
8 | "direction": (['PREVIOUS', 'NEXT', 'BOTH'],),
9 | "targets": (MAP_TYPES,),
10 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
11 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
12 | }}
13 | RETURN_TYPES = ("SCA_CONFIG",)
14 | FUNCTION = "build"
15 |
16 | CATEGORY = "vv/configs"
17 |
18 | def build(self, direction, targets, start_percent, end_percent):
19 |
20 | return ({ 'direction': direction, 'targets': targets, 'start_percent': start_percent, 'end_percent': end_percent},)
--------------------------------------------------------------------------------
/nodes/temporal_config_node.py:
--------------------------------------------------------------------------------
1 |
2 | class TemporalConfigNode:
3 | @classmethod
4 | def INPUT_TYPES(s):
5 | return {"required": {
6 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
7 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
8 | }}
9 | RETURN_TYPES = ("TEMPORAL_CONFIG",)
10 | FUNCTION = "build"
11 |
12 | CATEGORY = "vv/configs"
13 |
14 | def build(self, start_percent, end_percent):
15 |
16 | return ({ 'start_percent': start_percent, 'end_percent': end_percent},)
--------------------------------------------------------------------------------
/nodes/vv_apply_node.py:
--------------------------------------------------------------------------------
1 | from ..modules.vv_attention import inject_vv_atn
2 | from ..modules.vv_block import inject_vv_block
3 | from ..modules.vv_resnet import inject_vv_resblock
4 | from ..modules.vv_unet import inject_unet
5 |
6 |
7 | class ApplyVVModel:
8 | @classmethod
9 | def INPUT_TYPES(s):
10 | return {"required": {
11 | "model": ("MODEL",),
12 | }}
13 | RETURN_TYPES = ("MODEL",)
14 | FUNCTION = "apply"
15 |
16 | CATEGORY = "vv"
17 |
18 | def apply(self, model):
19 | inject_vv_atn(model.model.diffusion_model)
20 | inject_vv_block(model.model.diffusion_model)
21 | inject_vv_resblock(model.model.diffusion_model)
22 | inject_unet(model.model.diffusion_model)
23 | return (model, )
24 |
--------------------------------------------------------------------------------
/nodes/vv_sampler_node.py:
--------------------------------------------------------------------------------
1 | import comfy.samplers
2 | from comfy.samplers import KSAMPLER
3 |
4 | from ..utils.sampler_utils import get_sampler_fn, create_sampler
5 |
6 |
7 |
8 | class VVSamplerSamplerNode:
9 | @classmethod
10 | def INPUT_TYPES(s):
11 | return {"required": {
12 | "sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
13 | "attn_injection_steps": ("INT", {"default": 3, "min": 0, "max": 999, "step": 1}),
14 | "res_injection_steps": ("INT", {"default": 3, "min": 0, "max": 999, "step": 1}),
15 | "overlap": ("INT", {"default": 5, "min": 0, "max": 999, "step": 1}),
16 | }, "optional": {
17 | "flow_config": ("FLOW_CONFIG",),
18 | "inj_config": ("INJ_CONFIG",),
19 | "sca_config": ('SCA_CONFIG',),
20 | "pivot_config": ('PIVOT_CONFIG',),
21 | "rave_config": ('RAVE_CONFIG',),
22 | "temporal_config": ('TEMPORAL_CONFIG',),
23 | "sampler": ("SAMPLER",),
24 | }}
25 | RETURN_TYPES = ("SAMPLER",)
26 | FUNCTION = "build"
27 |
28 | CATEGORY = "vv/sampling"
29 |
30 | def build(self,
31 | sampler_name,
32 | attn_injection_steps,
33 | res_injection_steps,
34 | overlap,
35 | flow_config=None,
36 | inj_config=None,
37 | sca_config=None,
38 | pivot_config=None,
39 | rave_config=None,
40 | temporal_config=None,
41 | sampler=None):
42 |
43 | sampler_fn = get_sampler_fn(sampler_name)
44 | sampler_fn = create_sampler(sampler_fn, attn_injection_steps, res_injection_steps, overlap,
45 | flow_config, inj_config,pivot_config, rave_config, sca_config,
46 | temporal_config)
47 |
48 | if sampler is None:
49 | sampler = KSAMPLER(sampler_fn)
50 | else:
51 | sampler.sampler_function = sampler_fn
52 |
53 | return (sampler, )
--------------------------------------------------------------------------------
/nodes/vv_unsampler_node.py:
--------------------------------------------------------------------------------
1 | import comfy.samplers
2 | from comfy.samplers import KSAMPLER
3 |
4 | from ..utils.sampler_utils import get_sampler_fn, create_unsampler
5 |
6 |
7 | class VVUnsamplerSamplerNode:
8 | @classmethod
9 | def INPUT_TYPES(s):
10 | return {"required": {
11 | "sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
12 | "overlap": ("INT", {"default": 5, "min": 0, "max": 999, "step": 1}),
13 | }, "optional": {
14 | "flow_config": ("FLOW_CONFIG",),
15 | "inj_config": ("INJ_CONFIG",),
16 | "sampler": ("SAMPLER",)
17 | }}
18 | RETURN_TYPES = ("SAMPLER",)
19 | FUNCTION = "build"
20 |
21 | CATEGORY = "vv/sampling"
22 |
23 | def build(self,
24 | sampler_name,
25 | overlap,
26 | flow_config=None,
27 | inj_config=None,
28 | sampler=None):
29 |
30 | sampler_fn = get_sampler_fn(sampler_name)
31 | sampler_fn = create_unsampler(sampler_fn, overlap, flow_config, inj_config)
32 |
33 | if sampler is None:
34 | sampler = KSAMPLER(sampler_fn)
35 | else:
36 | sampler.sampler_function = sampler_fn
37 |
38 | return (sampler, )
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einshape
2 | timm
--------------------------------------------------------------------------------
/sea_raft/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | dist
3 | datasets
4 | pytorch_env
5 | models
6 | ckpt*
7 | build
8 | demo
9 | runs
10 | results
11 | */__pycache__/*
12 | checkpoints
13 | weights
14 | *.ckpt
15 | *.pth
16 | *.ppm
17 | *.flo
18 | *.egg-info
19 | *.npy
20 | *.npz
--------------------------------------------------------------------------------
/sea_raft/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Princeton Vision & Learning Lab
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/sea_raft/README.md:
--------------------------------------------------------------------------------
1 | # [ECCV24] SEA-RAFT
2 |
3 | We introduce SEA-RAFT, a simpler, more efficient, and more accurate [RAFT](https://github.com/princeton-vl/RAFT) for optical flow. Compared with RAFT, SEA-RAFT is trained with a new loss (mixture of Laplace). It directly regresses an initial flow for faster convergence in iterative refinements and introduces rigid-motion pre-training to improve generalization. SEA-RAFT achieves state-of-the-art accuracy on the [Spring benchmark](https://spring-benchmark.org/) with a 3.69 endpoint error (EPE) and a 0.36 1-pixel outlier rate (1px), representing 22.9% and 17.8% error reductions from the best published results. In addition, SEA-RAFT obtains the best cross-dataset generalization on KITTI and Spring. With its high efficiency, SEA-RAFT operates at least 2.3x faster than existing methods while maintaining competitive performance.
4 |
5 |
6 |
7 | If you find SEA-RAFT useful for your work, please consider citing our academic paper:
8 |
9 |
15 | Yihan Wang, Lahav Lipson, Jia Deng
18 |
19 |
20 | ```
21 | @article{wang2024sea,
22 | title={SEA-RAFT: Simple, Efficient, Accurate RAFT for Optical Flow},
23 | author={Wang, Yihan and Lipson, Lahav and Deng, Jia},
24 | journal={arXiv preprint arXiv:2405.14793},
25 | year={2024}
26 | }
27 | ```
28 |
29 | ## Requirements
30 | Our code is developed with PyTorch 2.2.0, CUDA 12.2, and Python 3.10.
31 | ```Shell
32 | conda create --name SEA-RAFT python=3.10.13
33 | conda activate SEA-RAFT
34 | pip install -r requirements.txt
35 | ```
36 |
37 | ## Model Zoo
38 | Please download the models from [google drive](https://drive.google.com/drive/folders/1YLovlvUW94vciWvTyLf-p3uWscbOQRWW?usp=sharing) and put them into the `models` folder.
39 |
40 | ## Custom Usage
41 |
42 | We provide an example in `custom.py`. By default, this file will take two RGB images as the input and provide visualizations of the optical flow and the uncertainty.
43 | ```Shell
44 | python custom.py --cfg config/eval/spring-M.json --model models/Tartan-C-T-TSKH-spring540x960-M.pth
45 | ```
46 |
47 | ## Datasets
48 | To evaluate/train SEA-RAFT, you will need to download the required datasets: [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs), [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html), [Sintel](http://sintel.is.tue.mpg.de/), [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow), [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/), [TartanAir](https://theairlab.org/tartanair-dataset/), and [Spring](https://spring-benchmark.org/).
49 |
50 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder. Please check [RAFT](https://github.com/princeton-vl/RAFT) for more details.
51 |
52 | ```Shell
53 | ├── datasets
54 | ├── Sintel
55 | ├── KITTI
56 | ├── FlyingChairs/FlyingChairs_release
57 | ├── FlyingThings3D
58 | ├── HD1K
59 | ├── spring
60 | ├── test
61 | ├── train
62 | ├── val
63 | ├── tartanair
64 | ```
65 |
66 | ## Training, Evaluation, and Submission
67 |
68 | Please refer to [scripts/train.sh](scripts/train.sh), [scripts/eval.sh](scripts/eval.sh), and [scripts/submission.sh](scripts/submission.sh) for more details.
69 |
70 | ## Acknowledgements
71 |
72 | This project relies on code from existing repositories: [RAFT](https://github.com/princeton-vl/RAFT), [unimatch](https://github.com/autonomousvision/unimatch/tree/master), [Flowformer](https://github.com/drinkingcoder/FlowFormer-Official), [ptlflow](https://github.com/hmorimitsu/ptlflow), and [LoFTR](https://github.com/zju3dv/LoFTR). We thank the original authors for their excellent work.
73 |
--------------------------------------------------------------------------------
/sea_raft/assets/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/assets/visualization.png
--------------------------------------------------------------------------------
/sea_raft/config/eval/kitti-L.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "kitti-L",
3 | "dataset": "kitti",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 12,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 16,
20 | "epsilon": 1e-8,
21 | "lr": 1e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 10000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/eval/kitti-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "kitti-M",
3 | "dataset": "kitti",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 16,
20 | "epsilon": 1e-8,
21 | "lr": 1e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 10000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/eval/kitti-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "kitti-S",
3 | "dataset": "kitti",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 16,
20 | "epsilon": 1e-8,
21 | "lr": 1e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 10000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/eval/sintel-L.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "sintel-M",
3 | "dataset": "sintel",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 12,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 300000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/eval/sintel-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "sintel-M",
3 | "dataset": "sintel",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 300000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/eval/sintel-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "sintel-S",
3 | "dataset": "sintel",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 300000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/eval/spring-L.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "spring-L",
3 | "dataset": "spring",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 12,
16 |
17 | "image_size": [540, 960],
18 | "scale": -1,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 120000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/eval/spring-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "spring-M",
3 | "dataset": "spring",
4 | "gpus": [
5 | 0,
6 | 1,
7 | 2,
8 | 3,
9 | 4,
10 | 5,
11 | 6,
12 | 7
13 | ],
14 | "use_var": true,
15 | "var_min": 0,
16 | "var_max": 10,
17 | "pretrain": "resnet34",
18 | "initial_dim": 64,
19 | "block_dims": [
20 | 64,
21 | 128,
22 | 256
23 | ],
24 | "radius": 4,
25 | "dim": 128,
26 | "num_blocks": 2,
27 | "iters": 4,
28 | "image_size": [
29 | 540,
30 | 960
31 | ],
32 | "scale": -1,
33 | "batch_size": 32,
34 | "epsilon": 1e-8,
35 | "lr": 4e-4,
36 | "wdecay": 1e-5,
37 | "dropout": 0,
38 | "clip": 1.0,
39 | "gamma": 0.85,
40 | "num_steps": 120000,
41 | "restore_ckpt": null,
42 | "coarse_config": null
43 | }
--------------------------------------------------------------------------------
/sea_raft/config/eval/spring-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "spring-S",
3 | "dataset": "spring",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [540, 960],
18 | "scale": -1,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 120000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/parser.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import json
3 | from typing import Any, List
4 |
5 | def get_config_data(json_path):
6 | with open(json_path, 'r') as f:
7 | data = json.load(f)
8 | return data
9 |
10 |
11 | def get_model_config(config_path):
12 | args = get_config_data(config_path)
13 | model_args = ModelArgs(**args)
14 | return model_args
--------------------------------------------------------------------------------
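`get_model_config` above references `ModelArgs`, which is not defined or imported in this file; a minimal stand-in, assuming the rest of the code only reads the JSON keys (as in the eval/train configs) as attributes, might look like this:

```python
import json
from types import SimpleNamespace

def get_model_config(config_path):
    # Same behavior as the function above, but with a plain namespace in
    # place of the missing ModelArgs, so args.dim, args.iters, etc. resolve.
    with open(config_path, 'r') as f:
        data = json.load(f)
    return SimpleNamespace(**data)
```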
/sea_raft/config/train/Tartan-C-T-TSKH-kitti432x960-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C-T-TSKH-kitti432x960-M",
3 | "dataset": "kitti",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 16,
20 | "epsilon": 1e-8,
21 | "lr": 1e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 10000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C-T-TSKH-kitti432x960-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C-T-TSKH-kitti432x960-S",
3 | "dataset": "kitti",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 16,
20 | "epsilon": 1e-8,
21 | "lr": 1e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 10000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C-T-TSKH-spring540x960-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C-T-TSKH-spring540x960-M",
3 | "dataset": "spring",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [540, 960],
18 | "scale": -1,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 120000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C-T-TSKH-spring540x960-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C-T-TSKH-spring540x960-S",
3 | "dataset": "spring",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [540, 960],
18 | "scale": -1,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 120000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C-T-TSKH432x960-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C-T-TSKH432x960-M",
3 | "dataset": "TSKH",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 300000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C-T-TSKH432x960-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C-T-TSKH432x960-S",
3 | "dataset": "TSKH",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 300000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C-T432x960-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C-T432x960-M",
3 | "dataset": "things",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 120000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C-T432x960-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C-T432x960-S",
3 | "dataset": "things",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [432, 960],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 120000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C368x496-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C368x496-M",
3 | "dataset": "chairs",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7, 8],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [368, 496],
18 | "scale": 0,
19 | "batch_size": 16,
20 | "epsilon": 1e-8,
21 | "lr": 2.5e-4,
22 | "wdecay": 1e-4,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.8,
26 | "num_steps": 100000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan-C368x496-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan-C368x496-S",
3 | "dataset": "chairs",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7, 8],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [368, 496],
18 | "scale": 0,
19 | "batch_size": 16,
20 | "epsilon": 1e-8,
21 | "lr": 2.5e-4,
22 | "wdecay": 1e-4,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.8,
26 | "num_steps": 100000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan480x640-M.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan480x640-M",
3 | "dataset": "TartanAir",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet34",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [480, 640],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 300000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/config/train/Tartan480x640-S.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Tartan480x640-S",
3 | "dataset": "TartanAir",
4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7],
5 |
6 | "use_var": true,
7 | "var_min": 0,
8 | "var_max": 10,
9 | "pretrain": "resnet18",
10 | "initial_dim": 64,
11 | "block_dims": [64, 128, 256],
12 | "radius": 4,
13 | "dim": 128,
14 | "num_blocks": 2,
15 | "iters": 4,
16 |
17 | "image_size": [480, 640],
18 | "scale": 0,
19 | "batch_size": 32,
20 | "epsilon": 1e-8,
21 | "lr": 4e-4,
22 | "wdecay": 1e-5,
23 | "dropout": 0,
24 | "clip": 1.0,
25 | "gamma": 0.85,
26 | "num_steps": 300000,
27 |
28 | "restore_ckpt": null,
29 | "coarse_config": null
30 | }
--------------------------------------------------------------------------------
/sea_raft/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/core/__init__.py
--------------------------------------------------------------------------------
/sea_raft/core/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .layer import BasicBlock, conv1x1, conv3x3
5 |
6 | class ResNetFPN(nn.Module):
7 | """
8 |     ResNet-18/34 style backbone; output resolution is 1/8 of the input.
9 |     The number of BasicBlocks per stage follows args.pretrain: 2-2-2 for resnet18, 3-4-6 for resnet34.
10 | """
11 | def __init__(self, args, input_dim=3, output_dim=256, ratio=1.0, norm_layer=nn.BatchNorm2d, init_weight=False):
12 | super().__init__()
13 | # Config
14 | block = BasicBlock
15 | block_dims = args.block_dims
16 | initial_dim = args.initial_dim
17 | self.init_weight = init_weight
18 | self.input_dim = input_dim
19 | # Class Variable
20 | self.in_planes = initial_dim
21 | for i in range(len(block_dims)):
22 | block_dims[i] = int(block_dims[i] * ratio)
23 | # Networks
24 | self.conv1 = nn.Conv2d(input_dim, initial_dim, kernel_size=7, stride=2, padding=3)
25 | self.bn1 = norm_layer(initial_dim)
26 | self.relu = nn.ReLU(inplace=True)
27 | if args.pretrain == 'resnet34':
28 | n_block = [3, 4, 6]
29 | elif args.pretrain == 'resnet18':
30 | n_block = [2, 2, 2]
31 | else:
32 | raise NotImplementedError
33 | self.layer1 = self._make_layer(block, block_dims[0], stride=1, norm_layer=norm_layer, num=n_block[0]) # 1/2
34 | self.layer2 = self._make_layer(block, block_dims[1], stride=2, norm_layer=norm_layer, num=n_block[1]) # 1/4
35 | self.layer3 = self._make_layer(block, block_dims[2], stride=2, norm_layer=norm_layer, num=n_block[2]) # 1/8
36 | self.final_conv = conv1x1(block_dims[2], output_dim)
37 | self._init_weights(args)
38 |
39 | def _init_weights(self, args):
40 |
41 | for m in self.modules():
42 | if isinstance(m, nn.Conv2d):
43 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
44 | if m.bias is not None:
45 | nn.init.constant_(m.bias, 0)
46 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
47 | if m.weight is not None:
48 | nn.init.constant_(m.weight, 1)
49 | if m.bias is not None:
50 | nn.init.constant_(m.bias, 0)
51 |
52 | if self.init_weight:
53 | from torchvision.models import resnet18, ResNet18_Weights, resnet34, ResNet34_Weights
54 | if args.pretrain == 'resnet18':
55 | pretrained_dict = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).state_dict()
56 | else:
57 | pretrained_dict = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1).state_dict()
58 | model_dict = self.state_dict()
59 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
60 | if self.input_dim == 6:
61 | for k, v in pretrained_dict.items():
62 | if k == 'conv1.weight':
63 | pretrained_dict[k] = torch.cat((v, v), dim=1)
64 | model_dict.update(pretrained_dict)
65 | self.load_state_dict(model_dict, strict=False)
66 |
67 |
68 | def _make_layer(self, block, dim, stride=1, norm_layer=nn.BatchNorm2d, num=2):
69 | layers = []
70 | layers.append(block(self.in_planes, dim, stride=stride, norm_layer=norm_layer))
71 | for i in range(num - 1):
72 | layers.append(block(dim, dim, stride=1, norm_layer=norm_layer))
73 | self.in_planes = dim
74 | return nn.Sequential(*layers)
75 |
76 | def forward(self, x):
77 | # ResNet Backbone
78 | x = self.relu(self.bn1(self.conv1(x)))
79 | for i in range(len(self.layer1)):
80 | x = self.layer1[i](x)
81 | for i in range(len(self.layer2)):
82 | x = self.layer2[i](x)
83 | for i in range(len(self.layer3)):
84 | x = self.layer3[i](x)
85 | # Output
86 | output = self.final_conv(x)
87 | return output
--------------------------------------------------------------------------------
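A small shape check illustrating the 1/8 output resolution described in the `ResNetFPN` docstring. The `SimpleNamespace` stand-in for `args` is an assumption; it only supplies the attributes the module actually reads, and the import path assumes the repository root is on `PYTHONPATH`.

```
# Hedged sketch: ResNetFPN downsamples by 8x and emits `output_dim` channels.
import torch
from types import SimpleNamespace
from sea_raft.core.extractor import ResNetFPN

args = SimpleNamespace(pretrain="resnet18", initial_dim=64, block_dims=[64, 128, 256])
net = ResNetFPN(args, input_dim=3, output_dim=256)
feat = net(torch.randn(1, 3, 432, 960))
print(feat.shape)  # torch.Size([1, 256, 54, 120]) -- 432/8 x 960/8
```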
/sea_raft/core/layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import torch
6 | import math
7 | from torch.nn import Module, Dropout
8 |
9 | ### Gradient Clipping and Zeroing Operations ###
10 |
11 | GRAD_CLIP = 0.1
12 |
13 | class GradClip(torch.autograd.Function):
14 | @staticmethod
15 | def forward(ctx, x):
16 | return x
17 |
18 | @staticmethod
19 | def backward(ctx, grad_x):
20 | grad_x = torch.where(torch.isnan(grad_x), torch.zeros_like(grad_x), grad_x)
21 | return grad_x.clamp(min=-0.01, max=0.01)
22 |
23 | class GradientClip(nn.Module):
24 | def __init__(self):
25 | super(GradientClip, self).__init__()
26 |
27 | def forward(self, x):
28 | return GradClip.apply(x)
29 |
30 | def _make_divisible(v, divisor, min_value=None):
31 | if min_value is None:
32 | min_value = divisor
33 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
34 | # Make sure that round down does not go down by more than 10%.
35 | if new_v < 0.9 * v:
36 | new_v += divisor
37 | return new_v
38 |
39 | class ConvNextBlock(nn.Module):
40 | r""" ConvNeXt Block. There are two equivalent implementations:
41 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
42 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
43 | We use (2) as we find it slightly faster in PyTorch
44 |
45 | Args:
46 | dim (int): Number of input channels.
47 |         output_dim (int): Number of output channels.
48 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
49 | """
50 | def __init__(self, dim, output_dim, layer_scale_init_value=1e-6):
51 | super().__init__()
52 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
53 | self.norm = LayerNorm(dim, eps=1e-6)
54 | self.pwconv1 = nn.Linear(dim, 4 * output_dim) # pointwise/1x1 convs, implemented with linear layers
55 | self.act = nn.GELU()
56 | self.pwconv2 = nn.Linear(4 * output_dim, dim)
57 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
58 | requires_grad=True) if layer_scale_init_value > 0 else None
59 | self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0)
60 |
61 | def forward(self, x):
62 | input = x
63 | x = self.dwconv(x)
64 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
65 | x = self.norm(x)
66 | x = self.pwconv1(x)
67 | x = self.act(x)
68 | x = self.pwconv2(x)
69 | if self.gamma is not None:
70 | x = self.gamma * x
71 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
72 | x = self.final(input + x)
73 | return x
74 |
75 | class LayerNorm(nn.Module):
76 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
77 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
78 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
79 | with shape (batch_size, channels, height, width).
80 | """
81 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
82 | super().__init__()
83 | self.weight = nn.Parameter(torch.ones(normalized_shape))
84 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
85 | self.eps = eps
86 | self.data_format = data_format
87 | if self.data_format not in ["channels_last", "channels_first"]:
88 | raise NotImplementedError
89 | self.normalized_shape = (normalized_shape, )
90 |
91 | def forward(self, x):
92 | if self.data_format == "channels_last":
93 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
94 | elif self.data_format == "channels_first":
95 | u = x.mean(1, keepdim=True)
96 | s = (x - u).pow(2).mean(1, keepdim=True)
97 | x = (x - u) / torch.sqrt(s + self.eps)
98 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
99 | return x
100 |
101 | def conv1x1(in_planes, out_planes, stride=1):
102 | """1x1 convolution without padding"""
103 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
104 |
105 |
106 | def conv3x3(in_planes, out_planes, stride=1):
107 | """3x3 convolution with padding"""
108 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
109 |
110 | class BasicBlock(nn.Module):
111 | def __init__(self, in_planes, planes, stride=1, norm_layer=nn.BatchNorm2d):
112 | super().__init__()
113 |
114 | # self.sparse = sparse
115 | self.conv1 = conv3x3(in_planes, planes, stride)
116 | self.conv2 = conv3x3(planes, planes)
117 | self.bn1 = norm_layer(planes)
118 | self.bn2 = norm_layer(planes)
119 | self.relu = nn.ReLU(inplace=True)
120 | if stride == 1 and in_planes == planes:
121 | self.downsample = None
122 | else:
123 | self.bn3 = norm_layer(planes)
124 | self.downsample = nn.Sequential(
125 | conv1x1(in_planes, planes, stride=stride),
126 | self.bn3
127 | )
128 |
129 | def forward(self, x):
130 | y = x
131 | y = self.relu(self.bn1(self.conv1(y)))
132 | y = self.relu(self.bn2(self.conv2(y)))
133 | if self.downsample is not None:
134 | x = self.downsample(x)
135 | return self.relu(x+y)
--------------------------------------------------------------------------------
/sea_raft/core/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | # exclude extremely large displacements
7 | MAX_FLOW = 400
8 | SUM_FREQ = 100
9 | VAL_FREQ = 5000
10 |
11 | def sequence_loss(output, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
12 | """ Loss function defined over sequence of flow predictions """
13 | n_predictions = len(output['flow'])
14 | flow_loss = 0.0
15 |     # exclude invalid pixels and extremely large displacements
16 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
17 | valid = (valid >= 0.5) & (mag < max_flow)
18 | for i in range(n_predictions):
19 | i_weight = gamma ** (n_predictions - i - 1)
20 | loss_i = output['nf'][i]
21 | final_mask = (~torch.isnan(loss_i.detach())) & (~torch.isinf(loss_i.detach())) & valid[:, None]
22 | flow_loss += i_weight * ((final_mask * loss_i).sum() / final_mask.sum())
23 |
24 | return flow_loss
--------------------------------------------------------------------------------
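The `gamma ** (n_predictions - i - 1)` term above weights later refinement iterations more heavily. A tiny illustration of the weights it produces:

```
# Hedged illustration of the per-iteration weights used by sequence_loss.
gamma, n_predictions = 0.8, 4
weights = [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]
print(weights)  # [0.512, 0.64, 0.8, 1.0] -- the final, most refined prediction gets full weight
```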
/sea_raft/core/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .layer import ConvNextBlock
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256, output_dim=4):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class BasicMotionEncoder(nn.Module):
17 | def __init__(self, args, dim=128):
18 | super(BasicMotionEncoder, self).__init__()
19 | cor_planes = args.corr_channel
20 | self.convc1 = nn.Conv2d(cor_planes, dim*2, 1, padding=0)
21 | self.convc2 = nn.Conv2d(dim*2, dim+dim//2, 3, padding=1)
22 | self.convf1 = nn.Conv2d(2, dim, 7, padding=3)
23 | self.convf2 = nn.Conv2d(dim, dim//2, 3, padding=1)
24 | self.conv = nn.Conv2d(dim*2, dim-2, 3, padding=1)
25 |
26 | def forward(self, flow, corr):
27 | cor = F.relu(self.convc1(corr))
28 | cor = F.relu(self.convc2(cor))
29 | flo = F.relu(self.convf1(flow))
30 | flo = F.relu(self.convf2(flo))
31 |
32 | cor_flo = torch.cat([cor, flo], dim=1)
33 | out = F.relu(self.conv(cor_flo))
34 | return torch.cat([out, flow], dim=1)
35 |
36 | class BasicUpdateBlock(nn.Module):
37 | def __init__(self, args, hdim=128, cdim=128):
38 | #net: hdim, inp: cdim
39 | super(BasicUpdateBlock, self).__init__()
40 | self.args = args
41 | self.encoder = BasicMotionEncoder(args, dim=cdim)
42 | self.refine = []
43 | for i in range(args.num_blocks):
44 | self.refine.append(ConvNextBlock(2*cdim+hdim, hdim))
45 | self.refine = nn.ModuleList(self.refine)
46 |
47 | def forward(self, net, inp, corr, flow, upsample=True):
48 | motion_features = self.encoder(flow, corr)
49 | inp = torch.cat([inp, motion_features], dim=1)
50 | for blk in self.refine:
51 | net = blk(torch.cat([net, inp], dim=1))
52 | return net
--------------------------------------------------------------------------------
/sea_raft/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/core/utils/__init__.py
--------------------------------------------------------------------------------
/sea_raft/core/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
110 | """
111 |     Expects a two-dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | rad = np.sqrt(np.square(u) + np.square(v))
128 | rad_max = np.max(rad)
129 | epsilon = 1e-5
130 | u = u / (rad_max + epsilon)
131 | v = v / (rad_max + epsilon)
132 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
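A minimal usage sketch for `flow_to_image`; the synthetic flow field and the import path (repository root assumed importable) are illustrative assumptions.

```
# Hedged sketch: turn an (H, W, 2) flow field into an RGB visualization.
import numpy as np
from sea_raft.core.utils.flow_viz import flow_to_image

flow = np.zeros((64, 64, 2), dtype=np.float32)
flow[..., 0] = 5.0           # constant rightward motion
rgb = flow_to_image(flow)    # uint8 array of shape (64, 64, 3)
print(rgb.shape, rgb.dtype)
```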
/sea_raft/core/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 | import h5py
6 |
7 | import cv2
8 | cv2.setNumThreads(0)
9 | cv2.ocl.setUseOpenCL(False)
10 |
11 | TAG_CHAR = np.array([202021.25], np.float32)
12 |
13 | def readFlow(fn):
14 | """ Read .flo file in Middlebury format"""
15 | # Code adapted from:
16 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
17 |
18 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
19 | # print 'fn = %s'%(fn)
20 | with open(fn, 'rb') as f:
21 | magic = np.fromfile(f, np.float32, count=1)
22 | if 202021.25 != magic:
23 | print('Magic number incorrect. Invalid .flo file')
24 | return None
25 | else:
26 | w = np.fromfile(f, np.int32, count=1)
27 | h = np.fromfile(f, np.int32, count=1)
28 | # print 'Reading %d x %d flo file\n' % (w, h)
29 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
30 | # Reshape data into 3D array (columns, rows, bands)
31 | # The reshape here is for visualization, the original code is (w,h,2)
32 | return np.resize(data, (int(h), int(w), 2))
33 |
34 | def readPFM(file):
35 | file = open(file, 'rb')
36 |
37 | color = None
38 | width = None
39 | height = None
40 | scale = None
41 | endian = None
42 |
43 | header = file.readline().rstrip()
44 | if header == b'PF':
45 | color = True
46 | elif header == b'Pf':
47 | color = False
48 | else:
49 | raise Exception('Not a PFM file.')
50 |
51 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
52 | if dim_match:
53 | width, height = map(int, dim_match.groups())
54 | else:
55 | raise Exception('Malformed PFM header.')
56 |
57 | scale = float(file.readline().rstrip())
58 | if scale < 0: # little-endian
59 | endian = '<'
60 | scale = -scale
61 | else:
62 | endian = '>' # big-endian
63 |
64 | data = np.fromfile(file, endian + 'f')
65 | shape = (height, width, 3) if color else (height, width)
66 |
67 | data = np.reshape(data, shape)
68 | data = np.flipud(data)
69 | return data
70 |
71 | def writeFlow(filename,uv,v=None):
72 | """ Write optical flow to file.
73 |
74 | If v is None, uv is assumed to contain both u and v channels,
75 | stacked in depth.
76 | Original code by Deqing Sun, adapted from Daniel Scharstein.
77 | """
78 | nBands = 2
79 |
80 | if v is None:
81 | assert(uv.ndim == 3)
82 | assert(uv.shape[2] == 2)
83 | u = uv[:,:,0]
84 | v = uv[:,:,1]
85 | else:
86 | u = uv
87 |
88 | assert(u.shape == v.shape)
89 | height,width = u.shape
90 | f = open(filename,'wb')
91 | # write the header
92 | f.write(TAG_CHAR)
93 | np.array(width).astype(np.int32).tofile(f)
94 | np.array(height).astype(np.int32).tofile(f)
95 | # arrange into matrix form
96 | tmp = np.zeros((height, width*nBands))
97 | tmp[:,np.arange(width)*2] = u
98 | tmp[:,np.arange(width)*2 + 1] = v
99 | tmp.astype(np.float32).tofile(f)
100 | f.close()
101 |
102 |
103 | def readFlowKITTI(filename):
104 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
105 | flow = flow[:,:,::-1].astype(np.float32)
106 | flow, valid = flow[:, :, :2], flow[:, :, 2]
107 | flow = (flow - 2**15) / 64.0
108 | return flow, valid
109 |
110 | def readDispKITTI(filename):
111 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
112 | valid = disp > 0.0
113 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
114 | return flow, valid
115 |
116 |
117 | def writeFlowKITTI(filename, uv):
118 | uv = 64.0 * uv + 2**15
119 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
120 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
121 | cv2.imwrite(filename, uv[..., ::-1])
122 |
123 | def readFlo5Flow(filename):
124 | with h5py.File(filename, "r") as f:
125 | if "flow" not in f.keys():
126 | raise IOError(f"File {filename} does not have a 'flow' key. Is this a valid flo5 file?")
127 | return f["flow"][()]
128 |
129 | def writeFlo5File(flow, filename):
130 | with h5py.File(filename, "w") as f:
131 | f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
132 |
133 | def read_gen(file_name, pil=False):
134 | ext = splitext(file_name)[-1]
135 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
136 | return Image.open(file_name)
137 | elif ext == '.bin' or ext == '.raw':
138 | return np.load(file_name)
139 | elif ext == '.flo':
140 | return readFlow(file_name).astype(np.float32)
141 | elif ext == '.pfm':
142 | flow = readPFM(file_name).astype(np.float32)
143 | if len(flow.shape) == 2:
144 | return flow
145 | else:
146 | return flow[:, :, :-1]
147 | elif ext == '.flo5':
148 | return readFlo5Flow(file_name)
149 | return []
--------------------------------------------------------------------------------
/sea_raft/core/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 | def load_ckpt(model, path):
7 | """ Load checkpoint """
8 | state_dict = torch.load(path, map_location=torch.device('cpu'))
9 | model.load_state_dict(state_dict, strict=False)
10 |
11 | def resize_data(img1, img2, flow, factor=1.0):
12 | _, _, h, w = img1.shape
13 | h = int(h * factor)
14 | w = int(w * factor)
15 | img1 = F.interpolate(img1, (h, w), mode='area')
16 | img2 = F.interpolate(img2, (h, w), mode='area')
17 | flow = F.interpolate(flow, (h, w), mode='area') * factor
18 | return img1, img2, flow
19 |
20 | class InputPadder:
21 | """ Pads images such that dimensions are divisible by 8 """
22 | def __init__(self, dims, mode='sintel'):
23 | self.ht, self.wd = dims[-2:]
24 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
25 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
26 | if mode == 'sintel':
27 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
28 | else:
29 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
30 |
31 | def pad(self, *inputs):
32 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
33 |
34 | def unpad(self, x):
35 | ht, wd = x.shape[-2:]
36 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
37 | return x[..., c[0]:c[1], c[2]:c[3]]
38 |
39 | def forward_interpolate(flow):
40 | flow = flow.detach().cpu().numpy()
41 | dx, dy = flow[0], flow[1]
42 |
43 | ht, wd = dx.shape
44 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
45 |
46 | x1 = x0 + dx
47 | y1 = y0 + dy
48 |
49 | x1 = x1.reshape(-1)
50 | y1 = y1.reshape(-1)
51 | dx = dx.reshape(-1)
52 | dy = dy.reshape(-1)
53 |
54 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
55 | x1 = x1[valid]
56 | y1 = y1[valid]
57 | dx = dx[valid]
58 | dy = dy[valid]
59 |
60 | flow_x = interpolate.griddata(
61 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
62 |
63 | flow_y = interpolate.griddata(
64 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
65 |
66 | flow = np.stack([flow_x, flow_y], axis=0)
67 | return torch.from_numpy(flow).float()
68 |
69 |
70 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
71 | """ Wrapper for grid_sample, uses pixel coordinates """
72 | H, W = img.shape[-2:]
73 | xgrid, ygrid = coords.split([1,1], dim=-1)
74 | xgrid = 2*xgrid/(W-1) - 1
75 | ygrid = 2*ygrid/(H-1) - 1
76 |
77 | grid = torch.cat([xgrid, ygrid], dim=-1)
78 | img = F.grid_sample(img, grid, align_corners=True)
79 |
80 | if mask:
81 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
82 | return img, mask.float()
83 |
84 | return img
85 |
86 | def coords_grid(batch, ht, wd, device):
87 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
88 | coords = torch.stack(coords[::-1], dim=0).float()
89 | return coords[None].repeat(batch, 1, 1, 1)
90 |
91 |
92 | def upflow8(flow, mode='bilinear'):
93 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
94 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
95 |
96 | def transform(T, p):
97 | assert T.shape == (4,4)
98 | return np.einsum('H W j, i j -> H W i', p, T[:3,:3]) + T[:3, 3]
99 |
100 | def from_homog(x):
101 | return x[...,:-1] / x[...,[-1]]
102 |
103 | def reproject(depth1, pose1, pose2, K1, K2):
104 | H, W = depth1.shape
105 | x, y = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')
106 | img_1_coords = np.stack((x, y, np.ones_like(x)), axis=-1).astype(np.float64)
107 | cam1_coords = np.einsum('H W, H W j, i j -> H W i', depth1, img_1_coords, np.linalg.inv(K1))
108 | rel_pose = np.linalg.inv(pose2) @ pose1
109 | cam2_coords = transform(rel_pose, cam1_coords)
110 | return from_homog(np.einsum('H W j, i j -> H W i', cam2_coords, K2))
111 |
112 | def induced_flow(depth0, depth1, data):
113 | H, W = depth0.shape
114 | coords1 = reproject(depth0, data['T0'], data['T1'], data['K0'], data['K1'])
115 | x, y = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')
116 | coords0 = np.stack([x, y], axis=-1)
117 | flow_01 = coords1 - coords0
118 |
119 | H, W = depth1.shape
120 | coords1 = reproject(depth1, data['T1'], data['T0'], data['K1'], data['K0'])
121 | x, y = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')
122 | coords0 = np.stack([x, y], axis=-1)
123 | flow_10 = coords1 - coords0
124 |
125 | return flow_01, flow_10
126 |
127 | def check_cycle_consistency(flow_01, flow_10):
128 | flow_01 = torch.from_numpy(flow_01).permute(2, 0, 1)[None]
129 | flow_10 = torch.from_numpy(flow_10).permute(2, 0, 1)[None]
130 | H, W = flow_01.shape[-2:]
131 | coords = coords_grid(1, H, W, flow_01.device)
132 | coords1 = coords + flow_01
133 | flow_reprojected = bilinear_sampler(flow_10, coords1.permute(0, 2, 3, 1))
134 | cycle = flow_reprojected + flow_01
135 | cycle = torch.norm(cycle, dim=1)
136 | mask = (cycle < 0.1 * min(H, W)).float()
137 | return mask[0].numpy()
--------------------------------------------------------------------------------
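A usage sketch for `InputPadder`, showing padding to a multiple of 8 and cropping back. The tensor sizes are illustrative, and the import path assumes the repository root is on `PYTHONPATH`.

```
# Hedged sketch: pad two frames to dimensions divisible by 8, then undo the padding.
import torch
from sea_raft.core.utils.utils import InputPadder

img1 = torch.randn(1, 3, 436, 1024)
img2 = torch.randn(1, 3, 436, 1024)
padder = InputPadder(img1.shape)          # 'sintel' mode: symmetric padding
img1_p, img2_p = padder.pad(img1, img2)   # both become 1 x 3 x 440 x 1024
# ... run the flow model on the padded pair, then:
# flow = padder.unpad(flow)               # crop back to 1 x 2 x 436 x 1024
print(img1_p.shape)
```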
/sea_raft/custom.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, List
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from .core.raft import RAFT
7 |
8 |
9 | def get_model_config():
10 | return {
11 | "name": "spring-M",
12 | "dataset": "spring",
13 | "gpus": [
14 | 0,
15 | 1,
16 | 2,
17 | 3,
18 | 4,
19 | 5,
20 | 6,
21 | 7
22 | ],
23 | "use_var": True,
24 | "var_min": 0,
25 | "var_max": 10,
26 | "pretrain": "resnet34",
27 | "initial_dim": 64,
28 | "block_dims": [
29 | 64,
30 | 128,
31 | 256
32 | ],
33 | "radius": 4,
34 | "dim": 128,
35 | "num_blocks": 2,
36 | "iters": 4,
37 | "image_size": [
38 | 540,
39 | 960
40 | ],
41 | "scale": -1,
42 | "batch_size": 32,
43 | "epsilon": 1e-8,
44 | "lr": 4e-4,
45 | "wdecay": 1e-5,
46 | "dropout": 0,
47 | "clip": 1.0,
48 | "gamma": 0.85,
49 | "num_steps": 120000,
50 | "restore_ckpt": None,
51 | "coarse_config": None
52 | }
53 |
54 |
55 | @dataclass
56 | class ModelArgs:
57 | name: str = None
58 | dataset: str = None
59 | gpus: List[int] = None
60 | use_var: bool = None
61 | var_min: int = None
62 | var_max: int = None
63 | pretrain: str = None
64 | initial_dim: int = None
65 | block_dims: List[int] = None
66 | radius: int = None
67 | dim: int = None
68 | num_blocks: int = None
69 | iters: int = None
70 | image_size: List[int] = None
71 | scale: int = None
72 | batch_size: int = None
73 | epsilon: float = None
74 | lr: float = None
75 | wdecay: float = None
76 | dropout: float = None
77 | clip: float = None
78 | gamma: float = None
79 | num_steps: int = None
80 | restore_ckpt: Any = None
81 | coarse_config: Any = None
82 |
83 |
84 | def forward_flow(args, model, image1, image2):
85 | output = model(image1, image2, iters=args.iters, test_mode=True)
86 | flow_final = output['flow'][-1]
87 | info_final = output['info'][-1]
88 | return flow_final, info_final
89 |
90 |
91 | @torch.no_grad()
92 | def calc_flow(args, model, image1, image2):
93 | img1 = F.interpolate(image1, scale_factor=2 ** args.scale, mode='bilinear', align_corners=False)
94 | img2 = F.interpolate(image2, scale_factor=2 ** args.scale, mode='bilinear', align_corners=False)
95 | H, W = img1.shape[2:]
96 | flow, info = forward_flow(args, model, img1, img2)
97 | flow_down = F.interpolate(flow, scale_factor=0.5 ** args.scale, mode='bilinear', align_corners=False) * (0.5 ** args.scale)
98 | info_down = F.interpolate(info, scale_factor=0.5 ** args.scale, mode='area')
99 | return flow_down, info_down
100 |
101 |
102 | class RaftWrapper(torch.nn.Module):
103 | def __init__(self, model, model_args):
104 | super(RaftWrapper, self).__init__()
105 | self.model = model
106 | self.args = model_args
107 |
108 | def forward(self, image1, image2):
109 | flow_down, info_down = calc_flow(self.args, self.model, image1, image2)
110 | return flow_down, info_down
111 |
112 | def eval(self):
113 | self.model.eval()
114 |
115 |
116 | def get_model(checkpoint_path):
117 | model_args = ModelArgs(**get_model_config())
118 | model = RAFT(model_args)
119 | state_dict = torch.load(checkpoint_path)
120 | model.load_state_dict(state_dict)
121 |
122 | return RaftWrapper(model, model_args)
123 |
124 |
--------------------------------------------------------------------------------
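A minimal end-to-end sketch of `get_model` / `RaftWrapper` from `custom.py`. The checkpoint filename matches the one referenced in `scripts/eval.sh` and the frames are the samples bundled under `sea_raft/custom/`; the import path, the presence of the checkpoint on disk, and the 0-255 input range are assumptions. Note that `RaftWrapper.eval()` returns `None`, so it is called separately rather than chained.

```
# Hedged sketch: load a SEA-RAFT checkpoint and estimate flow between two frames.
import numpy as np
import torch
from PIL import Image

from sea_raft.custom import get_model

def load_image(path, device):
    img = np.array(Image.open(path).convert("RGB"), dtype=np.float32)
    return torch.from_numpy(img).permute(2, 0, 1)[None].to(device)  # 1 x 3 x H x W, 0-255

device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_model("models/Tartan-C-T-TSKH-spring540x960-M.pth").to(device)
model.eval()

image1 = load_image("sea_raft/custom/image1.jpg", device)
image2 = load_image("sea_raft/custom/image2.jpg", device)
flow, info = model(image1, image2)   # flow: 1 x 2 x H x W at the input resolution
print(flow.shape)
```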
/sea_raft/custom/flow.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/custom/flow.jpg
--------------------------------------------------------------------------------
/sea_raft/custom/heatmap.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/custom/heatmap.jpg
--------------------------------------------------------------------------------
/sea_raft/custom/image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/custom/image1.jpg
--------------------------------------------------------------------------------
/sea_raft/custom/image2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/custom/image2.jpg
--------------------------------------------------------------------------------
/sea_raft/ddp_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from time import sleep
3 | import numpy as np
4 | import torch
5 | import torch.distributed as dist
6 | import torch.multiprocessing as mp
7 | from datetime import timedelta
8 | import random
9 |
10 | def init_fn(worker_id):
11 | np.random.seed(987)
12 | random.seed(987)
13 |
14 | def process_group_initialized():
15 | try:
16 | dist.get_world_size()
17 | return True
18 | except:
19 | return False
20 |
21 | def calc_num_workers():
22 | try:
23 | world_size = dist.get_world_size()
24 | except:
25 | world_size = 1
26 | return len(os.sched_getaffinity(0)) // world_size
27 |
28 | def setup_ddp(rank, world_size):
29 | dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
30 | torch.manual_seed(987)
31 | torch.cuda.set_device(rank)
32 |
33 | def init_ddp():
34 | os.environ['MASTER_ADDR'] = 'localhost'
35 | os.environ['MASTER_PORT'] = str(11451 + np.random.randint(100))
36 | world_size = torch.cuda.device_count()
37 | assert world_size > 0, "You need a GPU!"
38 | smp = mp.get_context('spawn')
39 | return smp, world_size
40 |
41 | def wait_for_world(state: mp.Queue, world_size):
42 | state.put(1)
43 | while state.qsize() < world_size:
44 | pass
45 | for _ in range(world_size):
46 | state.get()
--------------------------------------------------------------------------------
/sea_raft/eval_ptlflow.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 | import argparse
4 | import os
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.data as data
11 |
12 | import datasets
13 | from raft import RAFT
14 | from tqdm import tqdm
15 |
16 | from utils import flow_viz
17 | from utils import frame_utils
18 | from utils.utils import resize_data, load_ckpt
19 |
20 | import ptlflow
21 | from ptlflow.utils import flow_utils
22 |
23 | def forward_flow(model, image1, image2, scale=0, mode='downsample'):
24 | if mode == 'downsample':
25 | dlt = 0 # avoid edge effects
26 | image1 = image1 / 255.
27 | image2 = image2 / 255.
28 | img1 = F.interpolate(image1, scale_factor=2 ** scale, mode='bilinear', align_corners=False)
29 | img2 = F.interpolate(image2, scale_factor=2 ** scale, mode='bilinear', align_corners=False)
30 | img1 = F.pad(img1, (dlt, dlt, dlt, dlt), "constant", 0)
31 | img2 = F.pad(img2, (dlt, dlt, dlt, dlt), "constant", 0)
32 | H, W = img1.shape[2:]
33 | inputs = {"images": torch.stack([img1, img2], dim=1)}
34 | predictions = model(inputs)
35 | flow = predictions['flows'][:, 0]
36 | flow = flow[..., dlt: H-dlt, dlt: W-dlt]
37 | flow = F.interpolate(flow, scale_factor=0.5 ** scale, mode='bilinear', align_corners=False) * (0.5 ** scale)
38 | else:
39 | raise NotImplementedError
40 | return flow
41 |
42 | @torch.no_grad()
43 | def validate_spring(model, mode='downsample'):
44 |     """ Perform validation using the Spring (val) split """
45 | val_dataset = datasets.SpringFlowDataset(split='val') + datasets.SpringFlowDataset(split='train')
46 | val_loader = data.DataLoader(val_dataset, batch_size=4,
47 | pin_memory=False, shuffle=False, num_workers=16, drop_last=False)
48 |
49 | epe_list = np.array([], dtype=np.float32)
50 | px1_list = np.array([], dtype=np.float32)
51 | px3_list = np.array([], dtype=np.float32)
52 | px5_list = np.array([], dtype=np.float32)
53 | for i_batch, data_blob in enumerate(val_loader):
54 | image1, image2, flow_gt, valid = [x.cuda(non_blocking=True) for x in data_blob]
55 | flow = forward_flow(model, image1, image2, scale=-1, mode=mode)
56 | epe = torch.sum((flow - flow_gt)**2, dim=1).sqrt()
57 | px1 = (epe < 1.0).float().mean(dim=[1, 2]).cpu().numpy()
58 | px3 = (epe < 3.0).float().mean(dim=[1, 2]).cpu().numpy()
59 | px5 = (epe < 5.0).float().mean(dim=[1, 2]).cpu().numpy()
60 | epe = epe.mean(dim=[1, 2]).cpu().numpy()
61 | epe_list = np.append(epe_list, epe)
62 | px1_list = np.append(px1_list, px1)
63 | px3_list = np.append(px3_list, px3)
64 | px5_list = np.append(px5_list, px5)
65 |
66 | epe = np.mean(epe_list)
67 | px1 = np.mean(px1_list)
68 | px3 = np.mean(px3_list)
69 | px5 = np.mean(px5_list)
70 |
71 | print(f"Validation Spring EPE: {epe}, 1px: {100 * (1 - px1)}")
72 |
73 | @torch.no_grad()
74 | def validate_middlebury(model, mode='downsample'):
75 |     """ Perform validation using the Middlebury (public) split """
76 | val_dataset = datasets.Middlebury()
77 | val_loader = data.DataLoader(val_dataset, batch_size=1,
78 | pin_memory=False, shuffle=False, num_workers=16, drop_last=False)
79 | epe_list = np.array([], dtype=np.float32)
80 | num_valid_pixels = 0
81 | out_valid_pixels = 0
82 | for i_batch, data_blob in enumerate(val_loader):
83 | image1, image2, flow_gt, valid_gt = [x.cuda(non_blocking=True) for x in data_blob]
84 | flow = forward_flow(model, image1, image2, scale=0, mode=mode)
85 | epe = torch.sum((flow - flow_gt)**2, dim=1).sqrt()
86 | mag = torch.sum(flow_gt**2, dim=1).sqrt()
87 | val = valid_gt >= 0.5
88 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float()
89 | for b in range(out.shape[0]):
90 | epe_list = np.append(epe_list, epe[b][val[b]].mean().cpu().numpy())
91 | out_valid_pixels += out[b][val[b]].sum().cpu().numpy()
92 | num_valid_pixels += val[b].sum().cpu().numpy()
93 |
94 | epe = np.mean(epe_list)
95 | f1 = 100 * out_valid_pixels / num_valid_pixels
96 | print("Validation middlebury: %f, %f" % (epe, f1))
97 |
98 | def eval(args):
99 |
100 | # Get an initialized model from PTLFlow
101 | device = torch.device('cuda')
102 | model = ptlflow.get_model(args.model, 'mixed').to(device)
103 | if "use_tile_input" in model.args:
104 | model.args.use_tile_input = False
105 | model.eval()
106 | print(args.model)
107 | with torch.no_grad():
108 | try:
109 | validate_middlebury(model, mode='downsample')
110 | except:
111 | print('Middlebury validation failed')
112 | validate_spring(model, mode='downsample')
113 |
114 | def main():
115 | parser = argparse.ArgumentParser()
116 | parser.add_argument('--model', help='experiment configure file name', required=True, type=str)
117 | args = parser.parse_args()
118 | eval(args)
119 |
120 | if __name__ == '__main__':
121 | main()
--------------------------------------------------------------------------------
/sea_raft/profile_ptlflow.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 | import argparse
4 | import os
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.data as data
11 |
12 | import datasets
13 | from raft import RAFT
14 | from tqdm import tqdm
15 |
16 | from utils import flow_viz
17 | from utils import frame_utils
18 | from utils.profile import profile_model
19 | from utils.utils import resize_data, load_ckpt
20 |
21 | import ptlflow
22 | from ptlflow.utils import flow_utils
23 |
24 | @torch.no_grad()
25 | def eval(args):
26 | # Get an initialized model from PTLFlow
27 | model = ptlflow.get_model(args.model, 'mixed').cuda()
28 | if "use_tile_input" in model.args:
29 | model.args.use_tile_input = False
30 | model.eval()
31 | h, w = 540, 960
32 | inputs = {"images": torch.zeros(1, 2, 3, h, w).cuda()}
33 | with torch.profiler.profile(
34 | activities=[
35 | torch.profiler.ProfilerActivity.CUDA,
36 | torch.profiler.ProfilerActivity.CPU
37 | ],
38 | with_flops=True) as prof:
39 | output = model(inputs)
40 | events = prof.events()
41 | print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total', row_limit=5))
42 | forward_MACs = sum([int(evt.flops) for evt in events])
43 | print("forward MACs: ", forward_MACs / 2 / 1e9, "G")
44 |
45 | def main():
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('--model', help='experiment configure file name', required=True, type=str)
48 | args = parser.parse_args()
49 | eval(args)
50 |
51 | if __name__ == '__main__':
52 | main()
--------------------------------------------------------------------------------
/sea_raft/profiler.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 | import argparse
4 | import torch
5 | from config.parser import parse_args
6 | from raft import RAFT
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
11 | args = parse_args(parser)
12 | model = RAFT(args)
13 | model.eval()
14 | h, w = [540, 960]
15 | input = torch.zeros(1, 3, h, w)
16 | model = model.cuda()
17 | input = input.cuda()
18 | with torch.profiler.profile(
19 | activities=[
20 | torch.profiler.ProfilerActivity.CPU,
21 | torch.profiler.ProfilerActivity.CUDA
22 | ],
23 | with_flops=True) as prof:
24 | output = model(input, input, iters=args.iters, test_mode=True)
25 |
26 | print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total', row_limit=5))
27 | events = prof.events()
28 | forward_MACs = sum([int(evt.flops) for evt in events])
29 | print("forward MACs: ", forward_MACs / 2 / 1e9, "G")
30 |
31 | if __name__ == '__main__':
32 | main()
--------------------------------------------------------------------------------
/sea_raft/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchaudio
4 | numpy
5 | matplotlib
6 | scipy
7 | opencv-python
8 | tensorboard
9 | h5py
10 | tqdm
11 | einops
--------------------------------------------------------------------------------
/sea_raft/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | # evaluate on sintel
2 | python evaluate.py --cfg config/eval/sintel-M.json --model models/Tartan-C-T-TSKH432x960-M.pth
3 | # evaluate on kitti
4 | python evaluate.py --cfg config/eval/kitti-M.json --model models/Tartan-C-T-TSKH-kitti432x960-M.pth
5 | # evaluate on spring
6 | python evaluate.py --cfg config/eval/spring-M.json --model models/Tartan-C-T-TSKH-spring540x960-M.pth
--------------------------------------------------------------------------------
/sea_raft/scripts/submission.sh:
--------------------------------------------------------------------------------
1 | # submit to sintel public leaderboard
2 | python submission.py --cfg config/eval/sintel-M.json --model models/Tartan-C-T-TSKH432x960-M.pth
3 | # submit to kitti public leaderboard
4 | python submission.py --cfg config/eval/kitti-M.json --model models/Tartan-C-T-TSKH-kitti432x960-M.pth
5 | # submit to spring public leaderboard
6 | python submission.py --cfg config/eval/spring-M.json --model models/Tartan-C-T-TSKH-spring540x960-M.pth
--------------------------------------------------------------------------------
/sea_raft/scripts/train.sh:
--------------------------------------------------------------------------------
1 | # The training scripts are tested on 8 NVIDIA L40 GPUs
2 |
3 | # Stage0: Stereo pretraining on TartanAir
4 | python -u train.py --cfg config/train/Tartan480x640-M.json
5 | # Stage1: FlyingChairs, please update your ckpts path (from Stage0 results) in the config file (restore_ckpt)
6 | python -u train.py --cfg config/train/Tartan-C368x496-M.json
7 | # Stage2: FlyingThings3D, please update your ckpts path (from Stage1 results) in the config file (restore_ckpt)
8 | python -u train.py --cfg config/train/Tartan-C-T432x960-M.json
9 | # Stage3: FlyingThings + Sintel + KITTI + HD1K, please update your ckpts path (from Stage2 results) in the config file (restore_ckpt)
10 | # The ckpts from this stage are used for sintel submission
11 | python -u train.py --cfg config/train/Tartan-C-T-TSKH432x960-M.json
12 |
13 | # Stage4 (finetune for KITTI submission): KITTI, please update your ckpts path (from Stage3 results) in the config file (restore_ckpt)
14 | python -u train.py --cfg config/train/Tartan-C-T-TSKH-kitti432x960-M.json
15 |
16 | # Stage5 (finetune for Spring submission): Spring, please update your ckpts path (from Stage3 results) in the config file (restore_ckpt)
17 | python -u train.py --cfg config/train/Tartan-C-T-TSKH-spring540x960-M.json
18 |
--------------------------------------------------------------------------------
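Each stage above resumes from the previous stage's weights via `restore_ckpt`. A minimal sketch of pointing the next stage's config at the previous stage's checkpoint programmatically rather than editing the JSON by hand; the checkpoint filename is hypothetical (train.py, below, writes final weights to `checkpoints/<name>.pth`, where `<name>` has an `_exp##` suffix appended at launch).

```
# Hedged sketch: point Stage1's config at the checkpoint produced by Stage0.
import json

cfg_path = "config/train/Tartan-C368x496-M.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["restore_ckpt"] = "checkpoints/Tartan480x640-M_exp42.pth"  # hypothetical Stage0 output

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=4)
```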
/sea_raft/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('core')
3 |
4 | import argparse
5 | import numpy as np
6 |
7 | from config.parser import parse_args
8 |
9 | import torch
10 | import torch.optim as optim
11 |
12 | from raft import RAFT
13 | from datasets import fetch_dataloader
14 | from utils.utils import load_ckpt
15 | from loss import sequence_loss
16 | from ddp_utils import *
17 |
18 | os.environ["KMP_INIT_AT_FORK"] = "FALSE"  # set in this process; os.system("export ...") would only affect a child shell
19 |
20 | def fetch_optimizer(args, model):
21 | """ Create the optimizer and learning rate scheduler """
22 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
23 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 100,
24 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')
25 |
26 | return optimizer, scheduler
27 |
28 | def train(args, rank=0, world_size=1, use_ddp=False):
29 | """ Full training loop """
30 | device_id = rank
31 | model = RAFT(args).to(device_id)
32 | if args.restore_ckpt is not None:
33 | load_ckpt(model, args.restore_ckpt)
34 | print(f"restore ckpt from {args.restore_ckpt}")
35 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # there might not be any, actually
36 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], static_graph=True)
37 |
38 | model.train()
39 | train_loader = fetch_dataloader(args, rank=rank, world_size=world_size, use_ddp=use_ddp)
40 | optimizer, scheduler = fetch_optimizer(args, model)
41 | total_steps = 0
42 | VAL_FREQ = 10000
43 | epoch = 0
44 | should_keep_training = True
45 | # torch.autograd.set_detect_anomaly(True)
46 | while should_keep_training:
47 | # shuffle sampler
48 | train_loader.sampler.set_epoch(epoch)
49 | epoch += 1
50 | for i_batch, data_blob in enumerate(train_loader):
51 | optimizer.zero_grad()
52 | image1, image2, flow, valid = [x.cuda(non_blocking=True) for x in data_blob]
53 | output = model(image1, image2, flow_gt=flow, iters=args.iters)
54 | loss = sequence_loss(output, flow, valid, args.gamma)
55 | loss.backward()
56 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
57 | optimizer.step()
58 | scheduler.step()
59 |
60 | if total_steps % VAL_FREQ == VAL_FREQ - 1 and rank == 0:
61 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name)
62 | torch.save(model.module.state_dict(), PATH)
63 |
64 | if total_steps > args.num_steps:
65 | should_keep_training = False
66 | break
67 |
68 | total_steps += 1
69 |
70 | PATH = 'checkpoints/%s.pth' % args.name
71 | if rank == 0:
72 | torch.save(model.module.state_dict(), PATH)
73 |
74 | return PATH
75 |
76 | def main(rank, world_size, args, use_ddp):
77 |
78 | if use_ddp:
79 | print(f"Using DDP [{rank=} {world_size=}]")
80 | setup_ddp(rank, world_size)
81 |
82 | train(args, rank=rank, world_size=world_size, use_ddp=use_ddp)
83 |
84 | if __name__ == '__main__':
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
87 | args = parse_args(parser)
88 | args.name += f"_exp{str(np.random.randint(100))}"
89 | smp, world_size = init_ddp()
90 | if world_size > 1:
91 | spwn_ctx = mp.spawn(main, nprocs=world_size, args=(world_size, args, True), join=False)
92 | spwn_ctx.join()
93 | else:
94 | main(0, 1, args, False)
95 | print("Done!")
--------------------------------------------------------------------------------
/unimatch/DATASETS.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 |
4 |
5 | ## Optical Flow
6 |
7 | The datasets used to train and evaluate our GMFlow model are as follows:
8 |
9 | - [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
10 | - [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
11 | - [Sintel](http://sintel.is.tue.mpg.de/)
12 | - [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/)
13 | - [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow)
14 | - [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/)
15 |
16 | By default the dataloader [dataloader/flow/datasets.py](dataloader/flow/datasets.py) assumes the datasets are located in the `datasets` directory.
17 |
18 | It is recommended to symlink your dataset root to `datasets`:
19 |
20 | ```
21 | ln -s $YOUR_DATASET_ROOT datasets
22 | ```
23 |
24 | Otherwise, you may need to change the corresponding paths in [dataloader/flow/datasets.py](dataloader/flow/datasets.py).
25 |
26 |
27 |
28 | ## Stereo Matching
29 |
30 | The datasets used to train and evaluate our GMStereo model are as follows:
31 |
32 | - [Scene Flow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html)
33 | - [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/)
34 | - [KITTI](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo)
35 | - [TartanAir](https://github.com/castacks/tartanair_tools)
36 | - [Falling Things](https://research.nvidia.com/publication/2018-06_Falling-Things)
37 | - [HR-VS](https://drive.google.com/file/d/1SgEIrH_IQTKJOToUwR1rx4-237sThUqX/view)
38 | - [CREStereo Dataset](https://github.com/megvii-research/CREStereo/blob/master/dataset_download.sh)
39 | - [InStereo2K](https://github.com/YuhuaXu/StereoDataset)
40 | - [Middlebury](https://vision.middlebury.edu/stereo/data/)
41 | - [Sintel Stereo](http://sintel.is.tue.mpg.de/stereo)
42 | - [ETH3D](https://www.eth3d.net/datasets#low-res-two-view-training-data)
43 |
44 | By default the dataloader [dataloader/stereo/datasets.py](dataloader/stereo/datasets.py) assumes the datasets are located in the `datasets` directory.
45 |
46 | It is recommended to symlink your dataset root to `datasets`:
47 |
48 | ```
49 | ln -s $YOUR_DATASET_ROOT datasets
50 | ```
51 |
52 | Otherwise, you may need to change the corresponding paths in [dataloader/stereo/datasets.py](dataloader/stereo/datasets.py).
53 |
54 |
55 |
56 | ## Depth Estimation
57 |
58 | The datasets used to train and evaluate our GMDepth model are as follows:
59 |
60 | - [DeMoN](https://github.com/lmb-freiburg/demon)
61 | - [ScanNet](http://www.scan-net.org/)
62 |
63 | We provide scripts for downloading and preparing the DeMoN dataset: [dataloader/depth/download_demon_train.sh](dataloader/depth/download_demon_train.sh), [dataloader/depth/download_demon_test.sh](dataloader/depth/download_demon_test.sh), [dataloader/depth/prepare_demon_train.py](dataloader/depth/prepare_demon_train.py), and [dataloader/depth/prepare_demon_test.py](dataloader/depth/prepare_demon_test.py).
64 |
65 | By default the dataloader [dataloader/depth/datasets.py](dataloader/depth/datasets.py) assumes the datasets are located in the `datasets` directory.
66 |
67 | It is recommended to symlink your dataset root to `datasets`:
68 |
69 | ```
70 | ln -s $YOUR_DATASET_ROOT datasets
71 | ```
72 |
73 | Otherwise, you may need to change the corresponding paths in [dataloader/depth/datasets.py](dataloader/depth/datasets.py).
74 |
75 |
--------------------------------------------------------------------------------
/unimatch/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 autonomousvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/unimatch/conda_environment.yml:
--------------------------------------------------------------------------------
1 | name: unimatch
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1
7 | - _openmp_mutex=5.1
8 | - blas=1.0
9 | - brotli=1.0.9
10 | - brotli-bin=1.0.9
11 | - bzip2=1.0.8
12 | - ca-certificates=2022.10.11
13 | - certifi=2022.9.24
14 | - cloudpickle=2.0.0
15 | - cudatoolkit=10.2.89
16 | - cycler=0.11.0
17 | - cytoolz=0.12.0
18 | - dask-core=2022.7.0
19 | - dbus=1.13.18
20 | - expat=2.4.9
21 | - ffmpeg=4.3
22 | - fftw=3.3.9
23 | - fontconfig=2.13.1
24 | - fonttools=4.25.0
25 | - freetype=2.12.1
26 | - fsspec=2022.10.0
27 | - giflib=5.2.1
28 | - glib=2.69.1
29 | - gmp=6.2.1
30 | - gnutls=3.6.15
31 | - gst-plugins-base=1.14.0
32 | - gstreamer=1.14.0
33 | - icu=58.2
34 | - imageio=2.9.0
35 | - intel-openmp=2021.4.0
36 | - jpeg=9b
37 | - kiwisolver=1.4.2
38 | - lame=3.100
39 | - lcms2=2.12
40 | - ld_impl_linux-64=2.38
41 | - libbrotlicommon=1.0.9
42 | - libbrotlidec=1.0.9
43 | - libbrotlienc=1.0.9
44 | - libffi=3.3
45 | - libgcc-ng=11.2.0
46 | - libgfortran-ng=11.2.0
47 | - libgfortran5=11.2.0
48 | - libgomp=11.2.0
49 | - libiconv=1.16
50 | - libidn2=2.3.2
51 | - libpng=1.6.37
52 | - libstdcxx-ng=11.2.0
53 | - libtasn1=4.16.0
54 | - libtiff=4.1.0
55 | - libunistring=0.9.10
56 | - libuuid=1.41.5
57 | - libuv=1.40.0
58 | - libwebp=1.2.0
59 | - libxcb=1.15
60 | - libxml2=2.9.14
61 | - locket=1.0.0
62 | - lz4-c=1.9.3
63 | - matplotlib=3.5.1
64 | - matplotlib-base=3.5.1
65 | - mkl=2021.4.0
66 | - mkl-service=2.4.0
67 | - mkl_fft=1.3.1
68 | - mkl_random=1.2.2
69 | - munkres=1.1.4
70 | - ncurses=6.3
71 | - nettle=3.7.3
72 | - networkx=2.8.4
73 | - ninja=1.10.2
74 | - ninja-base=1.10.2
75 | - numpy=1.19.2
76 | - numpy-base=1.19.2
77 | - openh264=2.1.1
78 | - openssl=1.1.1s
79 | - packaging=21.3
80 | - partd=1.2.0
81 | - pcre=8.45
82 | - pillow=9.0.1
83 | - pip=22.2.2
84 | - pyparsing=3.0.9
85 | - pyqt=5.9.2
86 | - python=3.8.13
87 | - python-dateutil=2.8.2
88 | - pytorch=1.9.0
89 | - pywavelets=1.3.0
90 | - pyyaml=6.0
91 | - qt=5.9.7
92 | - readline=8.2
93 | - scikit-image=0.19.2
94 | - scipy=1.9.3
95 | - sip=4.19.13
96 | - six=1.16.0
97 | - sqlite=3.39.3
98 | - tifffile=2020.10.1
99 | - tk=8.6.12
100 | - toolz=0.12.0
101 | - torchvision=0.10.0
102 | - tornado=6.2
103 | - typing_extensions=4.3.0
104 | - wheel=0.37.1
105 | - xz=5.2.6
106 | - yaml=0.2.5
107 | - zlib=1.2.13
108 | - zstd=1.4.9
109 | - pip:
110 | - absl-py==1.3.0
111 | - cachetools==5.2.0
112 | - charset-normalizer==2.1.1
113 | - google-auth==2.14.1
114 | - google-auth-oauthlib==0.4.6
115 | - grpcio==1.50.0
116 | - h5py==3.7.0
117 | - idna==3.4
118 | - imageio-ffmpeg==0.4.7
119 | - importlib-metadata==5.0.0
120 | - joblib==1.2.0
121 | - lz4==4.0.2
122 | - markdown==3.4.1
123 | - markupsafe==2.1.1
124 | - oauthlib==3.2.2
125 | - opencv-python==4.6.0.66
126 | - path==16.5.0
127 | - protobuf==3.19.6
128 | - pyasn1==0.4.8
129 | - pyasn1-modules==0.2.8
130 | - requests==2.28.1
131 | - requests-oauthlib==1.3.1
132 | - rsa==4.9
133 | - setuptools==59.5.0
134 | - tensorboard==2.9.1
135 | - tensorboard-data-server==0.6.1
136 | - tensorboard-plugin-wit==1.8.1
137 | - tqdm==4.64.1
138 | - urllib3==1.26.12
139 | - werkzeug==2.2.2
140 | - zipp==3.10.0
141 |
--------------------------------------------------------------------------------
/unimatch/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/dataloader/__init__.py
--------------------------------------------------------------------------------
/unimatch/dataloader/depth/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/dataloader/depth/__init__.py
--------------------------------------------------------------------------------
/unimatch/dataloader/depth/download_demon_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Source from https://github.com/lmb-freiburg/demon
3 | clear
4 | cat << EOF
5 | ================================================================================
6 | The test datasets are provided for research purposes only.
7 | Some of the test datasets build upon other publicly available data.
8 | Make sure to cite the respective original source of the data if you use the
9 | provided files for your research.
10 | * sun3d_test.h5 is based on the SUN3D dataset http://sun3d.cs.princeton.edu/
11 | J. Xiao, A. Owens, and A. Torralba, “SUN3D: A Database of Big Spaces Reconstructed Using SfM and Object Labels,” in 2013 IEEE International Conference on Computer Vision (ICCV), 2013, pp. 1625–1632.
12 |
13 | * rgbd_test.h5 is based on the RGBD SLAM benchmark http://vision.in.tum.de/data/datasets/rgbd-dataset (licensed under CC-BY 3.0)
14 |
15 | J. Sturm, N. Engelhard, F. Endres, W. Burgard, and D. Cremers, “A benchmark for the evaluation of RGB-D SLAM systems,” in 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, 2012, pp. 573–580.
16 | * scenes11_test.h5 uses objects from shapenet https://www.shapenet.org/
17 |
18 | A. X. Chang et al., “ShapeNet: An Information-Rich 3D Model Repository,” arXiv:1512.03012 [cs], Dec. 2015.
19 | * mvs_test.h5 contains scenes from https://colmap.github.io/datasets.html
20 |
21 | J. L. Schönberger and J. M. Frahm, “Structure-from-Motion Revisited,” in 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 4104–4113.
22 | J. L. Schönberger, E. Zheng, J.-M. Frahm, and M. Pollefeys, “Pixelwise View Selection for Unstructured Multi-View Stereo,” in Computer Vision – ECCV 2016, 2016, pp. 501–518.
23 | * nyu2_test.h5 is based on http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html
24 |
25 | N. Silberman, D. Hoiem, P. Kohli, and R. Fergus, “Indoor Segmentation and Support Inference from RGBD Images,” in Computer Vision – ECCV 2012, 2012, pp. 746–760.
26 | ================================================================================
27 | type Y to start the download.
28 | EOF
29 |
30 | read -s -n 1 answer
31 | if [ "$answer" != "Y" -a "$answer" != "y" ]; then
32 | exit 0
33 | fi
34 | echo
35 |
36 | datasets=(sun3d rgbd mvs scenes11)
37 |
38 | OLD_PWD="$PWD"
39 | DESTINATION=testdata
40 | mkdir $DESTINATION
41 | cd $DESTINATION
42 |
43 | for ds in ${datasets[@]}; do
44 | if [ -e "${ds}_test.h5" ]; then
45 | echo "${ds}_test.h5 already exists, skipping ${ds}"
46 | else
47 | wget "https://lmb.informatik.uni-freiburg.de/data/demon/testdata/${ds}_test.tgz"
48 | tar -xvf "${ds}_test.tgz"
49 | fi
50 | done
51 |
52 | cd "$OLD_PWD"
--------------------------------------------------------------------------------
/unimatch/dataloader/depth/download_demon_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | clear
3 | cat << EOF
4 |
5 | ================================================================================
6 |
7 |
8 | The train datasets are provided for research purposes only.
9 |
10 | Some of the test datasets build upon other publicly available data.
11 | Make sure to cite the respective original source of the data if you use the
12 | provided files for your research.
13 |
14 | * sun3d_train.h5 is based on the SUN3D dataset http://sun3d.cs.princeton.edu/
15 |
16 | J. Xiao, A. Owens, and A. Torralba, “SUN3D: A Database of Big Spaces Reconstructed Using SfM and Object Labels,” in 2013 IEEE International Conference on Computer Vision (ICCV), 2013, pp. 1625–1632.
17 |
18 |
19 |
20 |
21 | * rgbd_bugfix_train.h5 is based on the RGBD SLAM benchmark http://vision.in.tum.de/data/datasets/rgbd-dataset (licensed under CC-BY 3.0)
22 |
23 | J. Sturm, N. Engelhard, F. Endres, W. Burgard, and D. Cremers, “A benchmark for the evaluation of RGB-D SLAM systems,” in 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, 2012, pp. 573–580.
24 |
25 |
26 |
27 | * scenes11_train.h5 uses objects from shapenet https://www.shapenet.org/
28 |
29 | A. X. Chang et al., “ShapeNet: An Information-Rich 3D Model Repository,” arXiv:1512.03012 [cs], Dec. 2015.
30 |
31 |
32 |
33 | * mvs_train.h5 contains the Citywall and Achteck-Turm scenes from MVE (Multi-View Environment) http://www.gcc.tu-darmstadt.de/home/proj/mve/
34 |
35 | S. Fuhrmann, F. Langguth, and M. Goesele, “MVE: A Multi-view Reconstruction Environment,” in Proceedings of the Eurographics Workshop on Graphics and Cultural Heritage, Aire-la-Ville, Switzerland, Switzerland, 2014, pp. 11–18.
36 |
37 |
38 |
39 | ================================================================================
40 |
41 | type Y to start the download.
42 |
43 | EOF
44 |
45 | read -s -n 1 answer
46 | if [ "$answer" != "Y" -a "$answer" != "y" ]; then
47 | exit 0
48 | fi
49 | echo
50 |
51 | datasets=(sun3d rgbd mvs scenes11)
52 |
53 | OLD_PWD="$PWD"
54 | DESTINATION=traindata
55 | mkdir $DESTINATION
56 | cd $DESTINATION
57 |
58 | if [ ! -e "README_traindata" ]; then
59 | wget --no-check-certificate "https://lmb.informatik.uni-freiburg.de/data/demon/traindata/README_traindata"
60 | fi
61 |
62 | for ds in ${datasets[@]}; do
63 | if [ -e "${ds}_train.h5" ]; then
64 | echo "${ds}_train.h5 already exists, skipping ${ds}"
65 | else
66 | wget --no-check-certificate "https://lmb.informatik.uni-freiburg.de/data/demon/traindata/${ds}_train.tgz"
67 | tar -xvf "${ds}_train.tgz"
68 | fi
69 | done
70 |
71 | cd "$OLD_PWD"
--------------------------------------------------------------------------------
/unimatch/dataloader/depth/prepare_demon_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess
4 |
5 | from joblib import Parallel, delayed
6 | import numpy as np
7 | import imageio
8 |
9 | imageio.plugins.freeimage.download()
10 | from imageio.plugins import freeimage
11 | import h5py
12 | from lz4.block import decompress
13 | import scipy.misc
14 |
15 | import cv2
16 |
17 | from path import Path
18 |
19 | path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
20 |
21 |
22 | def dump_example(dataset_name):
23 | print("Converting {:}.h5 ...".format(dataset_name))
24 | file = h5py.File(os.path.join(path, "testdata", "{:}.h5".format(dataset_name)), "r")
25 |
26 | for (seq_idx, seq_name) in enumerate(file):
27 | if dataset_name == 'scenes11_test':
28 | scale = 0.4
29 | else:
30 | scale = 1
31 |
32 | print("Processing sequence {:d}/{:d}".format(seq_idx, len(file)))
33 | dump_dir = os.path.join(path, 'test', dataset_name + "_" + "{:05d}".format(seq_idx))
34 | if not os.path.isdir(dump_dir):
35 | os.mkdir(dump_dir)
36 | dump_dir = Path(dump_dir)
37 | sequence = file[seq_name]["frames"]["t0"]
38 | poses = []
39 | for (f_idx, f_name) in enumerate(sequence):
40 | frame = sequence[f_name]
41 | for dt_type in frame:
42 | dataset = frame[dt_type]
43 | img = dataset[...]
44 | if dt_type == "camera":
45 | if f_idx == 0:
46 | intrinsics = np.array([[img[0], 0, img[3]], [0, img[1], img[4]], [0, 0, 1]])
47 | pose = np.array(
48 | [[img[5], img[8], img[11], img[14] * scale], [img[6], img[9], img[12], img[15] * scale],
49 | [img[7], img[10], img[13], img[16] * scale]])
50 | poses.append(pose.tolist())
51 | elif dt_type == "depth":
52 | dimension = dataset.attrs["extents"]
53 | depth = np.array(np.frombuffer(decompress(img.tobytes(), dimension[0] * dimension[1] * 2),
54 | dtype=np.float16)).astype(np.float32)
55 | depth = depth.reshape(dimension[0], dimension[1]) * scale
56 |
57 | dump_depth_file = dump_dir / '{:04d}.npy'.format(f_idx)
58 | np.save(dump_depth_file, depth)
59 | elif dt_type == "image":
60 | img = imageio.imread(img.tobytes())
61 | dump_img_file = dump_dir / '{:04d}.jpg'.format(f_idx)
62 | imageio.imsave(dump_img_file, img)
63 |
64 | dump_cam_file = dump_dir / 'cam.txt'
65 | np.savetxt(dump_cam_file, intrinsics)
66 | poses_file = dump_dir / 'poses.txt'
67 | np.savetxt(poses_file, np.array(poses).reshape(-1, 12), fmt='%.6e')
68 |
69 | if len(dump_dir.files('*.jpg')) < 2:
70 | dump_dir.rmtree()
71 |
72 |
73 | def preparedata():
74 | num_threads = 1
75 | SUB_DATASET_NAMES = (["rgbd_test", "scenes11_test", "sun3d_test"])
76 |
77 | dump_root = os.path.join(path, 'test')
78 | if not os.path.isdir(dump_root):
79 | os.mkdir(dump_root)
80 |
81 | if num_threads == 1:
82 | for scene in SUB_DATASET_NAMES:
83 | dump_example(scene)
84 | else:
85 | Parallel(n_jobs=num_threads)(delayed(dump_example)(scene) for scene in SUB_DATASET_NAMES)
86 |
87 | dump_root = Path(dump_root)
88 | subdirs = dump_root.dirs()
89 | subdirs = [subdir.basename() for subdir in subdirs]
90 | subdirs = sorted(subdirs)
91 | with open(dump_root / 'test.txt', 'w') as tf:
92 | for subdir in subdirs:
93 | tf.write('{}\n'.format(subdir))
94 |
95 | print("Finished Converting Data.")
96 |
97 |
98 | if __name__ == "__main__":
99 | preparedata()
100 |
--------------------------------------------------------------------------------
/unimatch/dataloader/depth/prepare_demon_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess
4 |
5 | from joblib import Parallel, delayed
6 | import numpy as np
7 | import imageio
8 |
9 | imageio.plugins.freeimage.download()
10 | from imageio.plugins import freeimage
11 | import h5py
12 | from lz4.block import decompress
13 | import scipy.misc
14 | import cv2
15 |
16 | from path import Path
17 |
18 | path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
19 |
20 |
21 | def dump_example(dataset_name):
22 | print("Converting {:}.h5 ...".format(dataset_name))
23 | file = h5py.File(os.path.join(path, "traindata", "{:}.h5".format(dataset_name)), "r")
24 |
25 | for (seq_idx, seq_name) in enumerate(file):
26 | if dataset_name == 'scenes11_train':
27 | scale = 0.4
28 | else:
29 | scale = 1
30 |
31 | if ((dataset_name == 'sun3d_train_1.6m_to_infm' and seq_idx == 7) or \
32 | (dataset_name == 'sun3d_train_0.4m_to_0.8m' and seq_idx == 15) or \
33 | (dataset_name == 'scenes11_train' and (
34 | seq_idx == 2758 or seq_idx == 4691 or seq_idx == 7023 or seq_idx == 11157 or seq_idx == 17168 or seq_idx == 19595))):
35 | continue # Skip error files
36 |
37 | print("Processing sequence {:d}/{:d}".format(seq_idx, len(file)))
38 |         dump_dir = os.path.join(path, 'train', dataset_name + "_" + "{:05d}".format(seq_idx))  # keep in sync with dump_root in preparedata()
39 | if not os.path.isdir(dump_dir):
40 | os.mkdir(dump_dir)
41 | dump_dir = Path(dump_dir)
42 | sequence = file[seq_name]["frames"]["t0"]
43 | poses = []
44 | for (f_idx, f_name) in enumerate(sequence):
45 | frame = sequence[f_name]
46 | for dt_type in frame:
47 | dataset = frame[dt_type]
48 | img = dataset[...]
49 | if dt_type == "camera":
50 | if f_idx == 0:
51 | intrinsics = np.array([[img[0], 0, img[3]], [0, img[1], img[4]], [0, 0, 1]])
52 | pose = np.array(
53 | [[img[5], img[8], img[11], img[14] * scale], [img[6], img[9], img[12], img[15] * scale],
54 | [img[7], img[10], img[13], img[16] * scale]])
55 | poses.append(pose.tolist())
56 | elif dt_type == "depth":
57 | dimension = dataset.attrs["extents"]
58 | depth = np.array(np.frombuffer(decompress(img.tobytes(), dimension[0] * dimension[1] * 2),
59 | dtype=np.float16)).astype(np.float32)
60 | depth = depth.reshape(dimension[0], dimension[1]) * scale
61 |
62 | dump_depth_file = dump_dir / '{:04d}.npy'.format(f_idx)
63 | np.save(dump_depth_file, depth)
64 | elif dt_type == "image":
65 | img = imageio.imread(img.tobytes())
66 | dump_img_file = dump_dir / '{:04d}.jpg'.format(f_idx)
67 | imageio.imsave(dump_img_file, img)
68 |
69 | dump_cam_file = dump_dir / 'cam.txt'
70 | np.savetxt(dump_cam_file, intrinsics)
71 | poses_file = dump_dir / 'poses.txt'
72 | np.savetxt(poses_file, np.array(poses).reshape(-1, 12), fmt='%.6e')
73 |
74 | if len(dump_dir.files('*.jpg')) < 2:
75 | dump_dir.rmtree()
76 |
77 |
78 | def preparedata():
79 | num_threads = 1
80 | SUB_DATASET_NAMES = ([
81 | "rgbd_10_to_20_3d_train", "rgbd_10_to_20_handheld_train", "rgbd_10_to_20_simple_train",
82 | "rgbd_20_to_inf_3d_train", "rgbd_20_to_inf_handheld_train", "rgbd_20_to_inf_simple_train",
83 | "sun3d_train_0.01m_to_0.1m", "sun3d_train_0.1m_to_0.2m", "sun3d_train_0.2m_to_0.4m", "sun3d_train_0.4m_to_0.8m",
84 | "sun3d_train_0.8m_to_1.6m", "sun3d_train_1.6m_to_infm",
85 | "scenes11_train",
86 | ])
87 |
88 | dump_root = os.path.join(path, 'train')
89 | if not os.path.isdir(dump_root):
90 | os.mkdir(dump_root)
91 |
92 | if num_threads == 1:
93 | for scene in SUB_DATASET_NAMES:
94 | dump_example(scene)
95 | else:
96 | Parallel(n_jobs=num_threads)(delayed(dump_example)(scene) for scene in SUB_DATASET_NAMES)
97 |
98 | np.random.seed(8964)
99 | dump_root = Path(dump_root)
100 | subdirs = dump_root.dirs()
101 | canonic_prefixes = set([subdir.basename()[:-2] for subdir in subdirs])
102 | with open(dump_root / 'train.txt', 'w') as tf:
103 | with open(dump_root / 'val.txt', 'w') as vf:
104 | for pr in canonic_prefixes:
105 | corresponding_dirs = dump_root.dirs('{}*'.format(pr))
106 | if np.random.random() < 0.1:
107 | for s in corresponding_dirs:
108 | vf.write('{}\n'.format(s.name))
109 | else:
110 | for s in corresponding_dirs:
111 | tf.write('{}\n'.format(s.name))
112 |
113 | print("Finished Converting Data.")
114 |
115 |
116 | if __name__ == "__main__":
117 | preparedata()
118 |
--------------------------------------------------------------------------------
/unimatch/dataloader/flow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/dataloader/flow/__init__.py
--------------------------------------------------------------------------------
/unimatch/dataloader/stereo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/dataloader/stereo/__init__.py
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/color/0048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/depth-scannet/color/0048.png
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/color/0054.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/depth-scannet/color/0054.png
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/color/0060.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/depth-scannet/color/0060.png
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/color/0066.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/depth-scannet/color/0066.png
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/intrinsic/intrinsic_depth.txt:
--------------------------------------------------------------------------------
1 | 577.590698 0.000000 318.905426 0.000000
2 | 0.000000 578.729797 242.683609 0.000000
3 | 0.000000 0.000000 1.000000 0.000000
4 | 0.000000 0.000000 0.000000 1.000000
5 |
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/pose/0048.txt:
--------------------------------------------------------------------------------
1 | 0.703694 -0.367391 0.608144 1.896290
2 | -0.708482 -0.427345 0.561630 2.467417
3 | 0.053549 -0.826075 -0.561010 1.399475
4 | 0.000000 0.000000 0.000000 1.000000
5 |
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/pose/0054.txt:
--------------------------------------------------------------------------------
1 | 0.750884 -0.329503 0.572363 1.915689
2 | -0.658300 -0.443024 0.608580 2.368469
3 | 0.053042 -0.833760 -0.549572 1.413484
4 | 0.000000 0.000000 0.000000 1.000000
5 |
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/pose/0060.txt:
--------------------------------------------------------------------------------
1 | 0.776779 -0.282017 0.563098 1.923212
2 | -0.625761 -0.446388 0.639656 2.259246
3 | 0.070966 -0.849237 -0.523221 1.407526
4 | 0.000000 0.000000 0.000000 1.000000
5 |
--------------------------------------------------------------------------------
/unimatch/demo/depth-scannet/pose/0066.txt:
--------------------------------------------------------------------------------
1 | 0.794580 -0.275900 0.540852 1.915812
2 | -0.604505 -0.442681 0.662273 2.192295
3 | 0.056703 -0.853178 -0.518529 1.423975
4 | 0.000000 0.000000 0.000000 1.000000
5 |
--------------------------------------------------------------------------------
/unimatch/demo/flow-davis/00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/flow-davis/00000.jpg
--------------------------------------------------------------------------------
/unimatch/demo/flow-davis/00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/flow-davis/00001.jpg
--------------------------------------------------------------------------------
/unimatch/demo/flow-davis/00002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/flow-davis/00002.jpg
--------------------------------------------------------------------------------
/unimatch/demo/kitti.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/kitti.mp4
--------------------------------------------------------------------------------
/unimatch/demo/stereo-middlebury/im0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/stereo-middlebury/im0.png
--------------------------------------------------------------------------------
/unimatch/demo/stereo-middlebury/im1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/stereo-middlebury/im1.png
--------------------------------------------------------------------------------
/unimatch/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/loss/__init__.py
--------------------------------------------------------------------------------
/unimatch/loss/depth_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def compute_errors(gt, pred):
7 | """Computation of error metrics between predicted and ground truth depths
8 | """
9 | thresh = np.maximum((gt / pred), (pred / gt))
10 | a1 = (thresh < 1.25).mean()
11 | a2 = (thresh < 1.25 ** 2).mean()
12 | a3 = (thresh < 1.25 ** 3).mean()
13 |
14 | rmse = (gt - pred) ** 2
15 | rmse = np.sqrt(rmse.mean())
16 |
17 | rmse_log = (np.log(gt) - np.log(pred)) ** 2
18 | rmse_log = np.sqrt(rmse_log.mean())
19 |
20 | abs_rel = np.mean(np.abs(gt - pred) / gt)
21 |
22 | sq_rel = np.mean(((gt - pred) ** 2) / gt)
23 |
24 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
25 |
26 |
27 | def get_depth_grad_loss(depth_pred, depth_gt, valid, inverse_depth_loss=True):
28 | # default is on inverse depth
29 | # both: [B, H, W]
30 | assert depth_pred.dim() == 3 and depth_gt.dim() == 3 and valid.dim() == 3
31 |
32 | valid = valid > 0.5
33 | valid_x = valid[:, :, :-1] & valid[:, :, 1:]
34 | valid_y = valid[:, :-1, :] & valid[:, 1:, :]
35 |
36 | if valid_x.max() < 0.5 or valid_y.max() < 0.5: # no valid pixel
37 | return 0.
38 |
39 | if inverse_depth_loss:
40 | grad_pred_x = torch.abs(1. / depth_pred[:, :, :-1][valid_x] - 1. / depth_pred[:, :, 1:][valid_x])
41 | grad_pred_y = torch.abs(1. / depth_pred[:, :-1, :][valid_y] - 1. / depth_pred[:, 1:, :][valid_y])
42 |
43 | grad_gt_x = torch.abs(1. / depth_gt[:, :, :-1][valid_x] - 1. / depth_gt[:, :, 1:][valid_x])
44 | grad_gt_y = torch.abs(1. / depth_gt[:, :-1, :][valid_y] - 1. / depth_gt[:, 1:, :][valid_y])
45 | else:
46 | grad_pred_x = torch.abs((depth_pred[:, :, :-1] - depth_pred[:, :, 1:])[valid_x])
47 | grad_pred_y = torch.abs((depth_pred[:, :-1, :] - depth_pred[:, 1:, :])[valid_y])
48 |
49 | grad_gt_x = torch.abs((depth_gt[:, :, :-1] - depth_gt[:, :, 1:])[valid_x])
50 | grad_gt_y = torch.abs((depth_gt[:, :-1, :] - depth_gt[:, 1:, :])[valid_y])
51 |
52 | loss_grad_x = torch.abs(grad_pred_x - grad_gt_x).mean()
53 | loss_grad_y = torch.abs(grad_pred_y - grad_gt_y).mean()
54 |
55 | return loss_grad_x + loss_grad_y
56 |
57 |
58 | def depth_grad_loss_func(depth_preds, depth_gt, valid,
59 | inverse_depth_loss=True,
60 | gamma=0.9):
61 | num = len(depth_preds)
62 | loss = 0.
63 |
64 | for i in range(num):
65 | weight = gamma ** (num - i - 1)
66 | loss += weight * get_depth_grad_loss(depth_preds[i], depth_gt, valid,
67 | inverse_depth_loss=inverse_depth_loss)
68 |
69 | return loss
70 |
71 |
72 | def depth_loss_func(depth_preds, depth_gt, valid, gamma=0.9,
73 | ):
74 | """ loss function defined over multiple depth predictions """
75 |
76 | n_predictions = len(depth_preds)
77 | depth_loss = 0.0
78 |
79 | for i in range(n_predictions):
80 | i_weight = gamma ** (n_predictions - i - 1)
81 |
82 | # inverse depth loss
83 | valid_bool = valid > 0.5
84 | if valid_bool.max() < 0.5: # no valid pixel
85 | i_loss = 0.
86 | else:
87 | i_loss = (1. / depth_preds[i][valid_bool] - 1. / depth_gt[valid_bool]).abs().mean()
88 |
89 | depth_loss += i_weight * i_loss
90 |
91 | return depth_loss
92 |
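
For reference, `depth_loss_func` above weights a list of per-iteration depth predictions with `gamma ** (n_predictions - i - 1)` and applies an L1 loss on inverse depth over valid pixels. A minimal usage sketch (tensor shapes and values are illustrative; it assumes the `unimatch` root is on `PYTHONPATH`):

```
import torch
from loss.depth_loss import depth_loss_func

B, H, W = 2, 32, 32
# Three iterative predictions; depths are kept strictly positive since the loss inverts them.
depth_preds = [torch.rand(B, H, W) + 0.5 for _ in range(3)]
depth_gt = torch.rand(B, H, W) + 0.5
valid = torch.ones(B, H, W)  # 1 where ground-truth depth is valid

loss = depth_loss_func(depth_preds, depth_gt, valid, gamma=0.9)
print(loss.item())
```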
--------------------------------------------------------------------------------
/unimatch/loss/flow_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def flow_loss_func(flow_preds, flow_gt, valid,
5 | gamma=0.9,
6 | max_flow=400,
7 | **kwargs,
8 | ):
9 | n_predictions = len(flow_preds)
10 | flow_loss = 0.0
11 |
12 |     # exclude invalid pixels and extremely large displacements
13 | mag = torch.sum(flow_gt ** 2, dim=1).sqrt() # [B, H, W]
14 | valid = (valid >= 0.5) & (mag < max_flow)
15 |
16 | for i in range(n_predictions):
17 | i_weight = gamma ** (n_predictions - i - 1)
18 |
19 | i_loss = (flow_preds[i] - flow_gt).abs()
20 |
21 | flow_loss += i_weight * (valid[:, None] * i_loss).mean()
22 |
23 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt()
24 |
25 | if valid.max() < 0.5:
26 | pass
27 |
28 | epe = epe.view(-1)[valid.view(-1)]
29 |
30 | metrics = {
31 | 'epe': epe.mean().item(),
32 | '1px': (epe > 1).float().mean().item(),
33 | '3px': (epe > 3).float().mean().item(),
34 | '5px': (epe > 5).float().mean().item(),
35 | }
36 |
37 | return flow_loss, metrics
38 |
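
For reference, `flow_loss_func` above takes a list of per-iteration flow predictions, down-weights early iterations with `gamma`, masks out invalid pixels and displacements larger than `max_flow`, and reports EPE/outlier metrics on the final prediction. A minimal usage sketch (shapes are illustrative; assumes the `unimatch` root is on `PYTHONPATH`):

```
import torch
from loss.flow_loss import flow_loss_func

B, H, W = 2, 32, 32
flow_preds = [torch.randn(B, 2, H, W) for _ in range(3)]  # per-iteration predictions
flow_gt = torch.randn(B, 2, H, W)
valid = torch.ones(B, H, W)  # 1 where ground-truth flow is valid

loss, metrics = flow_loss_func(flow_preds, flow_gt, valid, gamma=0.9, max_flow=400)
print(loss.item(), metrics['epe'])
```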
--------------------------------------------------------------------------------
/unimatch/loss/stereo_metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def epe_metric(d_est, d_gt, mask, use_np=False):
6 | d_est, d_gt = d_est[mask], d_gt[mask]
7 | if use_np:
8 | epe = np.mean(np.abs(d_est - d_gt))
9 | else:
10 | epe = torch.mean(torch.abs(d_est - d_gt))
11 |
12 | return epe
13 |
14 |
15 | def d1_metric(d_est, d_gt, mask, use_np=False):
16 | d_est, d_gt = d_est[mask], d_gt[mask]
17 | if use_np:
18 | e = np.abs(d_gt - d_est)
19 | else:
20 | e = torch.abs(d_gt - d_est)
21 | err_mask = (e > 3) & (e / d_gt > 0.05)
22 |
23 | if use_np:
24 | mean = np.mean(err_mask.astype('float'))
25 | else:
26 | mean = torch.mean(err_mask.float())
27 |
28 | return mean
29 |
30 |
31 | def bad_pixel_metric(d_est, d_gt, mask,
32 | abs_threshold=10,
33 | rel_threshold=0.1,
34 | use_np=False):
35 | d_est, d_gt = d_est[mask], d_gt[mask]
36 | if use_np:
37 | e = np.abs(d_gt - d_est)
38 | else:
39 | e = torch.abs(d_gt - d_est)
40 |     denom = np.maximum(d_gt, 1.0) if use_np else torch.clamp(d_gt, min=1.0)  # use the matching backend so the numpy path works too
41 |     err_mask = (e > abs_threshold) & (e / denom > rel_threshold)
42 |
43 | if use_np:
44 | mean = np.mean(err_mask.astype('float'))
45 | else:
46 | mean = torch.mean(err_mask.float())
47 |
48 | return mean
49 |
50 |
51 | def thres_metric(d_est, d_gt, mask, thres, use_np=False):
52 | assert isinstance(thres, (int, float))
53 | d_est, d_gt = d_est[mask], d_gt[mask]
54 | if use_np:
55 | e = np.abs(d_gt - d_est)
56 | else:
57 | e = torch.abs(d_gt - d_est)
58 | err_mask = e > thres
59 |
60 | if use_np:
61 | mean = np.mean(err_mask.astype('float'))
62 | else:
63 | mean = torch.mean(err_mask.float())
64 |
65 | return mean
66 |
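
For reference, `epe_metric` is the mean absolute disparity error, `d1_metric` is the KITTI D1 outlier rate (error > 3 px and > 5% of the ground-truth disparity), and `thres_metric` is a plain threshold outlier rate. A minimal usage sketch with illustrative tensors:

```
import torch
from loss.stereo_metric import epe_metric, d1_metric, thres_metric

d_gt = torch.full((2, 32, 32), 10.0)        # ground-truth disparity
d_est = d_gt + torch.randn(2, 32, 32)       # noisy prediction
mask = d_gt > 0                              # evaluate only on valid pixels

print(epe_metric(d_est, d_gt, mask).item())          # mean |error|
print(d1_metric(d_est, d_gt, mask).item())           # error > 3 px and > 5% of GT
print(thres_metric(d_est, d_gt, mask, 1.0).item())   # error > 1 px
```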
--------------------------------------------------------------------------------
/unimatch/pip_install.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
5 |
6 | pip install imageio==2.9.0 imageio-ffmpeg matplotlib opencv-python pillow scikit-image scipy tensorboard==2.9.1 setuptools==59.5.0
7 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmdepth_demo.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | # gmdepth-scale1-regrefine1
5 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \
6 | --inference_dir demo/depth-scannet \
7 | --output_path output/gmdepth-scale1-regrefine1-scannet \
8 | --resume pretrained/gmdepth-scale1-regrefine1-resumeflowthings-scannet-90325722.pth \
9 | --reg_refine \
10 | --num_reg_refine 1
11 |
12 | # --pred_bidir_depth
13 |
14 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmdepth_evaluate.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | # gmdepth-scale1
5 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \
6 | --eval \
7 | --resume pretrained/gmdepth-scale1-resumeflowthings-demon-a2fe127b.pth \
8 | --val_dataset demon \
9 | --demon_split scenes11
10 |
11 |
12 | # gmdepth-scale1-regrefine1, this is our final model
13 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \
14 | --eval \
15 | --resume pretrained/gmdepth-scale1-regrefine1-resumeflowthings-demon-7c23f230.pth \
16 | --val_dataset demon \
17 | --demon_split scenes11 \
18 | --reg_refine \
19 | --num_reg_refine 1
20 |
21 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmdepth_scale1_regrefine1_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # GMDepth (1/8 feature only), with additional 1 local regression refinement at 1/8 resolution
4 |
5 | # number of gpus for training, please set according to your hardware
6 | # trained on 8x 40GB A100 gpus
7 | NUM_GPUS=8
8 |
9 |
10 | # scannet
11 | CHECKPOINT_DIR=checkpoints_depth/scannet-gmdepth-scale1-regrefine1-resumeflowthings && \
12 | mkdir -p ${CHECKPOINT_DIR} && \
13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \
14 | --launcher pytorch \
15 | --checkpoint_dir ${CHECKPOINT_DIR} \
16 | --resume pretrained/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth \
17 | --no_resume_optimizer \
18 | --dataset scannet \
19 | --val_dataset scannet \
20 | --image_size 480 640 \
21 | --batch_size 64 \
22 | --lr 4e-4 \
23 | --reg_refine \
24 | --num_reg_refine 1 \
25 | --summary_freq 100 \
26 | --val_freq 5000 \
27 | --save_ckpt_freq 5000 \
28 | --num_steps 100000 \
29 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
30 |
31 |
32 | # demon
33 | CHECKPOINT_DIR=checkpoints_depth/demon-gmdepth-scale1-regrefine1-resumeflowthings && \
34 | mkdir -p ${CHECKPOINT_DIR} && \
35 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \
36 | --launcher pytorch \
37 | --checkpoint_dir ${CHECKPOINT_DIR} \
38 | --resume pretrained/gmdepth-scale1-resumeflowthings-demon-a2fe127b.pth \
39 | --no_resume_optimizer \
40 | --dataset demon \
41 | --val_dataset demon \
42 | --demon_split rgbd \
43 | --image_size 448 576 \
44 | --batch_size 64 \
45 | --lr 4e-4 \
46 | --reg_refine \
47 | --num_reg_refine 1 \
48 | --summary_freq 100 \
49 | --val_freq 5000 \
50 | --save_ckpt_freq 5000 \
51 | --num_steps 100000 \
52 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
53 |
54 |
55 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmdepth_scale1_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # basic GMDepth without any refinement (1/8 feature only)
4 |
5 | # number of gpus for training, please set according to your hardware
6 | # trained on 8x 40GB A100 gpus
7 | NUM_GPUS=8
8 |
9 |
10 | # scannet (our final model is trained for 100K steps, for ablation, we train for 50K)
11 | # resume flow things model (our ablations are trained from random init)
12 | CHECKPOINT_DIR=checkpoints_depth/scannet-gmdepth-scale1-resumeflowthings && \
13 | mkdir -p ${CHECKPOINT_DIR} && \
14 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \
15 | --launcher pytorch \
16 | --checkpoint_dir ${CHECKPOINT_DIR} \
17 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \
18 | --no_resume_optimizer \
19 | --dataset scannet \
20 | --val_dataset scannet \
21 | --image_size 480 640 \
22 | --batch_size 80 \
23 | --lr 4e-4 \
24 | --summary_freq 100 \
25 | --val_freq 5000 \
26 | --save_ckpt_freq 5000 \
27 | --num_steps 100000 \
28 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
29 |
30 |
31 | # demon, resume flow things model
32 | CHECKPOINT_DIR=checkpoints_depth/demon-gmdepth-scale1-resumeflowthings && \
33 | mkdir -p ${CHECKPOINT_DIR} && \
34 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \
35 | --launcher pytorch \
36 | --checkpoint_dir ${CHECKPOINT_DIR} \
37 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \
38 | --no_resume_optimizer \
39 | --dataset demon \
40 | --val_dataset demon \
41 | --demon_split rgbd \
42 | --image_size 448 576 \
43 | --batch_size 80 \
44 | --lr 4e-4 \
45 | --summary_freq 100 \
46 | --val_freq 5000 \
47 | --save_ckpt_freq 5000 \
48 | --num_steps 100000 \
49 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
50 |
51 |
52 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmflow_demo.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | # gmflow-scale2-regrefine6, inference on image dir
5 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
6 | --inference_dir demo/flow-davis \
7 | --resume pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth \
8 | --output_path output/gmflow-scale2-regrefine6-davis \
9 | --padding_factor 32 \
10 | --upsample_factor 4 \
11 | --num_scales 2 \
12 | --attn_splits_list 2 8 \
13 | --corr_radius_list -1 4 \
14 | --prop_radius_list -1 1 \
15 | --reg_refine \
16 | --num_reg_refine 6
17 |
18 |
19 | # gmflow-scale2-regrefine6, inference on video, save as video
20 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
21 | --inference_video demo/kitti.mp4 \
22 | --resume pretrained/gmflow-scale2-regrefine6-kitti15-25b554d7.pth \
23 | --output_path output/kitti \
24 | --padding_factor 32 \
25 | --upsample_factor 4 \
26 | --num_scales 2 \
27 | --attn_splits_list 2 8 \
28 | --corr_radius_list -1 4 \
29 | --prop_radius_list -1 1 \
30 | --reg_refine \
31 | --num_reg_refine 6 \
32 | --save_video \
33 | --concat_flow_img
34 |
35 |
36 |
37 | # gmflow-scale1, inference on image dir
38 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
39 | --inference_dir demo/flow-davis \
40 | --resume pretrained/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth \
41 | --output_path output/gmflow-scale1-davis
42 |
43 | # optionally predict bidirectional flow and run forward-backward consistency check
44 | #--pred_bidir_flow
45 | #--fwd_bwd_check
46 |
47 |
48 | # gmflow-scale2, inference on image dir
49 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
50 | --inference_dir demo/flow-davis \
51 | --resume pretrained/gmflow-scale2-mixdata-train320x576-9ff1c094.pth \
52 | --output_path output/gmflow-scale2-davis \
53 | --padding_factor 32 \
54 | --upsample_factor 4 \
55 | --num_scales 2 \
56 | --attn_splits_list 2 8 \
57 | --corr_radius_list -1 4 \
58 | --prop_radius_list -1 1
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmflow_evaluate.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | # gmflow-scale1
5 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
6 | --eval \
7 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \
8 | --val_dataset sintel \
9 | --with_speed_metric
10 |
11 |
12 | # gmflow-scale2
13 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
14 | --eval \
15 | --resume pretrained/gmflow-scale2-things-36579974.pth \
16 | --val_dataset kitti \
17 | --padding_factor 32 \
18 | --upsample_factor 4 \
19 | --num_scales 2 \
20 | --attn_splits_list 2 8 \
21 | --corr_radius_list -1 4 \
22 | --prop_radius_list -1 1 \
23 | --with_speed_metric
24 |
25 |
26 | # gmflow-scale2-regrefine6
27 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
28 | --eval \
29 | --resume pretrained/gmflow-scale2-regrefine6-things-776ed612.pth \
30 | --val_dataset kitti \
31 | --padding_factor 32 \
32 | --upsample_factor 4 \
33 | --num_scales 2 \
34 | --attn_splits_list 2 8 \
35 | --corr_radius_list -1 4 \
36 | --prop_radius_list -1 1 \
37 | --reg_refine \
38 | --num_reg_refine 6 \
39 | --with_speed_metric
40 |
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmflow_scale1_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # basic GMFlow without any refinement (1/8 feature only)
4 |
5 | # number of gpus for training, please set according to your hardware
6 | # can be trained on 4x 16GB V100 or 2x 32GB V100 or 2x 40GB A100 gpus
7 | NUM_GPUS=4
8 |
9 | # chairs
10 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale1 && \
11 | mkdir -p ${CHECKPOINT_DIR} && \
12 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
13 | --launcher pytorch \
14 | --checkpoint_dir ${CHECKPOINT_DIR} \
15 | --stage chairs \
16 | --batch_size 16 \
17 | --val_dataset chairs sintel kitti \
18 | --lr 4e-4 \
19 | --image_size 384 512 \
20 | --padding_factor 16 \
21 | --upsample_factor 8 \
22 | --with_speed_metric \
23 | --val_freq 10000 \
24 | --save_ckpt_freq 10000 \
25 | --num_steps 100000 \
26 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
27 |
28 | # things (our final model is trained for 800K iterations, for ablation study, you can train for 200K)
29 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale1 && \
30 | mkdir -p ${CHECKPOINT_DIR} && \
31 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
32 | --launcher pytorch \
33 | --checkpoint_dir ${CHECKPOINT_DIR} \
34 | --resume checkpoints_flow/chairs-gmflow-scale1/step_100000.pth \
35 | --stage things \
36 | --batch_size 8 \
37 | --val_dataset things sintel kitti \
38 | --lr 2e-4 \
39 | --image_size 384 768 \
40 | --padding_factor 16 \
41 | --upsample_factor 8 \
42 | --with_speed_metric \
43 | --val_freq 40000 \
44 | --save_ckpt_freq 50000 \
45 | --num_steps 800000 \
46 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
47 |
48 | # a final note: if your training is terminated unexpectedly, you can resume from the latest checkpoint
49 | # an example: resume chairs training
50 | # CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale1 && \
51 | # mkdir -p ${CHECKPOINT_DIR} && \
52 | # python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
53 | # --launcher pytorch \
54 | # --checkpoint_dir ${CHECKPOINT_DIR} \
55 | # --resume checkpoints_flow/chairs-gmflow-scale1/checkpoint_latest.pth \
56 | # --stage chairs \
57 | # --batch_size 16 \
58 | # --val_dataset chairs sintel kitti \
59 | # --lr 4e-4 \
60 | # --image_size 384 512 \
61 | # --padding_factor 16 \
62 | # --upsample_factor 8 \
63 | # --with_speed_metric \
64 | # --val_freq 10000 \
65 | # --save_ckpt_freq 10000 \
66 | # --num_steps 100000 \
67 | # 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
68 |
69 |
70 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmflow_scale2_regrefine6_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features)
4 | # with additional 6 local regression refinements
5 |
6 | # number of gpus for training, please set according to your hardware
7 | # can be trained on 8x 32G V100 or 8x 40GB A100 gpus
8 | NUM_GPUS=8
9 |
10 | # chairs, resume from scale2 model
11 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale2-regrefine6 && \
12 | mkdir -p ${CHECKPOINT_DIR} && \
13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
14 | --launcher pytorch \
15 | --checkpoint_dir ${CHECKPOINT_DIR} \
16 | --resume pretrained/gmflow-scale2-chairs-020cc9be.pth \
17 | --no_resume_optimizer \
18 | --stage chairs \
19 | --batch_size 16 \
20 | --val_dataset chairs sintel kitti \
21 | --lr 4e-4 \
22 | --image_size 384 512 \
23 | --padding_factor 32 \
24 | --upsample_factor 4 \
25 | --num_scales 2 \
26 | --attn_splits_list 2 8 \
27 | --corr_radius_list -1 4 \
28 | --prop_radius_list -1 1 \
29 | --reg_refine \
30 | --num_reg_refine 6 \
31 | --with_speed_metric \
32 | --val_freq 10000 \
33 | --save_ckpt_freq 10000 \
34 | --num_steps 100000 \
35 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
36 |
37 | # things
38 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale2-regrefine6 && \
39 | mkdir -p ${CHECKPOINT_DIR} && \
40 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
41 | --launcher pytorch \
42 | --checkpoint_dir ${CHECKPOINT_DIR} \
43 | --resume checkpoints_flow/chairs-gmflow-scale2-regrefine6/step_100000.pth \
44 | --stage things \
45 | --batch_size 8 \
46 | --val_dataset things sintel kitti \
47 | --lr 2e-4 \
48 | --image_size 384 768 \
49 | --padding_factor 32 \
50 | --upsample_factor 4 \
51 | --num_scales 2 \
52 | --attn_splits_list 2 8 \
53 | --corr_radius_list -1 4 \
54 | --prop_radius_list -1 1 \
55 | --reg_refine \
56 | --num_reg_refine 6 \
57 | --with_speed_metric \
58 | --val_freq 40000 \
59 | --save_ckpt_freq 50000 \
60 | --num_steps 800000 \
61 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
62 |
63 | # sintel, resume from things model
64 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2-regrefine6 && \
65 | mkdir -p ${CHECKPOINT_DIR} && \
66 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
67 | --launcher pytorch \
68 | --checkpoint_dir ${CHECKPOINT_DIR} \
69 | --resume checkpoints_flow/things-gmflow-scale2-regrefine6/step_800000.pth \
70 | --stage sintel \
71 | --batch_size 8 \
72 | --val_dataset sintel kitti \
73 | --lr 2e-4 \
74 | --image_size 320 896 \
75 | --padding_factor 32 \
76 | --upsample_factor 4 \
77 | --num_scales 2 \
78 | --attn_splits_list 2 8 \
79 | --corr_radius_list -1 4 \
80 | --prop_radius_list -1 1 \
81 | --reg_refine \
82 | --num_reg_refine 6 \
83 | --with_speed_metric \
84 | --val_freq 20000 \
85 | --save_ckpt_freq 20000 \
86 | --num_steps 200000 \
87 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
88 |
89 |
90 | # sintel finetune, resume from sintel model, this is our final model for sintel benchmark submission
91 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2-regrefine6-ft && \
92 | mkdir -p ${CHECKPOINT_DIR} && \
93 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
94 | --launcher pytorch \
95 | --checkpoint_dir ${CHECKPOINT_DIR} \
96 | --resume checkpoints_flow/sintel-gmflow-scale2-regrefine6/step_200000.pth \
97 | --stage sintel_ft \
98 | --batch_size 8 \
99 | --val_dataset sintel \
100 | --lr 1e-4 \
101 | --image_size 416 1024 \
102 | --padding_factor 32 \
103 | --upsample_factor 4 \
104 | --num_scales 2 \
105 | --attn_splits_list 2 8 \
106 | --corr_radius_list -1 4 \
107 | --prop_radius_list -1 1 \
108 | --reg_refine \
109 | --num_reg_refine 6 \
110 | --with_speed_metric \
111 | --val_freq 1000 \
112 | --save_ckpt_freq 1000 \
113 | --num_steps 5000 \
114 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
115 |
116 |
117 | # vkitti2, resume from things model
118 | CHECKPOINT_DIR=checkpoints_flow/vkitti2-gmflow-scale2-regrefine6 && \
119 | mkdir -p ${CHECKPOINT_DIR} && \
120 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
121 | --launcher pytorch \
122 | --checkpoint_dir ${CHECKPOINT_DIR} \
123 | --resume checkpoints_flow/things-gmflow-scale2-regrefine6/step_800000.pth \
124 | --stage vkitti2 \
125 | --batch_size 16 \
126 | --val_dataset kitti \
127 | --lr 2e-4 \
128 | --image_size 320 832 \
129 | --padding_factor 32 \
130 | --upsample_factor 4 \
131 | --num_scales 2 \
132 | --attn_splits_list 2 8 \
133 | --corr_radius_list -1 4 \
134 | --prop_radius_list -1 1 \
135 | --reg_refine \
136 | --num_reg_refine 6 \
137 | --with_speed_metric \
138 | --val_freq 10000 \
139 | --save_ckpt_freq 10000 \
140 | --num_steps 40000 \
141 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
142 |
143 |
144 | # kitti, resume from vkitti2 model, this is our final model for kitti benchmark submission
145 | CHECKPOINT_DIR=checkpoints_flow/kitti-gmflow-scale2-regrefine6 && \
146 | mkdir -p ${CHECKPOINT_DIR} && \
147 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
148 | --launcher pytorch \
149 | --checkpoint_dir ${CHECKPOINT_DIR} \
150 | --resume checkpoints_flow/vkitti2-gmflow-scale2-regrefine6/step_040000.pth \
151 | --stage kitti_mix \
152 | --batch_size 8 \
153 | --val_dataset kitti \
154 | --lr 2e-4 \
155 | --image_size 352 1216 \
156 | --padding_factor 32 \
157 | --upsample_factor 4 \
158 | --num_scales 2 \
159 | --attn_splits_list 2 8 \
160 | --corr_radius_list -1 4 \
161 | --prop_radius_list -1 1 \
162 | --reg_refine \
163 | --num_reg_refine 6 \
164 | --with_speed_metric \
165 | --val_freq 5000 \
166 | --save_ckpt_freq 10000 \
167 | --num_steps 30000 \
168 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
169 |
170 |
171 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmflow_scale2_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features)
4 |
5 | # number of gpus for training, please set according to your hardware
6 | # can be trained on 4x 32G V100 or 4x 40GB A100 or 8x 16G V100 gpus
7 | NUM_GPUS=4
8 |
9 | # chairs
10 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale2 && \
11 | mkdir -p ${CHECKPOINT_DIR} && \
12 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
13 | --launcher pytorch \
14 | --checkpoint_dir ${CHECKPOINT_DIR} \
15 | --stage chairs \
16 | --batch_size 16 \
17 | --val_dataset chairs sintel kitti \
18 | --lr 4e-4 \
19 | --image_size 384 512 \
20 | --padding_factor 32 \
21 | --upsample_factor 4 \
22 | --num_scales 2 \
23 | --attn_splits_list 2 8 \
24 | --corr_radius_list -1 4 \
25 | --prop_radius_list -1 1 \
26 | --with_speed_metric \
27 | --val_freq 10000 \
28 | --save_ckpt_freq 10000 \
29 | --num_steps 100000 \
30 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
31 |
32 | # things (our final model is trained for 800K iterations, for ablation study, you can train for 200K)
33 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale2 && \
34 | mkdir -p ${CHECKPOINT_DIR} && \
35 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
36 | --launcher pytorch \
37 | --checkpoint_dir ${CHECKPOINT_DIR} \
38 | --resume checkpoints_flow/chairs-gmflow-scale2/step_100000.pth \
39 | --stage things \
40 | --batch_size 8 \
41 | --val_dataset things sintel kitti \
42 | --lr 2e-4 \
43 | --image_size 384 768 \
44 | --padding_factor 32 \
45 | --upsample_factor 4 \
46 | --num_scales 2 \
47 | --attn_splits_list 2 8 \
48 | --corr_radius_list -1 4 \
49 | --prop_radius_list -1 1 \
50 | --with_speed_metric \
51 | --val_freq 40000 \
52 | --save_ckpt_freq 50000 \
53 | --num_steps 800000 \
54 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
55 |
56 | # sintel
57 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2 && \
58 | mkdir -p ${CHECKPOINT_DIR} && \
59 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
60 | --launcher pytorch \
61 | --checkpoint_dir ${CHECKPOINT_DIR} \
62 | --resume checkpoints_flow/things-gmflow-scale2/step_800000.pth \
63 | --stage sintel \
64 | --batch_size 8 \
65 | --val_dataset sintel kitti \
66 | --lr 2e-4 \
67 | --image_size 320 896 \
68 | --padding_factor 32 \
69 | --upsample_factor 4 \
70 | --num_scales 2 \
71 | --attn_splits_list 2 8 \
72 | --corr_radius_list -1 4 \
73 | --prop_radius_list -1 1 \
74 | --with_speed_metric \
75 | --val_freq 20000 \
76 | --save_ckpt_freq 20000 \
77 | --num_steps 200000 \
78 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
79 |
80 | # kitti
81 | CHECKPOINT_DIR=checkpoints_flow/kitti-gmflow-scale2 && \
82 | mkdir -p ${CHECKPOINT_DIR} && \
83 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \
84 | --launcher pytorch \
85 | --checkpoint_dir ${CHECKPOINT_DIR} \
86 | --resume checkpoints_flow/sintel-gmflow-scale2/step_200000.pth \
87 | --stage kitti \
88 | --batch_size 8 \
89 | --val_dataset kitti \
90 | --lr 2e-4 \
91 | --image_size 320 1152 \
92 | --padding_factor 32 \
93 | --upsample_factor 4 \
94 | --num_scales 2 \
95 | --attn_splits_list 2 8 \
96 | --corr_radius_list -1 4 \
97 | --prop_radius_list -1 1 \
98 | --with_speed_metric \
99 | --val_freq 10000 \
100 | --save_ckpt_freq 10000 \
101 | --num_steps 100000 \
102 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
103 |
104 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmflow_submission.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | # generate prediction results for submission on sintel and kitti online servers
5 |
6 |
7 | # submission to sintel
8 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
9 | --submission \
10 | --output_path submission/sintel-gmflow-scale2-regrefine6-sintelft \
11 | --val_dataset sintel \
12 | --resume pretrained/gmflow-scale2-regrefine6-sintelft-6e39e2b9.pth \
13 | --inference_size 416 1024 \
14 | --padding_factor 32 \
15 | --upsample_factor 4 \
16 | --num_scales 2 \
17 | --attn_splits_list 2 8 \
18 | --corr_radius_list -1 4 \
19 | --prop_radius_list -1 1 \
20 | --reg_refine \
21 | --num_reg_refine 6
22 |
23 |
24 | # you can also visualize the predictions before submission
25 | #CUDA_VISIBLE_DEVICES=0 python main_flow.py \
26 | #--submission \
27 | #--output_path submission/sintel-gmflow-scale2-regrefine6-sintelft-vis \
28 | #--val_dataset sintel \
29 | #--resume pretrained/gmflow-scale2-regrefine6-sintelft-6e39e2b9.pth \
30 | #--inference_size 416 1024 \
31 | #--save_vis_flow \
32 | #--no_save_flo \
33 | #--padding_factor 32 \
34 | #--upsample_factor 4 \
35 | #--num_scales 2 \
36 | #--attn_splits_list 2 8 \
37 | #--corr_radius_list -1 4 \
38 | #--prop_radius_list -1 1 \
39 | #--reg_refine \
40 | #--num_reg_refine 6
41 |
42 |
43 | # submission to kitti
44 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \
45 | --submission \
46 | --output_path submission/kitti-gmflow-scale2-regrefine6 \
47 | --val_dataset kitti \
48 | --resume pretrained/gmflow-scale2-regrefine6-kitti15-25b554d7.pth \
49 | --inference_size 352 1216 \
50 | --padding_factor 32 \
51 | --upsample_factor 4 \
52 | --num_scales 2 \
53 | --attn_splits_list 2 8 \
54 | --corr_radius_list -1 4 \
55 | --prop_radius_list -1 1 \
56 | --reg_refine \
57 | --num_reg_refine 6
58 |
59 |
60 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmstereo_demo.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | # gmstereo-scale2-regrefine3 model
5 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \
6 | --inference_dir demo/stereo-middlebury \
7 | --inference_size 1024 1536 \
8 | --output_path output/gmstereo-scale2-regrefine3-middlebury \
9 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-middleburyfthighres-a82bec03.pth \
10 | --padding_factor 32 \
11 | --upsample_factor 4 \
12 | --num_scales 2 \
13 | --attn_type self_swin2d_cross_swin1d \
14 | --attn_splits_list 2 8 \
15 | --corr_radius_list -1 4 \
16 | --prop_radius_list -1 1 \
17 | --reg_refine \
18 | --num_reg_refine 3
19 |
20 | # optionally predict both left and right disparities
21 | #--pred_bidir_disp
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmstereo_evaluate.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | # gmstereo-scale1
5 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \
6 | --eval \
7 | --resume pretrained/gmstereo-scale1-resumeflowthings-sceneflow-16e38788.pth \
8 | --val_dataset kitti15
9 |
10 |
11 | # gmstereo-scale2
12 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \
13 | --eval \
14 | --resume pretrained/gmstereo-scale2-resumeflowthings-sceneflow-48020649.pth \
15 | --val_dataset kitti15 \
16 | --padding_factor 32 \
17 | --upsample_factor 4 \
18 | --num_scales 2 \
19 | --attn_type self_swin2d_cross_swin1d \
20 | --attn_splits_list 2 8 \
21 | --corr_radius_list -1 4 \
22 | --prop_radius_list -1 1
23 |
24 |
25 | # gmstereo-scale2-regrefine3
26 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \
27 | --eval \
28 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-sceneflow-f724fee6.pth \
29 | --val_dataset kitti15 \
30 | --padding_factor 32 \
31 | --upsample_factor 4 \
32 | --num_scales 2 \
33 | --attn_type self_swin2d_cross_swin1d \
34 | --attn_splits_list 2 8 \
35 | --corr_radius_list -1 4 \
36 | --prop_radius_list -1 1 \
37 | --reg_refine \
38 | --num_reg_refine 3
39 |
40 |
41 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmstereo_scale1_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # basic GMStereo without any refinement (1/8 feature only)
4 |
5 | # number of gpus for training, please set according to your hardware
6 | # trained on 8x 40GB A100 gpus
7 | NUM_GPUS=8
8 |
9 | # sceneflow (our final model is trained for 100K steps, for ablation, we train for 50K)
10 | # resume flow things model (our ablations are trained from random init)
11 | CHECKPOINT_DIR=checkpoints_stereo/sceneflow-gmstereo-scale1-resumeflowthings && \
12 | mkdir -p ${CHECKPOINT_DIR} && \
13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \
14 | --launcher pytorch \
15 | --checkpoint_dir ${CHECKPOINT_DIR} \
16 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \
17 | --no_resume_optimizer \
18 | --stage sceneflow \
19 | --batch_size 64 \
20 | --val_dataset things kitti15 \
21 | --img_height 384 \
22 | --img_width 768 \
23 | --padding_factor 16 \
24 | --upsample_factor 8 \
25 | --attn_type self_swin2d_cross_1d \
26 | --summary_freq 1000 \
27 | --val_freq 10000 \
28 | --save_ckpt_freq 1000 \
29 | --save_latest_ckpt_freq 1000 \
30 | --num_steps 100000 \
31 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
32 |
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmstereo_scale2_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # GMStereo with hierarchical matching refinement (1/8 + 1/4 features)
4 |
5 | # number of gpus for training, please set according to your hardware
6 | # trained on 8x 40GB A100 gpus
7 | NUM_GPUS=8
8 |
9 | # sceneflow
10 | # resume flow things model
11 | CHECKPOINT_DIR=checkpoints_stereo/sceneflow-gmstereo-scale2-resumeflowthings && \
12 | mkdir -p ${CHECKPOINT_DIR} && \
13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \
14 | --launcher pytorch \
15 | --checkpoint_dir ${CHECKPOINT_DIR} \
16 | --resume pretrained/gmflow-scale2-things-36579974.pth \
17 | --no_resume_optimizer \
18 | --stage sceneflow \
19 | --batch_size 32 \
20 | --val_dataset things kitti15 \
21 | --img_height 384 \
22 | --img_width 768 \
23 | --padding_factor 32 \
24 | --upsample_factor 4 \
25 | --num_scales 2 \
26 | --attn_type self_swin2d_cross_swin1d \
27 | --attn_splits_list 2 8 \
28 | --corr_radius_list -1 4 \
29 | --prop_radius_list -1 1 \
30 | --summary_freq 100 \
31 | --val_freq 10000 \
32 | --save_ckpt_freq 1000 \
33 | --save_latest_ckpt_freq 1000 \
34 | --num_steps 100000 \
35 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log
36 |
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/unimatch/scripts/gmstereo_submission.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # generate prediction results for submission on kitti, middlebury and eth3d online servers
4 |
5 |
6 | # submission to kitti
7 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \
8 | --submission \
9 | --val_dataset kitti15 \
10 | --inference_size 352 1216 \
11 | --output_path submission/kitti-gmstereo-scale2-regrefine3 \
12 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-kitti15-04487ebf.pth \
13 | --padding_factor 32 \
14 | --upsample_factor 4 \
15 | --num_scales 2 \
16 | --attn_type self_swin2d_cross_swin1d \
17 | --attn_splits_list 2 8 \
18 | --corr_radius_list -1 4 \
19 | --prop_radius_list -1 1 \
20 | --reg_refine \
21 | --num_reg_refine 3
22 |
23 |
24 | # submission to middlebury
25 | # set --middlebury_submission_mode to train or test to generate results on the train or test set
26 | # use --save_vis_disp to visualize disparity
27 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \
28 | --submission \
29 | --val_dataset middlebury \
30 | --middlebury_resolution F \
31 | --middlebury_submission_mode test \
32 | --inference_size 1024 1536 \
33 | --output_path submission/middlebury-test-gmstereo-scale2-regrefine3 \
34 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-middleburyfthighres-a82bec03.pth \
35 | --padding_factor 32 \
36 | --upsample_factor 4 \
37 | --num_scales 2 \
38 | --attn_type self_swin2d_cross_swin1d \
39 | --attn_splits_list 2 8 \
40 | --corr_radius_list -1 4 \
41 | --prop_radius_list -1 1 \
42 | --reg_refine \
43 | --num_reg_refine 3
44 |
45 |
46 | # submission to eth3d
47 | # set --eth_submission_mode to train or test to generate results on the train or test set
48 | # use --save_vis_disp to visualize disparity
49 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \
50 | --submission \
51 | --eth_submission_mode test \
52 | --val_dataset eth3d \
53 | --inference_size 512 768 \
54 | --output_path submission/eth3d-test-gmstereo-scale2-regrefine3 \
55 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-eth3dft-46effc13.pth \
56 | --padding_factor 32 \
57 | --upsample_factor 4 \
58 | --num_scales 2 \
59 | --attn_type self_swin2d_cross_swin1d \
60 | --attn_splits_list 2 8 \
61 | --corr_radius_list -1 4 \
62 | --prop_radius_list -1 1 \
63 | --reg_refine \
64 | --num_reg_refine 3
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/unimatch/unimatch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/unimatch/__init__.py
--------------------------------------------------------------------------------
/unimatch/unimatch/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .trident_conv import MultiScaleTridentConv
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
8 | ):
9 | super(ResidualBlock, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
12 | dilation=dilation, padding=dilation, stride=stride, bias=False)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
14 | dilation=dilation, padding=dilation, bias=False)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | self.norm1 = norm_layer(planes)
18 | self.norm2 = norm_layer(planes)
19 | if not stride == 1 or in_planes != planes:
20 | self.norm3 = norm_layer(planes)
21 |
22 | if stride == 1 and in_planes == planes:
23 | self.downsample = None
24 | else:
25 | self.downsample = nn.Sequential(
26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
27 |
28 | def forward(self, x):
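        # residual block: two conv-norm-relu stages on y; the skip path x is passed
        # through a 1x1 conv + norm whenever the stride or channel count changes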
29 | y = x
30 | y = self.relu(self.norm1(self.conv1(y)))
31 | y = self.relu(self.norm2(self.conv2(y)))
32 |
33 | if self.downsample is not None:
34 | x = self.downsample(x)
35 |
36 | return self.relu(x + y)
37 |
38 |
39 | class CNNEncoder(nn.Module):
40 | def __init__(self, output_dim=128,
41 | norm_layer=nn.InstanceNorm2d,
42 | num_output_scales=1,
43 | **kwargs,
44 | ):
45 | super(CNNEncoder, self).__init__()
46 | self.num_branch = num_output_scales
47 |
48 | feature_dims = [64, 96, 128]
49 |
50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
51 | self.norm1 = norm_layer(feature_dims[0])
52 | self.relu1 = nn.ReLU(inplace=True)
53 |
54 | self.in_planes = feature_dims[0]
55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
57 |
58 | # highest resolution 1/4 or 1/8
59 | stride = 2 if num_output_scales == 1 else 1
60 | self.layer3 = self._make_layer(feature_dims[2], stride=stride,
61 | norm_layer=norm_layer,
62 | ) # 1/4 or 1/8
63 |
64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
65 |
66 | if self.num_branch > 1:
67 | if self.num_branch == 4:
68 | strides = (1, 2, 4, 8)
69 | elif self.num_branch == 3:
70 | strides = (1, 2, 4)
71 | elif self.num_branch == 2:
72 | strides = (1, 2)
73 | else:
74 | raise ValueError
75 |
76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
77 | kernel_size=3,
78 | strides=strides,
79 | paddings=1,
80 | num_branch=self.num_branch,
81 | )
82 |
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv2d):
85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
87 | if m.weight is not None:
88 | nn.init.constant_(m.weight, 1)
89 | if m.bias is not None:
90 | nn.init.constant_(m.bias, 0)
91 |
92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
95 |
96 | layers = (layer1, layer2)
97 |
98 | self.in_planes = dim
99 | return nn.Sequential(*layers)
100 |
101 | def forward(self, x):
102 | x = self.conv1(x)
103 | x = self.norm1(x)
104 | x = self.relu1(x)
105 |
106 | x = self.layer1(x) # 1/2
107 | x = self.layer2(x) # 1/4
108 | x = self.layer3(x) # 1/8 or 1/4
109 |
110 | x = self.conv2(x)
111 |
112 | if self.num_branch > 1:
113 | out = self.trident_conv([x] * self.num_branch) # high to low res
114 | else:
115 | out = [x]
116 |
117 | return out
118 |
--------------------------------------------------------------------------------
/unimatch/unimatch/position.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | import math
7 |
8 |
9 | class PositionEmbeddingSine(nn.Module):
10 | """
11 | This is a more standard version of the position embedding, very similar to the one
12 | used by the Attention is all you need paper, generalized to work on images.
13 | """
14 |
15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
16 | super().__init__()
17 | self.num_pos_feats = num_pos_feats
18 | self.temperature = temperature
19 | self.normalize = normalize
20 | if scale is not None and normalize is False:
21 | raise ValueError("normalize should be True if scale is passed")
22 | if scale is None:
23 | scale = 2 * math.pi
24 | self.scale = scale
25 |
26 | def forward(self, x):
27 | # x = tensor_list.tensors # [B, C, H, W]
28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
29 | b, c, h, w = x.size()
30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
31 | y_embed = mask.cumsum(1, dtype=torch.float32)
32 | x_embed = mask.cumsum(2, dtype=torch.float32)
33 | if self.normalize:
34 | eps = 1e-6
35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
37 |
38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
40 |
41 | pos_x = x_embed[:, :, :, None] / dim_t
42 | pos_y = y_embed[:, :, :, None] / dim_t
43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
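        # pos: [B, 2*num_pos_feats, H, W]; the y embedding occupies the first
        # num_pos_feats channels and the x embedding the remaining ones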
46 | return pos
47 |
--------------------------------------------------------------------------------
/unimatch/unimatch/reg_refine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256,
8 | out_dim=2,
9 | ):
10 | super(FlowHead, self).__init__()
11 |
12 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
13 | self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1)
14 | self.relu = nn.ReLU(inplace=True)
15 |
16 | def forward(self, x):
17 | out = self.conv2(self.relu(self.conv1(x)))
18 |
19 | return out
20 |
21 |
22 | class SepConvGRU(nn.Module):
23 | def __init__(self, hidden_dim=128, input_dim=192 + 128,
24 | kernel_size=5,
25 | ):
26 | padding = (kernel_size - 1) // 2
27 |
28 | super(SepConvGRU, self).__init__()
29 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
30 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
31 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
32 |
33 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
34 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
35 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
36 |
37 | def forward(self, h, x):
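        # separable ConvGRU: the standard update h <- (1 - z) * h + z * q is applied
        # twice, first with 1xk (horizontal) convolutions, then with kx1 (vertical) ones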
38 | # horizontal
39 | hx = torch.cat([h, x], dim=1)
40 | z = torch.sigmoid(self.convz1(hx))
41 | r = torch.sigmoid(self.convr1(hx))
42 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
43 | h = (1 - z) * h + z * q
44 |
45 | # vertical
46 | hx = torch.cat([h, x], dim=1)
47 | z = torch.sigmoid(self.convz2(hx))
48 | r = torch.sigmoid(self.convr2(hx))
49 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
50 | h = (1 - z) * h + z * q
51 |
52 | return h
53 |
54 |
55 | class BasicMotionEncoder(nn.Module):
56 | def __init__(self, corr_channels=324,
57 | flow_channels=2,
58 | ):
59 | super(BasicMotionEncoder, self).__init__()
60 |
61 | self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0)
62 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
63 | self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3)
64 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
65 | self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1)
66 |
67 | def forward(self, flow, corr):
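        # encodes correlation and flow separately, then fuses them; the output has
        # 128 channels: (128 - flow_channels) learned features plus the raw flow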
68 | cor = F.relu(self.convc1(corr))
69 | cor = F.relu(self.convc2(cor))
70 | flo = F.relu(self.convf1(flow))
71 | flo = F.relu(self.convf2(flo))
72 |
73 | cor_flo = torch.cat([cor, flo], dim=1)
74 | out = F.relu(self.conv(cor_flo))
75 | return torch.cat([out, flow], dim=1)
76 |
77 |
78 | class BasicUpdateBlock(nn.Module):
79 | def __init__(self, corr_channels=324,
80 | hidden_dim=128,
81 | context_dim=128,
82 | downsample_factor=8,
83 | flow_dim=2,
84 | bilinear_up=False,
85 | ):
86 | super(BasicUpdateBlock, self).__init__()
87 |
88 | self.encoder = BasicMotionEncoder(corr_channels=corr_channels,
89 | flow_channels=flow_dim,
90 | )
91 |
92 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim)
93 |
94 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256,
95 | out_dim=flow_dim,
96 | )
97 |
98 | if bilinear_up:
99 | self.mask = None
100 | else:
101 | self.mask = nn.Sequential(
102 | nn.Conv2d(hidden_dim, 256, 3, padding=1),
103 | nn.ReLU(inplace=True),
104 | nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0))
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 |
109 | inp = torch.cat([inp, motion_features], dim=1)
110 |
111 | net = self.gru(net, inp)
112 | delta_flow = self.flow_head(net)
113 |
114 | if self.mask is not None:
115 | mask = self.mask(net)
116 | else:
117 | mask = None
118 |
119 | return net, mask, delta_flow
120 |
--------------------------------------------------------------------------------
/unimatch/unimatch/trident_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn.modules.utils import _pair
8 |
9 |
10 | class MultiScaleTridentConv(nn.Module):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | stride=1,
17 | strides=1,
18 | paddings=0,
19 | dilations=1,
20 | dilation=1,
21 | groups=1,
22 | num_branch=1,
23 | test_branch_idx=-1,
24 | bias=False,
25 | norm=None,
26 | activation=None,
27 | ):
28 | super(MultiScaleTridentConv, self).__init__()
29 | self.in_channels = in_channels
30 | self.out_channels = out_channels
31 | self.kernel_size = _pair(kernel_size)
32 | self.num_branch = num_branch
33 | self.stride = _pair(stride)
34 | self.groups = groups
35 | self.with_bias = bias
36 | self.dilation = dilation
37 | if isinstance(paddings, int):
38 | paddings = [paddings] * self.num_branch
39 | if isinstance(dilations, int):
40 | dilations = [dilations] * self.num_branch
41 | if isinstance(strides, int):
42 | strides = [strides] * self.num_branch
43 | self.paddings = [_pair(padding) for padding in paddings]
44 | self.dilations = [_pair(dilation) for dilation in dilations]
45 | self.strides = [_pair(stride) for stride in strides]
46 | self.test_branch_idx = test_branch_idx
47 | self.norm = norm
48 | self.activation = activation
49 |
50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
51 |
52 | self.weight = nn.Parameter(
53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
54 | )
55 | if bias:
56 | self.bias = nn.Parameter(torch.Tensor(out_channels))
57 | else:
58 | self.bias = None
59 |
60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
61 | if self.bias is not None:
62 | nn.init.constant_(self.bias, 0)
63 |
64 | def forward(self, inputs):
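        # all branches share the same convolution weights; each branch applies its own
        # stride/padding, and at test time only a single branch is evaluated unless
        # test_branch_idx == -1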
65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
66 | assert len(inputs) == num_branch
67 |
68 | if self.training or self.test_branch_idx == -1:
69 | outputs = [
70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
71 | for input, stride, padding in zip(inputs, self.strides, self.paddings)
72 | ]
73 | else:
74 | outputs = [
75 | F.conv2d(
76 | inputs[0],
77 | self.weight,
78 | self.bias,
79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
81 | self.dilation,
82 | self.groups,
83 | )
84 | ]
85 |
86 | if self.norm is not None:
87 | outputs = [self.norm(x) for x in outputs]
88 | if self.activation is not None:
89 | outputs = [self.activation(x) for x in outputs]
90 | return outputs
91 |
--------------------------------------------------------------------------------
/unimatch/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | # https://github.com/open-mmlab/mmcv/blob/7540cf73ac7e5d1e14d0ffbd9b6759e83929ecfc/mmcv/runner/dist_utils.py
3 |
4 | import os
5 | import subprocess
6 |
7 | import torch
8 | import torch.multiprocessing as mp
9 | from torch import distributed as dist
10 |
11 |
12 | def init_dist(launcher, backend='nccl', **kwargs):
13 | if mp.get_start_method(allow_none=True) is None:
14 | mp.set_start_method('spawn')
15 | if launcher == 'pytorch':
16 | _init_dist_pytorch(backend, **kwargs)
17 | elif launcher == 'mpi':
18 | _init_dist_mpi(backend, **kwargs)
19 | elif launcher == 'slurm':
20 | _init_dist_slurm(backend, **kwargs)
21 | else:
22 | raise ValueError(f'Invalid launcher type: {launcher}')
23 |
24 |
25 | def _init_dist_pytorch(backend, **kwargs):
26 | # TODO: use local_rank instead of rank % num_gpus
27 | rank = int(os.environ['RANK'])
28 | num_gpus = torch.cuda.device_count()
29 | torch.cuda.set_device(rank % num_gpus)
30 | dist.init_process_group(backend=backend, **kwargs)
31 |
32 |
33 | def _init_dist_mpi(backend, **kwargs):
34 | # TODO: use local_rank instead of rank % num_gpus
35 | rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
36 | num_gpus = torch.cuda.device_count()
37 | torch.cuda.set_device(rank % num_gpus)
38 | dist.init_process_group(backend=backend, **kwargs)
39 |
40 |
41 | def _init_dist_slurm(backend, port=None):
42 | """Initialize slurm distributed training environment.
43 | If argument ``port`` is not specified, then the master port will be system
44 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
45 | environment variable, then a default port ``29500`` will be used.
46 | Args:
47 | backend (str): Backend of torch.distributed.
48 | port (int, optional): Master port. Defaults to None.
49 | """
50 | proc_id = int(os.environ['SLURM_PROCID'])
51 | ntasks = int(os.environ['SLURM_NTASKS'])
52 | node_list = os.environ['SLURM_NODELIST']
53 | num_gpus = torch.cuda.device_count()
54 | torch.cuda.set_device(proc_id % num_gpus)
55 | addr = subprocess.getoutput(
56 | f'scontrol show hostname {node_list} | head -n1')
57 | # specify master port
58 | if port is not None:
59 | os.environ['MASTER_PORT'] = str(port)
60 | elif 'MASTER_PORT' in os.environ:
61 | pass # use MASTER_PORT in the environment variable
62 | else:
63 | # 29500 is torch.distributed default port
64 | os.environ['MASTER_PORT'] = '29500'
65 | # use MASTER_ADDR in the environment variable if it already exists
66 | if 'MASTER_ADDR' not in os.environ:
67 | os.environ['MASTER_ADDR'] = addr
68 | os.environ['WORLD_SIZE'] = str(ntasks)
69 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
70 | os.environ['RANK'] = str(proc_id)
71 | dist.init_process_group(backend=backend)
72 |
73 |
74 | def get_dist_info():
75 | # if (TORCH_VERSION != 'parrots'
76 | # and digit_version(TORCH_VERSION) < digit_version('1.0')):
77 | # initialized = dist._initialized
78 | # else:
79 | if dist.is_available():
80 | initialized = dist.is_initialized()
81 | else:
82 | initialized = False
83 | if initialized:
84 | rank = dist.get_rank()
85 | world_size = dist.get_world_size()
86 | else:
87 | rank = 0
88 | world_size = 1
89 | return rank, world_size
90 |
91 |
92 | # from DETR repo
93 | def setup_for_distributed(is_master):
94 | """
95 | This function disables printing when not in master process
96 | """
97 | import builtins as __builtin__
98 | builtin_print = __builtin__.print
99 |
100 | def print(*args, **kwargs):
101 | force = kwargs.pop('force', False)
102 | if is_master or force:
103 | builtin_print(*args, **kwargs)
104 |
105 | __builtin__.print = print
106 |
--------------------------------------------------------------------------------
/unimatch/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 | import cv2
6 |
7 | TAG_CHAR = np.array([202021.25], np.float32)
8 |
9 |
10 | def readFlow(fn):
11 | """ Read .flo file in Middlebury format"""
12 | # Code adapted from:
13 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
14 |
15 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
16 | # print 'fn = %s'%(fn)
17 | with open(fn, 'rb') as f:
18 | magic = np.fromfile(f, np.float32, count=1)
19 | if 202021.25 != magic:
20 | print('Magic number incorrect. Invalid .flo file')
21 | return None
22 | else:
23 | w = np.fromfile(f, np.int32, count=1)
24 | h = np.fromfile(f, np.int32, count=1)
25 | # print 'Reading %d x %d flo file\n' % (w, h)
26 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
27 |             # Reshape data into 3D array (columns, rows, bands)
28 | # The reshape here is for visualization, the original code is (w,h,2)
29 | return np.resize(data, (int(h), int(w), 2))
30 |
31 |
32 | def readPFM(file):
33 | file = open(file, 'rb')
34 |
35 | color = None
36 | width = None
37 | height = None
38 | scale = None
39 | endian = None
40 |
41 | header = file.readline().rstrip()
42 | if header == b'PF':
43 | color = True
44 | elif header == b'Pf':
45 | color = False
46 | else:
47 | raise Exception('Not a PFM file.')
48 |
49 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
50 | if dim_match:
51 | width, height = map(int, dim_match.groups())
52 | else:
53 | raise Exception('Malformed PFM header.')
54 |
55 | scale = float(file.readline().rstrip())
56 | if scale < 0: # little-endian
57 | endian = '<'
58 | scale = -scale
59 | else:
60 | endian = '>' # big-endian
61 |
62 | data = np.fromfile(file, endian + 'f')
63 | shape = (height, width, 3) if color else (height, width)
64 |
65 | data = np.reshape(data, shape)
66 | data = np.flipud(data)
67 | return data
68 |
69 |
70 | def writeFlow(filename, uv, v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert (uv.ndim == 3)
81 | assert (uv.shape[2] == 2)
82 | u = uv[:, :, 0]
83 | v = uv[:, :, 1]
84 | else:
85 | u = uv
86 |
87 | assert (u.shape == v.shape)
88 | height, width = u.shape
89 | f = open(filename, 'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width * nBands))
96 | tmp[:, np.arange(width) * 2] = u
97 | tmp[:, np.arange(width) * 2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
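    # KITTI stores flow as a 16-bit PNG: u/v are encoded as 64 * flow + 2^15 and the
    # third channel marks valid pixels (writeFlowKITTI below applies the inverse)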
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
104 | flow = flow[:, :, ::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2 ** 15) / 64.0
107 | return flow, valid
108 |
109 |
110 | def readDispKITTI(filename):
111 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
112 | valid = disp > 0.0
113 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
114 | return flow, valid
115 |
116 |
117 | def writeFlowKITTI(filename, uv):
118 | uv = 64.0 * uv + 2 ** 15
119 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
120 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
121 | cv2.imwrite(filename, uv[..., ::-1])
122 |
123 |
124 | def read_gen(file_name, pil=False):
125 | ext = splitext(file_name)[-1]
126 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
127 | return Image.open(file_name)
128 | elif ext == '.bin' or ext == '.raw':
129 | return np.load(file_name)
130 | elif ext == '.flo':
131 | return readFlow(file_name).astype(np.float32)
132 | elif ext == '.pfm':
133 | flow = readPFM(file_name).astype(np.float32)
134 | if len(flow.shape) == 2:
135 | return flow
136 | else:
137 | return flow[:, :, :-1]
138 | return []
139 |
140 |
141 | def read_vkitti2_flow(filename):
142 | # In R, flow along x-axis normalized by image width and quantized to [0;2^16 – 1]
143 |     # In G, flow along y-axis normalized by image height and quantized to [0;2^16 – 1]
144 | # B = 0 for invalid flow (e.g., sky pixels)
145 | bgr = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
146 | h, w, _c = bgr.shape
147 | assert bgr.dtype == np.uint16 and _c == 3
148 | # b == invalid flow flag == 0 for sky or other invalid flow
149 | invalid = bgr[:, :, 0] == 0
150 | # g,r == flow_y,x normalized by height,width and scaled to [0;2**16 – 1]
151 | out_flow = 2.0 / (2 ** 16 - 1.0) * bgr[:, :, 2:0:-1].astype('f4') - 1 # [H, W, 2]
152 | out_flow[..., 0] *= (w - 1)
153 | out_flow[..., 1] *= (h - 1)
154 |
155 |     out_flow[invalid] = 0.000001  # set invalid flow to a very small value to still provide supervision on the sky
156 | valid = (np.logical_or(invalid, ~invalid)).astype(np.float32)
157 |
158 | return out_flow, valid
159 |
--------------------------------------------------------------------------------
/unimatch/utils/logger.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils.flow_viz import flow_tensor_to_image
4 | from .visualization import viz_depth_tensor
5 |
6 |
7 | class Logger:
8 | def __init__(self, lr_scheduler,
9 | summary_writer,
10 | summary_freq=100,
11 | start_step=0,
12 | img_mean=None,
13 | img_std=None,
14 | ):
15 | self.lr_scheduler = lr_scheduler
16 | self.total_steps = start_step
17 | self.running_loss = {}
18 | self.summary_writer = summary_writer
19 | self.summary_freq = summary_freq
20 |
21 | self.img_mean = img_mean
22 | self.img_std = img_std
23 |
24 | def print_training_status(self, mode='train', is_depth=False):
25 | if is_depth:
26 | print('step: %06d \t loss: %.3f' % (self.total_steps, self.running_loss['total_loss'] / self.summary_freq))
27 | else:
28 | print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq))
29 |
30 | for k in self.running_loss:
31 | self.summary_writer.add_scalar(mode + '/' + k,
32 | self.running_loss[k] / self.summary_freq, self.total_steps)
33 | self.running_loss[k] = 0.0
34 |
35 | def lr_summary(self):
36 | lr = self.lr_scheduler.get_last_lr()[0]
37 | self.summary_writer.add_scalar('lr', lr, self.total_steps)
38 |
39 | def add_image_summary(self, img1, img2, flow_preds=None, flow_gt=None, mode='train',
40 | is_depth=False,
41 | ):
42 | if self.total_steps % self.summary_freq == 0:
43 | if is_depth:
44 | img1 = self.unnormalize_image(img1.detach().cpu()) # [3, H, W], range [0, 1]
45 | img2 = self.unnormalize_image(img2.detach().cpu())
46 |
47 | concat = torch.cat((img1, img2), dim=-1) # [3, H, W*2]
48 |
49 | self.summary_writer.add_image(mode + '/img', concat, self.total_steps)
50 | else:
51 | img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1)
52 | img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard
53 |
54 | flow_pred = flow_tensor_to_image(flow_preds[-1][0])
55 | forward_flow_gt = flow_tensor_to_image(flow_gt[0])
56 | flow_concat = torch.cat((torch.from_numpy(flow_pred),
57 | torch.from_numpy(forward_flow_gt)), dim=-1)
58 |
59 | concat = torch.cat((img_concat, flow_concat), dim=-2)
60 |
61 | self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps)
62 |
63 | def add_depth_summary(self, depth_pred, depth_gt, mode='train'):
64 | # assert depth_pred.dim() == 2 # [H, W]
65 | if self.total_steps % self.summary_freq == 0 or 'val' in mode:
66 | pred_viz = viz_depth_tensor(depth_pred.detach().cpu()) # [3, H, W]
67 | gt_viz = viz_depth_tensor(depth_gt.detach().cpu())
68 |
69 | concat = torch.cat((pred_viz, gt_viz), dim=-1) # [3, H, W*2]
70 |
71 | self.summary_writer.add_image(mode + '/depth_pred_gt', concat, self.total_steps)
72 |
73 | def unnormalize_image(self, img):
74 | # img: [3, H, W], used for visualizing image
75 | mean = torch.tensor(self.img_mean).view(3, 1, 1).type_as(img)
76 | std = torch.tensor(self.img_std).view(3, 1, 1).type_as(img)
77 |
78 | out = img * std + mean
79 |
80 | return out
81 |
82 | def push(self, metrics, mode='train', is_depth=False, ):
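        # accumulates per-step metrics into running sums and flushes averaged
        # scalar summaries (then resets the sums) every summary_freq steps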
83 | self.total_steps += 1
84 |
85 | self.lr_summary()
86 |
87 | for key in metrics:
88 | if key not in self.running_loss:
89 | self.running_loss[key] = 0.0
90 |
91 | self.running_loss[key] += metrics[key]
92 |
93 | if self.total_steps % self.summary_freq == 0:
94 | self.print_training_status(mode, is_depth=is_depth)
95 | self.running_loss = {}
96 |
97 | def write_dict(self, results):
98 | for key in results:
99 | tag = key.split('_')[0]
100 | tag = tag + '/' + key
101 | self.summary_writer.add_scalar(tag, results[key], self.total_steps)
102 |
103 | def close(self):
104 | self.summary_writer.close()
105 |
--------------------------------------------------------------------------------
/unimatch/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 |
5 |
6 | def read_text_lines(filepath):
7 | with open(filepath, 'r') as f:
8 | lines = f.readlines()
9 | lines = [l.rstrip() for l in lines]
10 | return lines
11 |
12 |
13 | def check_path(path):
14 | if not os.path.exists(path):
15 | os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing
16 |
17 |
18 | def save_command(save_path, filename='command_train.txt'):
19 | check_path(save_path)
20 | command = sys.argv
21 | save_file = os.path.join(save_path, filename)
22 | # Save all training commands when resuming training
23 | with open(save_file, 'a') as f:
24 | f.write(' '.join(command))
25 | f.write('\n\n')
26 |
27 |
28 | def save_args(args, filename='args.json'):
29 | args_dict = vars(args)
30 | check_path(args.checkpoint_dir)
31 | save_path = os.path.join(args.checkpoint_dir, filename)
32 |
33 | # save all training args when resuming training
34 | with open(save_path, 'a') as f:
35 | json.dump(args_dict, f, indent=4, sort_keys=False)
36 | f.write('\n\n')
37 |
--------------------------------------------------------------------------------
/unimatch/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import numpy as np
4 | import torchvision.utils as vutils
5 | import cv2
6 | from matplotlib.cm import get_cmap
7 | import matplotlib as mpl
8 | import matplotlib.cm as cm
9 |
10 |
11 | def vis_disparity(disp):
12 | disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
13 | disp_vis = disp_vis.astype("uint8")
14 | disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
15 |
16 | return disp_vis
17 |
18 |
19 | def gen_error_colormap():
20 | cols = np.array(
21 | [[0 / 3.0, 0.1875 / 3.0, 49, 54, 149],
22 | [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180],
23 | [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209],
24 | [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233],
25 | [1.5 / 3.0, 3 / 3.0, 224, 243, 248],
26 | [3 / 3.0, 6 / 3.0, 254, 224, 144],
27 | [6 / 3.0, 12 / 3.0, 253, 174, 97],
28 | [12 / 3.0, 24 / 3.0, 244, 109, 67],
29 | [24 / 3.0, 48 / 3.0, 215, 48, 39],
30 | [48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32)
31 | cols[:, 2: 5] /= 255.
32 | return cols
33 |
34 |
35 | def disp_error_img(D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1):
36 | D_gt_np = D_gt_tensor.detach().cpu().numpy()
37 | D_est_np = D_est_tensor.detach().cpu().numpy()
38 | B, H, W = D_gt_np.shape
39 | # valid mask
40 | mask = D_gt_np > 0
41 | # error in percentage. When error <= 1, the pixel is valid since <= 3px & 5%
42 | error = np.abs(D_gt_np - D_est_np)
43 | error[np.logical_not(mask)] = 0
44 | error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres)
45 | # get colormap
46 | cols = gen_error_colormap()
47 | # create error image
48 | error_image = np.zeros([B, H, W, 3], dtype=np.float32)
49 | for i in range(cols.shape[0]):
50 | error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:]
51 | # TODO: imdilate
52 | # error_image = cv2.imdilate(D_err, strel('disk', dilate_radius));
53 | error_image[np.logical_not(mask)] = 0.
54 |     # show color tag in the top-left corner of the image
55 | for i in range(cols.shape[0]):
56 | distance = 20
57 | error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:]
58 |
59 | return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2])))
60 |
61 |
62 | def save_images(logger, mode_tag, images_dict, global_step):
63 | images_dict = tensor2numpy(images_dict)
64 | for tag, values in images_dict.items():
65 | if not isinstance(values, list) and not isinstance(values, tuple):
66 | values = [values]
67 | for idx, value in enumerate(values):
68 | if len(value.shape) == 3:
69 | value = value[:, np.newaxis, :, :]
70 | value = value[:1]
71 | value = torch.from_numpy(value)
72 |
73 | image_name = '{}/{}'.format(mode_tag, tag)
74 | if len(values) > 1:
75 | image_name = image_name + "_" + str(idx)
76 | logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True),
77 | global_step)
78 |
79 |
80 | def tensor2numpy(var_dict):
81 | for key, vars in var_dict.items():
82 | if isinstance(vars, np.ndarray):
83 | var_dict[key] = vars
84 | elif isinstance(vars, torch.Tensor):
85 | var_dict[key] = vars.data.cpu().numpy()
86 | else:
87 | raise NotImplementedError("invalid input type for tensor2numpy")
88 |
89 | return var_dict
90 |
91 |
92 | def viz_depth_tensor(disp, return_numpy=False, colormap='plasma'):
93 | # visualize inverse depth
94 | assert isinstance(disp, torch.Tensor)
95 |
96 | disp = disp.numpy()
97 | vmax = np.percentile(disp, 95)
98 | normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax)
99 | mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap)
100 | colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3]
101 |
102 | if return_numpy:
103 | return colormapped_im
104 |
105 | viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W]
106 |
107 | return viz
108 |
--------------------------------------------------------------------------------
/utils/attention_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 |
7 |
8 | def reshape_heads_to_batch_dim3(tensor, head_size):
9 | batch_size1, batch_size2, seq_len, dim = tensor.shape
10 | tensor = tensor.reshape(batch_size1, batch_size2,
11 | seq_len, head_size, dim // head_size)
12 | tensor = tensor.permute(0, 3, 1, 2, 4)
13 | return tensor
14 |
15 |
16 | def trajectory_hidden_states(query,
17 | key,
18 | value,
19 | trajectories,
20 | use_old_qk,
21 | extra_options,
22 | n_heads):
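    # gathers key/value features along precomputed per-pixel trajectories across frames
    # and runs masked attention over each trajectory window for every query location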
23 | if not use_old_qk:
24 | query = value
25 | key = value
26 | # TODO: Hardcoded for SD1.5
27 | _,_, oh, ow = extra_options['original_shape']
28 | height = 64 # int(value.shape[1]**0.5)
29 | width = height
30 | cond_size = len(extra_options['cond_or_uncond'])
31 | video_length = len(query) // cond_size
32 |
33 | sub_idxs = extra_options.get('ad_params', {}).get('sub_idxs', None)
34 | idx = 0
35 | if sub_idxs is not None:
36 | idx = sub_idxs[0]
37 |
38 | traj_window = trajectories['trajectory_windows'][idx]
39 | if f'traj{height}' not in traj_window:
40 | return value
41 | trajs = traj_window[f'traj{height}']
42 | traj_mask = traj_window[f'mask{height}']
43 |
44 | start = -video_length+1
45 | end = trajs.shape[2]
46 |
47 | traj_key_sequence_inds = torch.cat(
48 | [trajs[:, :, 0, :].unsqueeze(-2), trajs[:, :, start:end, :]], dim=-2)
49 | traj_mask = torch.cat([traj_mask[:, :, 0].unsqueeze(-1),
50 | traj_mask[:, :, start:end]], dim=-1)
51 |
52 | t_inds = traj_key_sequence_inds[:, :, :, 0]
53 | x_inds = traj_key_sequence_inds[:, :, :, 1]
54 | y_inds = traj_key_sequence_inds[:, :, :, 2]
55 |
56 | query_tempo = query.unsqueeze(-2)
57 | _key = rearrange(key, '(b f) (h w) d -> b f h w d',
58 | b=cond_size, h=height)
59 | _value = rearrange(value, '(b f) (h w) d -> b f h w d',
60 | b=cond_size, h=height)
61 | key_tempo = _key[:, t_inds, x_inds, y_inds]
62 | value_tempo = _value[:, t_inds, x_inds, y_inds]
63 | key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d')
64 | value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d')
65 |
66 | traj_mask = rearrange(torch.stack(
67 | [traj_mask] * cond_size), 'b f n l -> (b f) n l')
68 | traj_mask = traj_mask[:, None].repeat(
69 | 1, n_heads, 1, 1).unsqueeze(-2)
70 | attn_bias = torch.zeros_like(
71 | traj_mask, dtype=key_tempo.dtype, device=query.device) # regular zeros_like
72 | attn_bias[~traj_mask] = -torch.inf
73 |
74 | # flow attention
75 | query_tempo = reshape_heads_to_batch_dim3(query_tempo, n_heads)
76 | key_tempo = reshape_heads_to_batch_dim3(key_tempo, n_heads)
77 | value_tempo = reshape_heads_to_batch_dim3(value_tempo, n_heads)
78 |
79 | attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt(
80 | query_tempo.size(-1)) + attn_bias
81 | attn_matrix2 = F.softmax(attn_matrix2, dim=-1)
82 | out = (attn_matrix2@value_tempo).squeeze(-2)
83 |
84 | hidden_states = rearrange(out, 'b k r d -> b r (k d)')
85 |
86 | return hidden_states
87 |
88 |
89 | def select_pivot_indexes(length, batch_size, seed=None):
90 | # Create a new Random object with the given seed
91 | rng = random.Random(seed)
92 |
93 | # Use the seeded Random object to generate the random index
94 | rnd_idx = rng.randint(0, batch_size - 1)
95 |
96 | return [min(i, length-1) for i in range(rnd_idx, length, batch_size)] + [length-1]
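# example: with length=10, batch_size=4 and a drawn rnd_idx of 2, this returns
# [2, 6, 9]; the final frame index is always appended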
--------------------------------------------------------------------------------
/utils/batching_utils.py:
--------------------------------------------------------------------------------
1 |
2 | # Adjusted from ADE: https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved
3 |
4 | def create_windows_static_standard(num_frames, context_length, overlap):
5 | windows = []
6 | if num_frames <= context_length or context_length == 0:
7 | windows.append(list(range(num_frames)))
8 | return windows
9 | # always return the same set of windows
10 | delta = context_length - overlap
11 | for start_idx in range(0, num_frames, delta):
12 | # if past the end of frames, move start_idx back to allow same context_length
13 | ending = start_idx + context_length
14 | if ending >= num_frames:
15 | final_delta = ending - num_frames
16 | final_start_idx = start_idx - final_delta
17 | windows.append(
18 | list(range(final_start_idx, final_start_idx + context_length)))
19 | break
20 | windows.append(list(range(start_idx, start_idx + context_length)))
21 | return windows
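# example: create_windows_static_standard(10, 4, 1) -> [[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]]
# (the last window is shifted back so every window keeps the full context_length)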
22 |
--------------------------------------------------------------------------------
/utils/module_utils.py:
--------------------------------------------------------------------------------
1 | def isinstance_str(x: object, cls_name: str):
2 | for _cls in x.__class__.__mro__:
3 | if _cls.__name__ == cls_name:
4 | return True
5 |
6 | return False
7 |
--------------------------------------------------------------------------------
/utils/noise_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_alphacumprod(sigma):
5 | return 1 / ((sigma * sigma) + 1)
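# maps a k-diffusion-style sigma to the DDPM alpha_cumprod via alpha = 1 / (sigma^2 + 1),
# i.e. the inverse of sigma = sqrt((1 - alpha) / alpha)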
6 |
7 |
8 | def add_noise(src_latent, noise, sigma):
9 | alphas_cumprod = get_alphacumprod(sigma)
10 |
11 | sqrt_alpha_prod = alphas_cumprod ** 0.5
12 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
13 | while len(sqrt_alpha_prod.shape) < len(src_latent.shape):
14 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
15 |
16 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5
17 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
18 | while len(sqrt_one_minus_alpha_prod.shape) < len(src_latent.shape):
19 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
20 |
21 | noisy_samples = sqrt_alpha_prod * src_latent + sqrt_one_minus_alpha_prod * noise
22 | return noisy_samples
23 |
24 |
25 | def add_noise_test(latents, sigma, noise=None):
26 | alpha_cumprod = 1/ ((sigma * sigma) + 1)
27 | sqrt_alpha_prod = alpha_cumprod ** 0.5
28 | sqrt_one_minus_alpha_prod = (1 - alpha_cumprod) ** 0.5
29 | if noise is None:
30 | generator = torch.Generator(device='cuda')
31 | # generator.manual_seed(0)
32 | noise = torch.empty_like(latents).normal_(generator=generator).to(latents.device)
33 |
34 | return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise
--------------------------------------------------------------------------------
/vv_defaults.py:
--------------------------------------------------------------------------------
1 | SD1_ATTN_INJ_DEFAULTS = set([])
2 | for idx in [1,2,3,4,5,6]:
3 | SD1_ATTN_INJ_DEFAULTS.add(('output', idx))
4 |
5 | SD1_RES_INJ_DEFAULTS = set()
6 | for idx in [3,4,6]:
7 | SD1_RES_INJ_DEFAULTS.add(('output', idx))
8 |
9 | SD1_FLOW_MAP = set([('input', 0), ('input', 1), ('output', 6), ('output', 7), ('output', 8)])
10 |
11 | SD1_OUTER_MAP = SD1_FLOW_MAP
12 | SD1_INNER_MAP = set([])
13 | for i in range(2, 12):
14 | SD1_INNER_MAP.add(('input', i))
15 | for i in range(6):
16 | SD1_INNER_MAP.add(('output', i))
17 |
18 | SD_FULL_MAP = set([])
19 | for i in range(40):
20 | SD_FULL_MAP.add(('input', i))
21 | SD_FULL_MAP.add(('output', i))
22 |
23 | SD1_INPUT_MAP = set([])
24 | SD1_OUTPUT_MAP = set([])
25 | for i in range(12):
26 | SD1_INPUT_MAP.add(('input', i))
27 | SD1_OUTPUT_MAP.add(('output', i))
28 |
29 | # TODO get actual adj upsampler indexes
30 | SDXL_ATTN_INJ_DEFAULTS = set([])
31 | for idx in [10, 15, 20, 25, 30]:
32 | SDXL_ATTN_INJ_DEFAULTS.add(('output', idx))
33 |
34 | SDXL_RES_INJ_DEFAULTS = set([])
35 | for idx in [1,2,4]:
36 | SDXL_RES_INJ_DEFAULTS.add(('output', idx))
37 |
38 | SDXL_FLOW_MAP = set([])
39 | for idx in range(36):
40 | SDXL_FLOW_MAP.add(('input', idx))
41 | SDXL_FLOW_MAP.add(('output', idx))
42 |
43 | # These are approximate
44 | SDXL_OUTER_MAP = set([])
45 | for i in range(30, 40):
46 | SDXL_OUTER_MAP.add(('output', i))
47 | for i in range(5):
48 | SDXL_OUTER_MAP.add(('input', i))
49 |
50 | SDXL_INNER_MAP = set([])
51 | for i in range(30):
52 | SDXL_INNER_MAP.add(('output', i))
53 | for i in range(5, 40):
54 | SDXL_INNER_MAP.add(('input', i))
55 |
56 |
57 | SDXL_INPUT_MAP = set([])
58 | SDXL_OUTPUT_MAP = set([])
59 | for i in range(40):
60 | SDXL_INPUT_MAP.add(('input', i))
61 |     SDXL_OUTPUT_MAP.add(('output', i))
62 |
63 | MAP_TYPES = ['none', 'inner', 'outer', 'full', 'input', 'output']
64 |
65 | SD1_MAPS = {
66 | 'none': set(),
67 | 'inner': SD1_INNER_MAP,
68 | 'outer': SD1_OUTER_MAP,
69 | 'full': SD_FULL_MAP,
70 | 'output': SD1_OUTPUT_MAP,
71 | 'input': SD1_INPUT_MAP,
72 | }
73 |
74 | SDXL_MAPS = {
75 | 'none': set(),
76 | 'inner': SDXL_INNER_MAP,
77 | 'outer': SDXL_OUTER_MAP,
78 | 'full': SD_FULL_MAP,
79 | 'output': SDXL_OUTPUT_MAP,
80 | 'input': SDXL_INPUT_MAP,
81 | }
82 |
--------------------------------------------------------------------------------